Abstract base dataclass for cross-peers data aggregation containers.
This class defines an API for containers of values that are
to be shared across peers and aggregated with other similar
instances.
It is typically intended as a base structure to share model
updates, optimizer auxiliary variables, metadata, analytics
or model evaluation metrics that are to be aggregated, and
eventually finalized into some results, across a federated
or decentralized network of data-holding peers.
Aggregation
By default, fields are aggregated using `default_aggregate`,
which implements the plain summation of two values.
However, the aggregation rule for any field may be overridden
by declaring an `aggregate_<field.name>` method.
Subclasses may also overload the main `aggregate` method, if
some fields need to be aggregated in a specific way that
involves combining values from multiple fields.
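The per-field dispatch described above can be sketched without installing declearn. In the minimal example below, `MiniAggregate` is a hypothetical stand-in that only mimics the field-wise lookup of `aggregate_<field.name>` methods, and `LossStats` is an illustrative container that is not part of the library.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class MiniAggregate:
    """Hypothetical stand-in mimicking the field-wise dispatch of `Aggregate`."""

    def aggregate(self, other: Any) -> Any:
        if not isinstance(other, self.__class__):
            raise TypeError(f"Expected same type, got '{type(other)}'.")
        results = {}
        for field in dataclasses.fields(self):
            # Use `aggregate_<field.name>` when declared, else the default rule.
            rule = getattr(
                self, f"aggregate_{field.name}", self.default_aggregate
            )
            results[field.name] = rule(
                getattr(self, field.name), getattr(other, field.name)
            )
        return self.__class__(**results)

    @staticmethod
    def default_aggregate(val_a: Any, val_b: Any) -> Any:
        return val_a + val_b


@dataclasses.dataclass
class LossStats(MiniAggregate):
    """Illustrative container: two summed fields and a running maximum."""

    n_samples: int
    loss_sum: float
    loss_max: float

    @staticmethod
    def aggregate_loss_max(val_a: float, val_b: float) -> float:
        # Override: keep the largest value rather than summing.
        return max(val_a, val_b)


merged = LossStats(10, 4.0, 0.9).aggregate(LossStats(30, 9.0, 1.5))
print(merged)
```

Only `loss_max` departs from summation here; the two other fields fall back to `default_aggregate`.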
Secure Aggregation
The `prepare_for_secagg` method defines whether an `Aggregate`
is suitable for secure aggregation, and if so, which fields
are to be encrypted/sum-decrypted, and which are to be shared
in cleartext and aggregated as they would be in cleartext mode.
By default, subclasses are assumed to support secure summation
and require it for each and every field. The method should be
overridden when this is not the case, returning a pair of dicts
storing, respectively, fields that require secure summation,
and fields that are to remain cleartext. If secure aggregation
is not compatible with the subclass, the method should raise a
`NotImplementedError`.
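As a rough illustration of the three situations described above (default behaviour, split between secure and cleartext fields, and unsupported secure aggregation), here is a minimal sketch. `MiniAggregate` is a hypothetical stand-in that only reproduces the default `prepare_for_secagg` hook, and the `Counts`, `LossStats` and `Labels` classes are invented for the example.

```python
import dataclasses
from typing import Any, Dict, List, Optional, Tuple

SecAggOutput = Tuple[Dict[str, Any], Optional[Dict[str, Any]]]


@dataclasses.dataclass
class MiniAggregate:
    """Hypothetical stand-in keeping only the default SecAgg hook."""

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    def prepare_for_secagg(self) -> SecAggOutput:
        # Default: every field requires secure summation, none is cleartext.
        return self.to_dict(), None


@dataclasses.dataclass
class Counts(MiniAggregate):
    """Illustrative container relying on the default behaviour."""

    n_samples: int


@dataclasses.dataclass
class LossStats(MiniAggregate):
    """Illustrative container mixing summed fields and a maximum."""

    n_samples: int
    loss_sum: float
    loss_max: float

    def prepare_for_secagg(self) -> SecAggOutput:
        # A maximum cannot be obtained through summation: share it in cleartext.
        secagg = {"n_samples": self.n_samples, "loss_sum": self.loss_sum}
        clrtxt = {"loss_max": self.loss_max}
        return secagg, clrtxt


@dataclasses.dataclass
class Labels(MiniAggregate):
    """Illustrative container that opts out of secure aggregation."""

    names: List[str]

    def prepare_for_secagg(self) -> SecAggOutput:
        raise NotImplementedError(
            "'Labels' does not support secure aggregation."
        )


default_out = Counts(5).prepare_for_secagg()
secagg_fields, clrtxt_fields = LossStats(40, 13.0, 1.5).prepare_for_secagg()
print(default_out, secagg_fields, clrtxt_fields)
```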
Serialization
By default, subclasses will be made (de)serializable to and from
JSON, using `declearn.utils.add_json_support` and the `to_dict`
and `from_dict` methods. They will also be type-registered using
`declearn.utils.register_type`. This may be prevented by passing
the `register=False` keyword argument at inheritance time, i.e.
`class MyAggregate(Aggregate, register=False):`.
For this to succeed, first-child subclasses of `Aggregate` need
to define the class attribute `_group_key`, which acts as a root
for their children's JSON-registration name, and as the group name
for their type registration. They also need to be passed the
`base_cls=True` keyword argument at inheritance time, i.e.
`class FirstChild(Aggregate, base_cls=True):`.
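The inheritance-time keyword arguments rely on Python's `__init_subclass__` hook. The sketch below shows the mechanism under simplifying assumptions: the module-level `REGISTRY` dict is a hypothetical stand-in for declearn's type registries and JSON support (it does not call `add_json_support` or `register_type`), and the `Metric` and `Accuracy` classes are invented for the example.

```python
import dataclasses
import json
from typing import Any, ClassVar, Dict

# Hypothetical stand-in for declearn's registries: {group: {name: class}}.
REGISTRY: Dict[str, Dict[str, type]] = {}


@dataclasses.dataclass
class MiniAggregate:
    """Hypothetical stand-in mimicking `Aggregate`'s inheritance-time hooks."""

    _group_key: ClassVar[str]

    def __init_subclass__(
        cls, base_cls: bool = False, register: bool = True, **kwargs: Any
    ) -> None:
        super().__init_subclass__(**kwargs)
        if base_cls:  # first-child subclass: open a new registration group
            REGISTRY[cls._group_key] = {}
        if register:  # record the subclass under its group
            REGISTRY[cls._group_key][cls.__name__] = cls

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MiniAggregate":
        return cls(**data)


@dataclasses.dataclass
class Metric(MiniAggregate, base_cls=True, register=False):
    """First-child subclass: defines the group key, is not itself registered."""

    _group_key = "Metric"


@dataclasses.dataclass
class Accuracy(Metric):
    """Child subclass: registered under its parent's group by default."""

    correct: int
    total: int


# Round-trip an instance through JSON using the dict representation.
payload = json.dumps(Accuracy(correct=8, total=10).to_dict())
restored = Accuracy.from_dict(json.loads(payload))
print(payload, restored, REGISTRY)
```

Passing `register=False` on `Metric` is a choice made for this example; the source only states that first-child subclasses need `_group_key` and `base_cls=True`.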
Source code in declearn/utils/_aggregate.py
@dataclasses.dataclass
class Aggregate(metaclass=abc.ABCMeta):
    """Abstract base dataclass for cross-peers data aggregation containers.

    This class defines an API for containers of values that are
    to be shared across peers and aggregated with other similar
    instances.

    It is typically intended as a base structure to share model
    updates, optimizer auxiliary variables, metadata, analytics
    or model evaluation metrics that are to be aggregated, and
    eventually finalized into some results, across a federated
    or decentralized network of data-holding peers.

    Aggregation
    -----------

    By default, fields are aggregated using `default_aggregate`,
    which by default implements the mere summation of two values.
    However, the aggregation rule for any field may be overridden
    by declaring an `aggregate_<field.name>` method.

    Subclasses may also overload the main `aggregate` method, if
    some fields require to be aggregated in a specific way that
    involves crossing values from multiple ones.

    Secure Aggregation
    ------------------

    The `prepare_for_secagg` method defines whether an `Aggregate`
    is suitable for secure aggregation, and if so, which fields
    are to be encrypted/sum-decrypted, and which are to be shared
    in cleartext and aggregated similarly as in cleartext mode.

    By default, subclasses are assumed to support secure summation
    and require it for each and every field. The method should be
    overridden when this is not the case, returning a pair of dicts
    storing, respectively, fields that require secure summation,
    and fields that are to remain cleartext. If secure aggregation
    is not compatible with the subclass, the method should raise a
    `NotImplementedError`.

    Serialization
    -------------

    By default, subclasses will be made (de)serializable to and from
    JSON, using `declearn.utils.add_json_support` and the `to_dict`
    and `from_dict` methods. They will also be type-registered using
    `declearn.utils.register_type`. This may be prevented by passing
    the `register=False` keyword argument at inheritance time, i.e.
    `class MyAggregate(Aggregate, register=False):`.

    For this to succeed, first-child subclasses of `Aggregate` need
    to define the class attribute `_group_key`, that acts as a root
    for their children's JSON-registration name, and the group name
    for their type registration. They also need to be passed the
    `base_cls=True` keyword argument at inheritance time, i.e.
    `class FirstChild(Aggregate, base_cls=True):`.
    """

    _group_key: ClassVar[str]  # Group key for JSON registration.

    def __init_subclass__(
        cls,
        base_cls: bool = False,
        register: bool = True,
    ) -> None:
        """Automatically type-register and add JSON support for subclasses."""
        if base_cls:
            create_types_registry(cls, name=cls._group_key)
        if register:
            name = f"{cls._group_key}>{cls.__name__}"
            add_json_support(
                cls, pack=cls.to_dict, unpack=cls.from_dict, name=name
            )
            register_type(cls, name=cls.__name__, group=cls._group_key)

    def to_dict(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable dict representation of this instance."""
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(
        cls,
        data: Dict[str, Any],
    ) -> Self:
        """Instantiate from an object's dict representation."""
        return cls(**data)

    def __add__(
        self,
        other: Any,
    ) -> Self:
        """Overload the sum operator to aggregate multiple instances."""
        try:
            return self.aggregate(other)
        except TypeError:
            return NotImplemented

    def __radd__(
        self,
        other: Any,
    ) -> Self:
        """Enable `0 + Self -> Self`, to support `sum(Iterator[Self])`."""
        if isinstance(other, int) and not other:
            return self
        return NotImplemented

    def aggregate(
        self,
        other: Self,
    ) -> Self:
        """Aggregate this with another instance of the same class.

        Parameters
        ----------
        other:
            Another instance of the same type as `self`.

        Returns
        -------
        aggregated:
            An instance of the same class containing aggregated values.

        Raises
        ------
        TypeError
            If `other` is of improper type.
        ValueError
            If any field's aggregation fails.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(
                f"'{self.__class__.__name__}.aggregate' received a wrongful "
                f"'other' argument: expected same type, got '{type(other)}'."
            )
        # Run the fields' aggregation, wrapping any exception as ValueError.
        try:
            results = {
                field.name: getattr(
                    self, f"aggregate_{field.name}", self.default_aggregate
                )(getattr(self, field.name), getattr(other, field.name))
                for field in dataclasses.fields(self)
            }
        except Exception as exc:
            raise ValueError(
                "Exception encountered while aggregating two instances "
                f"of '{self.__class__.__name__}': {repr(exc)}."
            ) from exc
        # If everything went right, return the resulting instance.
        return self.__class__(**results)

    @staticmethod
    def default_aggregate(
        val_a: Any,
        val_b: Any,
    ) -> Any:
        """Aggregate two values using the default summation operator."""
        return val_a + val_b

    def prepare_for_secagg(
        self,
    ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
        """Return content for secure-aggregation of instances of this class.

        Returns
        -------
        secagg_fields:
            Dict storing fields that are compatible with encryption
            and secure aggregation using mere summation.
        clrtxt_fields:
            Dict storing fields that are to be shared in cleartext
            version. They will be aggregated using the same method
            as usual (`aggregate_<name>` or `default_aggregate`).

        Raises
        ------
        NotImplementedError
            If this class does not support Secure Aggregation,
            and its contents should therefore not be shared.

        Notes for developers
        --------------------
        - `secagg_fields` values should have one of the following types:
            - `int` (for positive integer values only)
            - `float`
            - `numpy.ndarray` (with any floating or integer dtype)
            - `Vector`
        - Classes that are incompatible with secure aggregation should
          implement a `raise NotImplementedError` statement, explaining
          whether SecAgg cannot be, or is yet to be, supported.
        """
        return self.to_dict(), None
__add__(other)
Overload the sum operator to aggregate multiple instances.
__init_subclass__(base_cls=False, register=True)
Automatically type-register and add JSON support for subclasses.
__radd__(other)
Enable `0 + Self -> Self`, to support `sum(Iterator[Self])`.
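To see how `__add__` and `__radd__` cooperate with the built-in `sum`, consider the minimal sketch below. `Tally` is a hypothetical single-field container that copies only the operator overloads documented on this page; it does not inherit from the real `Aggregate`.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Tally:
    """Hypothetical container mimicking `Aggregate`'s operator overloads."""

    count: int

    def aggregate(self, other: Any) -> "Tally":
        if not isinstance(other, self.__class__):
            raise TypeError(f"Expected same type, got '{type(other)}'.")
        return self.__class__(self.count + other.count)

    def __add__(self, other: Any) -> "Tally":
        try:
            return self.aggregate(other)
        except TypeError:
            # Let Python try the reflected operation, then raise TypeError.
            return NotImplemented

    def __radd__(self, other: Any) -> "Tally":
        # `sum` starts from the integer 0, so `0 + Tally` must return self.
        if isinstance(other, int) and not other:
            return self
        return NotImplemented


total = sum([Tally(1), Tally(2), Tally(3)])
print(total)

try:
    Tally(1) + 1  # neither operand knows how to add the other
except TypeError as exc:
    print(f"TypeError: {exc}")
```

Returning `NotImplemented` rather than raising keeps the standard operator protocol intact: Python gives the other operand a chance to handle the addition before raising a `TypeError` itself.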
aggregate(other)
Aggregate this with another instance of the same class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| other | Self | Another instance of the same type as `self`. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| aggregated | Self | An instance of the same class containing aggregated values. |

Raises:

| Type | Description |
|---|---|
| TypeError | If `other` is of improper type. |
| ValueError | If any field's aggregation fails. |
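The two error cases can be told apart with a short sketch. `MiniAggregate` below is a hypothetical stand-in that reproduces only the type check and the exception wrapping of `aggregate`, while `Record` and the `outcome` helper are invented for the example.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class MiniAggregate:
    """Hypothetical stand-in reproducing `aggregate`'s error handling."""

    def aggregate(self, other: Any) -> Any:
        if not isinstance(other, self.__class__):
            raise TypeError(f"Expected same type, got '{type(other)}'.")
        try:
            results = {
                field.name: getattr(
                    self, f"aggregate_{field.name}", self.default_aggregate
                )(getattr(self, field.name), getattr(other, field.name))
                for field in dataclasses.fields(self)
            }
        except Exception as exc:
            # Any failure while aggregating a field is wrapped as ValueError.
            raise ValueError(f"Field aggregation failed: {exc!r}.") from exc
        return self.__class__(**results)

    @staticmethod
    def default_aggregate(val_a: Any, val_b: Any) -> Any:
        return val_a + val_b


@dataclasses.dataclass
class Record(MiniAggregate):
    """Illustrative container with a single loosely-typed field."""

    value: Any


def outcome(left: Record, right: Any) -> str:
    """Return 'ok' or the name of the exception raised by `aggregate`."""
    try:
        left.aggregate(right)
    except (TypeError, ValueError) as exc:
        return type(exc).__name__
    return "ok"


print(outcome(Record(1), Record(2)))    # same type, summable values
print(outcome(Record(1), 3))            # `other` is not a Record
print(outcome(Record(1), Record("a")))  # `1 + "a"` fails inside a field
```

Note that the `TypeError` raised by `1 + "a"` in the last call surfaces as a `ValueError`, because it occurs during field aggregation rather than during the initial type check.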
default_aggregate(val_a, val_b)
staticmethod
Aggregate two values using the default summation operator.
from_dict(data)
classmethod
Instantiate from an object's dict representation.
prepare_for_secagg()
Return content for secure-aggregation of instances of this class.
Returns:

| Name | Type | Description |
|---|---|---|
| secagg_fields | Dict[str, Any] | Dict storing fields that are compatible with encryption and secure aggregation using mere summation. |
| clrtxt_fields | Optional[Dict[str, Any]] | Dict storing fields that are to be shared in cleartext version. They will be aggregated using the same method as usual (`aggregate_<name>` or `default_aggregate`). |

Raises:

| Type | Description |
|---|---|
| NotImplementedError | If this class does not support Secure Aggregation, and its contents should therefore not be shared. |

Notes for developers
- `secagg_fields` values should have one of the following types:
    - `int` (for positive integer values only)
    - `float`
    - `numpy.ndarray` (with any floating or integer dtype)
    - `Vector`
- Classes that are incompatible with secure aggregation should
  implement a `raise NotImplementedError` statement, explaining
  whether SecAgg cannot be, or is yet to be, supported.
to_dict()
Return a JSON-serializable dict representation of this instance.