
declearn.utils.Aggregate

Abstract base dataclass for cross-peers data aggregation containers.

This class defines an API for containers of values that are to be shared across peers and aggregated with other similar instances.

It is typically intended as a base structure to share model updates, optimizer auxiliary variables, metadata, analytics or model evaluation metrics that are to be aggregated, and eventually finalized into some results, across a federated or decentralized network of data-holding peers.

Aggregation

By default, fields are aggregated using default_aggregate, which (unless overridden) implements the mere summation of two values. However, the aggregation rule for any field may be overridden by declaring an aggregate_<field.name> method.

Subclasses may also override the main aggregate method if some fields need to be aggregated in a specific way that involves combining values from multiple fields.
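The field-wise dispatch rule can be illustrated with a minimal, self-contained sketch. `MiniAggregate` below is a hypothetical stand-in that mirrors only the per-field logic of `Aggregate.aggregate` (it does not import declearn), and `CountStats` is an illustrative subclass that overrides the rule for a single field.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class MiniAggregate:
    """Hypothetical stand-in mirroring the field-wise aggregation rule."""

    def aggregate(self, other: "MiniAggregate") -> "MiniAggregate":
        if not isinstance(other, self.__class__):
            raise TypeError(f"expected {self.__class__.__name__}, got {type(other)}")
        results = {}
        for field in dataclasses.fields(self):
            # Use `aggregate_<name>` when defined, else fall back to summation.
            func = getattr(self, f"aggregate_{field.name}", self.default_aggregate)
            results[field.name] = func(
                getattr(self, field.name), getattr(other, field.name)
            )
        return self.__class__(**results)

    @staticmethod
    def default_aggregate(val_a: Any, val_b: Any) -> Any:
        return val_a + val_b


@dataclasses.dataclass
class CountStats(MiniAggregate):
    """Illustrative subclass: sums `total` and `count`, keeps the max `peak`."""

    total: float
    count: int
    peak: float

    @staticmethod
    def aggregate_peak(val_a: float, val_b: float) -> float:
        return max(val_a, val_b)


stats = CountStats(total=3.0, count=2, peak=2.5).aggregate(
    CountStats(total=5.0, count=3, peak=4.0)
)
print(stats)  # CountStats(total=8.0, count=5, peak=4.0)
```

Only `peak` departs from summation here; `total` and `count` fall through to `default_aggregate` without any extra code in the subclass.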

Secure Aggregation

The prepare_for_secagg method defines whether an Aggregate is suitable for secure aggregation, and if so, which fields are to be encrypted/sum-decrypted, and which are to be shared in cleartext and aggregated as in cleartext mode.

By default, subclasses are assumed to support secure summation and to require it for every field. The method should be overridden when this is not the case, returning a pair of dicts storing, respectively, fields that require secure summation, and fields that are to remain cleartext. If secure aggregation is not compatible with the subclass, the method should raise a NotImplementedError.
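As a hedged illustration of this split, the sketch below uses a hypothetical `PeakStats` dataclass whose override sends `total` and `count` to secure summation and shares `peak` in cleartext. The "secure" part is simulated with a toy pairwise additive mask that cancels out in the sum; this is not declearn's SecAgg protocol, only a way to show that summed fields survive masking while the cleartext field keeps its usual aggregation rule.

```python
import dataclasses
import random
from typing import Any, Dict, Optional, Tuple


@dataclasses.dataclass
class PeakStats:
    """Hypothetical aggregate: `total`/`count` are summed, `peak` is maxed."""

    total: int
    count: int
    peak: int

    def prepare_for_secagg(
        self,
    ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
        # Summable fields go to SecAgg; `peak` needs `max`, so it stays cleartext.
        secagg = {"total": self.total, "count": self.count}
        clrtxt = {"peak": self.peak}
        return secagg, clrtxt


def toy_masked_sum(
    values_a: Dict[str, int], values_b: Dict[str, int]
) -> Dict[str, int]:
    """Toy (insecure) masking: peer A adds +r, peer B adds -r, r cancels out."""
    result = {}
    for key in values_a:
        mask = random.randint(-10**6, 10**6)
        masked_a = values_a[key] + mask  # what peer A would actually send
        masked_b = values_b[key] - mask  # what peer B would actually send
        result[key] = masked_a + masked_b  # the server only sees the sum
    return result


peer_a = PeakStats(total=10, count=2, peak=7)
peer_b = PeakStats(total=32, count=4, peak=9)
sec_a, clr_a = peer_a.prepare_for_secagg()
sec_b, clr_b = peer_b.prepare_for_secagg()
summed = toy_masked_sum(sec_a, sec_b)
merged = PeakStats(**summed, peak=max(clr_a["peak"], clr_b["peak"]))
print(merged)  # PeakStats(total=42, count=6, peak=9)
```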

Serialization

By default, subclasses will be made (de)serializable to and from JSON, using declearn.utils.add_json_support and the to_dict and from_dict methods. They will also be type-registered using declearn.utils.register_type. This may be prevented by passing the register=False keyword argument at inheritance time, i.e. class MyAggregate(Aggregate, register=False):.

For this to succeed, first-child subclasses of Aggregate need to define the class attribute _group_key, which acts as a root for their children's JSON-registration names, and as the group name for their type registration. They also need to be passed the base_cls=True keyword argument at inheritance time, i.e. class FirstChild(Aggregate, base_cls=True):.
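These inheritance-time keywords rely on Python's `__init_subclass__` hook. The sketch below reproduces only the naming scheme visible in the source (`"<_group_key>><ClassName>"` for JSON names, class name within the `_group_key` group for types), with plain dicts `REGISTRY` and `JSON_NAMES` standing in for declearn's `create_types_registry`, `register_type` and `add_json_support`. `MiniAggregate`, `ModelUpdates`, `SgdUpdates` and `DraftUpdates` are hypothetical names.

```python
from typing import ClassVar, Dict, Type

# Hypothetical stand-ins for declearn's registries.
REGISTRY: Dict[str, Dict[str, Type]] = {}  # group -> {class name: class}
JSON_NAMES: Dict[str, Type] = {}  # "<group>><name>" -> class


class MiniAggregate:
    _group_key: ClassVar[str]

    def __init_subclass__(cls, base_cls: bool = False, register: bool = True) -> None:
        super().__init_subclass__()
        if base_cls:
            # First-child subclasses open a registry named after `_group_key`.
            REGISTRY[cls._group_key] = {}
        if register:
            JSON_NAMES[f"{cls._group_key}>{cls.__name__}"] = cls
            REGISTRY[cls._group_key][cls.__name__] = cls


class ModelUpdates(MiniAggregate, base_cls=True):
    _group_key = "ModelUpdates"


class SgdUpdates(ModelUpdates):  # inherits `_group_key`, registered by default
    pass


class DraftUpdates(ModelUpdates, register=False):  # opted out of registration
    pass


print(sorted(REGISTRY["ModelUpdates"]))  # ['ModelUpdates', 'SgdUpdates']
```

The class body (and hence `_group_key`) is executed before `__init_subclass__` runs, which is why a first child can read its own `_group_key` while being registered.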

Source code in declearn/utils/_aggregate.py
@dataclasses.dataclass
class Aggregate(metaclass=abc.ABCMeta):
    """Abstract base dataclass for cross-peers data aggregation containers.

    This class defines an API for containers of values that are
    to be shared across peers and aggregated with other similar
    instances.

    It is typically intended as a base structure to share model
    updates, optimizer auxiliary variables, metadata, analytics
    or model evaluation metrics that are to be aggregated, and
    eventually finalized into some results, across a federated
    or decentralized network of data-holding peers.

    Aggregation
    -----------

    By default, fields are aggregated using `default_aggregate`,
    which (unless overridden) implements the mere summation of two values.
    However, the aggregation rule for any field may be overridden
    by declaring an `aggregate_<field.name>` method.

    Subclasses may also override the main `aggregate` method if
    some fields need to be aggregated in a specific way that
    involves combining values from multiple fields.

    Secure Aggregation
    ------------------

    The `prepare_for_secagg` method defines whether an `Aggregate`
    is suitable for secure aggregation, and if so, which fields
    are to be encrypted/sum-decrypted, and which are to be shared
    in cleartext and aggregated as in cleartext mode.

    By default, subclasses are assumed to support secure summation
    and to require it for every field. The method should be
    overridden when this is not the case, returning a pair of dicts
    storing, respectively, fields that require secure summation,
    and fields that are to remain cleartext. If secure aggregation
    is not compatible with the subclass, the method should raise a
    `NotImplementedError`.

    Serialization
    -------------

    By default, subclasses will be made (de)serializable to and from
    JSON, using `declearn.utils.add_json_support` and the `to_dict`
    and `from_dict` methods. They will also be type-registered using
    `declearn.utils.register_type`. This may be prevented by passing
    the `register=False` keyword argument at inheritance time, i.e.
    `class MyAggregate(Aggregate, register=False):`.

    For this to succeed, first-child subclasses of `Aggregate` need
    to define the class attribute `_group_key`, which acts as a root
    for their children's JSON-registration names, and as the group name
    for their type registration. They also need to be passed the
    `base_cls=True` keyword argument at inheritance time, i.e.
    `class FirstChild(Aggregate, base_cls=True):`.
    """

    _group_key: ClassVar[str]  # Group key for JSON registration.

    def __init_subclass__(
        cls,
        base_cls: bool = False,
        register: bool = True,
    ) -> None:
        """Automatically type-register and add JSON support for subclasses."""
        if base_cls:
            create_types_registry(cls, name=cls._group_key)
        if register:
            name = f"{cls._group_key}>{cls.__name__}"
            add_json_support(
                cls, pack=cls.to_dict, unpack=cls.from_dict, name=name
            )
            register_type(cls, name=cls.__name__, group=cls._group_key)

    def to_dict(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable dict representation of this instance."""
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(
        cls,
        data: Dict[str, Any],
    ) -> Self:
        """Instantiate from an object's dict representation."""
        return cls(**data)

    def __add__(
        self,
        other: Any,
    ) -> Self:
        """Overload the sum operator to aggregate multiple instances."""
        try:
            return self.aggregate(other)
        except TypeError:
            return NotImplemented

    def __radd__(
        self,
        other: Any,
    ) -> Self:
        """Enable `0 + Self -> Self`, to support `sum(Iterator[Self])`."""
        if isinstance(other, int) and not other:
            return self
        return NotImplemented

    def aggregate(
        self,
        other: Self,
    ) -> Self:
        """Aggregate this with another instance of the same class.

        Parameters
        ----------
        other:
            Another instance of the same type as `self`.

        Returns
        -------
        aggregated:
            An instance of the same class containing aggregated values.

        Raises
        ------
        TypeError
            If `other` is of improper type.
        ValueError
            If any field's aggregation fails.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(
                f"'{self.__class__.__name__}.aggregate' received a wrongful "
                f"'other' argument: expected same type, got '{type(other)}'."
            )
        # Run the fields' aggregation, wrapping any exception as ValueError.
        try:
            results = {
                field.name: getattr(
                    self, f"aggregate_{field.name}", self.default_aggregate
                )(getattr(self, field.name), getattr(other, field.name))
                for field in dataclasses.fields(self)
            }
        except Exception as exc:
            raise ValueError(
                "Exception encountered while aggregating two instances "
                f"of '{self.__class__.__name__}': {repr(exc)}."
            ) from exc
        # If everything went right, return the aggregated instance.
        return self.__class__(**results)

    @staticmethod
    def default_aggregate(
        val_a: Any,
        val_b: Any,
    ) -> Any:
        """Aggregate two values using the default summation operator."""
        return val_a + val_b

    def prepare_for_secagg(
        self,
    ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
        """Return content for secure-aggregation of instances of this class.

        Returns
        -------
        secagg_fields:
            Dict storing fields that are compatible with encryption
            and secure aggregation using mere summation.
        clrtxt_fields:
            Dict storing fields that are to be shared in cleartext.
            They will be aggregated using the same method as usual
            (`aggregate_<name>` or `default_aggregate`).

        Raises
        ------
        NotImplementedError
            If this class does not support Secure Aggregation,
            and its contents should therefore not be shared.

        Notes for developers
        --------------------
        - `secagg_fields` values should have one of the following types:
            - `int` (for positive integer values only)
            - `float`
            - `numpy.ndarray` (with any floating or integer dtype)
            - `Vector`
        - Classes that are incompatible with secure aggregation should
          implement a `raise NotImplementedError` statement, explaining
          whether SecAgg cannot be, or is yet to be, supported.
        """
        return self.to_dict(), None

__add__(other)

Overload the sum operator to aggregate multiple instances.
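Returning `NotImplemented` when `aggregate` raises a `TypeError` lets Python try the other operand's reflected method and, failing that, raise its own `TypeError`. The sketch below is a hypothetical stand-in (`Tally` is an illustrative name) that reproduces just the `__add__`/`aggregate` pair to show both outcomes.

```python
import dataclasses


@dataclasses.dataclass
class Tally:
    """Hypothetical one-field aggregate mirroring `Aggregate.__add__`."""

    count: int

    def aggregate(self, other: "Tally") -> "Tally":
        if not isinstance(other, self.__class__):
            raise TypeError(f"expected Tally, got {type(other)}")
        return Tally(self.count + other.count)

    def __add__(self, other: object) -> "Tally":
        try:
            return self.aggregate(other)
        except TypeError:
            # Defer to Python, which raises TypeError if nothing else handles it.
            return NotImplemented


print(Tally(2) + Tally(3))  # Tally(count=5)
try:
    Tally(2) + "three"
except TypeError as exc:
    print("TypeError:", exc)  # raised by Python, not by `aggregate` itself
```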


__init_subclass__(base_cls=False, register=True)

Automatically type-register and add JSON support for subclasses.


__radd__(other)

Enable 0 + Self -> Self, to support sum(Iterator[Self]).
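`sum()` starts from the integer `0`, so the first step is `0 + item`; `int.__add__` returns `NotImplemented` for an unknown type and Python then calls `item.__radd__(0)`. The hypothetical `Tally` stand-in below reproduces the `__radd__` rule to show `sum()` over several peers' instances.

```python
import dataclasses


@dataclasses.dataclass
class Tally:
    """Hypothetical aggregate mirroring `Aggregate.__add__` and `__radd__`."""

    count: int

    def __add__(self, other: object) -> "Tally":
        if not isinstance(other, Tally):
            return NotImplemented
        return Tally(self.count + other.count)

    def __radd__(self, other: object) -> "Tally":
        # Only the neutral start value `0` is accepted on the left-hand side.
        if isinstance(other, int) and not other:
            return self
        return NotImplemented


peers = [Tally(1), Tally(4), Tally(5)]
print(sum(peers))  # Tally(count=10)
```

Note that `sum([])` still returns the integer `0`, not an aggregate, so callers may want to handle an empty list of peers explicitly.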


aggregate(other)

Aggregate this with another instance of the same class.

Parameters:

  • other (Self, required): Another instance of the same type as self.

Returns:

  • aggregated (Self): An instance of the same class containing aggregated values.

Raises:

  • TypeError: If other is of improper type.
  • ValueError: If any field's aggregation fails.
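The two documented failure modes can be told apart with a small sketch. `Metrics` is a hypothetical single-field stand-in that copies the type check and the `ValueError` wrapping from the source (with plain `+` in place of per-field dispatch), and `error_type` is an illustrative helper.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass
class Metrics:
    """Hypothetical aggregate reproducing the error handling of `aggregate`."""

    loss: Any

    def aggregate(self, other: "Metrics") -> "Metrics":
        if not isinstance(other, self.__class__):
            raise TypeError(f"expected Metrics, got '{type(other)}'")
        try:
            results = {
                field.name: getattr(self, field.name) + getattr(other, field.name)
                for field in dataclasses.fields(self)
            }
        except Exception as exc:
            # Any field-level failure surfaces as a ValueError, chained to the cause.
            raise ValueError(f"aggregation failed: {exc!r}") from exc
        return self.__class__(**results)


def error_type(func: Callable[[], object]) -> str:
    """Return the name of the exception raised by `func`, or 'ok'."""
    try:
        func()
    except Exception as exc:
        return type(exc).__name__
    return "ok"


print(error_type(lambda: Metrics(1.0).aggregate(Metrics(2.0))))  # ok
print(error_type(lambda: Metrics(1.0).aggregate(3.0)))  # TypeError
print(error_type(lambda: Metrics(1.0).aggregate(Metrics("a"))))  # ValueError
```

Since `__add__` only catches `TypeError`, a field-level failure (wrapped as `ValueError`) propagates out of `a + b` rather than being turned into `NotImplemented`.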


default_aggregate(val_a, val_b) staticmethod

Aggregate two values using the default summation operator.
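Because `default_aggregate` simply applies `+`, its meaning depends on each field's type: numbers are summed, but sequences such as lists or strings are concatenated. The sketch below copies the one-line rule into a standalone function (a local copy, not an import from declearn) to show this.

```python
from typing import Any


def default_aggregate(val_a: Any, val_b: Any) -> Any:
    """Local copy of the default rule: plain `+` on the two values."""
    return val_a + val_b


print(default_aggregate(3, 4))  # 7
print(default_aggregate(0.5, 0.25))  # 0.75
print(default_aggregate([1, 2], [3]))  # [1, 2, 3] (concatenation, not element-wise)
print(default_aggregate("ab", "cd"))  # abcd
```

A field holding a Python list of per-coordinate values would therefore need an `aggregate_<name>` override if element-wise summation is intended, whereas `numpy.ndarray` values already sum element-wise under `+`.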


from_dict(data) classmethod

Instantiate from an object's dict representation.
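A JSON round trip through `to_dict` and `from_dict` can be sketched with the standard `json` module. `Stats` is a hypothetical dataclass that copies the two one-line implementations documented on this page; declearn's own JSON hooks (registered via `add_json_support`) are not used here.

```python
import dataclasses
import json
from typing import Any, Dict


@dataclasses.dataclass
class Stats:
    """Hypothetical aggregate copying the documented `to_dict`/`from_dict`."""

    total: float
    count: int

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Stats":
        # Keys must match the dataclass fields exactly, or `cls(**data)` raises.
        return cls(**data)


original = Stats(total=12.5, count=3)
payload = json.dumps(original.to_dict())  # what a peer could send over the wire
restored = Stats.from_dict(json.loads(payload))
print(payload)  # {"total": 12.5, "count": 3}
print(restored == original)  # True
```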


prepare_for_secagg()

Return content for secure-aggregation of instances of this class.

Returns:

  • secagg_fields (Dict[str, Any]): Dict storing fields that are compatible with encryption and secure aggregation using mere summation.
  • clrtxt_fields (Optional[Dict[str, Any]]): Dict storing fields that are to be shared in cleartext. They will be aggregated using the same method as usual (aggregate_<name> or default_aggregate).

Raises:

  • NotImplementedError: If this class does not support Secure Aggregation, and its contents should therefore not be shared.

Notes for developers

  • secagg_fields values should have one of the following types:
    • int (for positive integer values only)
    • float
    • numpy.ndarray (with any floating or integer dtype)
    • Vector
  • Classes that are incompatible with secure aggregation should implement a raise NotImplementedError statement, explaining whether SecAgg cannot be, or is yet to be, supported.
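The two ends of this contract can be sketched as follows. `SummedStats` keeps the default behaviour (every field goes to secure summation, no cleartext part) and `MedianStats` opts out by raising `NotImplementedError`; both classes and the `secagg_supported` helper are hypothetical illustrations, not declearn code.

```python
import dataclasses
from typing import Any, Dict, List, Optional, Tuple


@dataclasses.dataclass
class SummedStats:
    """Hypothetical aggregate keeping the default: every field is secured."""

    total: float
    count: int

    def prepare_for_secagg(self) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
        return dataclasses.asdict(self), None


@dataclasses.dataclass
class MedianStats:
    """Hypothetical aggregate whose content cannot be combined by mere summation."""

    samples: List[float]

    def prepare_for_secagg(self) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
        raise NotImplementedError(
            "MedianStats cannot be securely aggregated: medians are not sums."
        )


def secagg_supported(obj: Any) -> bool:
    """Hypothetical helper: check whether `obj` accepts secure aggregation."""
    try:
        obj.prepare_for_secagg()
    except NotImplementedError:
        return False
    return True


print(SummedStats(total=2.0, count=1).prepare_for_secagg())
# ({'total': 2.0, 'count': 1}, None)
```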

to_dict()

Return a JSON-serializable dict representation of this instance.
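`to_dict` relies on `dataclasses.asdict`, which recursively converts nested dataclasses into plain dicts and rebuilds lists, tuples and dicts as new containers. One consequence worth checking in subclasses: the inherited `from_dict`, being `cls(**data)`, does not turn nested dicts back into dataclasses. The hypothetical `Inner`/`Outer` classes below illustrate this.

```python
import dataclasses
from typing import Any, Dict, List


@dataclasses.dataclass
class Inner:
    value: int


@dataclasses.dataclass
class Outer:
    """Hypothetical aggregate with a nested dataclass field."""

    inner: Inner
    tags: List[str]

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Outer":
        return cls(**data)


obj = Outer(inner=Inner(value=1), tags=["a"])
data = obj.to_dict()
print(data)  # {'inner': {'value': 1}, 'tags': ['a']}
print(data["tags"] is obj.tags)  # False: asdict builds new containers
restored = Outer.from_dict(data)
print(type(restored.inner))  # <class 'dict'>, not Inner
```

In this sketch, a subclass holding nested dataclasses would need its own `from_dict` override to restore them.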
