Bases: Aggregator[GMAModelUpdates]
Gradient Masked Averaging Aggregator subclass.
This class implements the gradient masked averaging algorithm
proposed and analyzed in [1], which modifies the base averaging
algorithm of FedAvg (and its derivatives) by correcting the
averaged updates' magnitude based on the share of clients that
agree on the updates' direction (coordinate-wise).

The formula is the following:

```
threshold in [0, 1]  # hyperparameter
grads = [grads_client_0, ..., grads_client_N]
agree = abs(sum(sign(grads))) / len(grads)
score = 1 if agree >= threshold else agree
return score * avg(grads)
```

Client-based and/or number-of-training-steps-based weighting
may also be used; such weights are taken into account both when
averaging the input gradients and when computing the coordinate-wise
average direction from which the agreement scores are derived.
References
[1] Tenison et al., 2022.
Gradient Masked Averaging for Federated Learning.
https://arxiv.org/abs/2201.11986
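As an illustration of the rule above, here is a minimal sketch using NumPy arrays in place of declearn's `Vector` abstraction; the function name and sample values are ours, not part of declearn:

```python
import numpy as np

def gradient_masked_average(grads, threshold=1.0):
    # Stack client gradients into shape (n_clients, n_params).
    stacked = np.stack(grads)
    # Coordinate-wise share of clients agreeing with the majority sign.
    agree = np.abs(np.sign(stacked).sum(axis=0)) / len(grads)
    # Round the score up to 1 where agreement reaches the threshold.
    score = np.where(agree >= threshold, 1.0, agree)
    return score * stacked.mean(axis=0)

# Three clients: full agreement on coordinate 0, 2-out-of-3 on coordinate 1.
grads = [np.array([0.3, 0.1]), np.array([0.2, -0.4]), np.array([0.1, 0.2])]
masked_avg = gradient_masked_average(grads, threshold=1.0)
# Coordinate 0 keeps the plain average; coordinate 1 is scaled by 1/3.
```

With `threshold=0.0`, every score rounds to 1 and the function reduces to plain FedAvg-style averaging, matching the "0 edge case" mentioned below.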
Source code in declearn/aggregator/_gma.py:

```python
class GradientMaskedAveraging(Aggregator[GMAModelUpdates]):
    """Gradient Masked Averaging Aggregator subclass.

    This class implements the gradient masked averaging algorithm
    proposed and analyzed in [1], which modifies the base averaging
    algorithm of FedAvg (and its derivatives) by correcting the
    averaged updates' magnitude based on the share of clients that
    agree on the updates' direction (coordinate-wise).

    The formula is the following:

        threshold in [0, 1]  # hyperparameter
        grads = [grads_client_0, ..., grads_client_N]
        agree = abs(sum(sign(grads))) / len(grads)
        score = 1 if agree >= threshold else agree
        return score * avg(grads)

    Client-based and/or number-of-training-steps-based weighting
    may also be used; such weights are taken into account both when
    averaging the input gradients and when computing the coordinate-wise
    average direction from which the agreement scores are derived.

    References
    ----------
    [1] Tenison et al., 2022.
        Gradient Masked Averaging for Federated Learning.
        https://arxiv.org/abs/2201.11986
    """

    name = "gradient-masked-averaging"
    updates_cls = GMAModelUpdates

    def __init__(
        self,
        threshold: float = 1.0,
        steps_weighted: bool = True,
        client_weights: Optional[Dict[str, float]] = None,
    ) -> None:
        """Instantiate a gradient masked averaging aggregator.

        Parameters
        ----------
        threshold: float
            Threshold above which to round the coordinate-wise agreement
            score to 1. Must be in [0, 1] (FedAvg being the 0 edge case).
        steps_weighted: bool, default=True
            Whether to weight updates based on the number of optimization
            steps taken by the clients (relative to one another).
        client_weights: dict[str, float] or None, default=None
            Optional dict of client-wise base weights to use.
            If None, homogeneous base weights are used.

        Notes
        -----
        * One may specify `client_weights` and use `steps_weighted=True`.
          In that case, the product of the client's base weight and their
          number of training steps taken will be used (and unit-normed).
        * One may use incomplete `client_weights`. In that case, unknown
          clients' base weights will be set to 1.
        """
        self.threshold = threshold
        self._avg = AveragingAggregator(steps_weighted, client_weights)

    def get_config(
        self,
    ) -> Dict[str, Any]:
        config = super().get_config()
        config["threshold"] = self.threshold
        return config

    def prepare_for_sharing(
        self,
        updates: Vector,
        n_steps: int,
    ) -> GMAModelUpdates:
        data = self._avg.prepare_for_sharing(updates, n_steps)
        return GMAModelUpdates(data.updates, data.weights)

    def finalize_updates(
        self,
        updates: GMAModelUpdates,
    ) -> Vector:
        # Average model updates.
        values = self._avg.finalize_updates(
            ModelUpdates(updates.updates, updates.weights)
        )
        # Return if signs were not computed, denoting a lack of aggregation.
        if updates.up_sign is None:
            return values
        # Compute the average direction, taken as an agreement score.
        scores = self._avg.finalize_updates(
            ModelUpdates(updates.up_sign, updates.weights)
        )
        scores = scores * scores.sign()
        # Derive masking scores, using the thresholding hyper-parameter.
        clip = (scores - self.threshold).sign().maximum(0.0)
        scores = (1 - clip) * scores + clip  # s = 1 if s > t else s
        # Correct outputs' magnitude and return them.
        return values * scores

    def compute_client_weights(  # pragma: no cover
        self,
        updates: Dict[str, Vector],
        n_steps: Dict[str, int],
    ) -> Dict[str, float]:
        """Compute weights to use when averaging a given set of updates.

        This method is DEPRECATED as of DecLearn v2.4.
        It will be removed in DecLearn 2.6 and/or 3.0.
        """
        # pylint: disable=duplicate-code
        warnings.warn(
            f"'{self.__class__.__name__}.compute_client_weights' was"
            " deprecated in DecLearn v2.4. It will be removed in DecLearn"
            " v2.6 and/or v3.0.",
            DeprecationWarning,
        )
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=DeprecationWarning)
            return self._avg.compute_client_weights(updates, n_steps)
```
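The thresholding trick in `finalize_updates` uses only sign, maximum, and arithmetic operations, which keeps it within declearn's generic `Vector` API. A NumPy sketch of that single step (our own illustration, not declearn code):

```python
import numpy as np

threshold = 0.8
scores = np.array([0.2, 0.8, 0.9, 1.0])
# clip is 1.0 exactly where the score strictly exceeds the threshold.
clip = np.maximum(np.sign(scores - threshold), 0.0)
# Keep raw scores at or below the threshold, snap the rest to 1.
masked = (1 - clip) * scores + clip
```

Note that at exact equality (`score == threshold`), `sign` returns 0 and the raw score is kept, a slightly stricter rule than the `agree >= threshold` of the class docstring.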
__init__(threshold=1.0, steps_weighted=True, client_weights=None)
Instantiate a gradient masked averaging aggregator.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `threshold` | `float` | Threshold above which to round the coordinate-wise agreement score to 1. Must be in [0, 1] (FedAvg being the 0 edge case). | `1.0` |
| `steps_weighted` | `bool` | Whether to weight updates based on the number of optimization steps taken by the clients (relative to one another). | `True` |
| `client_weights` | `Optional[Dict[str, float]]` | Optional dict of client-wise base weights to use. If None, homogeneous base weights are used. | `None` |
Notes

- One may specify `client_weights` and use `steps_weighted=True`. In that case, the product of the client's base weight and their number of training steps taken will be used (and unit-normed).
- One may use incomplete `client_weights`. In that case, unknown clients' base weights will be set to 1.
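The combined weighting described in these notes (base weight times step count, unit-normed, with unknown clients defaulting to 1) can be sketched in plain Python; this is an illustration rather than declearn's own code:

```python
def combined_weights(client_weights, n_steps):
    # Multiply each client's base weight by its number of training steps;
    # clients absent from client_weights get a base weight of 1.
    raw = {
        client: client_weights.get(client, 1.0) * steps
        for client, steps in n_steps.items()
    }
    # Unit-norm so the final weights sum to 1.
    total = sum(raw.values())
    return {client: value / total for client, value in raw.items()}

weights = combined_weights({"alice": 2.0}, {"alice": 10, "bob": 20})
# alice: 2.0 * 10 = 20; bob (unknown): 1.0 * 20 = 20 -> both weigh 0.5
```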
compute_client_weights(updates, n_steps)
Compute weights to use when averaging a given set of updates.
This method is DEPRECATED as of DecLearn v2.4.
It will be removed in DecLearn 2.6 and/or 3.0.