declearn.model.api.Model

Bases: Generic[VectorT]

Abstract class defining an API to manipulate a ML model.

A 'Model' is an abstraction that defines a generic interface to access a model's parameters and perform operations (such as computing gradients or metrics over some data), enabling writing algorithms and operations agnostic to the framework in which the underlying model is implemented (e.g. PyTorch, TensorFlow, Scikit-Learn...).

Device-placement (i.e. running computations on CPU or GPU) is also handled as part of Model classes' backend, mapping the generic declearn.utils.DevicePolicy parameters to any required framework-specific instruction to adequately pick the device to use and ensure the wrapped model, input data and interfaced computations are placed there.
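
As a minimal sketch of how this API is meant to be used, the snippet below performs a plain gradient-descent step through the generic interface. It assumes that model is an instance of some concrete, framework-specific Model subclass that has already been initialized, that batch is a declearn.typing.Batch tuple, and that the framework's Vector type supports scalar multiplication.

# Hedged sketch of a plain gradient-descent training step via the Model API.
# Assumptions: `model` is an initialized concrete Model subclass instance,
# `batch` is a declearn.typing.Batch tuple, and the returned Vector type
# supports scalar multiplication.
lrate = 0.01
gradients = model.compute_batch_gradients(batch)  # batch-averaged gradients
model.apply_updates(-lrate * gradients)           # descend along the gradients
losses = model.collect_training_losses()          # purge recorded batch losses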

Source code in declearn/model/api/_model.py
@create_types_registry
class Model(Generic[VectorT], metaclass=ABCMeta):
    """Abstract class defining an API to manipulate a ML model.

    A 'Model' is an abstraction that defines a generic interface
    to access a model's parameters and perform operations (such
    as computing gradients or metrics over some data), enabling
    writing algorithms and operations agnostic to the framework
    in which the underlying model is implemented (e.g. PyTorch,
    TensorFlow, Scikit-Learn...).

    Device-placement (i.e. running computations on CPU or GPU)
    is also handled as part of Model classes' backend, mapping
    the generic `declearn.utils.DevicePolicy` parameters to any
    required framework-specific instruction to adequately pick
    the device to use and ensure the wrapped model, input data
    and interfaced computations are placed there.
    """

    def __init__(
        self,
        model: Any,
    ) -> None:
        """Instantiate a Model interface wrapping a 'model' object."""
        self._model = model
        # Declare a private list where to record batch-wise training losses.
        self._loss_history = []  # type: List[float]

    def get_wrapped_model(self) -> Any:
        """Getter to access the wrapped framework-specific model object.

        This getter should be used sparingly, so as to avoid undesirable
        side effects. In particular, it should not be used in declearn
        backend code (but may be in examples or tests), as it is merely
        a way for end-users to access the wrapped model after training.

        Returns
        -------
        model:
            Wrapped model, of (framework/Model-subclass)-specific type.
        """
        return self._model

    @property
    @abstractmethod
    def device_policy(
        self,
    ) -> DevicePolicy:
        """Return the device-placement policy currently used by this model."""

    @property
    @abstractmethod
    def required_data_info(
        self,
    ) -> Set[str]:
        """List of 'data_info' fields required to initialize this model.

        Note: These fields should match a registered specification
        (see the [`declearn.data_info`][] submodule).
        """

    @abstractmethod
    def initialize(
        self,
        data_info: Dict[str, Any],
    ) -> None:
        """Initialize the model based on data specifications.

        Parameters
        ----------
        data_info: dict[str, any]
            Data specifications, presenting values for all fields
            listed under `self.required_data_info`.

        Raises
        ------
        KeyError
            If some fields in `required_data_info` are missing.

        Notes
        -----
        See the `aggregate_data_info` method to derive `data_info`
        from client-wise dicts.
        """

    @abstractmethod
    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return the model's parameters as a JSON-serializable dict."""

    @classmethod
    @abstractmethod
    def from_config(
        cls,
        config: Dict[str, Any],
    ) -> Self:
        """Instantiate a model from a configuration dict."""

    @abstractmethod
    def get_weights(
        self,
        trainable: bool = False,
    ) -> VectorT:
        """Return the model's weights, optionally excluding frozen ones.

        Parameters
        ----------
        trainable: bool, default=False
            Whether to restrict the returned weights to the trainable ones,
            or include those that are frozen, i.e. are not updated as part
            of the training process.

        Returns
        -------
        weights: Vector
            Vector wrapping the named weights data arrays.
            The concrete type of the returned Vector depends on the concrete
            `Model`, and is the same as with `compute_batch_gradients`.
        """

    @abstractmethod
    def set_weights(
        self,
        weights: VectorT,
        trainable: bool = False,
    ) -> None:
        """Assign values to the model's weights.

        This method can only be used to update the values of *all*
        model weights, with the optional exception of frozen (i.e.
        non-trainable) ones. It cannot be used to alter the values
        of a subset of weight tensors.

        Parameters
        ----------
        weights: Vector
            Vector wrapping the named data arrays that should replace
            the current weights' values.
            The concrete type of Vector depends on the Model class,
            and matches the `get_weights` method's return type.
        trainable: bool, default=False
            Whether the assigned weights only cover the trainable ones,
            or include those that are frozen, i.e. are not updated as
            part of the training process.

        Raises
        ------
        KeyError
            If the input weights do not match the expected number and
            names of weight tensors.
        TypeError
            If the input weights are of an improper concrete Vector type.
        """

    @abstractmethod
    def compute_batch_gradients(
        self,
        batch: Batch,
        max_norm: Optional[float] = None,
    ) -> VectorT:
        """Compute and return gradients computed over a given data batch.

        Compute the average gradients of the model's loss with respect
        to its trainable parameters for the given data batch.
        Optionally clip sample-wise gradients before batch-averaging.

        Record the loss value over the batch, which may be collected
        (and thereby purged from the internal memory) by calling the
        `collect_training_losses` method.

        Parameters
        ----------
        batch: declearn.typing.Batch
            Tuple wrapping input data, (opt.) target values and (opt.)
            sample weights to be applied to the loss function.
        max_norm: float or None, default=None
            Maximum L2-norm of sample-wise gradients, beyond which to
            clip them before computing the batch-average gradients.
            If None, batch-averaged gradients are computed directly,
            which is less costly in computation time and memory.

        Returns
        -------
        gradients: Vector
            Batch-averaged gradients, wrapped into a Vector (using
            a suited Vector subclass depending on the Model class).
        """

    @abstractmethod
    def apply_updates(
        self,
        updates: VectorT,
    ) -> None:
        """Apply updates to the model's weights."""

    def collect_training_losses(
        self,
    ) -> List[float]:
        """Collect batch-wise training losses accumulated over time.

        Return all recorded batch-averaged loss values computed as
        part of `compute_batch_gradients` calls, and clear them
        from memory, so that next time this method is called, only
        new values are returned.

        Returns
        -------
        losses:
            List of batch-averaged loss values computed over inputs
            to the `compute_batch_gradients` method.
        """
        losses, self._loss_history = self._loss_history, []
        return losses

    @abstractmethod
    def compute_batch_predictions(
        self,
        batch: Batch,
    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Compute and return model predictions on given inputs.

        This method is designed to return numpy arrays independently
        from the wrapped model's actual framework, for compatibility
        purposes with the `declearn.metrics.Metric` API.

        Note that in most cases, the returned `y_true` and `s_wght`
        are directly taken from the input batch. Their inclusion in
        the inputs and outputs of this method aims to enable using
        some non-standard data-flow schemes, such as that of auto-
        encoder models, that re-use their inputs as labels.

        Parameters
        ----------
        batch: declearn.typing.Batch
            Tuple wrapping input data, (opt.) target values and (opt.)
            sample weights. Note that in general, predictions should
            only be computed from input data - but the API is flexible
            for edge cases, e.g. auto-encoder models, as target labels
            are equal to the input data.

        Returns
        -------
        y_true: np.ndarray
            Ground-truth labels, to which predictions are aligned
            and should be compared for loss (and other evaluation
            metrics) computation.
        y_pred: np.ndarray
            Output model predictions (scores or labels), wrapped as
            a (>=1)-d numpy array, batched along the first axis.
        s_wght: np.ndarray or None
            Optional sample weights to be used to weight metrics.
        """

    @abstractmethod
    def loss_function(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
    ) -> np.ndarray:
        """Compute the model's sample-wise loss from labels and predictions.

        This method is designed to be used when evaluating the model,
        to compute a sample-wise loss from the predictions output by
        `self.compute_batch_predictions`.

        It may further be wrapped as an ad-hoc samples-averaged Metric
        instance so as to share the inference computations between
        the loss and other evaluation metrics.

        Parameters
        ----------
        y_true: np.ndarray
            Target values or labels, wrapped as a (>=1)-d numpy array,
            the first axis of which is the batching one.
        y_pred: np.ndarray
            Predicted values or scores, as a (>=1)-d numpy array aligned
            with the `y_true` one.

        Returns
        -------
        s_loss: np.ndarray
            Sample-wise loss values, as a 1-d numpy array.
        """

    @abstractmethod
    def update_device_policy(
        self,
        policy: Optional[DevicePolicy] = None,
    ) -> None:
        """Update the device-placement policy of this model.

        This method is designed to be called after a change in the global
        device-placement policy (e.g. to disable using a GPU, or move to
        a specific one), so as to place pre-existing Model instances and
        avoid policy inconsistencies that might cause repeated memory or
        runtime costs from moving data or weights around each time they
        are used. You should otherwise not worry about a Model's device-
        placement, as it is handled at instantiation based on the global
        device policy (see `declearn.utils.set_device_policy`).

        Parameters
        ----------
        policy: DevicePolicy or None, default=None
            Optional DevicePolicy dataclass instance to be used.
            If None, use the global device policy, accessed via
            `declearn.utils.get_device_policy`.
        """

device_policy: DevicePolicy abstractmethod property

Return the device-placement policy currently used by this model.

required_data_info: Set[str] abstractmethod property

List of 'data_info' fields required to initialize this model.

Note: These fields should match a registered specification (see the declearn.data_info submodule).

__init__(model)

Instantiate a Model interface wrapping a 'model' object.

Source code in declearn/model/api/_model.py
def __init__(
    self,
    model: Any,
) -> None:
    """Instantiate a Model interface wrapping a 'model' object."""
    self._model = model
    # Declare a private list where to record batch-wise training losses.
    self._loss_history = []  # type: List[float]

apply_updates(updates) abstractmethod

Apply updates to the model's weights.

Source code in declearn/model/api/_model.py
@abstractmethod
def apply_updates(
    self,
    updates: VectorT,
) -> None:
    """Apply updates to the model's weights."""

collect_training_losses()

Collect batch-wise training losses accumulated over time.

Return all recorded batch-averaged loss values computed as part of compute_batch_gradients calls, and clear them from memory, so that next time this method is called, only new values are returned.

Returns:

Name Type Description
losses List[float]

List of batch-averaged loss values computed over inputs to the compute_batch_gradients method.

Source code in declearn/model/api/_model.py
def collect_training_losses(
    self,
) -> List[float]:
    """Collect batch-wise training losses accumulated over time.

    Return all recorded batch-averaged loss values computed as
    part of `compute_batch_gradients` calls, and clear them
    from memory, so that next time this method is called, only
    new values are returned.

    Returns
    -------
    losses:
        List of batch-averaged loss values computed over inputs
        to the `compute_batch_gradients` method.
    """
    losses, self._loss_history = self._loss_history, []
    return losses
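
As a short illustration of the purge semantics (assuming model is a concrete Model instance and batch_a, batch_b are Batch tuples):

# Losses accumulate across `compute_batch_gradients` calls and are cleared
# once collected (hypothetical `model`, `batch_a` and `batch_b` objects).
model.compute_batch_gradients(batch_a)
model.compute_batch_gradients(batch_b)
losses = model.collect_training_losses()  # two batch-averaged loss values
again = model.collect_training_losses()   # [] - previous values were purged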

compute_batch_gradients(batch, max_norm=None) abstractmethod

Compute and return gradients computed over a given data batch.

Compute the average gradients of the model's loss with respect to its trainable parameters for the given data batch. Optionally clip sample-wise gradients before batch-averaging.

Record the loss value over the batch, which may be collected (and thereby purged from the internal memory) by calling the collect_training_losses method.

Parameters:

Name Type Description Default
batch Batch

Tuple wrapping input data, (opt.) target values and (opt.) sample weights to be applied to the loss function.

required
max_norm Optional[float]

Maximum L2-norm of sample-wise gradients, beyond which to clip them before computing the batch-average gradients. If None, batch-averaged gradients are computed directly, which is less costly in computation time and memory.

None

Returns:

Name Type Description
gradients Vector

Batch-averaged gradients, wrapped into a Vector (using a suited Vector subclass depending on the Model class).

Source code in declearn/model/api/_model.py
@abstractmethod
def compute_batch_gradients(
    self,
    batch: Batch,
    max_norm: Optional[float] = None,
) -> VectorT:
    """Compute and return gradients computed over a given data batch.

    Compute the average gradients of the model's loss with respect
    to its trainable parameters for the given data batch.
    Optionally clip sample-wise gradients before batch-averaging.

    Record the loss value over the batch, which may be collected
    (and thereby purged from the internal memory) by calling the
    `collect_training_losses` method.

    Parameters
    ----------
    batch: declearn.typing.Batch
        Tuple wrapping input data, (opt.) target values and (opt.)
        sample weights to be applied to the loss function.
    max_norm: float or None, default=None
        Maximum L2-norm of sample-wise gradients, beyond which to
        clip them before computing the batch-average gradients.
        If None, batch-averaged gradients are computed directly,
        which is less costly in computation time and memory.

    Returns
    -------
    gradients: Vector
        Batch-averaged gradients, wrapped into a Vector (using
        a suited Vector subclass depending on the Model class).
    """

compute_batch_predictions(batch) abstractmethod

Compute and return model predictions on given inputs.

This method is designed to return numpy arrays independently from the wrapped model's actual framework, for compatibility purposes with the declearn.metrics.Metric API.

Note that in most cases, the returned y_true and s_wght are directly taken from the input batch. Their inclusion in the inputs and outputs of this method aims to enable using some non-standard data-flow schemes, such as that of auto-encoder models, which re-use their inputs as labels.

Parameters:

Name Type Description Default
batch Batch

Tuple wrapping input data, (opt.) target values and (opt.) sample weights. Note that in general, predictions should only be computed from input data - but the API is flexible for edge cases, e.g. auto-encoder models, as target labels are equal to the input data.

required

Returns:

Name Type Description
y_true np.ndarray

Ground-truth labels, to which predictions are aligned and should be compared for loss (and other evaluation metrics) computation.

y_pred np.ndarray

Output model predictions (scores or labels), wrapped as a (>=1)-d numpy array, batched along the first axis.

s_wght np.ndarray or None

Optional sample weights to be used to weight metrics.

Source code in declearn/model/api/_model.py
@abstractmethod
def compute_batch_predictions(
    self,
    batch: Batch,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Compute and return model predictions on given inputs.

    This method is designed to return numpy arrays independently
    from the wrapped model's actual framework, for compatibility
    purposes with the `declearn.metrics.Metric` API.

    Note that in most cases, the returned `y_true` and `s_wght`
    are directly taken from the input batch. Their inclusion in
    the inputs and outputs of this method aims to enable using
    some non-standard data-flow schemes, such as that of auto-
    encoder models, that re-use their inputs as labels.

    Parameters
    ----------
    batch: declearn.typing.Batch
        Tuple wrapping input data, (opt.) target values and (opt.)
        sample weights. Note that in general, predictions should
        only be computed from input data - but the API is flexible
        for edge cases, e.g. auto-encoder models, as target labels
        are equal to the input data.

    Returns
    -------
    y_true: np.ndarray
        Ground-truth labels, to which predictions are aligned
        and should be compared for loss (and other evaluation
        metrics) computation.
    y_pred: np.ndarray
        Output model predictions (scores or labels), wrapped as
        a (>=1)-d numpy array, batched along the first axis.
    s_wght: np.ndarray or None
        Optional sample weights to be used to weight metrics.
    """

from_config(config) abstractmethod classmethod

Instantiate a model from a configuration dict.

Source code in declearn/model/api/_model.py
@classmethod
@abstractmethod
def from_config(
    cls,
    config: Dict[str, Any],
) -> Self:
    """Instantiate a model from a configuration dict."""

get_config() abstractmethod

Return the model's parameters as a JSON-serializable dict.

Source code in declearn/model/api/_model.py
@abstractmethod
def get_config(
    self,
) -> Dict[str, Any]:
    """Return the model's parameters as a JSON-serializable dict."""

get_weights(trainable=False) abstractmethod

Return the model's weights, optionally excluding frozen ones.

Parameters:

Name Type Description Default
trainable bool

Whether to restrict the returned weights to the trainable ones, or include those that are frozen, i.e. are not updated as part of the training process.

False

Returns:

Name Type Description
weights Vector

Vector wrapping the named weights data arrays. The concrete type of the returned Vector depends on the concrete Model, and is the same as with compute_batch_gradients.

Source code in declearn/model/api/_model.py
@abstractmethod
def get_weights(
    self,
    trainable: bool = False,
) -> VectorT:
    """Return the model's weights, optionally excluding frozen ones.

    Parameters
    ----------
    trainable: bool, default=False
        Whether to restrict the returned weights to the trainable ones,
        or include those that are frozen, i.e. are not updated as part
        of the training process.

    Returns
    -------
    weights: Vector
        Vector wrapping the named weights data arrays.
        The concrete type of the returned Vector depends on the concrete
        `Model`, and is the same as with `compute_batch_gradients`.
    """

get_wrapped_model()

Getter to access the wrapped framework-specific model object.

This getter should be used sparingly, so as to avoid undesirable side effects. In particular, it should not be used in declearn backend code (but may be in examples or tests), as it is merely a way for end-users to access the wrapped model after training.

Returns:

Name Type Description
model Any

Wrapped model, of (framework/Model-subclass)-specific type.

Source code in declearn/model/api/_model.py
def get_wrapped_model(self) -> Any:
    """Getter to access the wrapped framework-specific model object.

    This getter should be used sparingly, so as to avoid undesirable
    side effects. In particular, it should not be used in declearn
    backend code (but may be in examples or tests), as it is merely
    a way for end-users to access the wrapped model after training.

    Returns
    -------
    model:
        Wrapped model, of (framework/Model-subclass)-specific type.
    """
    return self._model
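
As a short usage sketch (the concrete type of the returned object depends on the Model subclass; it would typically be the framework-native model that was wrapped at instantiation):

# End-user code: recover the framework-specific model once training is over.
wrapped = model.get_wrapped_model()
print(type(wrapped))  # framework-native class, e.g. for saving or inspection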

initialize(data_info) abstractmethod

Initialize the model based on data specifications.

Parameters:

Name Type Description Default
data_info Dict[str, Any]

Data specifications, presenting values for all fields listed under self.required_data_info.

required

Raises:

Type Description
KeyError

If some fields in required_data_info are missing.

Notes

See the aggregate_data_info method to derive data_info from client-wise dicts.

Source code in declearn/model/api/_model.py
@abstractmethod
def initialize(
    self,
    data_info: Dict[str, Any],
) -> None:
    """Initialize the model based on data specifications.

    Parameters
    ----------
    data_info: dict[str, any]
        Data specifications, presenting values for all fields
        listed under `self.required_data_info`.

    Raises
    ------
    KeyError
        If some fields in `required_data_info` are missing.

    Notes
    -----
    See the `aggregate_data_info` method to derive `data_info`
    from client-wise dicts.
    """

loss_function(y_true, y_pred) abstractmethod

Compute the model's sample-wise loss from labels and predictions.

This method is designed to be used when evaluating the model, to compute a sample-wise loss from the predictions output by self.compute_batch_predictions.

It may further be wrapped as an ad-hoc samples-averaged Metric instance so as to share the inference computations between the loss and other evaluation metrics.

Parameters:

Name Type Description Default
y_true np.ndarray

Target values or labels, wrapped as a (>=1)-d numpy array, the first axis of which is the batching one.

required
y_pred np.ndarray

Predicted values or scores, as a (>=1)-d numpy array aligned with the y_true one.

required

Returns:

Name Type Description
s_loss np.ndarray

Sample-wise loss values, as a 1-d numpy array.

Source code in declearn/model/api/_model.py
@abstractmethod
def loss_function(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> np.ndarray:
    """Compute the model's sample-wise loss from labels and predictions.

    This method is designed to be used when evaluating the model,
    to compute a sample-wise loss from the predictions output by
    `self.compute_batch_predictions`.

    It may further be wrapped as an ad-hoc samples-averaged Metric
    instance so as to share the inference computations between
    the loss and other evaluation metrics.

    Parameters
    ----------
    y_true: np.ndarray
        Target values or labels, wrapped as a (>=1)-d numpy array,
        the first axis of which is the batching one.
    y_pred: np.ndarray
        Predicted values or scores, as a (>=1)-d numpy array aligned
        with the `y_true` one.

    Returns
    -------
    s_loss: np.ndarray
        Sample-wise loss values, as a 1-d numpy array.
    """

set_weights(weights, trainable=False) abstractmethod

Assign values to the model's weights.

This method can only be used to update the values of all model weights, with the optional exception of frozen (i.e. non-trainable) ones. It cannot be used to alter the values of a subset of weight tensors.

Parameters:

Name Type Description Default
weights VectorT

Vector wrapping the named data arrays that should replace the current weights' values. The concrete type of Vector depends on the Model class, and matches the get_weights method's return type.

required
trainable bool

Whether the assigned weights only cover the trainable ones, or include those that are frozen, i.e. are not updated as part of the training process.

False

Raises:

Type Description
KeyError

If the input weights do not match the expected number and names of weight tensors.

TypeError

If the input weights are of an improper concrete Vector type.

Source code in declearn/model/api/_model.py
@abstractmethod
def set_weights(
    self,
    weights: VectorT,
    trainable: bool = False,
) -> None:
    """Assign values to the model's weights.

    This method can only be used to update the values of *all*
    model weights, with the optional exception of frozen (i.e.
    non-trainable) ones. It cannot be used to alter the values
    of a subset of weight tensors.

    Parameters
    ----------
    weights: Vector
        Vector wrapping the named data arrays that should replace
        the current weights' values.
        The concrete type of Vector depends on the Model class,
        and matches the `get_weights` method's return type.
    trainable: bool, default=False
        Whether the assigned weights only cover the trainable ones,
        or include those that are frozen, i.e. are not updated as
        part of the training process.

    Raises
    ------
    KeyError
        If the input weights do not match the expected number and
        names of weight tensors.
    TypeError
        If the input weights are of an improper concrete Vector type.
    """

update_device_policy(policy=None) abstractmethod

Update the device-placement policy of this model.

This method is designed to be called after a change in the global device-placement policy (e.g. to disable using a GPU, or move to a specific one), so as to place pre-existing Model instances and avoid policy inconsistencies that might cause repeated memory or runtime costs from moving data or weights around each time they are used. You should otherwise not worry about a Model's device-placement, as it is handled at instantiation based on the global device policy (see declearn.utils.set_device_policy).

Parameters:

Name Type Description Default
policy Optional[DevicePolicy]

Optional DevicePolicy dataclass instance to be used. If None, use the global device policy, accessed via declearn.utils.get_device_policy.

None
Source code in declearn/model/api/_model.py
@abstractmethod
def update_device_policy(
    self,
    policy: Optional[DevicePolicy] = None,
) -> None:
    """Update the device-placement policy of this model.

    This method is designed to be called after a change in the global
    device-placement policy (e.g. to disable using a GPU, or move to
    a specific one), so as to place pre-existing Model instances and
    avoid policy inconsistencies that might cause repeated memory or
    runtime costs from moving data or weights around each time they
    are used. You should otherwise not worry about a Model's device-
    placement, as it is handled at instantiation based on the global
    device policy (see `declearn.utils.set_device_policy`).

    Parameters
    ----------
    policy: DevicePolicy or None, default=None
        Optional DevicePolicy dataclass instance to be used.
        If None, use the global device policy, accessed via
        `declearn.utils.get_device_policy`.
    """