
declearn.optimizer.Optimizer

Base class to define gradient-descent-based optimizers.

The Optimizer class defines an API that other declearn components require for federated learning processes to run. It is also fully functional on its own, and is designed to be customized through "plug-in modules" rather than through subclassing (which may still be used for advanced algorithm modifications): see the base classes declearn.optimizer.modules.OptiModule and declearn.optimizer.regularizers.Regularizer for details.

The process implemented here is the following:

  • Compute or receive the (pseudo-)gradients of a model.
  • Compute loss-regularization terms and add them to the gradients, based on a list of plug-in regularizers.
  • Refine gradients by running them through plug-in modules, which are applied sequentially (i.e. composed).
  • Apply the learning rate to the refined gradients, yielding the updates.
  • Optionally add a decoupled weight decay term (see [1]) to the updates; this term is not scaled by the learning rate.
  • Perform the weights' update.
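Reading the source of `compute_updates_from_gradients`, the steps above amount to the following rule. The notation is ours, not declearn's: $g$ denotes the input gradients, $w$ the current trainable weights, $R_1, \dots, R_k$ the regularizers and $M_1, \dots, M_m$ the modules (in list order), $\eta$ the `lrate` and $\lambda$ the `w_decay`. The last line assumes that `model.apply_updates` adds the returned (already negated) updates to the weights.

```latex
\begin{aligned}
\tilde{g} &= R_k\big(\cdots R_1(g, w) \cdots, w\big) \\
\hat{g}   &= M_m\big(\cdots M_1(\tilde{g}) \cdots\big) \\
w'        &= w - \big(\eta\,\hat{g} + \lambda\,w\big)
\end{aligned}
```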

Most plug-in modules are self-contained, in the sense that they do not require any information flow between the server and its clients in a federated process, and may be used solely by the server, by clients, or even by a subset of clients, at least formally (there might be correctness or convergence issues with clients not adopting similar local optimization strategies).

However, some algorithms designed (or adapted) specifically for federated learning require some form of synchronicity between the server and clients. In that case, they should be coded to emit and expect auxiliary variables, shared between server and clients alongside updated model weights during training. These mechanisms are implemented at the level of the modules themselves, but are wrapped at the optimizer level, which collects the plugged-in modules' variables and maps received variables back to them.
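The sketch below mimics that collect-then-map-back mechanism in plain Python. `ToyAuxModule` and both helper functions are hypothetical stand-ins (not declearn's API) for plug-in modules and for `Optimizer.collect_aux_var` / `Optimizer.process_aux_var`, following the optimizer source in keying variables by `aux_name` with a fallback to `name`. Averaging the clients' values is an assumption made for illustration: the source only states that variables from multiple peers must be aggregated before being processed.

```python
from statistics import fmean
from typing import Dict, List, Optional


class ToyAuxModule:
    """Hypothetical plug-in module sharing a single float with its peers."""

    def __init__(self, name: str, aux_name: Optional[str] = None) -> None:
        self.name = name
        self.aux_name = aux_name
        self.local_value = 0.0
        self.received: Optional[float] = None

    def collect_aux_var(self) -> Optional[float]:
        return self.local_value

    def process_aux_var(self, aux_var: float) -> None:
        self.received = aux_var


def collect_aux_var(modules: List[ToyAuxModule]) -> Dict[str, float]:
    # Key each module's variables by `aux_name`, falling back to `name`,
    # and skip modules that have nothing to share.
    output: Dict[str, float] = {}
    for module in modules:
        auxv = module.collect_aux_var()
        if auxv is not None:
            output[module.aux_name or module.name] = auxv
    return output


def process_aux_var(
    modules: List[ToyAuxModule], aux_var: Dict[str, float]
) -> None:
    # Route each received entry back to the module holding the same key.
    by_key = {(mod.aux_name or mod.name): mod for mod in modules}
    for key, auxv in aux_var.items():
        if key not in by_key:
            raise KeyError(f"No module with name '{key}' to receive {auxv}.")
        by_key[key].process_aux_var(auxv)


# Two clients using the same module, plus the server-side counterpart.
clients = [[ToyAuxModule("toy")], [ToyAuxModule("toy")]]
clients[0][0].local_value = 1.0
clients[1][0].local_value = 3.0
server = [ToyAuxModule("toy")]

payloads = [collect_aux_var(mods) for mods in clients]
# Peers' variables are aggregated before processing (a plain mean here).
aggregated = {key: fmean(p[key] for p in payloads) for key in payloads[0]}
process_aux_var(server, aggregated)
print(server[0].received)  # 2.0
```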

Attributes:

  • lrate (float): Base learning rate applied to computed updates.
  • w_decay (float): Decoupled weight decay parameter.
  • modules (list[OptiModule]): List of plug-in modules composed into the optimizer's gradients-to-updates computation algorithm.
  • regularizers (list[Regularizer]): List of plug-in loss regularization modules composed into the optimizer's gradients-to-updates computation algorithm.

API methods

  • apply_gradients(Model, Vector) -> None: Update a Model based on a pre-computed Vector of gradients.
  • collect_aux_var() -> Dict[str, AuxVar]: Collect and package plug-in modules' auxiliary variables.
  • compute_updates_from_gradients(Model, Vector) -> Vector: Compute and return model updates based on pre-computed gradients.
  • process_aux_var(Dict[str, AuxVar]) -> None: Pass auxiliary variables to plug-in modules for processing.
  • run_train_step(Model, Batch) -> None: Compute gradients of a Model over a Batch and apply updates.
  • start_round() -> None: Signal to wrapped regularizers that a new training round is starting.
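The `start_round` hook matters for regularizers that hold per-round state. As a purely illustrative assumption, the sketch below uses a proximal-term regularizer (a hypothetical class, not declearn's implementation, with plain dicts standing in for `Vector`) whose reference weights must be refreshed whenever a new round starts; the `start_round` helper mirrors what `Optimizer.start_round` does with its wrapped regularizers.

```python
from typing import Dict, List, Optional

Vector = Dict[str, float]  # hypothetical stand-in for a declearn Vector


class ProximalRegularizer:
    """Hypothetical regularizer adding mu * (w - w_ref) to gradients.

    The reference weights w_ref are snapshotted on the first `run` call
    following each round start, i.e. at the start of local training.
    """

    def __init__(self, mu: float) -> None:
        self.mu = mu
        self.ref: Optional[Vector] = None

    def on_round_start(self) -> None:
        # Drop the reference point so that the next `run` call resets it.
        self.ref = None

    def run(self, gradients: Vector, weights: Vector) -> Vector:
        if self.ref is None:
            self.ref = dict(weights)
        ref = self.ref
        return {
            key: grad + self.mu * (weights[key] - ref[key])
            for key, grad in gradients.items()
        }


def start_round(regularizers: List[ProximalRegularizer]) -> None:
    # Mirrors Optimizer.start_round: notify every wrapped regularizer.
    for regularizer in regularizers:
        regularizer.on_round_start()


reg = ProximalRegularizer(mu=0.5)
start_round([reg])
print(reg.run({"w": 1.0}, {"w": 2.0}))  # {'w': 1.0}: reference just taken
print(reg.run({"w": 1.0}, {"w": 4.0}))  # {'w': 2.0}: 1.0 + 0.5 * (4.0 - 2.0)
start_round([reg])
print(reg.run({"w": 1.0}, {"w": 4.0}))  # {'w': 1.0}: reference reset
```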

References

[1] Loshchilov & Hutter, 2019. Decoupled Weight Decay Regularization. https://arxiv.org/abs/1711.05101

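See also

  • declearn.optimizer.list_optim_modules: Return a mapping of registered OptiModule subclasses.
  • declearn.optimizer.list_optim_regularizers: Return a mapping of registered Regularizer subclasses.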

Source code in declearn/optimizer/_base.py
class Optimizer:
    """Base class to define gradient-descent-based optimizers.

    The `Optimizer` class defines an API that is required by other
    declearn components for federated learning processes to run.
    It is also fully-workable and is designed to be customizable
    through the use of "plug-in modules" rather than subclassing
    (which might be used for advanced algorithm modifications):
    see the base classes [declearn.optimizer.modules.OptiModule][]
    and [declearn.optimizer.regularizers.Regularizer][] for details.

    The process implemented here is the following:

    - Compute or receive the (pseudo-)gradients of a model.
    - Compute loss-regularization terms and add them to the
      gradients, based on a list of plug-in regularizers.
    - Refine gradients by running them through plug-in modules,
      which are applied sequentially (i.e. composed).
    - Apply the learning rate to the refined gradients, yielding
      the updates.
    - Optionally add a decoupled weight decay term (see [1]) to
      the updates; this term is not scaled by the learning rate.
    - Perform the weights' update.

    Most plug-in modules are self-contained, in the sense that they
    do not require any information flow between the server and its
    clients in a federated process, and may be used solely by the
    server, by clients or even by a subset of clients - at least
    formally (there might be correctness or convergence issues with
    clients not adopting similar local optimization strategies).

    However, some algorithms designed (or adapted) specifically for
    federated learning require some form of synchronicity between
    the server and clients. In that case, they should be coded to
    emit and expect auxiliary variables, shared between server and
    clients alongside updated model weights during training. Those
    mechanisms are to be implemented at the level of the modules
    themselves, but are wrapped at optimizer level, which collects
    plugged-in modules' variables and maps back received variables
    to them.

    Attributes
    ----------
    lrate: float
        Base learning rate applied to computed updates.
    w_decay: float
        Decoupled weight decay parameter.
    modules: list[OptiModule]
        List of plug-in modules composed into the optimizer's
        gradients-to-updates computation algorithm.
    regularizers: list[Regularizer]
        List of plug-in loss regularization modules composed into
        the optimizer's gradients-to-updates computation algorithm.

    API methods
    -----------
    - apply_gradients(Model, Vector) -> None:
        Update a Model based on a pre-computed Vector of gradients.
    - collect_aux_var() -> Dict[str, AuxVar]:
        Collect and package plug-in modules' auxiliary variables.
    - compute_updates_from_gradients(Model, Vector) -> Vector:
        Compute and return model updates based on pre-computed gradients.
    - process_aux_var(Dict[str, AuxVar]) -> None:
        Pass auxiliary variables to plug-in modules for processing.
    - run_train_step(Model, Batch) -> None:
        Compute gradients of a Model over a Batch and apply updates.
    - start_round() -> None:
        Signal to wrapped regularizers that a new training round is starting.

    References
    ----------
    [1] Loshchilov & Hutter, 2019.
        Decoupled Weight Decay Regularization.
        https://arxiv.org/abs/1711.05101

    See also
    --------
    - [declearn.optimizer.list_optim_modules][]:
        Return a mapping of registered OptiModule subclasses.
    - [declearn.optimizer.list_optim_regularizers][]:
        Return a mapping of registered Regularizer subclasses.
    """

    def __init__(
        self,
        lrate: float,  # future: add scheduling tools
        w_decay: float = 0.0,  # future: add scheduling tools
        regularizers: Optional[
            Sequence[Union[Regularizer, str, Tuple[str, Dict[str, Any]]]]
        ] = None,
        modules: Optional[
            Sequence[Union[OptiModule, str, Tuple[str, Dict[str, Any]]]]
        ] = None,
    ) -> None:
        """Instantiate the gradient-descent optimizer.

        Parameters
        ----------
        lrate: float
            Base learning rate (i.e. step size) applied to gradients-
            based updates upon applying them to a model's weights.
        w_decay: float, default=0.
            Optional weight decay parameter, used to parameterize
            a decoupled weight decay regularization term (see [1])
            added to the updates right after the learning rate is
            applied (hence not scaled by it), before model weights
            are effectively updated.
        regularizers: list[Regularizer or specs] or None, default=None
            Optional list of plug-in loss regularizers. Regularizers will
            be applied to gradients following this list's order, prior to
            any other alteration (e.g. acceleration module - see below).
            See `declearn.optimizer.regularizers.Regularizer` for details.
            See Notes section below for details on the "specs" format.
        modules: list[OptiModule or specs] or None, default=None
            Optional list of plug-in modules implementing gradients'
            alteration into model weights' updates. Modules will be
            applied to gradients following this list's ordering.
            See `declearn.optimizer.modules.OptiModule` for details.
            See Notes section below for details on the "specs" format.

        Notes
        -----
        `Regularizer` and `OptiModule` to be used by this optimizer,
        specified using the `regularizers` and `modules` parameters,
        may be passed as ready-for-use instances, or be instantiated
        from specs, consisting either of a single string (the `name`
        attribute of the class to build) or a tuple grouping this
        name and a config dict (to specify some hyper-parameters).

        References
        ----------
        [1] Loshchilov & Hutter, 2019.
            Decoupled Weight Decay Regularization.
            https://arxiv.org/abs/1711.05101
        """
        self.lrate = lrate
        self.w_decay = w_decay
        self.regularizers = (
            []
            if regularizers is None
            else self._parse_plugins(Regularizer, regularizers)  # type: ignore
        )  # type: List[Regularizer]
        self.modules = (
            []
            if modules is None
            else self._parse_plugins(OptiModule, modules)  # type: ignore
        )  # type: List[OptiModule]

    def _parse_plugins(
        self,
        cls: Type[Union[OptiModule, Regularizer]],
        plugins: Sequence[Union[Any, str, Tuple[str, Dict[str, Any]]]],
    ) -> Union[List[OptiModule], List[Regularizer]]:
        """Parse a list of plug-in specs into a list of instances.

        Parameters
        ----------
        cls: Type[OptiModule or Regularizer]
            Base type of plug-ins being instantiated.
        plugins: list[`cls` | str | (str, dict)]
            List of instances or specifications to process and/or type-check.
            Specifications may be a single string (`name` attribute of the
            type to build) or a tuple grouping this name and a config dict
            (to specify non-default hyper-parameters).

        Returns
        -------
        plugins: list[`cls`]
            List of `cls` instances created (or taken) from the specs.
        """
        output = []
        for specs in plugins:
            if isinstance(specs, cls):
                plugin = specs
            elif isinstance(specs, str):
                plugin = cls.from_specs(specs, config={})
            elif isinstance(specs, (tuple, list)) and (len(specs) == 2):
                plugin = cls.from_specs(*specs)
            else:
                raise TypeError(
                    f"Cannot instantiate a {cls.__name__} from {specs}. "
                    "Required a name (str) or specs ((str, dict) tuple)."
                )
            output.append(plugin)
        return output  # type: ignore

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable dict with this optimizer's parameters.

        The counterpart to this method is the `from_config` classmethod.
        To access the optimizer's inner states, see the `get_state` method.

        Returns
        -------
        config: dict[str, any]
            JSON-serializable dict storing this optimizer's instantiation
            configuration.
        """
        regulzr = [(reg.name, reg.get_config()) for reg in self.regularizers]
        modules = [(mod.name, mod.get_config()) for mod in self.modules]
        return {
            "lrate": self.lrate,
            "w_decay": self.w_decay,
            "regularizers": regulzr,
            "modules": modules,
        }

    @classmethod
    def from_config(
        cls,
        config: Dict[str, Any],
    ) -> Self:
        """Instantiate an Optimizer from its configuration dict.

        The counterpart to this classmethod is the `get_config` method.
        To restore the optimizer's inner states, see its `set_state` method.

        Parameters
        ----------
        config: dict[str, Any]
            Dict storing the optimizer's instantiation configuration.

        Raises
        ------
        TypeError
            If the provided `config` lacks some required parameters
            and/or contains some unused ones.
        """
        return cls(**config)

    def compute_updates_from_gradients(
        self,
        model: Model[Vector[T]],
        gradients: Vector[T],
    ) -> Vector[T]:
        """Compute and return model updates based on pre-computed gradients.

        Parameters
        ----------
        model: Model
            Model instance that is to be trained using gradient-descent.
            This parameter is only used to access current weights in case
            some loss regularizers are part of the pipeline, or decoupled
            weight decay is enabled.
        gradients: Vector
            Pre-computed vector of (pseudo-)gradients based on which to
            perform the gradient-descent step, by applying the algorithm
            defined by this optimizer and its plug-in modules.

        Returns
        -------
        updates: Vector
            Model weights' updates, preserving input `gradients`'s specs,
            ready to be applied using the `model.apply_updates` method.
        """
        # Optionally fetch the model's trainable weights.
        if self.regularizers or self.w_decay:
            weights = model.get_weights(trainable=True)
        # Run input gradients and weights through plug-in regularizers.
        if self.regularizers:
            for regularizer in self.regularizers:
                gradients = regularizer.run(gradients, weights)
        # Run input gradients through plug-in modules.
        for module in self.modules:
            gradients = module.run(gradients)
        # Apply the base learning rate.
        updates = self.lrate * gradients
        # Optionally add the decoupled weight decay term.
        if self.w_decay:
            updates += self.w_decay * weights
        # Return ready-to-apply model updates.
        return -1.0 * updates

    def collect_aux_var(
        self,
    ) -> Dict[str, AuxVar]:
        """Return auxiliary variables that need to be shared between nodes.

        Returns
        -------
        aux_var: dict[str, AuxVar]
            Dict that associates `module.collect_aux_var()` values
            to `module.aux_name` (or, if unset, `module.name`) keys
            for each and every module plugged in this optimizer that
            produces auxiliary variables.
        """
        aux_var = {}  # type: Dict[str, AuxVar]
        for module in self.modules:
            auxv = module.collect_aux_var()
            if auxv is not None:
                name = module.aux_name or module.name
                aux_var[name] = auxv
        return aux_var

    def process_aux_var(
        self,
        aux_var: Dict[str, AuxVar],
    ) -> None:
        """Update plug-in modules based on received shared auxiliary variables.

        Received auxiliary variables will be passed to this optimizer's
        modules' `process_aux_var` method, mapped based on `module.aux_name`
        (or, if unset, `module.name`).

        Parameters
        ----------
        aux_var: dict[str, AuxVar]
            Auxiliary variables received from the counterpart optimizer
            (on the other side of the client/server relationship), that
            are packed as a `{module.name: module.auxvar_cls}` dict for
            modules that do use auxiliary variables.
            When auxiliary variables from multiple peers are due to be
            processed, they must be aggregated prior to being passed to
            this method.

        Raises
        ------
        KeyError
            If a key from `aux_var` does not match the name of any module
            plugged in this optimizer (i.e. if received variables cannot
            be mapped to a destination module).
        """
        modules = {
            (module.aux_name or module.name): module for module in self.modules
        }
        for name, auxv in aux_var.items():
            module = modules.get(name, None)
            if module is None:
                raise KeyError(
                    f"No module with name '{name}' is available to receive "
                    "auxiliary variables."
                )
            module.process_aux_var(auxv)

    def start_round(
        self,
    ) -> None:
        """Perform any required action at the start of a training round.

        This method calls the `on_round_start` callback of each and every
        wrapped `Regularizer` which may be used to regulate some internal
        state variables.
        """
        for regularizer in self.regularizers:
            regularizer.on_round_start()

    def run_train_step(
        self,
        model: Model,
        batch: Batch,
        sclip: Optional[float] = None,
    ) -> None:
        """Perform a gradient-descent step on a given batch.

        Parameters
        ----------
        model: Model
            Model instance that is to be trained using gradient-descent.
        batch: Batch
            Training data used for that training step.
        sclip: float or None, default=None
            Optional L2-norm clipping threshold for sample-wise gradients,
            restraining their sensitivity prior to any alteration designed
            as part of this Optimizer's pipeline of plug-in algorithms.

        Returns
        -------
        None
            This method does not return, as `model` is updated in-place.
        """
        gradients = model.compute_batch_gradients(batch, max_norm=sclip)
        self.apply_gradients(model, gradients)

    def apply_gradients(
        self,
        model: Model[Vector[T]],
        gradients: Vector[T],
    ) -> None:
        """Compute and apply model updates based on pre-computed gradients.

        Parameters
        ----------
        model: Model
            Model instance that is to be trained using gradient-descent.
        gradients: Vector
            Pre-computed vector of (pseudo-)gradients based on which to
            perform the gradient-descent step, by applying the algorithm
            defined by this optimizer and its plug-in modules.

        Returns
        -------
        None
            This method does not return, as `model` is updated in-place.
        """
        updates = self.compute_updates_from_gradients(model, gradients)
        model.apply_updates(updates)

    def get_state(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable dict with this optimizer's state.

        The counterpart to this method is the `set_state` one.

        Returns
        -------
        state: dict[str, any]
            JSON-serializable dict storing this optimizer's inner state
            variables (i.e. those from its modules).
        """
        modules = [(mod.name, mod.get_state()) for mod in self.modules]
        return {"modules": modules}

    def set_state(
        self,
        states: Dict[str, Any],
    ) -> None:
        """Load a saved state dict into an optimizer instance.

        The counterpart to this method is the `get_state` one.

        Parameters
        ----------
        states: dict[str, any]
            Dict storing values to assign to this optimizer's inner
            state variables (i.e. those from its modules).

        Raises
        ------
        KeyError
            If the received states do not match the expected config,
            whether because a module is missing or one of its states
            is missing.
            In both cases, the Optimizer's states will be reverted
            to their values prior to the failed call to this method.
        RuntimeError
            If a KeyError was raised both when trying to apply the
            input `state` and when trying to revert the states to
            their initial values after that first error was raised.
            This should never happen and indicates a source code
            error in a wrapped module, or even in this class.
        """
        if "modules" not in states:
            raise KeyError("Optimizer input 'states' lack a 'modules' field.")
        if len(states["modules"]) != len(self.modules):
            raise KeyError("Optimizer 'states' do not match modules config.")
        initial = self.get_state()
        try:
            self._set_state(states)
        except KeyError as exc:
            try:
                self._set_state(initial)
            except KeyError as exc_bis:
                raise RuntimeError(
                    "`Optimizer.set_state` failed to restore initial states "
                    "after a KeyError was raised during states' attempted "
                    "update. There probably is a source code error with one "
                    "of the wrapped modules.\n"
                    f"Error when reverting states: {exc_bis}\n"
                    f"Initial update error: {exc}\n"
                ) from exc_bis
            raise exc

    def _set_state(
        self,
        states: Dict[str, Any],
    ) -> None:
        """Backend to the `set_state` method, lacking exception-catching."""
        for mod, (name, state) in zip(self.modules, states["modules"]):
            if mod.name != name:
                raise KeyError(
                    "Optimizer 'states' do not match modules config."
                )
            # Note: this may raise a KeyError if 'state' is misspecified.
            mod.set_state(state)
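`set_state` follows an all-or-nothing pattern: snapshot the current states, try to apply the new ones, and restore the snapshot if a `KeyError` surfaces mid-way. The simplified sketch below reproduces that pattern with a hypothetical `ToyModule` class; unlike the source, it omits the `RuntimeError` raised when the rollback itself fails.

```python
from typing import Any, Dict, List, Tuple


class ToyModule:
    """Hypothetical stateful module exposing get_state / set_state."""

    def __init__(self, name: str, step: int = 0) -> None:
        self.name = name
        self.step = step

    def get_state(self) -> Dict[str, Any]:
        return {"step": self.step}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.step = state["step"]  # raises KeyError if 'step' is missing


def set_state(modules: List[ToyModule], states: Dict[str, Any]) -> None:
    """Apply per-module states all-or-nothing, reverting on KeyError."""
    if "modules" not in states or len(states["modules"]) != len(modules):
        raise KeyError("States do not match modules config.")
    initial = [(mod.name, mod.get_state()) for mod in modules]

    def _apply(pairs: List[Tuple[str, Dict[str, Any]]]) -> None:
        for mod, (name, state) in zip(modules, pairs):
            if mod.name != name:
                raise KeyError("States do not match modules config.")
            mod.set_state(state)

    try:
        _apply(states["modules"])
    except KeyError:
        _apply(initial)  # revert any partially-applied states
        raise


mods = [ToyModule("a", step=1), ToyModule("b", step=1)]
try:
    # 'b' lacks 'step': module 'a' is updated first, then reverted.
    set_state(mods, {"modules": [("a", {"step": 5}), ("b", {})]})
except KeyError:
    pass
print([mod.step for mod in mods])  # [1, 1]
```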

__init__(lrate, w_decay=0.0, regularizers=None, modules=None)

Instantiate the gradient-descent optimizer.

Parameters:

  • lrate (float, required): Base learning rate (i.e. step size) applied to gradients-based updates upon applying them to a model's weights.
  • w_decay (float, default 0.0): Optional weight decay parameter, used to parameterize a decoupled weight decay regularization term (see [1]) added to the updates right after the learning rate is applied (hence not scaled by it), before model weights are effectively updated.
  • regularizers (Optional[Sequence[Union[Regularizer, str, Tuple[str, Dict[str, Any]]]]], default None): Optional list of plug-in loss regularizers. Regularizers will be applied to gradients following this list's order, prior to any other alteration (e.g. acceleration module - see below). See declearn.optimizer.regularizers.Regularizer for details. See Notes section below for details on the "specs" format.
  • modules (Optional[Sequence[Union[OptiModule, str, Tuple[str, Dict[str, Any]]]]], default None): Optional list of plug-in modules implementing gradients' alteration into model weights' updates. Modules will be applied to gradients following this list's ordering. See declearn.optimizer.modules.OptiModule for details. See Notes section below for details on the "specs" format.

Notes

Regularizer and OptiModule to be used by this optimizer, specified using the regularizers and modules parameters, may be passed as ready-for-use instances, or be instantiated from specs, consisting either of a single string (the name attribute of the class to build) or a tuple grouping this name and a config dict (to specify some hyper-parameters).
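The three accepted forms (instance, name, or `(name, config)` pair) can be illustrated with a minimal sketch of the parsing logic. `Plugin`, its name-based registry and the `Momentum` subclass are hypothetical stand-ins for the declearn base classes and their `from_specs` classmethod; only the branching mirrors the `_parse_plugins` source.

```python
from typing import Any, Dict, List, Sequence, Tuple, Type, Union


class Plugin:
    """Hypothetical plug-in base class with a name-based registry."""

    registry: Dict[str, Type["Plugin"]] = {}
    name = "base"

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        Plugin.registry[cls.name] = cls  # register subclasses by name

    @classmethod
    def from_specs(cls, name: str, config: Dict[str, Any]) -> "Plugin":
        return cls.registry[name](**config)


class Momentum(Plugin):
    """Hypothetical plug-in with a single hyper-parameter."""

    name = "momentum"

    def __init__(self, beta: float = 0.9) -> None:
        self.beta = beta


Specs = Union[Plugin, str, Tuple[str, Dict[str, Any]]]


def parse_plugins(plugins: Sequence[Specs]) -> List[Plugin]:
    # Accept ready instances, bare names, or (name, config) pairs.
    output: List[Plugin] = []
    for specs in plugins:
        if isinstance(specs, Plugin):
            output.append(specs)
        elif isinstance(specs, str):
            output.append(Plugin.from_specs(specs, config={}))
        elif isinstance(specs, (tuple, list)) and len(specs) == 2:
            output.append(Plugin.from_specs(*specs))
        else:
            raise TypeError(f"Cannot instantiate a Plugin from {specs}.")
    return output


parsed = parse_plugins([Momentum(0.5), "momentum", ("momentum", {"beta": 0.99})])
print([plugin.beta for plugin in parsed])  # [0.5, 0.9, 0.99]
```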

References

[1] Loshchilov & Hutter, 2019. Decoupled Weight Decay Regularization. https://arxiv.org/abs/1711.05101

Source code in declearn/optimizer/_base.py
def __init__(
    self,
    lrate: float,  # future: add scheduling tools
    w_decay: float = 0.0,  # future: add scheduling tools
    regularizers: Optional[
        Sequence[Union[Regularizer, str, Tuple[str, Dict[str, Any]]]]
    ] = None,
    modules: Optional[
        Sequence[Union[OptiModule, str, Tuple[str, Dict[str, Any]]]]
    ] = None,
) -> None:
    """Instantiate the gradient-descent optimizer.

    Parameters
    ----------
    lrate: float
        Base learning rate (i.e. step size) applied to gradients-
        based updates upon applying them to a model's weights.
    w_decay: float, default=0.
        Optional weight decay parameter, used to parameterize
        a decoupled weight decay regularization term (see [1])
        added to the updates right after the learning rate is
        applied (hence not scaled by it), before model weights
        are effectively updated.
    regularizers: list[Regularizer or specs] or None, default=None
        Optional list of plug-in loss regularizers. Regularizers will
        be applied to gradients following this list's order, prior to
        any other alteration (e.g. acceleration module - see below).
        See `declearn.optimizer.regularizers.Regularizer` for details.
        See Notes section below for details on the "specs" format.
    modules: list[OptiModule or specs] or None, default=None
        Optional list of plug-in modules implementing gradients'
        alteration into model weights' updates. Modules will be
        applied to gradients following this list's ordering.
        See `declearn.optimizer.modules.OptiModule` for details.
        See Notes section below for details on the "specs" format.

    Notes
    -----
    `Regularizer` and `OptiModule` to be used by this optimizer,
    specified using the `regularizers` and `modules` parameters,
    may be passed as ready-for-use instances, or be instantiated
    from specs, consisting either of a single string (the `name`
    attribute of the class to build) or a tuple grouping this
    name and a config dict (to specify some hyper-parameters).

    References
    ----------
    [1] Loshchilov & Hutter, 2019.
        Decoupled Weight Decay Regularization.
        https://arxiv.org/abs/1711.05101
    """
    self.lrate = lrate
    self.w_decay = w_decay
    self.regularizers = (
        []
        if regularizers is None
        else self._parse_plugins(Regularizer, regularizers)  # type: ignore
    )  # type: List[Regularizer]
    self.modules = (
        []
        if modules is None
        else self._parse_plugins(OptiModule, modules)  # type: ignore
    )  # type: List[OptiModule]

apply_gradients(model, gradients)

Compute and apply model updates based on pre-computed gradients.

Parameters:

  • model (Model[Vector[T]], required): Model instance that is to be trained using gradient-descent.
  • gradients (Vector[T], required): Pre-computed vector of (pseudo-)gradients based on which to perform the gradient-descent step, by applying the algorithm defined by this optimizer and its plug-in modules.

Returns:

  • None: This method does not return, as model is updated in-place.

Source code in declearn/optimizer/_base.py
def apply_gradients(
    self,
    model: Model[Vector[T]],
    gradients: Vector[T],
) -> None:
    """Compute and apply model updates based on pre-computed gradients.

    Parameters
    ----------
    model: Model
        Model instance that is to be trained using gradient-descent.
    gradients: Vector
        Pre-computed vector of (pseudo-)gradients based on which to
        perform the gradient-descent step, by applying the algorithm
        defined by this optimizer and its plug-in modules.

    Returns
    -------
    None
        This method does not return, as `model` is updated in-place.
    """
    updates = self.compute_updates_from_gradients(model, gradients)
    model.apply_updates(updates)

collect_aux_var()

Return auxiliary variables that need to be shared between nodes.

Returns:

  • aux_var (Dict[str, AuxVar]): Dict that associates module.collect_aux_var() values to module.aux_name (or, if unset, module.name) keys for each and every module plugged in this optimizer that produces auxiliary variables.

Source code in declearn/optimizer/_base.py
def collect_aux_var(
    self,
) -> Dict[str, AuxVar]:
    """Return auxiliary variables that need to be shared between nodes.

    Returns
    -------
    aux_var: dict[str, AuxVar]
        Dict that associates `module.collect_aux_var()` values
        to `module.aux_name` (or, if unset, `module.name`) keys
        for each and every module plugged in this optimizer that
        produces auxiliary variables.
    """
    aux_var = {}  # type: Dict[str, AuxVar]
    for module in self.modules:
        auxv = module.collect_aux_var()
        if auxv is not None:
            name = module.aux_name or module.name
            aux_var[name] = auxv
    return aux_var

compute_updates_from_gradients(model, gradients)

Compute and return model updates based on pre-computed gradients.

Parameters:

  • model (Model[Vector[T]], required): Model instance that is to be trained using gradient-descent. This parameter is only used to access current weights in case some loss regularizers are part of the pipeline, or decoupled weight decay is enabled.
  • gradients (Vector[T], required): Pre-computed vector of (pseudo-)gradients based on which to perform the gradient-descent step, by applying the algorithm defined by this optimizer and its plug-in modules.

Returns:

  • updates (Vector): Model weights' updates, preserving the specs of the input gradients, ready to be applied using the model.apply_updates method.
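To make the order of operations concrete, here is a plain-Python mirror of the method. Dicts of floats stand in for `Vector`, and the hypothetical callables `l2_regularizer` and `halve` stand in for regularizers and modules (real ones expose a `run` method instead). The numbers are chosen to be exact in binary floating point.

```python
from typing import Callable, Dict, List, Optional

Vector = Dict[str, float]  # hypothetical stand-in for declearn's Vector


def compute_updates(
    gradients: Vector,
    weights: Vector,
    lrate: float,
    w_decay: float = 0.0,
    regularizers: Optional[List[Callable[[Vector, Vector], Vector]]] = None,
    modules: Optional[List[Callable[[Vector], Vector]]] = None,
) -> Vector:
    """Mirror the order of operations of compute_updates_from_gradients."""
    for regularizer in regularizers or []:
        gradients = regularizer(gradients, weights)  # loss regularization
    for module in modules or []:
        gradients = module(gradients)  # sequential refinement
    updates = {key: lrate * grad for key, grad in gradients.items()}
    if w_decay:
        # Decoupled weight decay, added after (and not scaled by) lrate.
        updates = {key: upd + w_decay * weights[key] for key, upd in updates.items()}
    return {key: -upd for key, upd in updates.items()}  # descent direction


def l2_regularizer(gradients: Vector, weights: Vector) -> Vector:
    # Hypothetical regularizer: gradient of 0.25 * ||w||^2, i.e. 0.5 * w.
    return {key: grad + 0.5 * weights[key] for key, grad in gradients.items()}


def halve(gradients: Vector) -> Vector:
    # Hypothetical module: scale gradients by 0.5.
    return {key: 0.5 * grad for key, grad in gradients.items()}


result = compute_updates(
    {"w": 2.0}, {"w": 4.0}, lrate=0.5, w_decay=0.25,
    regularizers=[l2_regularizer], modules=[halve],
)
# 2.0 + 0.5 * 4.0 = 4.0 -> halved: 2.0 -> times lrate: 1.0
# -> plus decay 0.25 * 4.0: 2.0 -> negated: -2.0
print(result)  # {'w': -2.0}
```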

Source code in declearn/optimizer/_base.py
def compute_updates_from_gradients(
    self,
    model: Model[Vector[T]],
    gradients: Vector[T],
) -> Vector[T]:
    """Compute and return model updates based on pre-computed gradients.

    Parameters
    ----------
    model: Model
        Model instance that is to be trained using gradient-descent.
        This parameter is only used to access current weights in case
        some loss regularizers are part of the pipeline.
    gradients: Vector
        Pre-computed vector of (pseudo-)gradients based on which to
        perform the gradient-descent step, by applying the algorithm
        defined by this optimizer and its plug-in modules.

    Returns
    -------
    updates: Vector
        Model weights' updates, preserving input `gradients`'s specs,
        ready to be applied using the `model.apply_updates` method.
    """
    # Optionally fetch the model's trainable weights.
    if self.regularizers or self.w_decay:
        weights = model.get_weights(trainable=True)
    # Run input gradients and weights through plug-in regularizers.
    if self.regularizers:
        for regularizer in self.regularizers:
            gradients = regularizer.run(gradients, weights)
    # Run input gradients through plug-in modules.
    for module in self.modules:
        gradients = module.run(gradients)
    # Apply the base learning rate.
    updates = self.lrate * gradients
    # Optionally add the decoupled weight decay term.
    if self.w_decay:
        updates += self.w_decay * weights
    # Return ready-to-apply model updates.
    return -1.0 * updates
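The update rule implemented by compute_updates_from_gradients can be traced on plain Python lists. The sketch below is not declearn code: lists stand in for Vector, and the l2_regularizer and clip_module helpers are hypothetical plug-ins written as plain functions. It keeps the order of operations of the method above: regularizers, then modules, then the learning rate, then the decoupled weight decay (which, as in the source, is not multiplied by lrate), and finally a sign flip.

```python
from typing import Callable, List, Sequence

Vector = List[float]  # hypothetical stand-in for declearn's Vector type


def compute_updates(
    gradients: Vector,
    weights: Vector,
    lrate: float,
    w_decay: float = 0.0,
    regularizers: Sequence[Callable[[Vector, Vector], Vector]] = (),
    modules: Sequence[Callable[[Vector], Vector]] = (),
) -> Vector:
    """Sketch of the Optimizer pipeline on plain Python lists."""
    # 1. Loss-regularization terms see both the gradients and current weights.
    for regularizer in regularizers:
        gradients = regularizer(gradients, weights)
    # 2. Plug-in modules are applied sequentially (function composition).
    for module in modules:
        gradients = module(gradients)
    # 3. Scale by the base learning rate.
    updates = [lrate * g for g in gradients]
    # 4. Add the decoupled weight decay term (not scaled by lrate).
    if w_decay:
        updates = [u + w_decay * w for u, w in zip(updates, weights)]
    # 5. Negate, so that the result is added to the weights as-is.
    return [-u for u in updates]


def l2_regularizer(alpha: float) -> Callable[[Vector, Vector], Vector]:
    # Hypothetical L2 penalty: adds alpha * w to each gradient coordinate.
    return lambda grads, weights: [g + alpha * w for g, w in zip(grads, weights)]


def clip_module(bound: float) -> Callable[[Vector], Vector]:
    # Hypothetical module: clips each coordinate to [-bound, bound].
    return lambda grads: [max(-bound, min(bound, g)) for g in grads]


weights = [1.0, -2.0]
grads = [0.5, 3.0]
updates = compute_updates(
    grads, weights, lrate=0.1, w_decay=0.01,
    regularizers=[l2_regularizer(0.5)], modules=[clip_module(1.0)],
)
print(updates)
```

Written out, each coordinate follows updates = -(lrate * modules(regularized gradients) + w_decay * weights), which is why model.apply_updates can add the returned vector directly.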

from_config(config) classmethod

Instantiate an Optimizer from its configuration dict.

The counterpart to this classmethod is the get_config method. To save and restore the optimizer's inner states, see its get_state and set_state methods.

Parameters:

Name Type Description Default
config Dict[str, Any]

Dict storing the optimizer's instantiation configuration.

required

Raises:

Type Description
KeyError

If the provided config lacks some required parameters and/or contains some unused ones.

Source code in declearn/optimizer/_base.py
@classmethod
def from_config(
    cls,
    config: Dict[str, Any],
) -> Self:
    """Instantiate an Optimizer from its configuration dict.

    The counterpart to this classmethod is the `get_config` method.
    To restore the optimizer's inner states, see its `set_state` method.

    Parameters
    ----------
    config: dict[str, Any]
        Dict storing the optimizer's instantiation configuration.

    Raises
    ------
    KeyError
        If the provided `config` lacks some required parameters
        and/or contains some unused ones.
    """
    return cls(**config)

get_config()

Return a JSON-serializable dict with this optimizer's parameters.

The counterpart to this method is the from_config classmethod. To access the optimizer's inner states, see the get_state method.

Returns:

Name Type Description
config dict[str, any]

JSON-serializable dict storing this optimizer's instantiation configuration.

Source code in declearn/optimizer/_base.py
def get_config(
    self,
) -> Dict[str, Any]:
    """Return a JSON-serializable dict with this optimizer's parameters.

    The counterpart to this method is the `from_config` classmethod.
    To access the optimizer's inner states, see the `get_state` method.

    Returns
    -------
    config: dict[str, any]
        JSON-serializable dict storing this optimizer's instantiation
        configuration.
    """
    regulzr = [(reg.name, reg.get_config()) for reg in self.regularizers]
    modules = [(mod.name, mod.get_config()) for mod in self.modules]
    return {
        "lrate": self.lrate,
        "w_decay": self.w_decay,
        "regularizers": regulzr,
        "modules": modules,
    }
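get_config and from_config are meant to round-trip through JSON. The hypothetical ToyOptimizer below mirrors that contract, with components kept as plain (name, config) pairs; in declearn those pairs are presumably turned back into regularizer and module instances by the constructor, which this sketch does not attempt. The "adam" and "beta_1" entries are placeholders.

```python
import json
from typing import Any, Dict, List, Optional, Sequence


class ToyOptimizer:
    """Hypothetical class mirroring the get_config / from_config contract."""

    def __init__(
        self,
        lrate: float,
        w_decay: float = 0.0,
        regularizers: Optional[Sequence[Sequence[Any]]] = None,
        modules: Optional[Sequence[Sequence[Any]]] = None,
    ) -> None:
        self.lrate = lrate
        self.w_decay = w_decay
        # (name, config) pairs; JSON turns tuples into lists, both are accepted.
        self.regularizers: List[Sequence[Any]] = list(regularizers or [])
        self.modules: List[Sequence[Any]] = list(modules or [])

    def get_config(self) -> Dict[str, Any]:
        return {
            "lrate": self.lrate,
            "w_decay": self.w_decay,
            "regularizers": [(name, cfg) for name, cfg in self.regularizers],
            "modules": [(name, cfg) for name, cfg in self.modules],
        }

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyOptimizer":
        return cls(**config)


opt = ToyOptimizer(lrate=0.01, w_decay=1e-4, modules=[("adam", {"beta_1": 0.9})])
dump = json.dumps(opt.get_config())  # the config is JSON-serializable
restored = ToyOptimizer.from_config(json.loads(dump))
print(restored.get_config())
```

Note that the config only covers instantiation parameters: any state accumulated by modules during training is not part of it and travels through get_state instead.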

get_state()

Return a JSON-serializable dict with this optimizer's state.

The counterpart to this method is the set_state one.

Returns:

Name Type Description
state dict[str, any]

JSON-serializable dict storing this optimizer's inner state variables (i.e. those from its modules).

Source code in declearn/optimizer/_base.py
def get_state(
    self,
) -> Dict[str, Any]:
    """Return a JSON-serializable dict with this optimizer's state.

    The counterpart to this method is the `set_state` one.

    Returns
    -------
    state: dict[str, any]
        JSON-serializable dict storing this optimizer's inner state
        variables (i.e. those from its modules).
    """
    modules = [(mod.name, mod.get_state()) for mod in self.modules]
    return {"modules": modules}

process_aux_var(aux_var)

Update plug-in modules based on received shared auxiliary variables.

Received auxiliary variables are passed to the process_aux_var method of this optimizer's modules, mapped based on module.aux_name (or module.name when aux_name is unset).

Parameters:

Name Type Description Default
aux_var Dict[str, AuxVar]

Auxiliary variables received from the counterpart optimizer (on the other side of the client/server relationship), packed as a {name: auxvar} dict for modules that use auxiliary variables, where name is module.aux_name (or module.name if unset) and auxvar is a module.auxvar_cls instance. When auxiliary variables from multiple peers are to be processed, they must be aggregated before being passed to this method.

required

Raises:

Type Description
KeyError

If a key from aux_var does not match the aux_name (or name) of any module plugged in this optimizer, i.e. if received variables cannot be mapped to a destination module.

Source code in declearn/optimizer/_base.py
def process_aux_var(
    self,
    aux_var: Dict[str, AuxVar],
) -> None:
    """Update plug-in modules based on received shared auxiliary variables.

    Received auxiliary variables will be passed to this optimizer's
    modules' `process_aux_var` method, mapped based on `module.name`.

    Parameters
    ----------
    aux_var: dict[str, AuxVar]
        Auxiliary variables received from the counterpart optimizer
        (on the other side of the client/server relationship), that
        are packed as a `{module.name: module.auxvar_cls}` dict for
        modules that do use auxiliary variables.
        When auxiliary variables from multiple peers are due to be
        processed, they must be aggregated prior to being passed to
        this method.

    Raises
    ------
    KeyError
        If a key from `aux_var` does not match the name of any module
        plugged in this optimizer (i.e. if received variables cannot
        be mapped to a destinatory module).
    """
    modules = {
        (module.aux_name or module.name): module for module in self.modules
    }
    for name, auxv in aux_var.items():
        module = modules.get(name, None)
        if module is None:
            raise KeyError(
                f"No module with name '{name}' is available to receive "
                "auxiliary variables."
            )
        module.process_aux_var(auxv)
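The routing performed by process_aux_var can be sketched with a hypothetical ToyModule that simply records what it receives; the module names are illustrative and the lookup reproduces the `aux_name or name` mapping shown above.

```python
from typing import Any, Dict, List, Optional


class ToyModule:
    """Hypothetical module that records the auxiliary variables it receives."""

    def __init__(self, name: str, aux_name: Optional[str] = None) -> None:
        self.name = name
        self.aux_name = aux_name
        self.received: List[Any] = []

    def process_aux_var(self, aux_var: Any) -> None:
        self.received.append(aux_var)


def process_aux_var(modules: List[ToyModule], aux_var: Dict[str, Any]) -> None:
    """Dispatch each entry to the module keyed by aux_name (or name)."""
    lookup = {(m.aux_name or m.name): m for m in modules}
    for name, auxv in aux_var.items():
        module = lookup.get(name)
        if module is None:
            raise KeyError(f"No module with name '{name}' is available.")
        module.process_aux_var(auxv)


client_mod = ToyModule("sync-client", aux_name="sync")
process_aux_var([client_mod], {"sync": {"state": 1.0}})

try:
    process_aux_var([client_mod], {"unknown": {}})
except KeyError as exc:
    print("rejected:", exc)
```

Because entries are dispatched one at a time, both this sketch and the method above will already have delivered any entries that precede an unknown key when the KeyError is raised.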

run_train_step(model, batch, sclip=None)

Perform a gradient-descent step on a given batch.

Parameters:

Name Type Description Default
model Model

Model instance that is to be trained using gradient-descent.

required
batch Batch

Training data used for that training step.

required
sclip Optional[float]

Optional L2-norm clipping threshold for sample-wise gradients, restraining their sensitivity prior to any alteration designed as part of this Optimizer's pipeline of plug-in algorithms.

None

Returns:

Type Description
None

This method returns nothing, as model is updated in place.

Source code in declearn/optimizer/_base.py
def run_train_step(
    self,
    model: Model,
    batch: Batch,
    sclip: Optional[float] = None,
) -> None:
    """Perform a gradient-descent step on a given batch.

    Parameters
    ----------
    model: Model
        Model instance that is to be trained using gradient-descent.
    batch: Batch
        Training data used for that training step.
    sclip: float or None, default=None
        Optional L2-norm clipping threshold for sample-wise gradients,
        restraining their sensitivity prior to any alteration designed
        as part of this Optimizer's pipeline of plug-in algorithms.

    Returns
    -------
    None
        This method does not return, as `model` is updated in-place.
    """
    gradients = model.compute_batch_gradients(batch, max_norm=sclip)
    self.apply_gradients(model, gradients)
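run_train_step forwards sclip to model.compute_batch_gradients as max_norm, so the clipping itself happens inside the model. The sketch below assumes, in the manner of DP-SGD, that each sample's gradient is rescaled to an L2 norm of at most max_norm and that the clipped gradients are then averaged; the actual reduction used by a given declearn Model may differ, and both helper functions are hypothetical.

```python
import math
from typing import List, Optional

Vector = List[float]


def clip_l2(grad: Vector, max_norm: float) -> Vector:
    """Rescale grad so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm
    return [g * scale for g in grad]


def batch_gradients(sample_grads: List[Vector], max_norm: Optional[float]) -> Vector:
    """Average per-sample gradients, clipping each one first if max_norm is set."""
    if max_norm is not None:
        sample_grads = [clip_l2(g, max_norm) for g in sample_grads]
    n = len(sample_grads)
    return [sum(column) / n for column in zip(*sample_grads)]


per_sample = [[3.0, 4.0], [0.3, 0.4]]  # L2 norms: 5.0 and 0.5
print(batch_gradients(per_sample, max_norm=1.0))
print(batch_gradients(per_sample, max_norm=None))
```

Clipping before averaging bounds how much any single sample can move the batch gradient, which is what "restraining their sensitivity" refers to; it happens before any plug-in regularizer or module sees the gradients.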

set_state(states)

Load a saved state dict into an optimizer instance.

The counterpart to this method is the get_state one.

Parameters:

Name Type Description Default
states Dict[str, Any]

Dict storing values to assign to this optimizer's inner state variables (i.e. those from its modules).

required

Raises:

Type Description
KeyError

If the received states do not match the expected config, whether because a module is missing or one of its states is missing. In both cases, the Optimizer's states will be reverted to their values prior to the failed call to this method.

RuntimeError

If a KeyError was raised both when trying to apply the input states and when trying to revert the states to their initial values after that first error. This should never happen and indicates a source code error in a wrapped module, or even in this class.

Source code in declearn/optimizer/_base.py
def set_state(
    self,
    states: Dict[str, Any],
) -> None:
    """Load a saved state dict into an optimizer instance.

    The counterpart to this method is the `get_state` one.

    Parameters
    ----------
    states: dict[str, any]
        Dict storing values to assign to this optimizer's inner
        state variables (i.e. those from its modules).

    Raises
    ------
    KeyError
        If the received states do not match the expected config,
        whether because a module is missing or one of its states
        is missing.
        In both cases, the Optimizer's states will be reverted
        to their values prior to the failed call to this method.
    RuntimeError
        If a KeyError was raised both when trying to apply the
        input `state` and when trying to revert the states to
        their initial values after that first error was raised.
        This should never happen and indicates a source code
        error in a wrapped module, or even in this class.
    """
    if "modules" not in states:
        raise KeyError("Optimizer input 'states' lack a 'modules' field.")
    if len(states["modules"]) != len(self.modules):
        raise KeyError("Optimizer 'states' do not match modules config.")
    initial = self.get_state()
    try:
        self._set_state(states)
    except KeyError as exc:
        try:
            self._set_state(initial)
        except KeyError as exc_bis:
            raise RuntimeError(
                "`Optimizer.set_state` failed to restore initial states "
                "after a KeyError was raised during states' attempted "
                "update. There probably is a source code error with one "
                "of the wrapped modules.\n"
                f"Error when reverting states: {exc_bis}\n"
                f"Initial update error: {exc}\n"
            ) from exc_bis
        raise exc
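The snapshot-and-revert logic of set_state can be reproduced in a compact standalone sketch. ToyModule, ToyOptimizer and its _set_state helper are hypothetical (declearn's own _set_state is not shown on this page), and the RuntimeError raised when the revert itself fails is omitted for brevity.

```python
from typing import Any, Dict, List


class ToyModule:
    """Hypothetical module holding a dict of named state variables."""

    def __init__(self, name: str, state: Dict[str, float]) -> None:
        self.name = name
        self.state = dict(state)

    def get_state(self) -> Dict[str, float]:
        return dict(self.state)

    def set_state(self, state: Dict[str, float]) -> None:
        # Updates keys one by one: a missing key raises KeyError
        # after earlier keys were already overwritten.
        for key in self.state:
            self.state[key] = state[key]


class ToyOptimizer:
    """Hypothetical optimizer reproducing the snapshot-and-revert pattern."""

    def __init__(self, modules: List[ToyModule]) -> None:
        self.modules = modules

    def get_state(self) -> Dict[str, Any]:
        return {"modules": [(m.name, m.get_state()) for m in self.modules]}

    def _set_state(self, states: Dict[str, Any]) -> None:
        for module, (name, state) in zip(self.modules, states["modules"]):
            if name != module.name:
                raise KeyError(f"Expected state for '{module.name}', got '{name}'.")
            module.set_state(state)

    def set_state(self, states: Dict[str, Any]) -> None:
        if len(states.get("modules", [])) != len(self.modules):
            raise KeyError("Optimizer 'states' do not match modules config.")
        initial = self.get_state()  # snapshot taken before any change
        try:
            self._set_state(states)
        except KeyError:
            self._set_state(initial)  # undo partially-applied updates
            raise


opt = ToyOptimizer([ToyModule("momentum", {"velocity": 0.0, "steps": 0.0})])
saved = opt.get_state()
opt.modules[0].state["velocity"] = 1.5
try:
    opt.set_state({"modules": [("momentum", {"velocity": 9.9})]})  # "steps" missing
except KeyError:
    pass
print(opt.get_state())
```

Without the snapshot, the failed call would leave "velocity" at 9.9 while "steps" kept its old value, i.e. a mix of old and new states that matches neither.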

start_round()

Perform any required action at the start of a training round.

This method calls the on_round_start callback of each wrapped Regularizer, which regularizers may use to update some of their internal state variables.

Source code in declearn/optimizer/_base.py
def start_round(
    self,
) -> None:
    """Perform any required action at the start of a training round.

    This method calls the `on_round_start` callback of each and every
    wrapped `Regularizer` which may be used to regulate some internal
    state variables.
    """
    for regularizer in self.regularizers:
        regularizer.on_round_start()
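start_round only forwards a signal to each regularizer's on_round_start callback. The hypothetical StepCountingRegularizer below uses that hook to reset a per-round step counter; its decaying penalty is purely illustrative and is not a declearn regularizer.

```python
from typing import List, Sequence


class StepCountingRegularizer:
    """Hypothetical regularizer whose penalty decays with steps within a round."""

    def __init__(self, alpha: float) -> None:
        self.alpha = alpha
        self.steps = 0

    def on_round_start(self) -> None:
        # Reset the per-round internal state.
        self.steps = 0

    def run(self, gradients: List[float], weights: List[float]) -> List[float]:
        self.steps += 1
        coef = self.alpha / self.steps
        return [g + coef * w for g, w in zip(gradients, weights)]


def start_round(regularizers: Sequence[StepCountingRegularizer]) -> None:
    # Mirrors Optimizer.start_round: notify every wrapped regularizer.
    for regularizer in regularizers:
        regularizer.on_round_start()


reg = StepCountingRegularizer(alpha=1.0)
for _ in range(3):
    reg.run([0.0], [1.0])
start_round([reg])
print(reg.steps)
```

Since start_round is the only place this reset happens, a training loop built on the Optimizer would call it once at the beginning of each round, before the first run_train_step.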