# NOTE(review): extraction artifact — original file lines 42-194 were lost
# during numbering-stripping (only bare line numbers remained here).
# Recover this span (imports and preceding definitions) from version control.
@register_vector_type(
    jax.Array,
    jaxlib.xla_extension.ArrayImpl,  # pylint: disable=c-extension-no-member
)
class JaxNumpyVector(Vector):
    """Vector subclass to store jax.numpy.ndarray coefficients.

    This Vector is designed to store a collection of named
    jax numpy arrays or scalars, enabling computations that are
    either applied to each and every coefficient, or imply
    two sets of aligned coefficients (i.e. two JaxNumpyVector
    instances with similar coefficients specifications).

    Use `vector.coefs` to access the stored coefficients.

    Notes
    -----
    - A `JaxNumpyVector` can be operated with either a:
      - scalar value
      - `NumpyVector` that has similar specifications
      - `JaxNumpyVector` that has similar specifications
      => resulting in a `JaxNumpyVector` in each of these cases.
    - The wrapped arrays may be placed on any device (CPU, GPU...)
      and may not be all on the same device.
    - The device-placement of the initial `JaxNumpyVector`'s data
      is preserved by operations, including with `NumpyVector`.
    - When combining two `JaxNumpyVector`, the device-placement
      of the left-most one is used; in that case, one ends up with
      `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a
      warning will be emitted to prevent silent un-optimized copies.
    - When deserializing a `JaxNumpyVector` (either by directly using
      `JaxNumpyVector.unpack` or loading one from a JSON dump), loaded
      arrays are placed based on the global device-placement policy
      (accessed via `declearn.utils.get_device_policy`). Thus it may
      have a different device-placement schema than at dump time but
      should be coherent with that of `HaikuModel` computations.
    """

    @property
    def _op_add(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.add

    @property
    def _op_sub(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.subtract

    @property
    def _op_mul(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.multiply

    @property
    def _op_div(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.divide

    @property
    def _op_pow(self) -> Callable[[Any, Any], jax.Array]:
        return jnp.power

    @property
    def compatible_vector_types(self) -> Set[Type[Vector]]:
        # Jax vectors can be combined with plain-numpy ones as well.
        types = super().compatible_vector_types
        return types.union({NumpyVector, JaxNumpyVector})

    def __init__(self, coefs: Dict[str, jax.Array]) -> None:
        super().__init__(coefs)

    def _apply_operation(
        self,
        other: Any,
        func: Callable[[jax.Array, Any], jax.Array],
    ) -> Self:
        """Apply `func` to aligned coefficients, preserving device placement.

        When `other` is a `JaxNumpyVector`, its coefficients are first
        copied onto this vector's devices, so that the left-most operand
        dictates the result's placement (see class-level notes).
        """
        # Ensure 'other' JaxNumpyVector shares this vector's device placement.
        if isinstance(other, JaxNumpyVector):
            coefs = {
                key: jax.device_put(val, self.coefs[key].device())
                for key, val in other.coefs.items()
            }
            other = JaxNumpyVector(coefs)
        return super()._apply_operation(other, func)

    def __eq__(self, other: Any) -> bool:
        """Return whether `other` is an equal-valued `JaxNumpyVector`.

        Coefficients are moved to matching devices before comparison,
        so placement differences do not affect equality.
        """
        valid = isinstance(other, JaxNumpyVector)
        valid = valid and (self.coefs.keys() == other.coefs.keys())
        return valid and all(
            jnp.array_equal(
                val, jax.device_put(other.coefs[key], val.device())
            )
            for key, val in self.coefs.items()
        )

    def sign(
        self,
    ) -> Self:
        """Return a Vector storing the sign of each coefficient."""
        return self.apply_func(jnp.sign)

    def minimum(
        self,
        other: Union[Self, float],
    ) -> Self:
        """Compute coef.-wise, element-wise minimum wrt to another Vector."""
        # Fixed: test against the `Vector` base (as `maximum` does) so that
        # compatible operands such as `NumpyVector` take the aligned-vector
        # path instead of being mishandled as scalar-like arguments.
        if isinstance(other, Vector):
            return self._apply_operation(other, jnp.minimum)
        return self.apply_func(jnp.minimum, other)

    def maximum(
        self,
        other: Union[Self, float],
    ) -> Self:
        """Compute coef.-wise, element-wise maximum wrt to another Vector."""
        if isinstance(other, Vector):
            return self._apply_operation(other, jnp.maximum)
        return self.apply_func(jnp.maximum, other)

    def sum(
        self,
    ) -> Self:
        """Return a Vector storing the sum of each coefficient's elements."""
        coefs = {
            key: jnp.array(jnp.sum(val)) for key, val in self.coefs.items()
        }
        return self.__class__(coefs)

    def pack(
        self,
    ) -> Dict[str, Any]:
        """Return a serializable dict of (device-agnostic) numpy arrays."""
        return {key: np.asarray(arr) for key, arr in self.coefs.items()}

    @classmethod
    def unpack(
        cls,
        data: Dict[str, Any],
    ) -> Self:
        """Rebuild a vector from packed data, honoring the device policy.

        Arrays are placed according to the global device policy (see
        `declearn.utils.get_device_policy`), not their dump-time devices.
        """
        policy = get_device_policy()
        device = select_device(gpu=policy.gpu, idx=policy.idx)
        coefs = {k: jax.device_put(arr, device) for k, arr in data.items()}
        return cls(coefs)

    # similar code to that of TorchVector; pylint: disable=duplicate-code

    def flatten(
        self,
    ) -> Tuple[List[float], VectorSpec]:
        """Flatten coefficients into a list of floats plus their specs."""
        v_spec = self.get_vector_specs()
        arrays = self.pack()
        values = flatten_numpy_arrays([arrays[name] for name in v_spec.names])
        return values, v_spec

    @classmethod
    def unflatten(
        cls,
        values: List[float],
        v_spec: VectorSpec,
    ) -> Self:
        """Rebuild a vector from flattened values and their specs."""
        shapes = [v_spec.shapes[name] for name in v_spec.names]
        dtypes = [v_spec.dtypes[name] for name in v_spec.names]
        arrays = unflatten_numpy_arrays(values, shapes, dtypes)
        return cls.unpack(dict(zip(v_spec.names, arrays)))