declearn.optimizer.modules.ScaffoldServerModule

Bases: OptiModule[ScaffoldAuxVar]

Server-side Stochastic Controlled Averaging (SCAFFOLD) module.

This module is to be added to the optimizer used by a federated-learning server, and expects the clients' optimizers to use its counterpart module: ScaffoldClientModule.

This module implements the following algorithm:

Init:
    s_state = 0
    clients = {}
Step(grads):
    grads
Send -> state:
    state = s_state / max(len(clients), 1)
Receive(delta=sum(state_i^t+1 - state_i^t), clients=set{uuid}):
    s_state += delta
    clients.update(clients)

In other words, this module holds a global state variable, set to zero at instantiation. At the beginning of a training round, the server sends this state to all clients so that they can derive a correction term for their processed gradients, based on a local state they hold. At the end of a training round, client-wise local state updates are sum-aggregated into an update for the global state variable, which will be sent to clients at the start of the next round. The state that is sent is always the average of the last known local states of all known clients.
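
The round cycle described above can be sketched with plain floats standing in for vectors. This is a minimal illustration rather than declearn's code: `ToyAuxVar` and `ToyScaffoldServer` are hypothetical names, and the client-side computation of local states is replaced by hand-picked numbers.

```python
from dataclasses import dataclass, field
from typing import Optional, Set


@dataclass
class ToyAuxVar:
    """Hypothetical stand-in for ScaffoldAuxVar, holding plain floats."""

    state: Optional[float] = None  # server -> clients
    delta: Optional[float] = None  # clients -> server (sum of local updates)
    clients: Set[str] = field(default_factory=set)


class ToyScaffoldServer:
    """Hypothetical re-implementation of the server-side bookkeeping."""

    def __init__(self) -> None:
        self.s_state = 0.0
        self.clients: Set[str] = set()

    def send(self) -> ToyAuxVar:
        # Without any known client, the shared state is simply zero.
        if not self.clients:
            return ToyAuxVar(state=0.0)
        return ToyAuxVar(state=self.s_state / len(self.clients))

    def receive(self, aux: ToyAuxVar) -> None:
        self.clients.update(aux.clients)
        self.s_state += aux.delta


server = ToyScaffoldServer()
first_state = server.send().state

# Round 1: local states move a: 0.0 -> 1.0 and b: 0.0 -> 3.0 (summed delta 4.0).
server.receive(ToyAuxVar(delta=4.0, clients={"a", "b"}))
second_state = server.send().state

# Round 2: only "a" participates, a: 1.0 -> 2.0 (delta 1.0); "b" stays at 3.0.
server.receive(ToyAuxVar(delta=1.0, clients={"a"}))
third_state = server.send().state
```

After round 2, the shared state equals the mean of the last known local states (2.0 for "a", 3.0 for "b"), even though "b" did not take part in that round.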

The SCAFFOLD algorithm is described in reference [1]. The client-side correction of gradients and the computation of updated local states are deferred to ScaffoldClientModule.

References

[1] Karimireddy et al., 2019. SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. https://arxiv.org/abs/1910.06378

Source code in declearn/optimizer/modules/_scaffold.py
class ScaffoldServerModule(OptiModule[ScaffoldAuxVar]):
    """Server-side Stochastic Controlled Averaging (SCAFFOLD) module.

    This module is to be added to the optimizer used by a federated-
    learning server, and expects that the clients' optimizer use its
    counterpart module:
    [`ScaffoldClientModule`][declearn.optimizer.modules.ScaffoldClientModule].

    This module implements the following algorithm:

        Init:
            s_state = 0
            clients = {}
        Step(grads):
            grads
        Send -> state:
            state = s_state / max(len(clients), 1)
        Receive(delta=sum(state_i^t+1 - state_i^t), clients=set{uuid}):
            s_state += delta
            clients.update(clients)

    In other words, this module holds a global state variable, set
    to zero at instantiation. At the beginning of a training round
    it sends it to all clients so that they can derive a correction
    term for their processed gradients, based on a local state they
    hold. At the end of a training round, client-wise local state
    updates are sum-aggregated into an update for the global state
    variable, which will be sent to clients at the start of the next
    round. The sent state is always the average of the last known
    local state from each and every known client.

    The SCAFFOLD algorithm is described in reference [1].
    The client-side correction of gradients and the computation of
    updated local states are deferred to `ScaffoldClientModule`.

    References
    ----------
    [1] Karimireddy et al., 2019.
        SCAFFOLD: Stochastic Controlled Averaging for Federated Learning.
        https://arxiv.org/abs/1910.06378
    """

    name = "scaffold-server"
    aux_name = "scaffold"
    auxvar_cls = ScaffoldAuxVar

    def __init__(
        self,
        clients: Optional[List[str]] = None,
    ) -> None:
        """Instantiate the server-side SCAFFOLD gradients-correction module.

        Parameters
        ----------
        clients:
            DEPRECATED and unused starting with declearn 2.4.
            Optional list of known clients' id strings.
        """
        self.s_state = 0.0  # type: Union[Vector, float]
        self.clients = set()  # type: Set[str]
        if clients:  # pragma: no cover
            warnings.warn(
                "ScaffoldServerModule's 'clients' argument has been deprecated"
                " as of declearn v2.4, and no longer has any effect. It will"
                " be removed in declearn 2.6 and/or 3.0.",
                DeprecationWarning,
            )

    def run(
        self,
        gradients: Vector,
    ) -> Vector:
        # Note: ScaffoldServer only manages auxiliary variables.
        return gradients

    def collect_aux_var(
        self,
    ) -> ScaffoldAuxVar:
        """Return auxiliary variables that need to be shared between nodes.

        Returns
        -------
        aux_var:
            `ScaffoldAuxVar` instance holding auxiliary variables that are
            to be shared with clients' `ScaffoldClientModule` instances.
        """
        # When un-initialized, send lightweight information.
        if not self.clients:
            return ScaffoldAuxVar(state=0.0)
        # Otherwise, compute and return the current shared state.
        return ScaffoldAuxVar(state=self.s_state / len(self.clients))

    def process_aux_var(
        self,
        aux_var: ScaffoldAuxVar,
    ) -> None:
        """Update this module based on received shared auxiliary variables.

        Update the global state variable based on the sum of clients'
        local state updates.

        Parameters
        ----------
        aux_var:
            `ScaffoldAuxVar` resulting from the aggregation of clients'
            `ScaffoldClientModule` auxiliary variables.

        Raises
        ------
        KeyError:
            If `aux_var` has an empty `delta` field.
        TypeError:
            If `aux_var` is of improper type.
        """
        if not isinstance(aux_var, ScaffoldAuxVar):
            raise TypeError(
                f"'{self.__class__.__name__}.process_aux_var' received "
                f"auxiliary variables of unproper type: '{type(aux_var)}'."
            )
        if aux_var.delta is None:
            raise KeyError(
                f"'{self.__class__.__name__}.process_aux_var' received "
                "auxiliary variables with an empty 'delta' field."
            )
        # Update the list of known clients, and the sum of local states.
        self.clients.update(aux_var.clients)
        self.s_state += aux_var.delta

    def get_state(
        self,
    ) -> Dict[str, Any]:
        return {"s_state": self.s_state, "clients": list(self.clients)}

    def set_state(
        self,
        state: Dict[str, Any],
    ) -> None:
        for key in ("s_state", "clients"):
            if key not in state:
                raise KeyError(f"Missing required state variable '{key}'.")
        self.s_state = state["s_state"]
        self.clients = set(state["clients"])
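
The `get_state` / `set_state` pair at the end of the class converts the client set to a list and back, which keeps the state dict friendly to serializers that lack a set type. The sketch below illustrates that round trip through JSON. `ToyServerState` is a hypothetical stand-in, and it assumes `s_state` is a plain float; a real `Vector` value would need a serializer that supports it.

```python
import json
from typing import Any, Dict, Set


class ToyServerState:
    """Hypothetical stand-in reproducing only the state accessors."""

    def __init__(self) -> None:
        self.s_state = 0.0
        self.clients: Set[str] = set()

    def get_state(self) -> Dict[str, Any]:
        # Sets are not JSON-serializable, hence the conversion to a list.
        return {"s_state": self.s_state, "clients": list(self.clients)}

    def set_state(self, state: Dict[str, Any]) -> None:
        for key in ("s_state", "clients"):
            if key not in state:
                raise KeyError(f"Missing required state variable '{key}'.")
        self.s_state = state["s_state"]
        self.clients = set(state["clients"])


original = ToyServerState()
original.s_state = 4.0
original.clients = {"a", "b"}

# Dump the state to a JSON string, then restore it into a fresh instance.
payload = json.dumps(original.get_state())
restored = ToyServerState()
restored.set_state(json.loads(payload))

# A state dict lacking one of the two keys is rejected.
try:
    restored.set_state({"s_state": 1.0})
    missing_key_rejected = False
except KeyError:
    missing_key_rejected = True
```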

__init__(clients=None)

Instantiate the server-side SCAFFOLD gradients-correction module.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| clients | Optional[List[str]] | DEPRECATED and unused starting with declearn 2.4. Optional list of known clients' id strings. | None |

Source code in declearn/optimizer/modules/_scaffold.py
def __init__(
    self,
    clients: Optional[List[str]] = None,
) -> None:
    """Instantiate the server-side SCAFFOLD gradients-correction module.

    Parameters
    ----------
    clients:
        DEPRECATED and unused starting with declearn 2.4.
        Optional list of known clients' id strings.
    """
    self.s_state = 0.0  # type: Union[Vector, float]
    self.clients = set()  # type: Set[str]
    if clients:  # pragma: no cover
        warnings.warn(
            "ScaffoldServerModule's 'clients' argument has been deprecated"
            " as of declearn v2.4, and no longer has any effect. It will"
            " be removed in declearn 2.6 and/or 3.0.",
            DeprecationWarning,
        )
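
The deprecation check above only fires for a truthy `clients` value, so passing `None` or an empty list stays silent. A minimal sketch of how calling code might observe this, using a hypothetical `make_toy_module` function in place of the real constructor (the `stacklevel=2` argument is an addition of this sketch, not part of the source above):

```python
import warnings
from typing import Any, Dict, List, Optional


def make_toy_module(clients: Optional[List[str]] = None) -> Dict[str, Any]:
    """Hypothetical constructor mirroring the deprecation pattern."""
    if clients:
        warnings.warn(
            "'clients' is deprecated and no longer has any effect.",
            DeprecationWarning,
            stacklevel=2,  # assumption: point the warning at the caller
        )
    # The argument is ignored: the module always starts without known clients.
    return {"s_state": 0.0, "clients": set()}


# A non-empty list triggers exactly one DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    module = make_toy_module(clients=["a", "b"])

# An empty list is falsy, so no warning is emitted.
with warnings.catch_warnings(record=True) as caught_empty:
    warnings.simplefilter("always")
    make_toy_module(clients=[])
```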

collect_aux_var()

Return auxiliary variables that need to be shared between nodes.

Returns:

| Name | Type | Description |
|---|---|---|
| aux_var | ScaffoldAuxVar | ScaffoldAuxVar instance holding auxiliary variables that are to be shared with clients' ScaffoldClientModule instances. |

Source code in declearn/optimizer/modules/_scaffold.py
def collect_aux_var(
    self,
) -> ScaffoldAuxVar:
    """Return auxiliary variables that need to be shared between nodes.

    Returns
    -------
    aux_var:
        `ScaffoldAuxVar` instance holding auxiliary variables that are
        to be shared with clients' `ScaffoldClientModule` instances.
    """
    # When un-initialized, send lightweight information.
    if not self.clients:
        return ScaffoldAuxVar(state=0.0)
    # Otherwise, compute and return the current shared state.
    return ScaffoldAuxVar(state=self.s_state / len(self.clients))
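
Dividing the accumulated sum by the number of known clients yields an average because the per-client updates telescope. The derivation below is a sketch under two assumptions that are not stated on this page: every local state c_i starts at zero, and c_i only changes in rounds where client i participates. The symbols S_t (participants of round t) and K (set of known clients) are introduced here for illustration.

```latex
% s is the server's accumulated state (s_state); c_i^{last} is the most
% recent local state of client i.
\begin{align*}
s &= \sum_{t} \sum_{i \in S_t} \left( c_i^{t+1} - c_i^{t} \right)
   = \sum_{i \in K} \left( c_i^{\text{last}} - c_i^{0} \right)
   = \sum_{i \in K} c_i^{\text{last}}, \\
\text{state} &= \frac{s}{|K|}
   = \frac{1}{|K|} \sum_{i \in K} c_i^{\text{last}}.
\end{align*}
```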

process_aux_var(aux_var)

Update this module based on received shared auxiliary variables.

Update the global state variable based on the sum of clients' local state updates.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| aux_var | ScaffoldAuxVar | ScaffoldAuxVar resulting from the aggregation of clients' ScaffoldClientModule auxiliary variables. | required |

Raises:

| Type | Description |
|---|---|
| KeyError | If aux_var has an empty delta field (aux_var.delta is None). |
| TypeError | If aux_var is not a ScaffoldAuxVar instance. |

Source code in declearn/optimizer/modules/_scaffold.py
def process_aux_var(
    self,
    aux_var: ScaffoldAuxVar,
) -> None:
    """Update this module based on received shared auxiliary variables.

    Update the global state variable based on the sum of clients'
    local state updates.

    Parameters
    ----------
    aux_var:
        `ScaffoldAuxVar` resulting from the aggregation of clients'
        `ScaffoldClientModule` auxiliary variables.

    Raises
    ------
    KeyError:
        If `aux_var` has an empty `delta` field.
    TypeError:
        If `aux_var` is of improper type.
    """
    if not isinstance(aux_var, ScaffoldAuxVar):
        raise TypeError(
            f"'{self.__class__.__name__}.process_aux_var' received "
            f"auxiliary variables of unproper type: '{type(aux_var)}'."
        )
    if aux_var.delta is None:
        raise KeyError(
            f"'{self.__class__.__name__}.process_aux_var' received "
            "auxiliary variables with an empty 'delta' field."
        )
    # Update the list of known clients, and the sum of local states.
    self.clients.update(aux_var.clients)
    self.s_state += aux_var.delta
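
The two checks above run in a fixed order: the type check comes first, then the `delta` check. The sketch below reproduces that order with hypothetical stand-ins (`ToyAuxVar`, `toy_process`, `outcome`) to show which input triggers which exception; it is an illustration, not declearn's implementation.

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Set


@dataclass
class ToyAuxVar:
    """Hypothetical stand-in for ScaffoldAuxVar, holding plain floats."""

    state: Optional[float] = None
    delta: Optional[float] = None
    clients: Set[str] = field(default_factory=set)


def toy_process(s_state: float, clients: Set[str], aux_var: Any) -> float:
    """Validate `aux_var`, update `clients` in place, return the new sum."""
    if not isinstance(aux_var, ToyAuxVar):
        raise TypeError(f"Improper auxiliary variables type: '{type(aux_var)}'.")
    if aux_var.delta is None:
        raise KeyError("Auxiliary variables have an empty 'delta' field.")
    clients.update(aux_var.clients)
    return s_state + aux_var.delta


def outcome(aux_var: Any) -> str:
    """Return the name of the exception raised for `aux_var`, or 'ok'."""
    try:
        toy_process(0.0, set(), aux_var)
    except (TypeError, KeyError) as exc:
        return type(exc).__name__
    return "ok"


wrong_type = outcome({"delta": 1.0})  # a dict is not a ToyAuxVar
server_side = outcome(ToyAuxVar(state=0.5))  # 'state' only, no 'delta'
valid = outcome(ToyAuxVar(delta=1.0, clients={"a"}))
```

In this sketch, an object carrying only a `state` value (the shape the server itself emits) is rejected with a `KeyError`, since its `delta` field is empty.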