Skip to content

declearn.test_utils.GradientsTestCase

Framework-parametrized Vector instances provider for testing purposes.

This class provides seeded random-valued, one-valued or zero-valued Vector instances (with deterministic specifications) that may be used in the context of unit tests.

Source code in declearn/test_utils/_vectors.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class GradientsTestCase:
    """Framework-parametrized Vector instances provider for testing purposes.

    This class provides Vector instances (with deterministic specifications)
    holding seeded random values, ones or zeros, backed by the array types
    of a chosen framework ("numpy", "tensorflow", "torch" or "jax"), so
    that they may be used in the context of unit tests.
    """

    def __init__(
        self,
        framework: FrameworkType,
        seed: Optional[int] = 0,
    ) -> None:
        """Instantiate the parametrized test-case.

        Parameters
        ----------
        framework:
            Name of the framework whose Vector subclass and array types
            are to be used.
        seed:
            Seed of the numpy RNG used by `mock_gradient`. If None, that
            property yields non-reproducible values.

        Raises
        ------
        RuntimeError
            If `framework` is not listed by `list_available_frameworks()`.
        """
        if framework not in list_available_frameworks():
            raise RuntimeError(f"Framework '{framework}' is unavailable.")
        self.framework = framework
        self.seed = seed

    @property
    def vector_cls(self) -> Type[Vector]:
        """Vector subclass suitable to the tested framework.

        Framework-specific modules are imported lazily via `importlib`.

        Raises
        ------
        ValueError
            If `self.framework` is not a supported framework name.
        """
        if self.framework == "numpy":
            return NumpyVector
        if self.framework == "tensorflow":
            module = importlib.import_module("declearn.model.tensorflow")
            return module.TensorflowVector
        if self.framework == "torch":
            module = importlib.import_module("declearn.model.torch")
            return module.TorchVector
        if self.framework == "jax":
            module = importlib.import_module("declearn.model.haiku")
            return module.JaxNumpyVector
        raise ValueError(f"Invalid framework '{self.framework}'")

    def convert(self, array: np.ndarray) -> ArrayLike:
        """Convert an input numpy array to a framework-based structure.

        Numpy arrays are returned as-is. TensorFlow tensors are created
        under a CPU device scope. Torch tensors are built with
        `torch.from_numpy`, hence share memory with the input array.

        Raises
        ------
        ValueError
            If `self.framework` is not a supported framework name.
        """
        if self.framework == "numpy":
            return array
        if self.framework == "tensorflow":
            tensorflow = importlib.import_module("tensorflow")
            with tensorflow.device("CPU"):
                return tensorflow.convert_to_tensor(array)
        if self.framework == "torch":
            torch = importlib.import_module("torch")
            return torch.from_numpy(array)
        if self.framework == "jax":
            jnp = importlib.import_module("jax.numpy")
            return jnp.asarray(array)
        raise ValueError(f"Invalid framework '{self.framework}'")

    def to_numpy(self, array: ArrayLike) -> np.ndarray:
        """Convert an input framework-based structure to a numpy array.

        Numpy arrays are returned unchanged; jax arrays go through
        `np.asarray`; TensorFlow `IndexedSlices` are densified (under
        their own device scope) first. Any other input is expected to
        expose a `.numpy()` method.
        """
        if isinstance(array, np.ndarray):
            return array
        if self.framework == "jax":
            return np.asarray(array)
        if self.framework == "tensorflow":  # add support for IndexedSlices
            tensorflow = importlib.import_module("tensorflow")
            if isinstance(array, tensorflow.IndexedSlices):
                with tensorflow.device(array.device):
                    return tensorflow.convert_to_tensor(array).numpy()
        return array.numpy()  # type: ignore

    def _build_vector(self, *arrays: np.ndarray) -> Vector:
        """Convert numpy arrays to the framework and wrap them in a Vector.

        Coefficients are keyed "0", "1", ... following the input order.
        """
        coefs = {str(idx): self.convert(arr) for idx, arr in enumerate(arrays)}
        return self.vector_cls(coefs)

    @property
    def mock_gradient(self) -> Vector:
        """Instantiate a Vector with random-valued mock gradients.

        Note: the RNG used to generate gradients is seeded with `self.seed`
              (0 by default), so that gradients have the same values
              whatever the tensor framework used is.
        """
        rng = np.random.default_rng(self.seed)
        shapes = [(64, 32), (32,), (32, 16), (16,), (16, 1), (1,)]
        vector = self._build_vector(*[rng.normal(size=shape) for shape in shapes])
        # In Tensorflow, convert the first gradients to IndexedSlices.
        # In this case they are equivalent to dense ones, but this enables
        # testing the support for these structures while maintaining the
        # possibility to compare outputs' values with other frameworks.
        if self.framework == "tensorflow":
            tensorflow = importlib.import_module("tensorflow")
            vector.coefs["0"] = tensorflow.IndexedSlices(
                values=vector.coefs["0"],
                indices=tensorflow.range(64),
                dense_shape=tensorflow.convert_to_tensor([64, 32]),
            )
        return vector

    @property
    def mock_ones(self) -> Vector:
        """Instantiate a Vector with one-valued mock gradients.

        Values are deterministic (no RNG is involved), hence identical
        whatever the tensor framework used is.
        """
        shapes = [(5, 5), (4,), (1,)]
        return self._build_vector(*[np.ones(shape) for shape in shapes])

    @property
    def mock_zeros(self) -> Vector:
        """Instantiate a Vector with zero-valued mock gradients.

        Values are deterministic (no RNG is involved), hence identical
        whatever the tensor framework used is.
        """
        shapes = [(5, 5), (4,), (1,)]
        return self._build_vector(*[np.zeros(shape) for shape in shapes])

mock_gradient: Vector property

Instantiate a Vector with random-valued mock gradients.

Note: the RNG used to generate gradients is seeded with the test case's seed (0 by default), so that gradients have the same values whatever the tensor framework used is.

mock_ones: Vector property

Instantiate a Vector with one-valued mock gradients.

Note: values are deterministic (no RNG is involved), hence identical whatever the tensor framework used is.

mock_zeros: Vector property

Instantiate a Vector with zero-valued mock gradients.

Note: values are deterministic (no RNG is involved), hence identical whatever the tensor framework used is.

vector_cls: Type[Vector] property

Vector subclass suitable to the tested framework.

__init__(framework, seed=0)

Instantiate the parametrized test-case.

Source code in declearn/test_utils/_vectors.py
62
63
64
65
66
67
68
69
70
71
def __init__(
    self,
    framework: FrameworkType,
    seed: Optional[int] = 0,
) -> None:
    """Instantiate the parametrized test-case.

    Record `framework` and `seed` as instance attributes, after checking
    that `framework` is listed by `list_available_frameworks()`; raise a
    RuntimeError otherwise.
    """
    available = list_available_frameworks()
    if framework in available:
        self.framework = framework
        self.seed = seed
        return
    raise RuntimeError(f"Framework '{framework}' is unavailable.")

convert(array)

Convert an input numpy array to a framework-based structure.

Source code in declearn/test_utils/_vectors.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def convert(self, array: np.ndarray) -> ArrayLike:
    """Convert an input numpy array to a framework-based structure.

    Numpy inputs are passed through unchanged. Other framework modules
    are imported lazily via `importlib`; TensorFlow tensors are created
    under a CPU device scope. Unsupported framework names raise a
    ValueError.
    """
    name = self.framework
    if name == "numpy":
        return array
    if name == "tensorflow":
        tf_module = importlib.import_module("tensorflow")
        with tf_module.device("CPU"):
            tensor = tf_module.convert_to_tensor(array)
            return tensor
    if name == "torch":
        return importlib.import_module("torch").from_numpy(array)
    if name == "jax":
        return importlib.import_module("jax.numpy").asarray(array)
    raise ValueError(f"Invalid framework '{self.framework}'")

to_numpy(array)

Convert an input framework-based structure to a numpy array.

Source code in declearn/test_utils/_vectors.py
105
106
107
108
109
110
111
112
113
114
115
116
def to_numpy(self, array: ArrayLike) -> np.ndarray:
    """Convert an input framework-based structure to a numpy array.

    Numpy inputs are returned unchanged and jax arrays go through
    `np.asarray`. TensorFlow `IndexedSlices` are densified (under their
    own device scope) before extraction; any other input is expected to
    expose a `.numpy()` method.
    """
    if not isinstance(array, np.ndarray):
        if self.framework == "jax":
            return np.asarray(array)
        if self.framework == "tensorflow":
            tf_module = importlib.import_module("tensorflow")
            if isinstance(array, tf_module.IndexedSlices):
                # Densify the slices before pulling out their values.
                with tf_module.device(array.device):
                    dense = tf_module.convert_to_tensor(array)
                    return dense.numpy()
        return array.numpy()  # type: ignore
    return array