
declearn.optimizer.modules.ScaffoldClientModule

Bases: OptiModule

Client-side Stochastic Controlled Averaging (SCAFFOLD) module.

This module is to be added to the optimizer used by a federated-learning client, and expects the server's optimizer to use its counterpart module, ScaffoldServerModule.

This module implements the following algorithm:

Init:
    delta = 0
    _past = 0
    _step = 0
Step(grads):
    _past += grads
    _step += 1
    grads = grads - delta
Send:
    state = (_past / _step)
Receive(delta):
    delta = delta
    reset(_past, _step) to 0

In other words, this module receives a "delta" variable from the server instance, set as the difference between a client-specific state and a shared one, and corrects input gradients by subtracting this delta from them. At the end of a training round (made of multiple steps), it computes an updated client state as the average of the raw (uncorrected) input gradients accumulated over the round. This value is sent to the server, which emits a new value for the local delta in return.
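The algorithm above can be sketched in plain Python as follows. This is only an illustration under simplifying assumptions: gradients are scalar floats, and the class `ScaffoldClientSketch` with its `step`, `send` and `receive` methods is a hypothetical stand-in named after the pseudocode, not part of declearn.

```python
class ScaffoldClientSketch:
    """Toy stand-in for the client-side update rule (scalar gradients only)."""

    def __init__(self) -> None:
        self.delta = 0.0  # correction term received from the server
        self._past = 0.0  # running sum of raw (uncorrected) gradients
        self._step = 0    # number of steps since the last reset

    def step(self, grads: float) -> float:
        """Accumulate the raw gradient, then return the corrected one."""
        self._past += grads
        self._step += 1
        return grads - self.delta

    def send(self) -> float:
        """Return the average raw gradient (0.0 if no step was run)."""
        return self._past / self._step if self._step else 0.0

    def receive(self, delta: float) -> None:
        """Store the new delta and reset the accumulators."""
        self.delta = delta
        self._past = 0.0
        self._step = 0


client = ScaffoldClientSketch()
client.receive(delta=0.5)
corrected = [client.step(g) for g in (1.0, 2.0, 3.0)]
state = client.send()
print(corrected, state)
```

Note that the state sent to the server is built from the raw inputs, while the values returned by `step` are the corrected ones.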

The SCAFFOLD algorithm is described in reference [1]. The server-side correction of aggregated gradients, the storage of raw local and shared states, and the computation of the updated shared state and derived client-wise delta values are deferred to ScaffoldServerModule.

The formula applied to compute the updated local state variables corresponds to the "Option-II" in the paper. Implementing Option-I would require holding a copy of the shared model and computing its gradients in addition to those of the local model, effectively doubling computations. This can be done in declearn, but requires implementing an alternative training procedure rather than an optimizer plug-in.
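The equivalence between the paper's Option-II formula and a plain average of the input gradients can be checked numerically. The sketch below assumes plain SGD local steps with a fixed learning rate on a toy scalar objective; the helper `local_round` and all numeric values are made up for illustration.

```python
import math
from typing import Callable, List, Tuple


def local_round(
    x: float,
    grad_fn: Callable[[float], float],
    delta: float,
    lr: float,
    n_steps: int,
) -> Tuple[float, List[float]]:
    """Run n_steps of SGD with SCAFFOLD-corrected gradients (g - delta)."""
    y = x
    raw_grads = []
    for _ in range(n_steps):
        g = grad_fn(y)
        raw_grads.append(g)      # what the client module accumulates
        y -= lr * (g - delta)    # corrected local update
    return y, raw_grads


# Toy objective f(y) = (y - 3)^2 / 2, whose gradient is (y - 3).
c_i, c = 0.4, 0.1                # local and shared states (arbitrary values)
x, lr, K = 1.0, 0.1, 5           # shared weights, learning rate, step count
y, grads = local_round(x, lambda w: w - 3.0, c_i - c, lr, K)

option_ii = (c_i - c) + (x - y) / (K * lr)  # formula from the paper
averaged = sum(grads) / K                   # average of raw gradients
print(math.isclose(option_ii, averaged, abs_tol=1e-9))
```

The identity relies on the local update being exactly `y -= lr * (g - delta)`; it is not claimed here for other local optimizers.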

References

[1] Karimireddy et al., 2019. SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. https://arxiv.org/abs/1910.06378

Source code in declearn/optimizer/modules/_scaffold.py
class ScaffoldClientModule(OptiModule):
    """Client-side Stochastic Controlled Averaging (SCAFFOLD) module.

    This module is to be added to the optimizer used by a federated-
    learning client, and expects that the server's optimizer use its
    counterpart module:
    [`ScaffoldServerModule`][declearn.optimizer.modules.ScaffoldServerModule].

    This module implements the following algorithm:

        Init:
            delta = 0
            _past = 0
            _step = 0
        Step(grads):
            _past += grads
            _step += 1
            grads = grads - delta
        Send:
            state = (_past / _step)
        Receive(delta):
            delta = delta
            reset(_past, _step) to 0

    In other words, this module receives a "delta" variable from the
    server instance which is set as the difference between a client-
    specific state and a shared one, and corrects input gradients by
    subtracting this delta from them. At the end of a training round
    (made of multiple steps) it computes an updated client state as
    the average of the raw (uncorrected) input gradients. This value
    is to be sent to the server, which will emit a new value for the
    local delta in return.

    The SCAFFOLD algorithm is described in reference [1].
    The server-side correction of aggregated gradients, the storage
    of raw local and shared states, and the computation of the updated
    shared state and derived client-wise delta values are deferred to
    `ScaffoldServerModule`.

    The formula applied to compute the updated local state variables
    corresponds to the "Option-II" in the paper.
    Implementing Option-I would require holding a copy of the shared
    model and computing its gradients in addition to those of the
    local model, effectively doubling computations. This can be done
    in `declearn`, but requires implementing an alternative training
    procedure rather than an optimizer plug-in.

    References
    ----------
    [1] Karimireddy et al., 2019.
        SCAFFOLD: Stochastic Controlled Averaging for Federated Learning.
        https://arxiv.org/abs/1910.06378
    """

    name: ClassVar[str] = "scaffold-client"
    aux_name: ClassVar[str] = "scaffold"

    def __init__(
        self,
    ) -> None:
        """Instantiate the client-side SCAFFOLD gradients-correction module."""
        self.delta = 0.0  # type: Union[Vector, float]
        self._grads = 0.0  # type: Union[Vector, float]
        self._steps = 0

    def run(
        self,
        gradients: Vector,
    ) -> Vector:
        # Accumulate the uncorrected gradients.
        self._grads = self._grads + gradients
        self._steps += 1
        # Apply state-based correction to outputs.
        return gradients - self.delta

    def collect_aux_var(
        self,
    ) -> Dict[str, Any]:
        """Return auxiliary variables that need to be shared between nodes.

        Compute and package (without applying it) the updated value
        of the local state variable, so that the server may compute
        the updated shared state variable.

        Returns
        -------
        aux_var: dict[str, any]
            JSON-serializable dict of auxiliary variables that
            are to be shared with a ScaffoldServerModule held
            by the orchestrating server.

        Warns
        -----
        RuntimeWarning
            If called on an instance that has not processed any gradients
            (via a call to `run`) since the last call to `process_aux_var`
            (or its instantiation).
        """
        state = self._compute_updated_state()
        return {"state": state}

    def _compute_updated_state(
        self,
    ) -> Union[Vector, float]:
        """Compute and return the updated value of the local state.

        Note: the computed update is *not* applied by this method.

        The computation implemented here is equivalent to "Option II"
        of the SCAFFOLD paper. In that paper, authors write that:
            c_i^+ = (c_i - c) + (x - y_i) / (K * eta_l)
        where x are the shared model's weights, y_i are the local
        model's weights after K optimization steps with eta_l lr,
        c is the shared global state and c_i is the local state.

        Noting that (x - y_i) is in fact the difference between the
        local model's weights before and after running K training
        steps, we rewrite it as eta_l * Sum_k(grad(y_i^k) - D_i),
        where we define D_i = (c_i - c). Thus we rewrite c_i^+ as:
            c_i^+ = D_i + (1/K)*Sum_k(grad(y_i^k) - D_i)
        We then note that D_i is constant and can be taken out
        of the summation term, leaving us with:
            c_i^+ = Avg_k(grad(y_i^k))

        Hence the new local state can be computed by averaging the
        gradients input to this module along the training steps.
        """
        if not self._steps:
            warnings.warn(
                "Collecting auxiliary variables from a scaffold module "
                "that was not run. Returned zero-valued scalar updates.",
                category=RuntimeWarning,
            )
            return 0.0
        return self._grads / self._steps

    def process_aux_var(
        self,
        aux_var: Dict[str, Any],
    ) -> None:
        """Update this module based on received shared auxiliary variables.

        Collect the (local_state - shared_state) variable sent by server.
        Reset hidden variables used to compute the local state's updates.

        Parameters
        ----------
        aux_var: dict[str, any]
            JSON-serializable dict of auxiliary variables that are to be
            processed by this module at the start of a training round (on
            the client side).
            Expected keys for this class: {"delta"}.

        Raises
        ------
        KeyError
            If an expected auxiliary variable is missing.
        TypeError
            If a variable is of improper type, or if aux_var
            is not formatted as it should be.
        """
        # Expect a state variable and apply it.
        delta = aux_var.get("delta", None)
        if delta is None:
            raise KeyError(
                "Missing 'delta' key in ScaffoldClientModule's "
                "received auxiliary variables."
            )
        if isinstance(delta, (float, Vector)):
            self.delta = delta
        else:
            raise TypeError(
                "Unsupported type for ScaffoldClientModule's "
                "received auxiliary variable 'delta'."
            )
        # Reset local variables.
        self._grads = 0.0
        self._steps = 0

__init__()

Instantiate the client-side SCAFFOLD gradients-correction module.

Source code in declearn/optimizer/modules/_scaffold.py
def __init__(
    self,
) -> None:
    """Instantiate the client-side SCAFFOLD gradients-correction module."""
    self.delta = 0.0  # type: Union[Vector, float]
    self._grads = 0.0  # type: Union[Vector, float]
    self._steps = 0

collect_aux_var()

Return auxiliary variables that need to be shared between nodes.

Compute and package (without applying it) the updated value of the local state variable, so that the server may compute the updated shared state variable.

Returns:

aux_var (dict[str, any]): JSON-serializable dict of auxiliary variables that are to be shared with a ScaffoldServerModule held by the orchestrating server.

Warns:

RuntimeWarning: If called on an instance that has not processed any gradients (via a call to run) since the last call to process_aux_var (or its instantiation).
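The zero-step fallback described above can be observed with the standard `warnings` machinery. The sketch below uses a hypothetical stand-in function, `compute_updated_state`, that mirrors the documented behavior on scalar inputs; it does not import or call declearn.

```python
import warnings


def compute_updated_state(grads_sum: float, steps: int) -> float:
    """Hypothetical stand-in: average the gradients, or warn and return 0.0."""
    if not steps:
        warnings.warn(
            "Collecting auxiliary variables from a module that was not run.",
            category=RuntimeWarning,
        )
        return 0.0
    return grads_sum / steps


# Record warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    state = compute_updated_state(0.0, 0)

print(state, [w.category.__name__ for w in caught])
```

A caller that treats an unused client as an error could instead use `warnings.simplefilter("error", RuntimeWarning)` to turn the warning into an exception.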

Source code in declearn/optimizer/modules/_scaffold.py
def collect_aux_var(
    self,
) -> Dict[str, Any]:
    """Return auxiliary variables that need to be shared between nodes.

    Compute and package (without applying it) the updated value
    of the local state variable, so that the server may compute
    the updated shared state variable.

    Returns
    -------
    aux_var: dict[str, any]
        JSON-serializable dict of auxiliary variables that
        are to be shared with a ScaffoldServerModule held
        by the orchestrating server.

    Warns
    -----
    RuntimeWarning
        If called on an instance that has not processed any gradients
        (via a call to `run`) since the last call to `process_aux_var`
        (or its instantiation).
    """
    state = self._compute_updated_state()
    return {"state": state}

process_aux_var(aux_var)

Update this module based on received shared auxiliary variables.

Collect the (local_state - shared_state) variable sent by server. Reset hidden variables used to compute the local state's updates.

Parameters:

aux_var (Dict[str, Any], required): JSON-serializable dict of auxiliary variables that are to be processed by this module at the start of a training round (on the client side). Expected keys for this class: {"delta"}.

Raises:

KeyError: If an expected auxiliary variable is missing.

TypeError: If a variable is of improper type, or if aux_var is not formatted as it should be.
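The two failure modes above can be illustrated with a small stand-in. The function `extract_delta` below is hypothetical and only accepts floats, whereas the actual module also accepts declearn `Vector` instances; it is meant to show how a caller might tell the two errors apart, not to reproduce the library's code.

```python
from typing import Any, Dict


def extract_delta(aux_var: Dict[str, Any]) -> float:
    """Hypothetical stand-in: validate and return the 'delta' entry."""
    delta = aux_var.get("delta", None)
    if delta is None:
        raise KeyError("Missing 'delta' key in received auxiliary variables.")
    if not isinstance(delta, float):
        raise TypeError("Unsupported type for auxiliary variable 'delta'.")
    return delta


outcomes = []
for payload in ({"delta": 0.25}, {}, {"delta": "0.25"}):
    try:
        outcomes.append(extract_delta(payload))
    except (KeyError, TypeError) as exc:
        # Record the exception type rather than aborting the loop.
        outcomes.append(type(exc).__name__)

print(outcomes)
```

Because the lookup uses `get`, an explicit `{"delta": None}` payload is reported as a missing key rather than as a type error.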

Source code in declearn/optimizer/modules/_scaffold.py
def process_aux_var(
    self,
    aux_var: Dict[str, Any],
) -> None:
    """Update this module based on received shared auxiliary variables.

    Collect the (local_state - shared_state) variable sent by server.
    Reset hidden variables used to compute the local state's updates.

    Parameters
    ----------
    aux_var: dict[str, any]
        JSON-serializable dict of auxiliary variables that are to be
        processed by this module at the start of a training round (on
        the client side).
        Expected keys for this class: {"delta"}.

    Raises
    ------
    KeyError
        If an expected auxiliary variable is missing.
    TypeError
        If a variable is of improper type, or if aux_var
        is not formatted as it should be.
    """
    # Expect a state variable and apply it.
    delta = aux_var.get("delta", None)
    if delta is None:
        raise KeyError(
            "Missing 'delta' key in ScaffoldClientModule's "
            "received auxiliary variables."
        )
    if isinstance(delta, (float, Vector)):
        self.delta = delta
    else:
        raise TypeError(
            "Unsupported type for ScaffoldClientModule's "
            "received auxiliary variable 'delta'."
        )
    # Reset local variables.
    self._grads = 0.0
    self._steps = 0