declearn.model.haiku.JaxNumpyVector

Bases: Vector

Vector subclass to store jax.numpy.ndarray coefficients.

This Vector is designed to store a collection of named jax numpy arrays or scalars, enabling computations that either apply to each and every coefficient, or involve two sets of aligned coefficients (i.e. two JaxNumpyVector instances with similar coefficient specifications).

Use vector.coefs to access the stored coefficients.
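For instance, here is a minimal construction sketch (the coefficient names "kernel" and "bias" are illustrative only; this assumes jax and declearn, with its haiku extra, are installed):

import jax.numpy as jnp

from declearn.model.haiku import JaxNumpyVector

# Wrap a couple of named jax arrays into a single vector object.
vector = JaxNumpyVector({"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)})
print(vector.coefs["bias"])  # wrapped arrays are accessed by name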

Notes

  • A JaxNumpyVector can be operated with either a:
    • scalar value
    • NumpyVector that has similar specifications
    • JaxNumpyVector that has similar specifications
    each of these operations resulting in a JaxNumpyVector (see the sketch after these notes).
  • The wrapped arrays may be placed on any device (CPU, GPU...) and need not all be on the same device.
  • The device placement of the initial JaxNumpyVector's data is preserved by operations, including those with NumpyVector.
  • When combining two JaxNumpyVector instances, the device placement of the left-most one is used; as a result, gpu + cpu = gpu while cpu + gpu = cpu. In both cases, a warning is emitted to prevent silent un-optimized copies.
  • When deserializing a JaxNumpyVector (either directly via JaxNumpyVector.unpack or when loading one from a JSON dump), the arrays are placed based on the global device-placement policy (accessed via declearn.utils.get_device_policy). It may hence have a different device-placement scheme than at dump time, but one consistent with that of HaikuModel computations.
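As a hedged sketch of these rules (device placements depend on the machine; the NumpyVector import path is assumed here, and all coefficient names are illustrative):

import jax.numpy as jnp
import numpy as np

from declearn.model.haiku import JaxNumpyVector
from declearn.model.sklearn import NumpyVector

vec_a = JaxNumpyVector({"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)})
vec_b = JaxNumpyVector({"kernel": jnp.full((2, 2), 2.0), "bias": jnp.ones(2)})

half = vec_a * 0.5     # scalar operand -> JaxNumpyVector
total = vec_a + vec_b  # aligned JaxNumpyVector operand -> JaxNumpyVector
mixed = vec_a + NumpyVector({k: np.asarray(v) for k, v in vec_b.coefs.items()})
# In all three cases, the result keeps the device placement of 'vec_a',
# the left-most operand.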
Source code in declearn/model/haiku/_vector.py
@register_vector_type(
    jax.Array,
    jaxlib.xla_extension.ArrayImpl,  # pylint: disable=c-extension-no-member
)
class JaxNumpyVector(Vector):
    """Vector subclass to store jax.numpy.ndarray coefficients.

    This Vector is designed to store a collection of named jax
    numpy arrays or scalars, enabling computations that either
    apply to each and every coefficient, or involve two sets of
    aligned coefficients (i.e. two JaxNumpyVector instances with
    similar coefficient specifications).

    Use `vector.coefs` to access the stored coefficients.

    Notes
    -----
    - A `JaxNumpyVector` can be operated with either a:
        - scalar value
        - `NumpyVector` that has similar specifications
        - `JaxNumpyVector` that has similar specifications
      each of these operations resulting in a `JaxNumpyVector`.
    - The wrapped arrays may be placed on any device (CPU, GPU...)
      and need not all be on the same device.
    - The device placement of the initial `JaxNumpyVector`'s data
      is preserved by operations, including those with `NumpyVector`.
    - When combining two `JaxNumpyVector` instances, the device
      placement of the left-most one is used; as a result,
      `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a
      warning is emitted to prevent silent un-optimized copies.
    - When deserializing a `JaxNumpyVector` (either directly via
      `JaxNumpyVector.unpack` or when loading one from a JSON dump),
      the arrays are placed based on the global device-placement
      policy (accessed via `declearn.utils.get_device_policy`). It
      may hence have a different device-placement scheme than at
      dump time, but one consistent with that of `HaikuModel`
      computations.
    """

    @property
    def _op_add(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.add

    @property
    def _op_sub(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.subtract

    @property
    def _op_mul(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.multiply

    @property
    def _op_div(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.divide

    @property
    def _op_pow(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.power

    @property
    def compatible_vector_types(self) -> Set[Type[Vector]]:
        types = super().compatible_vector_types
        return types.union({NumpyVector, JaxNumpyVector})

    def __init__(self, coefs: Dict[str, jax.Array]) -> None:
        super().__init__(coefs)

    def _apply_operation(
        self,
        other: Any,
        func: Callable[[jax.Array, Any], jax.Array],
    ) -> Self:
        # Ensure 'other' JaxNumpyVector shares this vector's device placement.
        if isinstance(other, JaxNumpyVector):
            coefs = {
                key: jax.device_put(val, self.coefs[key].device())
                for key, val in other.coefs.items()
            }
            other = JaxNumpyVector(coefs)
        return super()._apply_operation(other, func)

    def __eq__(self, other: Any) -> bool:
        valid = isinstance(other, JaxNumpyVector)
        valid = valid and (self.coefs.keys() == other.coefs.keys())
        return valid and all(
            jnp.array_equal(
                val, jax.device_put(other.coefs[key], val.device())
            )
            for key, val in self.coefs.items()
        )

    def sign(
        self,
    ) -> Self:
        return self.apply_func(jnp.sign)

    def minimum(
        self,
        other: Union[Self, float],
    ) -> Self:
        if isinstance(other, Vector):  # accept any compatible Vector, as in 'maximum'
            return self._apply_operation(other, jnp.minimum)
        return self.apply_func(jnp.minimum, other)

    def maximum(
        self,
        other: Union[Self, float],
    ) -> Self:
        if isinstance(other, Vector):
            return self._apply_operation(other, jnp.maximum)
        return self.apply_func(jnp.maximum, other)

    def sum(
        self,
    ) -> Self:
        coefs = {
            key: jnp.array(jnp.sum(val)) for key, val in self.coefs.items()
        }
        return self.__class__(coefs)

    def pack(
        self,
    ) -> Dict[str, Any]:
        return {key: np.asarray(arr) for key, arr in self.coefs.items()}

    @classmethod
    def unpack(
        cls,
        data: Dict[str, Any],
    ) -> Self:
        policy = get_device_policy()
        device = select_device(gpu=policy.gpu, idx=policy.idx)
        coefs = {k: jax.device_put(arr, device) for k, arr in data.items()}
        return cls(coefs)

    # similar code to that of TorchVector; pylint: disable=duplicate-code

    def flatten(
        self,
    ) -> Tuple[List[float], VectorSpec]:
        v_spec = self.get_vector_specs()
        arrays = self.pack()
        values = flatten_numpy_arrays([arrays[name] for name in v_spec.names])
        return values, v_spec

    @classmethod
    def unflatten(
        cls,
        values: List[float],
        v_spec: VectorSpec,
    ) -> Self:
        shapes = [v_spec.shapes[name] for name in v_spec.names]
        dtypes = [v_spec.dtypes[name] for name in v_spec.names]
        arrays = unflatten_numpy_arrays(values, shapes, dtypes)
        return cls.unpack(dict(zip(v_spec.names, arrays)))
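Finally, a short (de)serialization round-trip using the methods above, reusing 'vector' from the first sketch; note that unpack and unflatten place the loaded arrays according to the current global device policy:

values, spec = vector.flatten()    # plain list of floats + specifications
restored = JaxNumpyVector.unflatten(values, spec)
assert restored == vector          # __eq__ aligns devices before comparing

data = vector.pack()               # dict of numpy arrays
same = JaxNumpyVector.unpack(data) # placed per the global device policy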