
declearn.optimizer.modules.OptiModule

Bases: Generic[AuxVarT]

Abstract class defining an API to implement gradients adaptation tools.

The aim of this abstraction (which itself operates on the Vector abstraction, so as to provide framework-agnostic algorithms) is to enable implementing unitary gradients-adaptation bricks that can easily and modularly be composed into complex algorithms.

The declearn.optimizer.Optimizer class defines the main tools and routines for computing and applying gradient-based updates. OptiModule instances are designed to be "plugged into" such an Optimizer instance to add intermediary operations between the moment when gradients are obtained and the moment when they are applied as updates. Note that learning-rate use and optional (decoupled) weight-decay mechanisms are implemented at the Optimizer level.
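
For orientation, here is a minimal sketch of how modules plug into an Optimizer. It assumes "adam" is the name under which declearn's Adam module is registered and that the constructor matches declearn's documented signature; exact arguments may differ across versions.

from declearn.optimizer import Optimizer

# Sketch: the Optimizer itself handles the learning rate and optional
# (decoupled) weight decay; OptiModule instances, or the names under
# which modules are registered, are chained in between gradient
# collection and update application.
optim = Optimizer(
    lrate=0.001,
    w_decay=0.0,
    modules=["adam"],  # assumed registered OptiModule name
)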

Abstract

The following attribute and method must be overridden by any non-abstract child class of OptiModule:

  • name: str class attribute: Name identifier of the class (should be unique across existing OptiModule classes). Also used for automatic types-registration of the class (see the Inheritance section below).
  • run(gradients: Vector) -> Vector: Apply an adaptation algorithm to input gradients and return them. This is the main method for any OptiModule; a minimal sketch follows this list.
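
As an illustration, here is a minimal sketch of a concrete subclass. The GradientScaler class, its name identifier and its factor parameter are hypothetical; scalar multiplication of a Vector is assumed to be supported, as per declearn's Vector API.

from typing import Any, ClassVar, Dict

from declearn.model.api import Vector
from declearn.optimizer.modules import OptiModule

class GradientScaler(OptiModule):
    """Hypothetical module scaling gradients by a constant factor."""

    name: ClassVar[str] = "demo-gradient-scaler"

    def __init__(self, factor: float = 0.1) -> None:
        self.factor = factor

    def run(self, gradients: Vector) -> Vector:
        # Vector implements framework-agnostic arithmetic, hence `*`.
        return gradients * self.factor

    def get_config(self) -> Dict[str, Any]:
        # Expose the constructor parameter so that the default
        # `from_config` (which calls `cls(**config)`) can rebuild it.
        return {"factor": self.factor}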

Overridable

The following methods may be overridden to implement information-passing and parallel behaviors between client/server module pairs. As defined at the OptiModule level, they have no effect and may thus be safely ignored when implementing self-contained algorithms. A hypothetical pairing sketch follows this list.

  • collect_aux_var() -> Optional[AuxVar]: Emit an AuxVar instance holding auxiliary variables that may be shared with peers, aggregated across them, and eventually processed by a counterpart module on the other side of the client/server relationship.
  • process_aux_var(AuxVar) -> None: Process auxiliary variables received from a counterpart module on the other side of the client/server relationship.
  • aux_name: Optional[str] class attribute, default=None: Name to use when sending or receiving auxiliary variables between synchronous client/server modules, which therefore need to share the same aux_name.
  • auxvar_cls: Optional[Type[AuxVar]] class attribute, default=None: Type of AuxVar used by this module (defining the actual signature of collect_aux_var and process_aux_var).
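
As announced above, here is a purely hypothetical sketch of a client-side module wiring these hooks together. The ScaleAuxVar and ClientScaler names are made up, and the AuxVar dataclass API is only summarized; refer to declearn.optimizer.modules.AuxVar for its actual contract.

import dataclasses
from typing import Optional

from declearn.optimizer.modules import AuxVar, OptiModule

@dataclasses.dataclass
class ScaleAuxVar(AuxVar):
    """Hypothetical AuxVar carrying a single scaling coefficient."""

    scale: float

class ClientScaler(OptiModule, register=False):  # opt out of registration
    """Hypothetical client-side module with a server counterpart."""

    name = "demo-scaler-client"
    aux_name = "demo-scaler"  # must match the server-side module's
    auxvar_cls = ScaleAuxVar

    def __init__(self) -> None:
        self.scale = 1.0

    def run(self, gradients):
        return gradients * self.scale

    def collect_aux_var(self) -> Optional[ScaleAuxVar]:
        # Called after local steps, before sending updates for aggregation.
        return ScaleAuxVar(scale=self.scale)

    def process_aux_var(self, aux_var: ScaleAuxVar) -> None:
        # Called at the start of a round, with server-emitted variables.
        self.scale = aux_var.scale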

Inheritance

When a subclass inheriting from OptiModule is declared, it is automatically registered under the "OptiModule" group using its class-attribute name. This can be prevented by adding register=False to the inheritance specs (e.g. class MyCls(OptiModule, register=False)). See declearn.utils.register_type for details on types registration.
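
Hence, assuming the hypothetical GradientScaler sketch above has been declared (and thus auto-registered under its name), it can be retrieved by name:

from declearn.utils import access_registered

# Name-based retrieval of the auto-registered subclass.
cls = access_registered("demo-gradient-scaler", group="OptiModule")
assert cls is GradientScaler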

Source code in declearn/optimizer/modules/_api.py
@create_types_registry
class OptiModule(Generic[AuxVarT], metaclass=abc.ABCMeta):
    """Abstract class defining an API to implement gradients adaptation tools.

    The aim of this abstraction (which itself operates on the Vector
    abstraction, so as to provide framework-agnostic algorithms) is
    to enable implementing unitary gradients-adaptation bricks that
    can easily and modularly be composed into complex algorithms.

    The `declearn.optimizer.Optimizer` class defines the main tools
    and routines for computing and applying gradient-based updates.
    `OptiModule` instances are designed to be "plugged into" such an
    `Optimizer` instance to add intermediary operations between the
    moment when gradients are obtained and the moment when they are
    applied as updates. Note that learning-rate use and optional
    (decoupled) weight-decay mechanisms are implemented at the
    `Optimizer` level.

    Abstract
    --------
    The following attribute and method must be overridden
    by any non-abstract child class of `OptiModule`:

    - name: str class attribute
        Name identifier of the class (should be unique across existing
        OptiModule classes). Also used for automatic types-registration
        of the class (see `Inheritance` section below).
    - run(gradients: Vector) -> Vector:
        Apply an adaptation algorithm to input gradients and return
        them. This is the main method for any `OptiModule`.

    Overridable
    -----------
    The following methods may be overridden to implement information-
    passing and parallel behaviors between client/server module pairs.
    As defined at `OptiModule` level, they have no effect and may thus
    be safely ignored when implementing self-contained algorithms.

    - collect_aux_var() -> Optional[AuxVar]:
        Emit an `AuxVar` instance holding auxiliary variables,
        that may be shared with peers, aggregated across them,
        and eventually processed by a counterpart module on the
        other side of the client/server relationship.
    - process_aux_var(AuxVar) -> None:
        Process auxiliary variables received from a counterpart
        module on the other side of the client/server relationship.
    - aux_name: Optional[str] class attribute, default=None
        Name to use when sending or receiving auxiliary variables
        between synchronous client/server modules, that therefore
        need to share the *same* `aux_name`.
    - auxvar_cls: Optional[Type[AuxVar]] class attribute, default=None
        Type of `AuxVar` used by this module (defining the actual
        signature of `collect_aux_var` and `process_aux_var`).

    Inheritance
    -----------
    When a subclass inheriting from `OptiModule` is declared, it is
    automatically registered under the "OptiModule" group using its
    class-attribute `name`. This can be prevented by adding `register=False`
    to the inheritance specs (e.g. `class MyCls(OptiModule, register=False)`).
    See `declearn.utils.register_type` for details on types registration.
    """

    name: ClassVar[str] = NotImplemented
    """Name identifier of the class, unique across OptiModule classes."""

    aux_name: ClassVar[Optional[str]] = None
    """Optional aux-var-sharing identifier of the class.

    This name may be shared by a pair of OptiModule classes, designed
    to operate on the client and server side respectively. It should
    be unique to that pair of classes across all OptiModule classes.
    """

    auxvar_cls: Optional[Type[AuxVar]] = None
    """Optional `AuxVar` subtype used by this module and its counterpart."""

    def __init_subclass__(
        cls,
        register: bool = True,
        **kwargs: Any,
    ) -> None:
        """Automatically type-register OptiModule subclasses."""
        super().__init_subclass__(**kwargs)
        if register:
            register_type(cls, cls.name, group="OptiModule")

    @abc.abstractmethod
    def run(
        self,
        gradients: Vector[T],
    ) -> Vector[T]:
        """Apply the module's algorithm to input gradients.

        Please refer to the module's main docstring for details
        on the implemented algorithm and the way it transforms
        input gradients.

        Parameters
        ----------
        gradients: Vector
            Input gradients that are to be processed and updated.

        Returns
        -------
        gradients: Vector
            Modified input gradients. The output Vector should be
            fully compatible with the input one - only the values
            of the wrapped coefficients may have changed.
        """

    def collect_aux_var(
        self,
    ) -> Optional[AuxVarT]:
        """Return auxiliary variables that need to be shared between nodes.

        Returns
        -------
        aux_var: Optional[AuxVar]
            Optional `AuxVar` instance holding auxiliary variables that
            are to be shared with a counterpart OptiModule on the other
            side of the client-server relationship.

        Notes
        -----
        The calling context depends on whether the module is part of a
        client's optimizer or of the server's one:

        - Client:
            - `collect_aux_var` is expected to happen after taking a series
              of local optimization steps, before sending the local updates
              to the server for aggregation and further processing.
        - Server:
            - `collect_aux_var` is expected to happen when the global model
              weights are ready to be shared with clients, i.e. either at
              the very end or very beginning of a training round.
        """
        return None

    def process_aux_var(
        self,
        aux_var: AuxVarT,
    ) -> None:
        """Update this module based on received shared auxiliary variables.

        Parameters
        ----------
        aux_var:
            Auxiliary variables that are to be processed by this module,
            emitted by a counterpart OptiModule on the other side of the
            client-server relationship.

        Notes
        -----
        The calling context depends on whether the module is part of a
        client's optimizer or of the server's one:

        - Client:
            - `process_aux_var` is expected to happen at the beginning of
              a training round to define gradients' processing during the
              local optimization steps taken through that round.
        - Server:
            - `process_aux_var` is expected to happen upon receiving local
              updates (and, thus, aux_var), before the aggregated updates
              are computed and passed through the server optimizer (which
              comprises this module).

        Raises
        ------
        KeyError
            If received auxiliary variables lack some required data.
        NotImplementedError
            If auxiliary variables are passed to a module that is not meant
            to receive any.
        TypeError
            If `aux_var` or one of its fields has an improper type.
        """
        if aux_var is not None:  # pragma: no cover
            raise NotImplementedError(
                f"'{self.__class__.__name__}.process_aux_var' was called, but"
                " this class is not designed to receive auxiliary variables."
            )

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable dict with this module's parameters.

        The counterpart to this method is the `from_config` classmethod.
        To access the module's inner states, see the `get_state` method.

        Returns
        -------
        config: Dict[str, Any]
            JSON-serializable dict storing this module's instantiation
            configuration.
        """
        return {}

    @classmethod
    def from_config(
        cls,
        config: Dict[str, Any],
    ) -> Self:
        """Instantiate an OptiModule from its configuration dict.

        The counterpart to this classmethod is the `get_config` method.
        To restore the module's inner states, see its `get_state` method.

        Parameters
        ----------
        config: dict[str, Any]
            Dict storing the module's instantiation configuration.
            This must match the target subclass's requirements.

        Raises
        ------
        KeyError
            If the provided `config` lacks some required parameters
            and/or contains some unused ones.
        """
        return cls(**config)

    @staticmethod
    def from_specs(
        name: str,
        config: Dict[str, Any],
    ) -> "OptiModule":
        """Instantiate an OptiModule from its specifications.

        Parameters
        ----------
        name: str
            Name based on which the module can be retrieved.
            Available as a class attribute.
        config: dict[str, Any]
            Configuration dict of the module, that is to be
            passed to its `from_config` class constructor.
        """
        cls = access_registered(name, group="OptiModule")
        assert issubclass(cls, OptiModule)  # force-tested by access_registered
        return cls.from_config(config)

    def get_state(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable dict with this module's state(s).

        The counterpart to this method is the `set_state` one.

        Returns
        -------
        state: Dict[str, Any]
            JSON-serializable dict storing this module's inner state
            variables.
        """
        return {}

    def set_state(
        self,
        state: Dict[str, Any],
    ) -> None:
        """Load a state dict into an instantiated module.

        The counterpart to this method is the `get_state` one.

        Parameters
        ----------
        state: dict[str, Any]
            Dict storing values to assign to this module's inner
            state variables.

        Raises
        ------
        KeyError
            If an expected state variable is missing from `state`.
        """
        if state:
            raise KeyError(
                f"'{self.__class__.__name__}.set_state' received some data, "
                "but it is not implemented to actually use any."
            )

aux_name: ClassVar[Optional[str]] = None class-attribute

Optional aux-var-sharing identifier of the class.

This name may be shared by a pair of OptiModule classes, designed to operate on the client and server side respectively. It should be unique to that pair of classes across all OptiModule classes.

auxvar_cls: Optional[Type[AuxVar]] = None class-attribute

Optional AuxVar subtype used by this module and its counterpart.

name: ClassVar[str] = NotImplemented class-attribute

Name identifier of the class, unique across OptiModule classes.

__init_subclass__(register=True, **kwargs)

Automatically type-register OptiModule subclasses.

Source code in declearn/optimizer/modules/_api.py
def __init_subclass__(
    cls,
    register: bool = True,
    **kwargs: Any,
) -> None:
    """Automatically type-register OptiModule subclasses."""
    super().__init_subclass__(**kwargs)
    if register:
        register_type(cls, cls.name, group="OptiModule")

collect_aux_var()

Return auxiliary variables that need to be shared between nodes.

Returns:

    aux_var: Optional[AuxVar]
        Optional AuxVar instance holding auxiliary variables that are to be shared with a counterpart OptiModule on the other side of the client-server relationship.

Notes

The calling context depends on whether the module is part of a client's optimizer or of the server's one:

  • Client:
    • collect_aux_var is expected to happen after taking a series of local optimization steps, before sending the local updates to the server for aggregation and further processing.
  • Server:
    • collect_aux_var is expected to happen when the global model weights are ready to be shared with clients, i.e. either at the very end or very beginning of a training round.
Source code in declearn/optimizer/modules/_api.py
def collect_aux_var(
    self,
) -> Optional[AuxVarT]:
    """Return auxiliary variables that need to be shared between nodes.

    Returns
    -------
    aux_var: Optional[AuxVar]
        Optional `AuxVar` instance holding auxiliary variables that
        are to be shared with a counterpart OptiModule on the other
        side of the client-server relationship.

    Notes
    -----
    The calling context depends on whether the module is part of a
    client's optimizer or of the server's one:

    - Client:
        - `collect_aux_var` is expected to happen after taking a series
          of local optimization steps, before sending the local updates
          to the server for aggregation and further processing.
    - Server:
        - `collect_aux_var` is expected to happen when the global model
          weights are ready to be shared with clients, i.e. either at
          the very end or very beginning of a training round.
    """
    return None

from_config(config) classmethod

Instantiate an OptiModule from its configuration dict.

The counterpart to this classmethod is the get_config method. To restore the module's inner states, see its get_state method.

Parameters:

    config: Dict[str, Any], required
        Dict storing the module's instantiation configuration. This must match the target subclass's requirements.

Raises:

    KeyError
        If the provided config lacks some required parameters and/or contains some unused ones.

Source code in declearn/optimizer/modules/_api.py
@classmethod
def from_config(
    cls,
    config: Dict[str, Any],
) -> Self:
    """Instantiate an OptiModule from its configuration dict.

    The counterpart to this classmethod is the `get_config` method.
    To restore the module's inner states, see its `get_state` method.

    Parameters
    ----------
    config: dict[str, Any]
        Dict storing the module's instantiation configuration.
        This must match the target subclass's requirements.

    Raises
    ------
    KeyError
        If the provided `config` lacks some required parameters
        and/or contains some unused ones.
    """
    return cls(**config)

from_specs(name, config) staticmethod

Instantiate an OptiModule from its specifications.

Parameters:

    name: str, required
        Name based on which the module can be retrieved. Available as a class attribute.
    config: Dict[str, Any], required
        Configuration dict of the module, that is to be passed to its from_config class constructor.
Source code in declearn/optimizer/modules/_api.py
@staticmethod
def from_specs(
    name: str,
    config: Dict[str, Any],
) -> "OptiModule":
    """Instantiate an OptiModule from its specifications.

    Parameters
    ----------
    name: str
        Name based on which the module can be retrieved.
        Available as a class attribute.
    config: dict[str, Any]
        Configuration dict of the module, that is to be
        passed to its `from_config` class constructor.
    """
    cls = access_registered(name, group="OptiModule")
    assert issubclass(cls, OptiModule)  # force-tested by access_registered
    return cls.from_config(config)
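
Putting from_specs, get_config and from_config together, here is a sketch of a specification round-trip. It assumes "adam" is a registered module name whose constructor has valid defaults.

# Instantiate from specs, then rebuild an equivalent module
# from its JSON-serializable configuration.
module = OptiModule.from_specs("adam", config={})
config = module.get_config()
clone = type(module).from_config(config)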

get_config()

Return a JSON-serializable dict with this module's parameters.

The counterpart to this method is the from_config classmethod. To access the module's inner states, see the get_state method.

Returns:

    config: Dict[str, Any]
        JSON-serializable dict storing this module's instantiation configuration.

Source code in declearn/optimizer/modules/_api.py
def get_config(
    self,
) -> Dict[str, Any]:
    """Return a JSON-serializable dict with this module's parameters.

    The counterpart to this method is the `from_config` classmethod.
    To access the module's inner states, see the `get_state` method.

    Returns
    -------
    config: Dict[str, Any]
        JSON-serializable dict storing this module's instantiation
        configuration.
    """
    return {}

get_state()

Return a JSON-serializable dict with this module's state(s).

The counterpart to this method is the set_state one.

Returns:

    state: Dict[str, Any]
        JSON-serializable dict storing this module's inner state variables.

Source code in declearn/optimizer/modules/_api.py
def get_state(
    self,
) -> Dict[str, Any]:
    """Return a JSON-serializable dict with this module's state(s).

    The counterpart to this method is the `set_state` one.

    Returns
    -------
    state: Dict[str, Any]
        JSON-serializable dict storing this module's inner state
        variables.
    """
    return {}

process_aux_var(aux_var)

Update this module based on received shared auxiliary variables.

Parameters:

    aux_var: AuxVarT, required
        Auxiliary variables that are to be processed by this module, emitted by a counterpart OptiModule on the other side of the client-server relationship.

Notes

The calling context depends on whether the module is part of a client's optimizer or of the server's one:

  • Client:
    • process_aux_var is expected to happen at the beginning of a training round to define gradients' processing during the local optimization steps taken through that round.
  • Server:
    • process_aux_var is expected to happen upon receiving local updates (and, thus, aux_var), before the aggregated updates are computed and passed through the server optimizer (which comprises this module).

Raises:

    KeyError
        If received auxiliary variables lack some required data.
    NotImplementedError
        If auxiliary variables are passed to a module that is not meant to receive any.
    TypeError
        If aux_var or one of its fields has an improper type.

Source code in declearn/optimizer/modules/_api.py
def process_aux_var(
    self,
    aux_var: AuxVarT,
) -> None:
    """Update this module based on received shared auxiliary variables.

    Parameters
    ----------
    aux_var:
        Auxiliary variables that are to be processed by this module,
        emitted by a counterpart OptiModule on the other side of the
        client-server relationship.

    Notes
    -----
    The calling context depends on whether the module is part of a
    client's optimizer or of the server's one:

    - Client:
        - `process_aux_var` is expected to happen at the beginning of
          a training round to define gradients' processing during the
          local optimization steps taken through that round.
    - Server:
        - `process_aux_var` is expected to happen upon receiving local
          updates (and, thus, aux_var), before the aggregated updates
          are computed and passed through the server optimizer (which
          comprises this module).

    Raises
    ------
    KeyError
        If received auxiliary variables lack some required data.
    NotImplementedError
        If auxiliary variables are passed to a module that is not meant
        to receive any.
    TypeError
        If `aux_var` or one of its fields has an improper type.
    """
    if aux_var is not None:  # pragma: no cover
        raise NotImplementedError(
            f"'{self.__class__.__name__}.process_aux_var' was called, but"
            " this class is not designed to receive auxiliary variables."
        )

run(gradients) abstractmethod

Apply the module's algorithm to input gradients.

Please refer to the module's main docstring for details on the implemented algorithm and the way it transforms input gradients.

Parameters:

    gradients: Vector[T], required
        Input gradients that are to be processed and updated.

Returns:

    gradients: Vector
        Modified input gradients. The output Vector should be fully compatible with the input one; only the values of the wrapped coefficients may have changed.

Source code in declearn/optimizer/modules/_api.py
@abc.abstractmethod
def run(
    self,
    gradients: Vector[T],
) -> Vector[T]:
    """Apply the module's algorithm to input gradients.

    Please refer to the module's main docstring for details
    on the implemented algorithm and the way it transforms
    input gradients.

    Parameters
    ----------
    gradients: Vector
        Input gradients that are to be processed and updated.

    Returns
    -------
    gradients: Vector
        Modified input gradients. The output Vector should be
        fully compatible with the input one - only the values
        of the wrapped coefficients may have changed.
    """

set_state(state)

Load a state dict into an instantiated module.

The counterpart to this method is the get_state one.

Parameters:

    state: Dict[str, Any], required
        Dict storing values to assign to this module's inner state variables.

Raises:

    KeyError
        If an expected state variable is missing from state.

Source code in declearn/optimizer/modules/_api.py
def set_state(
    self,
    state: Dict[str, Any],
) -> None:
    """Load a state dict into an instantiated module.

    The counterpart to this method is the `get_state` one.

    Parameters
    ----------
    state: dict[str, Any]
        Dict storing values to assign to this module's inner
        state variables.

    Raises
    ------
    KeyError
        If an expected state variable is missing from `state`.
    """
    if state:
        raise KeyError(
            f"'{self.__class__.__name__}.set_state' received some data, "
            "but it is not implemented to actually use any."
        )