Skip to content

declearn.test_utils.to_numpy

Convert an input framework-based structure to a numpy array.

Source code in declearn/test_utils/_convert.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def to_numpy(array: Any, framework: str) -> np.ndarray:
    """Convert an input framework-based structure to a numpy array.

    Parameters
    ----------
    array:
        Array-like structure to convert. Inputs that already are numpy
        arrays are returned as-is, whatever the value of `framework`.
    framework:
        Name of the framework `array` belongs to, among
        {"jax", "tensorflow", "torch"}.

    Returns
    -------
    array:
        Numpy array holding the values of the input structure.

    Raises
    ------
    ValueError
        If `framework` is not one of the supported values (and `array`
        is not already a numpy array).
    """
    if isinstance(array, np.ndarray):
        return array
    if framework == "jax":
        return np.asarray(array)
    if framework == "tensorflow":
        # Import lazily so that tensorflow is only required when used.
        tensorflow = importlib.import_module("tensorflow")
        # IndexedSlices need to be densified into a Tensor before they can
        # be exported; do so on the device that currently holds them.
        if isinstance(array, tensorflow.IndexedSlices):
            with tensorflow.device(array.device):
                return tensorflow.convert_to_tensor(array).numpy()
        return array.numpy()
    if framework == "torch":
        # Detach from the autograd graph first: `Tensor.numpy()` raises a
        # RuntimeError on tensors that require grad. Then move to CPU, as
        # numpy cannot view device (e.g. GPU) memory.
        return array.detach().cpu().numpy()
    raise ValueError(
        f"Invalid 'framework' from which to convert to numpy: '{framework}'."
    )