declearn.metrics.Metric

Abstract class defining an API to compute federative metrics.

This class defines an API to instantiate stateful containers for one or multiple metrics, that enable computing the final results through iterative update steps that may additionally be run in a federative way.

Usage

Single-party usage:

>>> metric = MetricSubclass()
>>> metric.update(y_true, y_pred)  # take one update step
>>> metric.get_result()    # after one or multiple updates
>>> metric.reset()  # reset before a next evaluation round

Multiple-parties usage:

>>> # Instantiate 2+ metrics and run local update steps.
>>> metric_0 = MetricSubclass()
>>> metric_1 = MetricSubclass()
>>> metric_0.update(y_true_0, y_pred_0)
>>> metric_1.update(y_true_1, y_pred_1)
>>> # Gather and share metric states (aggregated information).
>>> states_0 = metric_0.get_states()  # metric_0 is unaltered
>>> metric_1.agg_states(states_0)     # metric_1 is updated
>>> # Compute results that aggregate info from both clients.
>>> metric_1.get_result()

Abstract

To define a concrete Metric, one must subclass it and define:

  • name: str class attribute. Name identifier of the class (should be unique across existing Metric classes). Also used for automatic types-registration of the class (see Inheritance section below).
  • _build_states() -> dict[str, (float | np.ndarray)]: Build and return an ensemble of state variables. This method is called to initialize the _states attribute, that should be used and updated by other abstract methods.
  • update(y_true: np.ndarray, y_pred: np.ndarray, s_wght: np.ndarray|None): Update the metric's internal state based on a data batch. This method should update self._states in-place.
  • get_result() -> dict[str, (float | np.ndarray)]: Compute the metric(s), based on the current state variables. This method should make use of self._states and prevent side effects on its contents.
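
As a minimal sketch of the abstract pieces listed above, the following stand-alone code mirrors the pattern without importing declearn. MetricSketch and MeanAbsoluteError are hypothetical names, and plain Python sequences stand in for numpy arrays so that the block runs with the standard library only; an actual metric would subclass declearn.metrics.Metric and receive np.ndarray inputs.

```python
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence


class MetricSketch(ABC):
    """Hypothetical stand-in for the Metric base class (not declearn code)."""

    name: str = NotImplemented

    def __init__(self) -> None:
        self._states = self._build_states()

    @abstractmethod
    def _build_states(self) -> Dict[str, float]:
        """Return the initial state variables."""

    @abstractmethod
    def update(
        self,
        y_true: Sequence[float],
        y_pred: Sequence[float],
        s_wght: Optional[Sequence[float]] = None,
    ) -> None:
        """Accumulate one batch into the states."""

    @abstractmethod
    def get_result(self) -> Dict[str, float]:
        """Compute results from the states, leaving them untouched."""


class MeanAbsoluteError(MetricSketch):
    """Weighted mean absolute error, stored as two summable states."""

    name = "mae-sketch"

    def _build_states(self) -> Dict[str, float]:
        return {"abs_err": 0.0, "weight": 0.0}

    def update(self, y_true, y_pred, s_wght=None):
        weights = [1.0] * len(y_true) if s_wght is None else list(s_wght)
        for true, pred, wght in zip(y_true, y_pred, weights):
            # Update the states in place, one sample at a time.
            self._states["abs_err"] += wght * abs(true - pred)
            self._states["weight"] += wght

    def get_result(self) -> Dict[str, float]:
        # Read-only use of the states: no side effects on their contents.
        total = self._states["weight"]
        return {self.name: self._states["abs_err"] / total if total else 0.0}


metric = MeanAbsoluteError()
metric.update([1.0, 2.0], [1.5, 2.0])
metric.update([3.0], [1.0], s_wght=[2.0])
result = metric.get_result()
```

Storing a numerator and a denominator as separate states, rather than the ratio itself, is what allows states from several instances to be summed before the final division.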

Overridable

Some methods may be overridden based on the concrete Metric's needs. The most important one is the states-aggregation method:

  • agg_states(states: dict[str, (float | np.ndarray)]): Aggregate provided state variables into self ones. By default, it expects input and internal states to have similar specifications, and aggregates them by summation, which might not be appropriate depending on the actual metric.
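
To illustrate a case where the default summation is not appropriate, here is a stand-alone sketch of a state that has to be combined by taking a maximum. MaxErrorSketch is a hypothetical class that only mimics the method names described here; it does not import declearn, and plain sequences stand in for numpy arrays.

```python
from typing import Dict, Sequence


class MaxErrorSketch:
    """Hypothetical metric tracking the largest absolute error seen."""

    def __init__(self) -> None:
        self._states = {"max_err": 0.0}

    def update(self, y_true: Sequence[float], y_pred: Sequence[float]) -> None:
        for true, pred in zip(y_true, y_pred):
            error = abs(true - pred)
            self._states["max_err"] = max(self._states["max_err"], error)

    def get_states(self) -> Dict[str, float]:
        return dict(self._states)

    def agg_states(self, states: Dict[str, float]) -> None:
        # Summing two maxima would overestimate the error, so the
        # aggregation rule is overridden to keep the larger value.
        self._states["max_err"] = max(self._states["max_err"], states["max_err"])

    def get_result(self) -> Dict[str, float]:
        return {"max_err": self._states["max_err"]}


metric_0, metric_1 = MaxErrorSketch(), MaxErrorSketch()
metric_0.update([1.0, 2.0], [1.5, 4.0])  # absolute errors: 0.5 and 2.0
metric_1.update([0.0], [3.0])            # absolute error: 3.0
metric_1.agg_states(metric_0.get_states())
result = metric_1.get_result()
```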

A pair of methods may be extended to cover non-self._states-contained variables:

  • reset(): Reset the metric to its initial state.
  • get_states() -> dict[str, (float | np.ndarray)]: Return a copy of the current state variables.
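
A rough sketch of what such an extension could look like is given below. BatchCounterSketch and its _n_batches attribute are hypothetical, and the class is written stand-alone rather than as a declearn subclass.

```python
from copy import deepcopy
from typing import Any, Dict, Sequence


class BatchCounterSketch:
    """Hypothetical metric keeping one variable outside of `_states`."""

    def __init__(self) -> None:
        self._states = self._build_states()
        self._n_batches = 0  # not part of `_states`

    def _build_states(self) -> Dict[str, float]:
        return {"total": 0.0}

    def update(self, values: Sequence[float]) -> None:
        self._states["total"] += float(sum(values))
        self._n_batches += 1

    def reset(self) -> None:
        # Base behaviour: rebuild the states...
        self._states = self._build_states()
        # ...extended to also clear the extra variable.
        self._n_batches = 0

    def get_states(self) -> Dict[str, Any]:
        # Base behaviour: copy the states, then add the extra variable.
        states = deepcopy(self._states)
        states["n_batches"] = self._n_batches
        return states


metric = BatchCounterSketch()
metric.update([1.0, 2.0])
metric.update([3.0])
snapshot = metric.get_states()
metric.reset()
after_reset = metric.get_states()
```

In the source listing on this page, the default agg_states only iterates over the keys of self._states, so an extra key such as n_batches would be ignored unless agg_states is extended as well.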

Finally, depending on the hyper-parameters defined by the subclass's __init__, one should adjust JSON-configuration-interfacing methods:

  • get_config() -> dict[str, any]: Return a JSON-serializable configuration dict for this Metric.
  • from_config(config: dict[str, any]) -> Self: Instantiate a Metric from its configuration dict.
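
The configuration round trip can be sketched as follows, assuming a hypothetical metric with a single threshold hyper-parameter. ThresholdMetricSketch is a stand-alone illustration and not a declearn class.

```python
import json
from typing import Any, Dict


class ThresholdMetricSketch:
    """Hypothetical metric with one hyper-parameter set in `__init__`."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold

    def get_config(self) -> Dict[str, Any]:
        # Expose every `__init__` argument, using JSON-serializable values.
        return {"threshold": self.threshold}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ThresholdMetricSketch":
        # Same pattern as the base class: forward the dict as keyword arguments.
        return cls(**config)


metric = ThresholdMetricSketch(threshold=0.7)
dumped = json.dumps(metric.get_config())  # string that could be shared or stored
clone = ThresholdMetricSketch.from_config(json.loads(dumped))
```

Because the base from_config simply calls cls(**config), keeping the keys of get_config aligned with the __init__ argument names is usually enough, and from_config only needs overriding when some values require conversion.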

Inheritance

When a subclass inheriting from Metric is declared, it is automatically registered under the "Metric" group using its class-attribute name. This can be prevented by adding register=False to the inheritance specs (e.g. class MyCls(Metric, register=False)).

See declearn.utils.register_type for details on types registration.
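
The registration hook can be sketched in isolation as below. REGISTRY is a hypothetical dict standing in for the "Metric" group managed by declearn.utils.register_type, and BaseSketch only mimics the __init_subclass__ signature shown in the source listing.

```python
from typing import Any, Dict

REGISTRY: Dict[str, type] = {}  # hypothetical stand-in for the "Metric" group


class BaseSketch:
    """Hypothetical base class mimicking the registration hook."""

    name: str = NotImplemented

    def __init_subclass__(cls, register: bool = True, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if register:
            # Runs once, when the subclass statement is executed.
            REGISTRY[cls.name] = cls


class Registered(BaseSketch):
    name = "registered"


class Unregistered(BaseSketch, register=False):
    name = "unregistered"
```

One plausible use of register=False is an intermediate abstract subclass that has no meaningful name of its own.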

Source code in declearn/metrics/_api.py
@create_types_registry(name="Metric")
class Metric(metaclass=ABCMeta):
    """Abstract class defining an API to compute federative metrics.

    This class defines an API to instantiate stateful containers
    for one or multiple metrics, that enable computing the final
    results through iterative update steps that may additionally
    be run in a federative way.

    Usage
    -----
    Single-party usage:
    ```
    >>> metric = MetricSubclass()
    >>> metric.update(y_true, y_pred)  # take one update step
    >>> metric.get_result()    # after one or multiple updates
    >>> metric.reset()  # reset before a next evaluation round
    ```

    Multiple-parties usage:
    ```
    >>> # Instantiate 2+ metrics and run local update steps.
    >>> metric_0 = MetricSubclass()
    >>> metric_1 = MetricSubclass()
    >>> metric_0.update(y_true_0, y_pred_0)
    >>> metric_1.update(y_true_1, y_pred_1)
    >>> # Gather and share metric states (aggregated information).
    >>> states_0 = metric_0.get_states()  # metric_0 is unaltered
    >>> metric_1.agg_states(states_0)     # metric_1 is updated
    >>> # Compute results that aggregate info from both clients.
    >>> metric_1.get_result()
    ```

    Abstract
    --------
    To define a concrete Metric, one must subclass it and define:

    - name: str class attribute
        Name identifier of the class (should be unique across existing
        Metric classes). Also used for automatic types-registration of
        the class (see `Inheritance` section below).
    - _build_states() -> dict[str, (float | np.ndarray)]:
        Build and return an ensemble of state variables.
        This method is called to initialize the `_states` attribute,
        that should be used and updated by other abstract methods.
    - update(y_true: np.ndarray, y_pred: np.ndarray, s_wght: np.ndarray|None):
        Update the metric's internal state based on a data batch.
        This method should update `self._states` in-place.
    - get_result() -> dict[str, (float | np.ndarray)]:
        Compute the metric(s), based on the current state variables.
        This method should make use of `self._states` and prevent
        side effects on its contents.

    Overridable
    -----------
    Some methods may be overridden based on the concrete Metric's needs.
    The most important one is the states-aggregation method:

    - agg_states(states: dict[str, (float | np.ndarray)]):
        Aggregate provided state variables into self ones.
        By default, it expects input and internal states to have
        similar specifications, and aggregates them by summation,
        which might not be appropriate depending on the actual metric.

    A pair of methods may be extended to cover non-`self._states`-contained
    variables:

    - reset():
        Reset the metric to its initial state.
    - get_states() -> dict[str, (float | np.ndarray)]:
        Return a copy of the current state variables.


    Finally, depending on the hyper-parameters defined by the subclass's
    `__init__`, one should adjust JSON-configuration-interfacing methods:

    - get_config() -> dict[str, any]:
        Return a JSON-serializable configuration dict for this Metric.
    - from_config(config: dict[str, any]) -> Self:
        Instantiate a Metric from its configuration dict.

    Inheritance
    -----------
    When a subclass inheriting from `Metric` is declared, it is automatically
    registered under the "Metric" group using its class-attribute `name`.
    This can be prevented by adding `register=False` to the inheritance specs
    (e.g. `class MyCls(Metric, register=False)`).

    See `declearn.utils.register_type` for details on types registration.
    """

    name: ClassVar[str] = NotImplemented
    """Name identifier of the class, unique across Metric classes."""

    def __init__(
        self,
    ) -> None:
        """Instantiate the metric object."""
        self._states = self._build_states()

    @abstractmethod
    def _build_states(
        self,
    ) -> Dict[str, Union[float, np.ndarray]]:
        """Build and return an ensemble of state variables.

        The state variables stored in this dict are (by default)
        sharable with other instances of this metric and may be
        combined with the latter's through summation in order to
        compute final metrics in a federated way.

        Note that the update process may be altered by extending
        or overriding the `agg_states` method.

        Returns
        -------
        states: dict[str, float or numpy.ndarray]
            Dict of initial states that are to be assigned as
            `_states` private attribute.
        """

    @abstractmethod
    def get_result(
        self,
    ) -> Dict[str, Union[float, np.ndarray]]:
        """Compute the metric(s), based on the current state variables.

        Returns
        -------
        results: dict[str, float or numpy.ndarray]
            Dict of named result metrics, that may either be
            unitary float scores or numpy arrays.
        """

    @abstractmethod
    def update(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        s_wght: Optional[np.ndarray] = None,
    ) -> None:
        """Update the metric's internal state based on a data batch.

        Parameters
        ----------
        y_true: numpy.ndarray
            True labels or values that were to be predicted.
        y_pred: numpy.ndarray
            Predictions (scores or values) that are to be evaluated.
        s_wght: numpy.ndarray or None, default=None
            Optional sample weights to take into account in scores.
        """

    def reset(
        self,
    ) -> None:
        """Reset the metric to its initial state."""
        self._states = self._build_states()

    def get_states(
        self,
    ) -> Dict[str, Union[float, np.ndarray]]:
        """Return a copy of the current state variables.

        This method is designed to expose and share partial results
        that may be aggregated with those of other instances of the
        same metric before computing overall results.

        Returns
        -------
        states: dict[str, float or numpy.ndarray]
            Dict of states that may be fed to another instance of
            this class via its `agg_states` method.
        """
        return deepcopy(self._states)

    def agg_states(
        self,
        states: Dict[str, Union[float, np.ndarray]],
    ) -> None:
        """Aggregate provided state variables into self ones.

        This method is designed to aggregate results from multiple
        similar metrics objects into a single one before computing
        its results.

        Parameters
        ----------
        states: dict[str, float or numpy.ndarray]
            Dict of states emitted by another instance of this class
            via its `get_states` method.

        Raises
        ------
        KeyError
            If any state variable is missing from `states`.
        TypeError
            If any state variable is of improper type.
        ValueError
            If any array state variable is of improper shape.
        """
        final = {}  # type: Dict[str, Union[float, np.ndarray]]
        # Iteratively compute sum-aggregated states, running sanity checks.
        for name, own in self._states.items():
            if name not in states:
                raise KeyError(f"Missing required state variable: '{name}'.")
            oth = states[name]
            if not isinstance(oth, type(own)):
                raise TypeError(f"Input state '{name}' is of unproper type.")
            if isinstance(own, np.ndarray):
                if own.shape != oth.shape:  # type: ignore
                    msg = f"Input state '{name}' is of unproper shape."
                    raise ValueError(msg)
            final[name] = own + oth
        # Assign the sum-aggregated states.
        self._states = final

    def __init_subclass__(
        cls,
        register: bool = True,
        **kwargs: Any,
    ) -> None:
        """Automatically type-register Metric subclasses."""
        super().__init_subclass__(**kwargs)
        if register:
            register_type(cls, name=cls.name, group="Metric")

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable configuration dict for this Metric."""
        return {}

    @classmethod
    def from_config(
        cls,
        config: Dict[str, Any],
    ) -> Self:
        """Instantiate a Metric from its configuration dict."""
        return cls(**config)

    @staticmethod
    def from_specs(
        name: str,
        config: Optional[Dict[str, Any]] = None,
    ) -> "Metric":
        """Instantiate a Metric from its registered name and config dict.

        Parameters
        ----------
        name: str
            Name based on which the metric can be retrieved.
            Available as a class attribute.
        config: dict[str, any] or None
            Configuration dict of the metric, that is to be
            passed to its `from_config` class constructor.

        Raises
        ------
        KeyError
            If the provided `name` fails to be mapped to a registered
            Metric subclass.
        """
        try:
            cls = access_registered(name, group="Metric")
        except KeyError as exc:
            raise KeyError(
                f"Failed to retrieve Metric subclass from name '{name}'."
            ) from exc
        return cls.from_config(config or {})

    @staticmethod
    def _prepare_sample_weights(
        s_wght: Optional[np.ndarray],
        n_samples: int,
    ) -> np.ndarray:
        """Flatten or generate sample weights and validate their shape.

        This method is a shared util that may or may not be used as part
        of concrete Metric classes' backend depending on their formula.

        Parameters
        ----------
        s_wght: np.ndarray or None
            1-d (or squeezable) array of sample-wise positive scalar
            weights. If None, an array of ones will be generated.
        n_samples: int
            Expected length of the sample weights.

        Returns
        -------
        s_wght: np.ndarray
            Input (opt. squeezed) `s_wght`, or `np.ones(n_samples)`
            if input was None.

        Raises
        ------
        ValueError
            If the input array has improper shape or negative values.
        """
        if s_wght is None:
            return np.ones(shape=(n_samples,))
        s_wght = s_wght.squeeze()
        if s_wght.shape != (n_samples,) or np.any(s_wght < 0):
            raise ValueError(
                "Improper shape for 's_wght': should be a 1-d array "
                "of sample-wise positive scalar weights."
            )
        return s_wght

    @staticmethod
    def normalize_weights(s_wght: np.ndarray) -> np.ndarray:
        """Utility method to ensure weights sum to one.

        Note that this method may or may not be used depending on
        the actual `Metric` considered, and is merely provided as
        a utility to metric developers.
        """
        warn = DeprecationWarning(
            "'Metric.normalize_weights' is unfit for the iterative "
            "nature of the metric-computation process. It will be "
            "removed from the Metric API in declearn v3.0."
        )
        warnings.warn(warn)
        if s_wght.sum():
            s_wght /= s_wght.sum()
        else:
            raise ValueError(
                "Weights provided sum to zero, please provide only "
                "positive weights with at least one non-zero weight."
            )
        return s_wght

name: ClassVar[str] = NotImplemented class-attribute

Name identifier of the class, unique across Metric classes.

__init__()

Instantiate the metric object.

Source code in declearn/metrics/_api.py
def __init__(
    self,
) -> None:
    """Instantiate the metric object."""
    self._states = self._build_states()

__init_subclass__(register=True, **kwargs)

Automatically type-register Metric subclasses.

Source code in declearn/metrics/_api.py
def __init_subclass__(
    cls,
    register: bool = True,
    **kwargs: Any,
) -> None:
    """Automatically type-register Metric subclasses."""
    super().__init_subclass__(**kwargs)
    if register:
        register_type(cls, name=cls.name, group="Metric")

agg_states(states)

Aggregate provided state variables into self ones.

This method is designed to aggregate results from multiple similar metrics objects into a single one before computing its results.

Parameters:

  • states (Dict[str, Union[float, np.ndarray]], required): Dict of states emitted by another instance of this class via its get_states method.

Raises:

  • KeyError: If any state variable is missing from states.
  • TypeError: If any state variable is of improper type.
  • ValueError: If any array state variable is of improper shape.
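
The type check in the source listing for this method relies on isinstance(oth, type(own)). As a stand-alone sketch of what that test implies for scalar states (check_type is a hypothetical helper that reproduces only this one expression), an integer input is rejected where a float state is expected:

```python
def check_type(own, oth) -> bool:
    """Hypothetical helper reproducing the `isinstance(oth, type(own))` test."""
    return isinstance(oth, type(own))


float_vs_float = check_type(0.0, 2.5)  # both floats: accepted
float_vs_int = check_type(0.0, 2)      # int is not a subclass of float: rejected
int_vs_bool = check_type(0, True)      # bool is a subclass of int: accepted
```

Under this reading, initializing scalar states as floats (0.0 rather than 0) in _build_states, and keeping them as floats in update, avoids a TypeError at aggregation time.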

Source code in declearn/metrics/_api.py
def agg_states(
    self,
    states: Dict[str, Union[float, np.ndarray]],
) -> None:
    """Aggregate provided state variables into self ones.

    This method is designed to aggregate results from multiple
    similar metrics objects into a single one before computing
    its results.

    Parameters
    ----------
    states: dict[str, float or numpy.ndarray]
        Dict of states emitted by another instance of this class
        via its `get_states` method.

    Raises
    ------
    KeyError
        If any state variable is missing from `states`.
    TypeError
        If any state variable is of improper type.
    ValueError
        If any array state variable is of improper shape.
    """
    final = {}  # type: Dict[str, Union[float, np.ndarray]]
    # Iteratively compute sum-aggregated states, running sanity checks.
    for name, own in self._states.items():
        if name not in states:
            raise KeyError(f"Missing required state variable: '{name}'.")
        oth = states[name]
        if not isinstance(oth, type(own)):
            raise TypeError(f"Input state '{name}' is of unproper type.")
        if isinstance(own, np.ndarray):
            if own.shape != oth.shape:  # type: ignore
                msg = f"Input state '{name}' is of unproper shape."
                raise ValueError(msg)
        final[name] = own + oth
    # Assign the sum-aggregated states.
    self._states = final

from_config(config) classmethod

Instantiate a Metric from its configuration dict.

Source code in declearn/metrics/_api.py
@classmethod
def from_config(
    cls,
    config: Dict[str, Any],
) -> Self:
    """Instantiate a Metric from its configuration dict."""
    return cls(**config)

from_specs(name, config=None) staticmethod

Instantiate a Metric from its registered name and config dict.

Parameters:

  • name (str, required): Name based on which the metric can be retrieved. Available as a class attribute.
  • config (Optional[Dict[str, Any]], default None): Configuration dict of the metric, that is to be passed to its from_config class constructor.

Raises:

  • KeyError: If the provided name fails to be mapped to a registered Metric subclass.
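
The lookup-then-build flow can be sketched without declearn as below. REGISTRY, TopKSketch and this module-level from_specs function are hypothetical stand-ins; the real method resolves names through access_registered, as shown in the source listing.

```python
from typing import Any, Dict, Optional

REGISTRY: Dict[str, type] = {}  # hypothetical stand-in for the "Metric" group


class TopKSketch:
    """Hypothetical metric class with one hyper-parameter."""

    def __init__(self, k: int = 1) -> None:
        self.k = k

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TopKSketch":
        return cls(**config)


REGISTRY["top-k"] = TopKSketch


def from_specs(name: str, config: Optional[Dict[str, Any]] = None):
    """Look a class up by name, then build it from its config dict."""
    try:
        cls = REGISTRY[name]
    except KeyError as exc:
        # Re-raise with a clearer message, chaining the original error.
        raise KeyError(f"Failed to retrieve class from name '{name}'.") from exc
    return cls.from_config(config or {})


metric = from_specs("top-k", {"k": 5})
default = from_specs("top-k")  # a None config falls back to an empty dict
try:
    from_specs("unknown")
except KeyError as error:
    error_cause = type(error.__cause__).__name__
```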

Source code in declearn/metrics/_api.py
@staticmethod
def from_specs(
    name: str,
    config: Optional[Dict[str, Any]] = None,
) -> "Metric":
    """Instantiate a Metric from its registered name and config dict.

    Parameters
    ----------
    name: str
        Name based on which the metric can be retrieved.
        Available as a class attribute.
    config: dict[str, any] or None
        Configuration dict of the metric, that is to be
        passed to its `from_config` class constructor.

    Raises
    ------
    KeyError
        If the provided `name` fails to be mapped to a registered
        Metric subclass.
    """
    try:
        cls = access_registered(name, group="Metric")
    except KeyError as exc:
        raise KeyError(
            f"Failed to retrieve Metric subclass from name '{name}'."
        ) from exc
    return cls.from_config(config or {})

get_config()

Return a JSON-serializable configuration dict for this Metric.

Source code in declearn/metrics/_api.py
def get_config(
    self,
) -> Dict[str, Any]:
    """Return a JSON-serializable configuration dict for this Metric."""
    return {}

get_result() abstractmethod

Compute the metric(s), based on the current state variables.

Returns:

  • results (dict[str, float or numpy.ndarray]): Dict of named result metrics, that may either be unitary float scores or numpy arrays.

Source code in declearn/metrics/_api.py
@abstractmethod
def get_result(
    self,
) -> Dict[str, Union[float, np.ndarray]]:
    """Compute the metric(s), based on the current state variables.

    Returns
    -------
    results: dict[str, float or numpy.ndarray]
        Dict of named result metrics, that may either be
        unitary float scores or numpy arrays.
    """

get_states()

Return a copy of the current state variables.

This method is designed to expose and share partial results that may be aggregated with those of other instances of the same metric before computing overall results.

Returns:

  • states (dict[str, float or numpy.ndarray]): Dict of states that may be fed to another instance of this class via its agg_states method.
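
The source listing for this method returns deepcopy(self._states). A small sketch of why a deep copy matters for array-like states is given below, with a nested list standing in for a numpy array and hypothetical state names:

```python
from copy import deepcopy

# Hypothetical states: a nested list stands in for a numpy array.
states = {"count": 3.0, "confusion": [[1.0, 0.0], [0.0, 2.0]]}

shallow = dict(states)   # copies the dict, but shares the nested list
deep = deepcopy(states)  # copies the nested list as well

states["confusion"][0][0] += 10.0  # simulate a later in-place update
```

After the in-place update, the shallow copy sees the modified value while the deep copy still holds the snapshot taken when it was created.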

Source code in declearn/metrics/_api.py
def get_states(
    self,
) -> Dict[str, Union[float, np.ndarray]]:
    """Return a copy of the current state variables.

    This method is designed to expose and share partial results
    that may be aggregated with those of other instances of the
    same metric before computing overall results.

    Returns
    -------
    states: dict[str, float or numpy.ndarray]
        Dict of states that may be fed to another instance of
        this class via its `agg_states` method.
    """
    return deepcopy(self._states)

normalize_weights(s_wght) staticmethod

Utility method to ensure weights sum to one.

Note that this method may or may not be used depending on the actual Metric considered, and is merely provided as a utility to metric developers.
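
The deprecation warning in the source listing for this method states that it is unfit for the iterative nature of the metric computation. One way to read this, sketched below in plain Python with hypothetical numbers and helper names, is that normalizing weights within each batch discards the batches' relative weight, so the result depends on how the data was split.

```python
def weighted_mean(values, weights):
    """Weighted mean computed from a sum of products and a sum of weights."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def normalize(weights):
    """Rescale weights so that they sum to one (returns a new list)."""
    total = sum(weights)
    return [w / total for w in weights]


# Two batches of unequal size: (per-sample scores, per-sample weights).
batch_a = ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0])
batch_b = ([0.0], [1.0])

# Raw weights accumulated across batches: every sample counts equally.
pooled = weighted_mean(batch_a[0] + batch_b[0], batch_a[1] + batch_b[1])

# Weights normalized per batch: each batch now counts for one unit in total.
per_batch = weighted_mean(
    batch_a[0] + batch_b[0],
    normalize(batch_a[1]) + normalize(batch_b[1]),
)
```

Here the pooled score is 0.75, whereas per-batch normalization yields 0.5, because the single-sample batch ends up weighing as much as the three-sample one.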

Source code in declearn/metrics/_api.py
@staticmethod
def normalize_weights(s_wght: np.ndarray) -> np.ndarray:
    """Utility method to ensure weights sum to one.

    Note that this method may or may not be used depending on
    the actual `Metric` considered, and is merely provided as
    a utility to metric developers.
    """
    warn = DeprecationWarning(
        "'Metric.normalize_weights' is unfit for the iterative "
        "nature of the metric-computation process. It will be "
        "removed from the Metric API in declearn v3.0."
    )
    warnings.warn(warn)
    if s_wght.sum():
        s_wght /= s_wght.sum()
    else:
        raise ValueError(
            "Weights provided sum to zero, please provide only "
            "positive weights with at least one non-zero weight."
        )
    return s_wght

reset()

Reset the metric to its initial state.

Source code in declearn/metrics/_api.py
def reset(
    self,
) -> None:
    """Reset the metric to its initial state."""
    self._states = self._build_states()

update(y_true, y_pred, s_wght=None) abstractmethod

Update the metric's internal state based on a data batch.

Parameters:

  • y_true (np.ndarray, required): True labels or values that were to be predicted.
  • y_pred (np.ndarray, required): Predictions (scores or values) that are to be evaluated.
  • s_wght (Optional[np.ndarray], default None): Optional sample weights to take into account in scores.

Source code in declearn/metrics/_api.py
@abstractmethod
def update(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    s_wght: Optional[np.ndarray] = None,
) -> None:
    """Update the metric's internal state based on a data batch.

    Parameters
    ----------
    y_true: numpy.ndarray
        True labels or values that were to be predicted.
    y_pred: numpy.ndarray
        Predictions (scores or values) that are to be evaluated.
    s_wght: numpy.ndarray or None, default=None
        Optional sample weights to take into account in scores.
    """