declearn.main.config.FLOptimConfig

Bases: TomlConfig

Container dataclass for a federated optimization strategy.

This dataclass is designed to wrap together an Aggregator and a pair of Optimizer instances, which are respectively meant to be used by the server and the clients. The main point of this class is to provide TOML-parsing capabilities, so that a strategy can be specified via a TOML file, which end-users are expected to find simpler to edit and maintain than direct Python code.

It is designed to be used by the orchestrating server in the case of a centralized federated learning process.

Fields

  • client_opt: Optimizer Optimizer to be used by clients (that each hold a copy) so as to conduct the step-wise local model updates.
  • server_opt: Optimizer, default=Optimizer(lrate=1.0) Optimizer to be used by the server so as to conduct a round-wise global model update based on the aggregated client updates.
  • aggregator: Aggregator, default=AveragingAggregator() Client weights aggregator to be used by the server so as to conduct the round-wise aggregation of client updates.
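A minimal sketch of how these defaults behave, using plain `dataclasses` and hypothetical stand-ins (`StubOptimizer`, `StubConfig`) rather than declearn classes. It follows the `default_factory` pattern used in the source listed further down: each configuration built without an explicit `server_opt` gets its own fresh `lrate=1.0` optimizer instead of a shared instance, while `client_opt` has no default and must always be provided. That a real optimizer may carry per-instance state worth keeping separate is an assumption made for illustration.

```python
import dataclasses
import functools


class StubOptimizer:
    """Hypothetical stand-in for an optimizer that may carry state."""

    def __init__(self, lrate: float) -> None:
        self.lrate = lrate
        self.state: dict = {}  # placeholder for any per-instance state


@dataclasses.dataclass
class StubConfig:
    """Hypothetical config mirroring the client_opt/server_opt layout."""

    client_opt: StubOptimizer  # no default: must always be provided
    # Same pattern as FLOptimConfig: build a fresh default per instance.
    server_opt: StubOptimizer = dataclasses.field(
        default_factory=functools.partial(StubOptimizer, lrate=1.0)
    )


cfg_a = StubConfig(client_opt=StubOptimizer(lrate=0.01))
cfg_b = StubConfig(client_opt=StubOptimizer(lrate=0.01))
print(cfg_a.server_opt.lrate)  # 1.0
print(cfg_a.server_opt is cfg_b.server_opt)  # False: no shared default

try:
    StubConfig()  # type: ignore[call-arg]
except TypeError as exc:  # the required client_opt is missing
    print("TypeError:", exc)
```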

Notes

The aggregator field may be specified in a variety of ways:

  • a single string may specify the registered name of the class constructor to use. In TOML, use aggregator = "<name>" outside of any section.
  • a serialization dict, that specifies the registration name, and optionally a registration group and/or arguments to be passed to the class constructor. In TOML, use an [aggregator] section with a name = "<name>" field and any other fields you wish to pass. Kwargs may either be grouped into a dedicated [aggregator.config] sub-section or provided as fields of the main aggregator section.
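Both TOML layouts end up as nested dicts once parsed, and the dict branch of `parse_aggregator` then moves every key other than `name`, `group` and `config` into `config`. The sketch below reproduces that normalization as a standalone function so the two layouts can be compared; `normalize_aggregator_spec`, `my_aggregator` and `some_kwarg` are hypothetical placeholders, not declearn names. One deliberate difference: the source mutates its input dict in place, whereas this sketch works on a deep copy.

```python
import copy
from typing import Any, Dict, Union


def normalize_aggregator_spec(
    inputs: Union[str, Dict[str, Any]],
) -> Dict[str, Any]:
    """Return an `aggregator` spec as a {name, group, config} dict.

    Hypothetical helper mirroring the dict branch of `parse_aggregator`.
    A bare string is treated as {"name": inputs}, i.e. no constructor kwargs.
    """
    if isinstance(inputs, str):
        inputs = {"name": inputs}
    spec = copy.deepcopy(inputs)  # unlike the source, leave the input intact
    if "name" not in spec:
        raise TypeError("Missing 'name' field in aggregator specs.")
    spec.setdefault("group", "Aggregator")
    spec.setdefault("config", {})
    # Fold every other top-level key into the constructor kwargs.
    for key in list(spec):
        if key not in ("name", "group", "config"):
            spec["config"][key] = spec.pop(key)
    return spec


# Dict yielded by `[aggregator]` plus an `[aggregator.config]` sub-section.
nested = {"name": "my_aggregator", "config": {"some_kwarg": True}}
# Dict yielded by a single flat `[aggregator]` section.
flat = {"name": "my_aggregator", "some_kwarg": True}

print(normalize_aggregator_spec(nested) == normalize_aggregator_spec(flat))  # True
```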

The client_opt and server_opt fields may be specified as:

  • a single float, specifying the learning rate for vanilla SGD. In TOML, use client_opt = 0.001 for Optimizer(lrate=0.001).
  • a dict of keyword arguments for declearn.optimizer.Optimizer. In TOML, use a [client_opt] section with fields specifying the input parameters you wish to pass to the constructor.
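The dispatch that `_parse_optimizer` performs on these input types can be sketched with a hypothetical stand-in, `StubOptimizer`, that merely records its keyword arguments; the real `declearn.optimizer.Optimizer` constructor is not reproduced here, and the `default_lrate` fallback for `None` only approximates what `default_parser` does. One detail is carried over from the source: the type check is `isinstance(inputs, float)`, so as written a TOML integer such as `client_opt = 1` would be rejected with a `TypeError`, and `1.0` should be written instead.

```python
from typing import Any, Dict, Optional, Union


class StubOptimizer:
    """Hypothetical stand-in for declearn's Optimizer: stores its kwargs."""

    def __init__(self, lrate: float, **kwargs: Any) -> None:
        self.lrate = lrate
        self.kwargs = kwargs


def parse_optimizer(
    inputs: Union[float, Dict[str, Any], StubOptimizer, None],
    default_lrate: Optional[float] = None,
) -> StubOptimizer:
    """Sketch of the float / dict / instance dispatch of `_parse_optimizer`."""
    if inputs is None:
        # Approximation of the default parser: use a default if one exists.
        if default_lrate is None:
            raise TypeError("Missing required optimizer specification.")
        return StubOptimizer(lrate=default_lrate)
    if isinstance(inputs, StubOptimizer):  # ready-made instance: keep it
        return inputs
    if isinstance(inputs, dict):  # keyword arguments for the constructor
        return StubOptimizer(**inputs)
    if isinstance(inputs, float):  # bare float: lrate of vanilla SGD
        return StubOptimizer(lrate=inputs)
    # Anything else, including an int such as 1, is unsupported.
    raise TypeError(f"Unsupported inputs type: {type(inputs).__name__}")


print(parse_optimizer(0.001).lrate)  # 0.001
```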

Instantiation classmethods

  • from_toml: Instantiate by parsing a TOML configuration file.
  • from_params: Instantiate by parsing input dicts (or objects).
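The `parse_<field>` classmethods defined on this class suggest that the base `TomlConfig` looks up a field-specific parser by name and falls back to a generic one otherwise. The following is a minimal sketch of that pattern, assuming such a naming convention; `MiniConfig`, `Example` and this `default_parser` are hypothetical and do not reproduce declearn's actual implementation.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class MiniConfig:
    """Hypothetical base class dispatching to `parse_<field>` classmethods."""

    @classmethod
    def default_parser(cls, field: dataclasses.Field, inputs: Any) -> Any:
        # Fall back to the field's default (or default factory) when missing.
        if inputs is None:
            if field.default is not dataclasses.MISSING:
                return field.default
            if field.default_factory is not dataclasses.MISSING:
                return field.default_factory()
            raise TypeError(f"Missing required field '{field.name}'.")
        return inputs

    @classmethod
    def from_params(cls, **kwargs: Any) -> "MiniConfig":
        params: Dict[str, Any] = {}
        for field in dataclasses.fields(cls):
            # Use `parse_<name>` when defined, else the generic parser.
            parser = getattr(cls, f"parse_{field.name}", cls.default_parser)
            params[field.name] = parser(field, kwargs.get(field.name))
        return cls(**params)


@dataclasses.dataclass
class Example(MiniConfig):
    """Hypothetical config with one custom parser and one default field."""

    lrate: float
    tags: list = dataclasses.field(default_factory=list)

    @classmethod
    def parse_lrate(cls, field: dataclasses.Field, inputs: Any) -> float:
        # Field-specific parser: also accept strings such as "0.1".
        return float(cls.default_parser(field, inputs))


cfg = Example.from_params(lrate="0.1")
print(cfg.lrate, cfg.tags)  # 0.1 []
```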
Source code in declearn/main/config/_strategy.py
@dataclasses.dataclass
class FLOptimConfig(TomlConfig):
    """Container dataclass for a federated optimization strategy.

    This dataclass is designed to wrap together an Aggregator and
    a pair of Optimizer instances, which are respectively meant to
    be used by the server and the clients. The main point of this
    class is to provide TOML-parsing capabilities, so that a
    strategy can be specified via a TOML file, which end-users are
    expected to find simpler to edit and maintain than direct
    Python code.

    It is designed to be used by the orchestrating server in the
    case of a centralized federated learning process.

    Fields
    ------
    - client_opt: Optimizer
        Optimizer to be used by clients (that each hold a copy)
        so as to conduct the step-wise local model updates.
    - server_opt: Optimizer, default=Optimizer(lrate=1.0)
        Optimizer to be used by the server so as to conduct a
        round-wise global model update based on the aggregated
        client updates.
    - aggregator: Aggregator, default=AveragingAggregator()
        Client weights aggregator to be used by the server so as
        to conduct the round-wise aggregation of client updates.

    Notes
    -----
    The `aggregator` field may be specified in a variety of ways:

    - a single string may specify the registered name of the class
    constructor to use.
    In TOML, use `aggregator = "<name>"` outside of any section.
    - a serialization dict, that specifies the registration `name`,
    and optionally a registration `group` and/or arguments to be
    passed to the class constructor.
    In TOML, use an `[aggregator]` section with a `name = "<name>"`
    field and any other fields you wish to pass. Kwargs may either
    be grouped into a dedicated `[aggregator.config]` sub-section
    or provided as fields of the main aggregator section.

    The `client_opt` and `server_opt` fields may be specified as:

    - a single float, specifying the learning rate for vanilla SGD.
    In TOML, use `client_opt = 0.001` for `Optimizer(lrate=0.001)`.
    - a dict of keyword arguments for `declearn.optimizer.Optimizer`.
    In TOML, use a `[client_opt]` section with fields specifying
    the input parameters you wish to pass to the constructor.

    Instantiation classmethods
    --------------------------
    - from_toml:
        Instantiate by parsing a TOML configuration file.
    - from_params:
        Instantiate by parsing input dicts (or objects).
    """

    client_opt: Optimizer
    server_opt: Optimizer = dataclasses.field(
        default_factory=functools.partial(Optimizer, lrate=1.0)
    )
    aggregator: Aggregator = dataclasses.field(
        default_factory=AveragingAggregator
    )

    @classmethod
    def parse_client_opt(
        cls,
        field: dataclasses.Field,  # future: dataclasses.Field[Optimizer]
        inputs: Union[float, Dict[str, Any], Optimizer],
    ) -> Optimizer:
        """Field-specific parser to instantiate the client-side Optimizer."""
        return cls._parse_optimizer(field, inputs)

    @classmethod
    def parse_server_opt(
        cls,
        field: dataclasses.Field,  # future: dataclasses.Field[Optimizer]
        inputs: Union[float, Dict[str, Any], Optimizer, None],
    ) -> Optimizer:
        """Field-specific parser to instantiate the server-side Optimizer."""
        return cls._parse_optimizer(field, inputs)

    @classmethod
    def _parse_optimizer(
        cls,
        field: dataclasses.Field,  # future: dataclasses.Field[Optimizer]
        inputs: Union[float, Dict[str, Any], Optimizer, None],
    ) -> Optimizer:
        """Field-specific parser to instantiate an Optimizer."""
        # Delegate to the default parser for most cases.
        if inputs is None or isinstance(inputs, (dict, Optimizer)):
            return cls.default_parser(field, inputs)
        # Case when provided with a single float: use it as lrate for SGD.
        if isinstance(inputs, float):
            return Optimizer(lrate=inputs)
        # Otherwise, raise a TypeError as inputs are unsupported.
        raise TypeError(f"Unsupported inputs type for field '{field.name}'.")

    @classmethod
    def parse_aggregator(
        cls,
        field: dataclasses.Field,  # future: dataclasses.Field[Aggregator]
        inputs: Union[str, Dict[str, Any], Aggregator, None],
    ) -> Aggregator:
        """Field-specific parser to instantiate an Aggregator.

        This method supports specifying `aggregator`:

        - as a str, used to retrieve a registered Aggregator class
        - as a dict, parsed as a serialized Aggregator configuration:
            - name: str used to retrieve a registered Aggregator class
            - (opt.) group: str used to retrieve the registered class
            - (opt.) config: dict specifying kwargs for the constructor
            - any other field will be added to the `config` kwargs dict
        - as None (or missing kwarg), using default AveragingAggregator()
        """
        # Case when using the default value: delegate to the default parser.
        if inputs is None:
            return cls.default_parser(field, inputs)
        # Case when the input is a valid instance: return it.
        if isinstance(inputs, Aggregator):
            return inputs
        # Case when provided with a string: retrieve the class and instantiate.
        if isinstance(inputs, str):
            try:
                # Note: subclass-checking is performed by `access_registered`.
                agg_cls = access_registered(inputs, group="Aggregator")
            except KeyError as exc:
                raise TypeError(
                    f"Failed to retrieve Aggregator class from name '{inputs}'"
                ) from exc
            return agg_cls()
        # Case when provided with a dict: check/fix formatting and deserialize.
        if isinstance(inputs, dict):
            if "name" not in inputs:
                raise TypeError(
                    "Wrong format for Aggregator serialized config: missing "
                    "'name' field."
                )
            inputs.setdefault("group", "Aggregator")
            inputs.setdefault("config", {})
            for key in list(inputs):
                if key not in ("name", "group", "config"):
                    inputs["config"][key] = inputs.pop(key)
            obj = deserialize_object(inputs)  # type: ignore
            if not isinstance(obj, Aggregator):
                raise TypeError(
                    "Input specifications for 'aggregator' resulted in a non-"
                    f"Aggregator object with type '{type(obj)}'."
                )
            return obj
        # Otherwise, raise a TypeError as inputs are unsupported.
        raise TypeError("Unsupported inputs type for field 'aggregator'.")

parse_aggregator(field, inputs) classmethod

Field-specific parser to instantiate an Aggregator.

This method supports specifying aggregator:

  • as a str, used to retrieve a registered Aggregator class
  • as a dict, parsed as a serialized Aggregator configuration:
    • name: str used to retrieve a registered Aggregator class
    • (opt.) group: str used to retrieve the registered class
    • (opt.) config: dict specifying kwargs for the constructor
    • any other field will be added to the config kwargs dict
  • as None (or missing kwarg), using default AveragingAggregator()
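The branching documented above can be sketched without declearn by replacing its type registry with a plain dict keyed by `(group, name)`. `REGISTRY`, `BaseAggregator`, `MeanAggregator`, `NotAnAggregator` and `parse_aggregator_sketch` are hypothetical stand-ins for `access_registered`, `Aggregator` and the real parser. For brevity the sketch assumes any extra top-level keys have already been moved into `config`. It keeps the source's choice of converting a registry `KeyError` into a `TypeError`, so callers only need to handle one exception type.

```python
from typing import Any, Dict, Tuple, Union


class BaseAggregator:
    """Hypothetical stand-in for declearn's Aggregator base class."""

    def __init__(self, **config: Any) -> None:
        self.config = config


class MeanAggregator(BaseAggregator):
    """Hypothetical default aggregator."""


class NotAnAggregator:
    """Hypothetical registered class of the wrong kind."""


# Hypothetical (group, name) -> class registry, in lieu of `access_registered`.
REGISTRY: Dict[Tuple[str, str], type] = {
    ("Aggregator", "mean"): MeanAggregator,
    ("Other", "thing"): NotAnAggregator,
}


def parse_aggregator_sketch(
    inputs: Union[str, Dict[str, Any], BaseAggregator, None],
) -> BaseAggregator:
    """Mirror the branches of `parse_aggregator` on normalized specs."""
    if inputs is None:  # missing: fall back to the default aggregator
        return MeanAggregator()
    if isinstance(inputs, BaseAggregator):  # ready-made instance: keep it
        return inputs
    if isinstance(inputs, str):  # bare name: no constructor kwargs
        inputs = {"name": inputs}
    if not isinstance(inputs, dict) or "name" not in inputs:
        raise TypeError("Unsupported inputs for field 'aggregator'.")
    key = (inputs.get("group", "Aggregator"), inputs["name"])
    try:
        agg_cls = REGISTRY[key]
    except KeyError as exc:  # surface lookup failures as TypeError
        raise TypeError(f"No registered class for {key}.") from exc
    obj = agg_cls(**inputs.get("config", {}))
    if not isinstance(obj, BaseAggregator):
        raise TypeError(f"Got a non-Aggregator object of type {type(obj)}.")
    return obj


print(type(parse_aggregator_sketch("mean")).__name__)  # MeanAggregator
```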

parse_client_opt(field, inputs) classmethod

Field-specific parser to instantiate the client-side Optimizer.


parse_server_opt(field, inputs) classmethod

Field-specific parser to instantiate the server-side Optimizer.
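By default, the server-side optimizer applies vanilla SGD with a learning rate of 1.0 to the aggregated client updates. Under two assumptions not stated on this page (that each client's update is a pseudo-gradient equal to its start weights minus its end weights, and that the aggregator averages these uniformly, whereas declearn's default aggregator may weight clients differently), this default step amounts to replacing the global weights with the mean of the clients' end weights, as in FedAvg with equal client weights. The toy arithmetic below checks this; `average` and `sgd_step` are hypothetical helpers.

```python
from typing import List


def average(updates: List[List[float]]) -> List[float]:
    """Uniformly average per-client vectors (hypothetical helper)."""
    n_clients = len(updates)
    return [sum(values) / n_clients for values in zip(*updates)]


def sgd_step(
    weights: List[float], update: List[float], lrate: float
) -> List[float]:
    """Vanilla SGD step: w <- w - lrate * update (hypothetical helper)."""
    return [w - lrate * u for w, u in zip(weights, update)]


# Global weights at the start of a round, and each client's local end weights.
global_w = [1.0, 2.0]
client_end = [[0.5, 2.5], [1.5, 1.0]]

# Assumed convention: each client sends start weights minus end weights.
pseudo_grads = [[g - e for g, e in zip(global_w, end)] for end in client_end]

new_w = sgd_step(global_w, average(pseudo_grads), lrate=1.0)
print(new_w)  # [1.0, 1.75]
print(new_w == average(client_end))  # True: the clients' mean end weights
```

With any other server learning rate, the global model only moves part of the way (or beyond) towards that client average.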
