Skip to content

declearn.model.torch.TorchVector

Bases: Vector

Vector subclass to store PyTorch tensors.

This Vector is designed to store a collection of named PyTorch tensors, enabling computations that are either applied to each and every coefficient, or imply two sets of aligned coefficients (i.e. two TorchVector instances with similar specifications).

Use vector.coefs to access the stored coefficients.

Notes

  • A TorchVector can be operated with either a:
    • scalar value
    • NumpyVector that has similar specifications
    • TorchVector that has similar specifications
    • => resulting in a TorchVector in each of these cases.
  • The wrapped tensors may be placed on any device (CPU, GPU...) and need not all be on the same device.
  • The device-placement of the initial TorchVector's data is preserved by operations, including with NumpyVector.
  • When combining two TorchVector, the device-placement of the left-most one is used; in that case, one ends up with gpu + cpu = gpu while cpu + gpu = cpu. The right-hand tensors are moved to the left-hand device without any warning being emitted, which may imply costly copies: keep both operands on the same device where possible.
  • When deserializing a TorchVector (either by directly using TorchVector.unpack or loading one from a JSON dump), loaded tensors are placed based on the global device-placement policy (accessed via declearn.utils.get_device_policy). Thus it may have a different device-placement schema than at dump time but should be coherent with that of TorchModel computations.
Source code in declearn/model/torch/_vector.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
@register_vector_type(torch.Tensor)
class TorchVector(Vector):
    """Vector subclass to store PyTorch tensors.

    This Vector is designed to store a collection of named PyTorch
    tensors, enabling computations that are either applied to each
    and every coefficient, or imply two sets of aligned coefficients
    (i.e. two TorchVector instances with similar specifications).

    Use `vector.coefs` to access the stored coefficients.

    Notes
    -----
    - A `TorchVector` can be operated with either a:
        - scalar value
        - `NumpyVector` that has similar specifications
        - `TorchVector` that has similar specifications
        - => resulting in a `TorchVector` in each of these cases.
    - The wrapped tensors may be placed on any device (CPU, GPU...)
      and need not all be on the same device.
    - The device-placement of the initial `TorchVector`'s data
      is preserved by operations, including with `NumpyVector`.
    - When combining two `TorchVector`, the device-placement
      of the left-most one is used; in that case, one ends up with
      `gpu + cpu = gpu` while `cpu + gpu = cpu`. The right-hand
      tensors are moved to the left-hand ones' devices without any
      warning being emitted, which may imply costly copies: keep
      both operands on the same device where possible.
    - When deserializing a `TorchVector` (either by directly using
      `TorchVector.unpack` or loading one from a JSON dump), loaded
      tensors are placed based on the global device-placement policy
      (accessed via `declearn.utils.get_device_policy`). Thus it may
      have a different device-placement schema than at dump time but
      should be coherent with that of `TorchModel` computations.
    """

    @property
    def _op_add(self) -> Callable[[Any, Any], Any]:
        return torch.add  # pylint: disable=no-member

    @property
    def _op_sub(self) -> Callable[[Any, Any], Any]:
        return torch.sub  # pylint: disable=no-member

    @property
    def _op_mul(self) -> Callable[[Any, Any], Any]:
        return torch.mul  # pylint: disable=no-member

    @property
    def _op_div(self) -> Callable[[Any, Any], Any]:
        return torch.div  # pylint: disable=no-member

    @property
    def _op_pow(self) -> Callable[[Any, Any], Any]:
        return torch.pow  # pylint: disable=no-member

    @property
    def compatible_vector_types(self) -> Set[Type[Vector]]:
        types = super().compatible_vector_types
        return types.union({NumpyVector, TorchVector})

    def __init__(self, coefs: Dict[str, torch.Tensor]) -> None:
        """Wrap a dict of named PyTorch tensors."""
        super().__init__(coefs)

    def _apply_operation(
        self,
        other: Any,
        func: Callable[[Any, Any], Any],
    ) -> Self:
        """Apply `func` between this vector and `other`.

        A `NumpyVector` operand is first converted to a CPU-backed
        `TorchVector`. Any `TorchVector` operand then has its tensors
        moved to the devices of this vector's matching tensors, so
        that results keep this (left-hand) vector's device placement.
        The actual computation is delegated to the parent class.
        """
        # Convert 'other' NumpyVector to a (CPU-backed) TorchVector.
        if isinstance(other, NumpyVector):
            # false-positive; pylint: disable=no-member
            coefs = {
                key: torch.from_numpy(val) for key, val in other.coefs.items()
            }
            other = TorchVector(coefs)
        # Ensure 'other' TorchVector shares this vector's device placement.
        if isinstance(other, TorchVector):
            coefs = {
                key: val.to(self.coefs[key].device)
                for key, val in other.coefs.items()
            }
            other = TorchVector(coefs)
        return super()._apply_operation(other, func)

    def dtypes(
        self,
    ) -> Dict[str, str]:
        """Return the dtype name of each coefficient, without its prefix.

        Everything up to the first '.' is dropped, e.g. a parent-provided
        'torch.float32' becomes 'float32' (assuming the parent returns
        `str(tensor.dtype)`-like names).
        """
        dtypes = super().dtypes()
        return {key: val.split(".", 1)[-1] for key, val in dtypes.items()}

    def shapes(
        self,
    ) -> Dict[str, Tuple[int, ...]]:
        """Return the shape of each coefficient, as a tuple of ints."""
        return {key: tuple(coef.shape) for key, coef in self.coefs.items()}

    def pack(
        self,
    ) -> Dict[str, Any]:
        """Return a dict of CPU-backed numpy arrays copied from the tensors.

        Tensors are detached from the autograd graph first, since
        `Tensor.numpy()` raises on tensors that require gradients.
        `np.array` then copies the data, so the output never shares
        memory with the wrapped tensors.
        """
        return {
            key: np.array(tns.detach().cpu().numpy())
            for key, tns in self.coefs.items()
        }

    @classmethod
    def unpack(
        cls,
        data: Dict[str, Any],
    ) -> Self:
        """Build a TorchVector from a dict of numpy arrays.

        Tensors are placed on the device selected by the global
        device-placement policy, regardless of where they lived
        at packing time.
        """
        policy = get_device_policy()
        device = select_device(gpu=policy.gpu, idx=policy.idx)
        coefs = {
            # false-positive on `torch.from_numpy`; pylint: disable=no-member
            key: torch.from_numpy(dat).to(device)
            for key, dat in data.items()
        }
        return cls(coefs)

    def __eq__(
        self,
        other: Any,
    ) -> bool:
        """Check that `other` is a TorchVector with identical coefficients.

        Device placement is ignored: each of the other vector's tensors
        is moved to the device of this vector's matching tensor before
        the element-wise comparison.
        """
        valid = isinstance(other, TorchVector)
        if valid:
            valid = self.coefs.keys() == other.coefs.keys()
        if valid:
            valid = all(
                # false-positive on 'torch.equal'; pylint: disable=no-member
                torch.equal(tns, other.coefs[key].to(tns.device))
                for key, tns in self.coefs.items()
            )
        return valid

    def sign(self) -> Self:
        """Return a vector holding the element-wise sign of coefficients."""
        # false-positive; pylint: disable=no-member
        return self.apply_func(torch.sign)

    def _apply_minmax(
        self,
        other: Any,
        func: Callable[[Any, Any], Any],
    ) -> Self:
        """Shared backend of `minimum` and `maximum`.

        Vectors go through `_apply_operation`. Python scalars are
        wrapped as 0-dim tensors placed on each coefficient's own
        device, so that:
        - coefficient shapes are preserved, including 0-dim ones;
        - GPU-placed coefficients do not trigger device mismatches;
        - PyTorch type promotion keeps the coefficients' float dtype.
        Any other input is forwarded as-is to `func`.
        """
        if isinstance(other, Vector):
            return self._apply_operation(other, func)
        if isinstance(other, (int, float)):
            # false-positive on 'torch.tensor'; pylint: disable=no-member
            coefs = {
                key: func(val, torch.tensor(other, device=val.device))
                for key, val in self.coefs.items()
            }
            return self.__class__(coefs)
        return self.apply_func(func, other)

    def minimum(
        self,
        other: Any,
    ) -> Self:
        """Return the element-wise minimum with a vector or scalar."""
        # false-positive; pylint: disable=no-member
        return self._apply_minmax(other, torch.minimum)

    def maximum(
        self,
        other: Any,
    ) -> Self:
        """Return the element-wise maximum with a vector or scalar."""
        # false-positive; pylint: disable=no-member
        return self._apply_minmax(other, torch.maximum)

    def sum(
        self,
        axis: Optional[int] = None,
        keepdims: bool = False,
    ) -> Self:
        """Sum each coefficient, either fully or along a given axis.

        With `axis=None`, each tensor is reduced to a scalar; with
        `keepdims=True`, that scalar is reshaped to `(1,) * ndim`,
        mirroring numpy's `keepdims` semantics.
        """

        def _reduce(val: torch.Tensor) -> torch.Tensor:
            if axis is not None:
                return val.sum(dim=axis, keepdim=keepdims)
            # Full reduction done explicitly rather than passing
            # `dim=None`, which some PyTorch releases reject.
            total = val.sum()
            if keepdims:
                total = total.reshape((1,) * val.dim())
            return total

        coefs = {key: _reduce(val) for key, val in self.coefs.items()}
        return self.__class__(coefs)