Convert an input framework-based structure to a numpy array.
Source code in declearn/test_utils/_convert.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def to_numpy(array: Any, framework: str) -> np.ndarray:
    """Convert an input framework-based structure to a numpy array.

    Parameters
    ----------
    array: Any
        Array-like structure to convert. Inputs that already are numpy
        arrays are returned as-is, whatever the value of `framework`.
    framework: str
        Name of the framework `array` originates from. Supported values
        are "jax", "tensorflow" and "torch".

    Returns
    -------
    array: np.ndarray
        Numpy array holding the values of the input structure.

    Raises
    ------
    ValueError
        If `framework` is not a supported value (and `array` is not
        already a numpy array).
    """
    if isinstance(array, np.ndarray):
        return array
    if framework == "jax":
        return np.asarray(array)
    if framework == "tensorflow":
        # Imported lazily, so that tensorflow is only required when this
        # branch is actually used.
        tensorflow = importlib.import_module("tensorflow")
        # IndexedSlices are not dense tensors: convert them into one (on
        # their own device) before extracting their values.
        if isinstance(array, tensorflow.IndexedSlices):
            with tensorflow.device(array.device):
                return tensorflow.convert_to_tensor(array).numpy()
        return array.numpy()
    if framework == "torch":
        # Detach first: `Tensor.numpy()` raises a RuntimeError on tensors
        # that require grad (e.g. model parameters). Detaching a tensor
        # that does not require grad is a no-op view, so data sharing
        # semantics are unchanged for those inputs.
        return array.detach().cpu().numpy()
    raise ValueError(
        f"Invalid 'framework' from which to convert to numpy: '{framework}'."
    )