
declearn.optimizer.modules.ScaffoldServerModule

Bases: OptiModule

Server-side Stochastic Controlled Averaging (SCAFFOLD) module.

This module is to be added to the optimizer used by a federated-learning server, and expects the clients' optimizers to use its counterpart module: ScaffoldClientModule.

This module implements the following algorithm:

Init(clients):
    state = 0
    s_loc = {client: 0 for client in clients}
Step(grads):
    grads
Send:
    delta = {client: (s_loc[client] - state); client in s_loc}
Receive(s_new = {client: state}):
    s_upd = sum(s_new[client] - s_loc[client]; client in s_new)
    s_loc.update(s_new)
    state += s_upd / len(s_loc)

In other words, this module holds a shared state variable and a set of client-specific ones, all zero-valued when created. At the beginning of a training round, it sends each client its delta variable. This variable is the difference between that client's local state and the shared one, and is applied as a correction term to local gradients. At the end of a training round, aggregated gradients are returned unchanged, because this module only manages auxiliary variables. Finally, updated local states received from clients are recorded and used to update the shared state variable. New delta values can then be sent to clients when the next round of training starts.
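The pseudocode above can be illustrated with a minimal scalar sketch of the server-side bookkeeping. `ToyScaffoldServer` is a hypothetical class and is not part of declearn. It replaces declearn `Vector` objects with plain floats and assumes that all clients are known when it is created.

```python
from typing import Dict, Iterable


class ToyScaffoldServer:
    """Hypothetical scalar re-implementation of the server-side bookkeeping."""

    def __init__(self, clients: Iterable[str]) -> None:
        self.state = 0.0  # shared state, zero-valued at creation
        self.s_loc: Dict[str, float] = {c: 0.0 for c in clients}  # client-wise states

    def step(self, grads: float) -> float:
        # Mirrors `Step(grads): grads`: aggregated gradients are left unchanged.
        return grads

    def send(self) -> Dict[str, float]:
        # delta = local state - shared state, for each known client.
        return {c: s - self.state for c, s in self.s_loc.items()}

    def receive(self, s_new: Dict[str, float]) -> None:
        # Sum the changes in local states, then divide by the number of known clients.
        s_upd = sum(s - self.s_loc.get(c, 0.0) for c, s in s_new.items())
        self.s_loc.update(s_new)
        self.state += s_upd / len(self.s_loc)


server = ToyScaffoldServer(["a", "b"])
server.receive({"a": 2.0, "b": 4.0})
print(server.state)   # 3.0
print(server.send())  # {'a': -1.0, 'b': 1.0}
```

If all clients are listed at the start and every state begins at zero, the divisor stays fixed. This rule then keeps `state` equal to the mean of the client-wise states.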

The SCAFFOLD algorithm is described in reference [1]. The client-side correction of gradients and the computation of updated local states are deferred to ScaffoldClientModule.
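The client-side rules from [1] can be sketched as follows, using the paper's notation: c is the shared state, c_i is a client's local state, x is the global model, and y is the local model after K local steps with learning rate lr. This is a hypothetical float-only illustration of Algorithm 1 in [1] with its "option II" local-state update. It is not the code of ScaffoldClientModule, which may differ in its details.

```python
def corrected_gradient(grad: float, delta: float) -> float:
    # A SCAFFOLD local step uses g - c_i + c. Since delta = c_i - c, this equals g - delta.
    return grad - delta


def updated_local_state(
    c_i: float, c: float, x_global: float, y_local: float, n_steps: int, lrate: float
) -> float:
    # "Option II" in [1]: c_i+ = c_i - c + (x - y) / (K * lr).
    return c_i - c + (x_global - y_local) / (n_steps * lrate)


print(corrected_gradient(0.5, -1.0))                          # 1.5
print(updated_local_state(0.5, 0.25, 1.0, 0.5, 2, 0.25))      # 1.25
```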

References

[1] Karimireddy et al., 2019. SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. https://arxiv.org/abs/1910.06378

Source code in declearn/optimizer/modules/_scaffold.py
class ScaffoldServerModule(OptiModule):
    """Server-side Stochastic Controlled Averaging (SCAFFOLD) module.

    This module is to be added to the optimizer used by a federated-
    learning server, and expects that the clients' optimizer use its
    counterpart module:
    [`ScaffoldClientModule`][declearn.optimizer.modules.ScaffoldClientModule].

    This module implements the following algorithm:

        Init(clients):
            state = 0
            s_loc = {client: 0 for client in clients}
        Step(grads):
            grads
        Send:
            delta = {client: (s_loc[client] - state); client in s_loc}
        Receive(s_new = {client: state}):
            s_upd = sum(s_new[client] - s_loc[client]; client in s_new)
            s_loc.update(s_new)
            state += s_upd / len(s_loc)

    In other words, this module holds a shared state variable, and a
    set of client-specific ones, which are zero-valued when created.
    At the beginning of a training round it sends to each client its
    delta variable, set to the difference between its current state
    and the shared one, which is to be applied as a correction term
    to local gradients. At the end of a training round, aggregated
    gradients are returned unchanged, as this module only manages
    auxiliary variables. Finally, updated local states received from
    clients are recorded, and used to update the shared state variable,
    so that new delta values can be sent to clients as the next round
    of training starts.

    The SCAFFOLD algorithm is described in reference [1].
    The client-side correction of gradients and the computation of
    updated local states are deferred to `ScaffoldClientModule`.

    References
    ----------
    [1] Karimireddy et al., 2019.
        SCAFFOLD: Stochastic Controlled Averaging for Federated Learning.
        https://arxiv.org/abs/1910.06378
    """

    name: ClassVar[str] = "scaffold-server"
    aux_name: ClassVar[str] = "scaffold"

    def __init__(
        self,
        clients: Optional[List[str]] = None,
    ) -> None:
        """Instantiate the server-side SCAFFOLD gradients-correction module.

        Parameters
        ----------
        clients: list[str] or None, default=None
            Optional list of known clients' id strings.

        Notes
        -----
        - If this module is used under a training strategy that has
          participating clients vary across rounds, leaving `clients`
          as None will affect the update rule for the shared state,
          as it uses a (n_participating / n_total_clients) term, the
          divisor of which will be incorrect (at least on the first
          step, potentially on following ones as well).
        - Similarly, listing clients that in fact do not participate
          in training will bias computations, as they still count in
          that divisor.
        """
        self.state = 0.0  # type: Union[Vector, float]
        self.s_loc = {}  # type: Dict[str, Union[Vector, float]]
        if clients:
            self.s_loc = {client: 0.0 for client in clients}

    def get_config(
        self,
    ) -> Dict[str, Any]:
        return {"clients": list(self.s_loc)}

    def run(
        self,
        gradients: Vector,
    ) -> Vector:
        # Note: ScaffoldServer only manages auxiliary variables.
        return gradients

    def collect_aux_var(
        self,
    ) -> Dict[str, Dict[str, Any]]:
        """Return auxiliary variables that need to be shared between nodes.

        Package client-wise `delta = (local_state - shared_state)` variables.

        Returns
        -------
        aux_var:
            JSON-serializable dict of auxiliary variables that are to
            be shared with the client-wise ScaffoldClientModule. This
            dict has a `{client-name: {"delta": value}}` structure.
        """
        # Compute clients' delta variable, package them and return.
        aux_var = {}  # type: Dict[str, Dict[str, Any]]
        for client, state in self.s_loc.items():
            delta = state - self.state
            aux_var[client] = {"delta": delta}
        return aux_var

    def process_aux_var(
        self,
        aux_var: Dict[str, Dict[str, Any]],
    ) -> None:
        """Update this module based on received shared auxiliary variables.

        Collect updated local state variables sent by clients.
        Update the global state variable based on the latter.

        Parameters
        ----------
        aux_var: dict[str, dict[str, any]]
            JSON-serializable dict of auxiliary variables that are to be
            processed by this module before processing global updates.
            This dict should have a `{client-name: {"state": value}}`
            structure.

        Raises
        ------
        KeyError:
            If an expected auxiliary variable is missing.
        TypeError:
            If a variable is of an unsupported type, or if aux_var
            is not formatted as it should be.
        """
        # Collect updated local states received from Scaffold client modules.
        s_new = {}  # type: Dict[str, Union[Vector, float]]
        for client, c_dict in aux_var.items():
            if not isinstance(c_dict, dict):
                raise TypeError(
                    "ScaffoldServerModule requires auxiliary variables "
                    "to be received as client-wise dictionaries."
                )
            if "state" not in c_dict:
                raise KeyError(
                    "Missing required 'state' key in auxiliary variables "
                    f"received by ScaffoldServerModule from client '{client}'."
                )
            state = c_dict["state"]
            if isinstance(state, float) and state == 0.0:
                # Drop info from clients that have not processed gradients.
                continue
            if isinstance(state, (Vector, float)):
                s_new[client] = state
            else:
                raise TypeError(
                    "Unsupported type for auxiliary variable 'state' "
                    f"received by ScaffoldServerModule from client '{client}'."
                )
        # Update the global and client-wise state variables.
        update = sum(
            state - self.s_loc.get(client, 0.0)
            for client, state in s_new.items()
        )
        self.s_loc.update(s_new)
        update = update / len(self.s_loc)
        self.state = self.state + update

    def get_state(
        self,
    ) -> Dict[str, Any]:
        return {"state": self.state, "s_loc": self.s_loc}

    def set_state(
        self,
        state: Dict[str, Any],
    ) -> None:
        for key in ("state", "s_loc"):
            if key not in state:
                raise KeyError(f"Missing required state variable '{key}'.")
        self.state = state["state"]
        self.s_loc = state["s_loc"]
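The `get_state` method above returns the module's live `s_loc` dict instead of a copy, and `process_aux_var` later updates that dict in place. The sketch below reproduces both methods in a hypothetical `StatefulToy` class, which is not part of declearn. It shows why a checkpoint may be safer to deep-copy. This sketch does not cover whether declearn's own checkpointing already copies or serializes the state.

```python
import copy
from typing import Any, Dict


class StatefulToy:
    """Hypothetical stand-in reproducing get_state/set_state from the listing."""

    def __init__(self) -> None:
        self.state = 0.0
        self.s_loc: Dict[str, float] = {"a": 0.0}

    def get_state(self) -> Dict[str, Any]:
        return {"state": self.state, "s_loc": self.s_loc}  # no copy is made

    def set_state(self, state: Dict[str, Any]) -> None:
        for key in ("state", "s_loc"):
            if key not in state:
                raise KeyError(f"Missing required state variable '{key}'.")
        self.state = state["state"]
        self.s_loc = state["s_loc"]


module = StatefulToy()
shallow = module.get_state()                  # shares the s_loc dict with the module
snapshot = copy.deepcopy(module.get_state())  # independent copy
module.s_loc.update({"a": 2.0})               # in-place update, as in process_aux_var
print(shallow["s_loc"])   # {'a': 2.0}: the shallow "checkpoint" changed too
print(snapshot["s_loc"])  # {'a': 0.0}
```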

__init__(clients=None)

Instantiate the server-side SCAFFOLD gradients-correction module.

Parameters:

  • clients (Optional[List[str]], default: None): Optional list of known clients' id strings.

Notes

  • If this module is used under a training strategy where participating clients vary across training rounds, leaving clients as None will affect the update rule for the shared state. That rule uses a (n_participating / n_total_clients) term, and its divisor will be incorrect (at least on the first step, potentially on following ones as well).
  • Similarly, listing clients that in fact do not participate in training will bias computations, as those clients still count in that divisor.
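The effect described in these notes can be checked with a hypothetical helper, `run_rounds`. It replays the shared-state update rule from the source listing on scalar states. The example assumes float-valued states and two clients that each participate in a single round.

```python
from typing import Dict, List, Optional


def run_rounds(rounds: List[Dict[str, float]], clients: Optional[List[str]] = None) -> float:
    """Hypothetical helper replaying the shared-state update over several rounds."""
    state = 0.0
    s_loc: Dict[str, float] = {c: 0.0 for c in clients} if clients else {}
    for s_new in rounds:
        s_upd = sum(s - s_loc.get(c, 0.0) for c, s in s_new.items())
        s_loc.update(s_new)
        state += s_upd / len(s_loc)  # divisor = number of clients known so far
    return state


# Round 1: only client "a" participates; round 2: only client "b".
rounds = [{"a": 2.0}, {"b": 4.0}]
print(run_rounds(rounds, clients=["a", "b"]))  # 3.0, the mean of the local states
print(run_rounds(rounds))                       # 4.0: round 1 was divided by 1, not 2
```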

collect_aux_var()

Return auxiliary variables that need to be shared between nodes.

Package client-wise delta = (local_state - shared_state) variables.

Returns:

  • aux_var (Dict[str, Dict[str, Any]]): JSON-serializable dict of auxiliary variables that are to be shared with the client-wise ScaffoldClientModule. This dict has a {client-name: {"delta": value}} structure.
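For scalar states, the returned structure can be sketched with a hypothetical `collect_deltas` function that mirrors the loop in `collect_aux_var`. In practice, states are declearn `Vector` objects, and the library handles their serialization. The plain `json` round trip used here is an assumption that only holds for floats.

```python
import json
from typing import Dict


def collect_deltas(state: float, s_loc: Dict[str, float]) -> Dict[str, Dict[str, float]]:
    """Hypothetical scalar mirror of collect_aux_var: {client: {"delta": value}}."""
    return {client: {"delta": local - state} for client, local in s_loc.items()}


aux_var = collect_deltas(3.0, {"a": 2.0, "b": 4.0})
print(aux_var)  # {'a': {'delta': -1.0}, 'b': {'delta': 1.0}}
# Float-valued deltas survive a JSON round trip unchanged.
assert json.loads(json.dumps(aux_var)) == aux_var
```

Right after creation, all states are zero, so every delta is 0.0. The first round's correction is therefore null.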


process_aux_var(aux_var)

Update this module based on received shared auxiliary variables.

Collect updated local state variables sent by clients. Update the global state variable based on the latter.

Parameters:

  • aux_var (Dict[str, Dict[str, Any]], required): JSON-serializable dict of auxiliary variables that are to be processed by this module before processing global updates. This dict should have a {client-name: {"state": value}} structure.

Raises:

  • KeyError: If an expected auxiliary variable is missing.
  • TypeError: If a variable is of an unsupported type, or if aux_var is not formatted as it should be.
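The checks and update performed by `process_aux_var` can be sketched with a hypothetical float-only function, `process_states`, which is not part of declearn. Unlike the listed source, this sketch skips the final division when no client is known yet. In that case, the listed source would raise ZeroDivisionError.

```python
from typing import Any, Dict, Tuple


def process_states(
    state: float, s_loc: Dict[str, float], aux_var: Dict[str, Any]
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical float-only mirror of process_aux_var; returns (state, s_loc)."""
    s_new: Dict[str, float] = {}
    for client, c_dict in aux_var.items():
        if not isinstance(c_dict, dict):
            raise TypeError("Auxiliary variables must be client-wise dicts.")
        if "state" not in c_dict:
            raise KeyError(f"Missing 'state' key for client '{client}'.")
        value = c_dict["state"]
        if isinstance(value, float) and value == 0.0:
            continue  # client has not processed gradients: drop its info
        if not isinstance(value, float):
            raise TypeError(f"Unsupported 'state' type for client '{client}'.")
        s_new[client] = value
    # Sum the changes relative to previously recorded local states.
    update = sum(value - s_loc.get(client, 0.0) for client, value in s_new.items())
    s_loc = {**s_loc, **s_new}
    if s_loc:  # guard added in this sketch only
        state += update / len(s_loc)
    return state, s_loc


state, s_loc = process_states(
    0.0,
    {"a": 0.0, "b": 0.0},
    {"a": {"state": 2.0}, "b": {"state": 0.0}},  # "b" reports 0.0 and is dropped
)
print(state, s_loc)  # 1.0 {'a': 2.0, 'b': 0.0}
```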
