Bases: OptiModule
Server-side Stochastic Controlled Averaging (SCAFFOLD) module.
This module is to be added to the optimizer used by a federated-
learning server, and expects that the clients' optimizers use its
counterpart module: ScaffoldClientModule.
This module implements the following algorithm:

    Init(clients):
        state = 0
        s_loc = {client: 0 for client in clients}
    Step(grads):
        grads
    Send:
        delta = {client: (s_loc[client] - state); client in s_loc}
    Receive(s_new = {client: state}):
        s_upd = sum(s_new[client] - s_loc[client]; client in s_new)
        s_loc.update(s_new)
        state += s_upd / len(s_loc)
In other words, this module holds a shared state variable, and a
set of client-specific ones, which are zero-valued when created.
At the beginning of a training round it sends to each client its
delta variable, set to the difference between its current state
and the shared one, which is to be applied as a correction term
to local gradients. At the end of a training round, aggregated
gradients are corrected by subtracting the shared state value
from them. Finally, updated local states received from clients
are recorded, and used to update the shared state variable, so
that new delta values can be sent to clients as the next round
of training starts.
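Under the hood this is simple dictionary arithmetic. The following walkthrough traces one full round of the update rule with plain Python floats standing in for declearn Vector values (client names are hypothetical):

```python
# Init: shared state and per-client local states start at zero.
state = 0.0
s_loc = {"alice": 0.0, "bob": 0.0}

# Send: each client receives delta = s_loc[client] - state.
delta = {client: s_loc[client] - state for client in s_loc}
assert delta == {"alice": 0.0, "bob": 0.0}

# Receive: clients return updated local states.
s_new = {"alice": 0.5, "bob": 0.25}
s_upd = sum(s_new[c] - s_loc[c] for c in s_new)  # 0.75
s_loc.update(s_new)
state += s_upd / len(s_loc)  # 0.75 / 2 = 0.375

# The next round's deltas reflect the new shared state.
delta = {client: s_loc[client] - state for client in s_loc}
assert state == 0.375
assert delta == {"alice": 0.125, "bob": -0.125}
```

In a real deployment the states are Vector objects with the same shape as the model gradients; the arithmetic is identical, applied coefficient-wise.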
The SCAFFOLD algorithm is described in reference [1].
The client-side correction of gradients and the computation of
updated local states are deferred to ScaffoldClientModule.
References
[1] Karimireddy et al., 2019.
SCAFFOLD: Stochastic Controlled Averaging for Federated Learning.
https://arxiv.org/abs/1910.06378
Source code in declearn/optimizer/modules/_scaffold.py
class ScaffoldServerModule(OptiModule):
"""Server-side Stochastic Controlled Averaging (SCAFFOLD) module.
This module is to be added to the optimizer used by a federated-
learning server, and expects that the clients' optimizer use its
counterpart module:
[`ScaffoldClientModule`][declearn.optimizer.modules.ScaffoldClientModule].
This module implements the following algorithm:
Init(clients):
state = 0
s_loc = {client: 0 for client in clients}
Step(grads):
grads
Send:
delta = {client: (s_loc[client] - state); client in s_loc}
Receive(s_new = {client: state}):
s_upd = sum(s_new[client] - s_loc[client]; client in s_new)
s_loc.update(s_new)
state += s_upd / len(s_loc)
In other words, this module holds a shared state variable, and a
set of client-specific ones, which are zero-valued when created.
At the beginning of a training round it sends to each client its
delta variable, set to the difference between its current state
and the shared one, which is to be applied as a correction term
to local gradients. At the end of a training round, aggregated
gradients are corrected by subtracting the shared state value
from them. Finally, updated local states received from clients
are recorded, and used to update the shared state variable, so
that new delta values can be sent to clients as the next round
of training starts.
The SCAFFOLD algorithm is described in reference [1].
The client-side correction of gradients and the computation of
updated local states are deferred to `ScaffoldClientModule`.
References
----------
[1] Karimireddy et al., 2019.
SCAFFOLD: Stochastic Controlled Averaging for Federated Learning.
https://arxiv.org/abs/1910.06378
"""
name: ClassVar[str] = "scaffold-server"
aux_name: ClassVar[str] = "scaffold"
def __init__(
self,
clients: Optional[List[str]] = None,
) -> None:
"""Instantiate the server-side SCAFFOLD gradients-correction module.
Parameters
----------
clients: list[str] or None, default=None
Optional list of known clients' id strings.
Notes
-----
- If this module is used under a training strategy that has
participating clients vary across epochs, leaving `clients`
to None will affect the update rule for the shared state,
as it uses a (n_participating / n_total_clients) term, the
divisor of which will be incorrect (at least on the first
step, potentially on following ones as well).
- Similarly, listing clients that in fact do not participate
in training will have side effects on computations.
"""
self.state = 0.0 # type: Union[Vector, float]
self.s_loc = {} # type: Dict[str, Union[Vector, float]]
if clients:
self.s_loc = {client: 0.0 for client in clients}
def get_config(
self,
) -> Dict[str, Any]:
return {"clients": list(self.s_loc)}
def run(
self,
gradients: Vector,
) -> Vector:
# Note: ScaffoldServer only manages auxiliary variables.
return gradients
def collect_aux_var(
self,
) -> Dict[str, Dict[str, Any]]:
"""Return auxiliary variables that need to be shared between nodes.
Package client-wise `delta = (local_state - shared_state)` variables.
Returns
-------
aux_var:
JSON-serializable dict of auxiliary variables that are to
be shared with the client-wise ScaffoldClientModule. This
dict has a `{client-name: {"delta": value}}` structure.
"""
# Compute clients' delta variable, package them and return.
aux_var = {} # type: Dict[str, Dict[str, Any]]
for client, state in self.s_loc.items():
delta = state - self.state
aux_var[client] = {"delta": delta}
return aux_var
def process_aux_var(
self,
aux_var: Dict[str, Dict[str, Any]],
) -> None:
"""Update this module based on received shared auxiliary variables.
Collect updated local state variables sent by clients.
Update the global state variable based on the latter.
Parameters
----------
aux_var: dict[str, dict[str, any]]
JSON-serializable dict of auxiliary variables that are to be
processed by this module before processing global updates.
This dict should have a `{client-name: {"state": value}}`
structure.
Raises
------
KeyError:
If an expected auxiliary variable is missing.
TypeError:
If a variable is of improper type, or if aux_var
is not formatted as it should be.
"""
# Collect updated local states received from Scaffold client modules.
s_new = {} # type: Dict[str, Union[Vector, float]]
for client, c_dict in aux_var.items():
if not isinstance(c_dict, dict):
raise TypeError(
"ScaffoldServerModule requires auxiliary variables "
"to be received as client-wise dictionaries."
)
if "state" not in c_dict:
raise KeyError(
"Missing required 'state' key in auxiliary variables "
f"received by ScaffoldServerModule from client '{client}'."
)
state = c_dict["state"]
if isinstance(state, float) and state == 0.0:
# Drop info from clients that have not processed gradients.
continue
if isinstance(state, (Vector, float)):
s_new[client] = state
else:
raise TypeError(
"Unsupported type for auxiliary variable 'state' "
f"received by ScaffoldServerModule from client '{client}'."
)
# Update the global and client-wise state variables.
update = sum(
state - self.s_loc.get(client, 0.0)
for client, state in s_new.items()
)
self.s_loc.update(s_new)
update = update / len(self.s_loc)
self.state = self.state + update
def get_state(
self,
) -> Dict[str, Any]:
return {"state": self.state, "s_loc": self.s_loc}
def set_state(
self,
state: Dict[str, Any],
) -> None:
for key in ("state", "s_loc"):
if key not in state:
raise KeyError(f"Missing required state variable '{key}'.")
self.state = state["state"]
self.s_loc = state["s_loc"]
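The get_state / set_state pair at the end of the class supports checkpointing. A minimal sketch of the round-trip, mimicking those two methods as standalone functions (illustrative helpers, not declearn's API) with plain floats in place of Vector values:

```python
# Illustrative mimic of the get_state / set_state checkpoint round-trip.
def get_state(state, s_loc):
    """Package the shared and client-wise states for checkpointing."""
    return {"state": state, "s_loc": s_loc}

def set_state(checkpoint):
    """Restore states from a checkpoint, validating required keys."""
    for key in ("state", "s_loc"):
        if key not in checkpoint:
            raise KeyError(f"Missing required state variable '{key}'.")
    return checkpoint["state"], checkpoint["s_loc"]

saved = get_state(0.375, {"alice": 0.5, "bob": 0.25})
state, s_loc = set_state(saved)
assert (state, s_loc) == (0.375, {"alice": 0.5, "bob": 0.25})
```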
__init__(clients=None)
Instantiate the server-side SCAFFOLD gradients-correction module.
Parameters
----------
clients: Optional[List[str]], default=None
    Optional list of known clients' id strings.
Notes
- If this module is used under a training strategy that has
participating clients vary across epochs, leaving clients
to None will affect the update rule for the shared state,
as it uses a (n_participating / n_total_clients) term, the
divisor of which will be incorrect (at least on the first
step, potentially on following ones as well).
- Similarly, listing clients that in fact do not participate
in training will have side effects on computations.
Source code in declearn/optimizer/modules/_scaffold.py
def __init__(
self,
clients: Optional[List[str]] = None,
) -> None:
"""Instantiate the server-side SCAFFOLD gradients-correction module.
Parameters
----------
clients: list[str] or None, default=None
Optional list of known clients' id strings.
Notes
-----
- If this module is used under a training strategy that has
participating clients vary across epochs, leaving `clients`
to None will affect the update rule for the shared state,
as it uses a (n_participating / n_total_clients) term, the
divisor of which will be incorrect (at least on the first
step, potentially on following ones as well).
- Similarly, listing clients that in fact do not participate
in training will have side effects on computations.
"""
self.state = 0.0 # type: Union[Vector, float]
self.s_loc = {} # type: Dict[str, Union[Vector, float]]
if clients:
self.s_loc = {client: 0.0 for client in clients}
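The caveat in the Notes can be made concrete. When clients are pre-declared, a round in which only some of them participate divides the state update by the full client count; when s_loc starts empty (the clients=None case), the divisor is only the number of clients seen so far. A sketch under those assumptions (illustrative helper and client names, plain floats for Vector values):

```python
def update_shared_state(state, s_loc, s_new):
    """Apply the Receive step of the algorithm; return the new shared state."""
    s_upd = sum(s_new[c] - s_loc.get(c, 0.0) for c in s_new)
    s_loc.update(s_new)
    return state + s_upd / len(s_loc)

# Pre-declared clients: round 1 with 1 of 2 participants divides by 2.
s_loc_known = {"alice": 0.0, "bob": 0.0}
state_known = update_shared_state(0.0, s_loc_known, {"alice": 0.5})
assert state_known == 0.25  # 0.5 / 2

# clients=None: "bob" is still unknown, so the divisor is 1 on round 1.
s_loc_lazy = {}
state_lazy = update_shared_state(0.0, s_loc_lazy, {"alice": 0.5})
assert state_lazy == 0.5  # 0.5 / 1 -- the update is over-weighted
```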
collect_aux_var()
Return auxiliary variables that need to be shared between nodes.
Package client-wise delta = (local_state - shared_state)
variables.
Returns
-------
aux_var: Dict[str, Dict[str, Any]]
    JSON-serializable dict of auxiliary variables that are to
    be shared with the client-wise ScaffoldClientModule. This
    dict has a {client-name: {"delta": value}} structure.
Source code in declearn/optimizer/modules/_scaffold.py
def collect_aux_var(
self,
) -> Dict[str, Dict[str, Any]]:
"""Return auxiliary variables that need to be shared between nodes.
Package client-wise `delta = (local_state - shared_state)` variables.
Returns
-------
aux_var:
JSON-serializable dict of auxiliary variables that are to
be shared with the client-wise ScaffoldClientModule. This
dict has a `{client-name: {"delta": value}}` structure.
"""
# Compute clients' delta variable, package them and return.
aux_var = {} # type: Dict[str, Dict[str, Any]]
for client, state in self.s_loc.items():
delta = state - self.state
aux_var[client] = {"delta": delta}
return aux_var
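As a sketch of the structure collect_aux_var produces, with plain floats standing in for Vector values and hypothetical client names:

```python
# Delta packaging as performed by collect_aux_var, with plain floats.
state = 0.375
s_loc = {"alice": 0.5, "bob": 0.25}

aux_var = {client: {"delta": s_loc[client] - state} for client in s_loc}
assert aux_var == {
    "alice": {"delta": 0.125},
    "bob": {"delta": -0.125},
}
```

Each per-client entry is what the matching ScaffoldClientModule receives at the start of the next training round.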
process_aux_var(aux_var)
Update this module based on received shared auxiliary variables.
Collect updated local state variables sent by clients.
Update the global state variable based on the latter.
Parameters
----------
aux_var: Dict[str, Dict[str, Any]]
    JSON-serializable dict of auxiliary variables that are to be
    processed by this module before processing global updates.
    This dict should have a {client-name: {"state": value}}
    structure.

Raises
------
KeyError
    If an expected auxiliary variable is missing.
TypeError
    If a variable is of improper type, or if aux_var
    is not formatted as it should be.
Source code in declearn/optimizer/modules/_scaffold.py
def process_aux_var(
self,
aux_var: Dict[str, Dict[str, Any]],
) -> None:
"""Update this module based on received shared auxiliary variables.
Collect updated local state variables sent by clients.
Update the global state variable based on the latter.
Parameters
----------
aux_var: dict[str, dict[str, any]]
JSON-serializable dict of auxiliary variables that are to be
processed by this module before processing global updates.
This dict should have a `{client-name: {"state": value}}`
structure.
Raises
------
KeyError:
If an expected auxiliary variable is missing.
TypeError:
If a variable is of improper type, or if aux_var
is not formatted as it should be.
"""
# Collect updated local states received from Scaffold client modules.
s_new = {} # type: Dict[str, Union[Vector, float]]
for client, c_dict in aux_var.items():
if not isinstance(c_dict, dict):
raise TypeError(
"ScaffoldServerModule requires auxiliary variables "
"to be received as client-wise dictionaries."
)
if "state" not in c_dict:
raise KeyError(
"Missing required 'state' key in auxiliary variables "
f"received by ScaffoldServerModule from client '{client}'."
)
state = c_dict["state"]
if isinstance(state, float) and state == 0.0:
# Drop info from clients that have not processed gradients.
continue
if isinstance(state, (Vector, float)):
s_new[client] = state
else:
raise TypeError(
"Unsupported type for auxiliary variable 'state' "
f"received by ScaffoldServerModule from client '{client}'."
)
# Update the global and client-wise state variables.
update = sum(
state - self.s_loc.get(client, 0.0)
for client, state in s_new.items()
)
self.s_loc.update(s_new)
update = update / len(self.s_loc)
self.state = self.state + update
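The validation rules above can be sketched as a standalone helper (an illustrative mimic, not declearn's API; plain floats stand in for Vector values, so the Vector branch of the type check is omitted), showing which inputs are dropped, collected, or rejected:

```python
def validate_states(aux_var):
    """Collect client states, mimicking process_aux_var's checks."""
    s_new = {}
    for client, c_dict in aux_var.items():
        if not isinstance(c_dict, dict):
            raise TypeError(f"non-dict entry for client '{client}'")
        if "state" not in c_dict:
            raise KeyError(f"missing 'state' key for client '{client}'")
        state = c_dict["state"]
        if isinstance(state, float) and state == 0.0:
            continue  # drop clients that have not processed gradients
        if not isinstance(state, float):
            raise TypeError(f"unsupported 'state' type for '{client}'")
        s_new[client] = state
    return s_new

# Zero-valued states are silently dropped; others are collected.
collected = validate_states({"alice": {"state": 0.5}, "bob": {"state": 0.0}})
assert collected == {"alice": 0.5}

# A missing 'state' key raises KeyError.
try:
    validate_states({"carol": {}})
except KeyError:
    pass
else:
    raise AssertionError("expected KeyError")
```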