
declearn.metrics.Metric

Bases: Generic[MetricStateT]

Abstract class defining an API to compute federative metrics.

This class defines an API to instantiate stateful containers for one or multiple metrics, that enable computing the final results through iterative update steps that may additionally be run in a federative way.

Usage

Single-party usage:

>>> metric = MetricSubclass()
>>> metric.update(y_true, y_pred)  # take one update step
>>> metric.get_result()    # after one or multiple updates
>>> metric.reset()  # reset before a next evaluation round

Multiple-parties usage:

>>> # Instantiate 2+ metrics and run local update steps.
>>> metric_0 = MetricSubclass()
>>> metric_1 = MetricSubclass()
>>> metric_0.update(y_true_0, y_pred_0)
>>> metric_1.update(y_true_1, y_pred_1)
>>> # Gather and share metric states (aggregated information).
>>> states_0 = metric_0.get_states()  # metric_0 is unaltered
>>> states_1 = metric_1.get_states()  # metric_1 is unaltered
>>> # Compute results that aggregate info from both clients.
>>> states = states_0 + states_1
>>> metric_0.set_states(states)  # would work the same with metric_1
>>> metric_0.get_result()
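
The multi-party flow above relies on state objects that can be summed with `+`. As a minimal sketch of that pattern, the block below uses `ToyMeanState` and `ToyMeanMetric`, which are hypothetical stand-ins written for this illustration and not DecLearn classes:

```python
from copy import deepcopy
from dataclasses import dataclass


@dataclass
class ToyMeanState:
    """Hypothetical state: running sum of absolute errors and sample count."""

    total: float = 0.0
    count: int = 0

    def __add__(self, other: "ToyMeanState") -> "ToyMeanState":
        # Aggregating two states sums their fields.
        return ToyMeanState(self.total + other.total, self.count + other.count)


class ToyMeanMetric:
    """Hypothetical mean-absolute-error metric mirroring the usage pattern."""

    def __init__(self) -> None:
        self._states = ToyMeanState()

    def update(self, y_true, y_pred) -> None:
        self._states.total += sum(abs(t - p) for t, p in zip(y_true, y_pred))
        self._states.count += len(y_true)

    def get_states(self) -> ToyMeanState:
        return deepcopy(self._states)

    def set_states(self, states: ToyMeanState) -> None:
        self._states = deepcopy(states)

    def get_result(self) -> dict:
        return {"mae": self._states.total / max(self._states.count, 1)}


metric_0, metric_1 = ToyMeanMetric(), ToyMeanMetric()
metric_0.update([1.0, 2.0], [1.0, 4.0])  # local absolute errors: 0 and 2
metric_1.update([3.0], [2.0])  # local absolute error: 1
states = metric_0.get_states() + metric_1.get_states()
metric_0.set_states(states)
result = metric_0.get_result()  # mean over all three samples
```

Only the aggregated counters travel between parties in this sketch; the raw labels and predictions stay local.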

Abstract

To define a concrete Metric, one must subclass it and define:

  • name: str class attribute: Name identifier of the class (should be unique across existing Metric classes). Also used for automatic type registration of the class (see the Inheritance section below).
  • build_initial_states() -> MetricState: Return the initial states for this Metric instance. This method is called to initialize the _states attribute, that should be used and updated by other abstract methods.
  • update(y_true: np.ndarray, y_pred: np.ndarray, s_wght: np.ndarray|None): Update the metric's internal state based on a data batch. This method should update self._states in-place.
  • get_result() -> dict[str, (float | np.ndarray)]: Compute the metric(s), based on the current state variables. This method should make use of self._states and prevent side effects on its contents.
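
The contract listed above can be sketched without DecLearn installed. The block below assumes a hypothetical `ToyMetric` base class that mimics the abstract API, plus a `ToyAccuracy` subclass; both names, and the dict-based state, are inventions for this illustration rather than the library's own types:

```python
import abc
from typing import Dict


class ToyMetric(abc.ABC):
    """Hypothetical stand-in for the abstract API (not the DecLearn class)."""

    name: str

    def __init__(self) -> None:
        self._states = self.build_initial_states()

    @abc.abstractmethod
    def build_initial_states(self) -> Dict[str, float]:
        """Return fresh state variables."""

    @abc.abstractmethod
    def update(self, y_true, y_pred, s_wght=None) -> None:
        """Update the states in place from one batch."""

    @abc.abstractmethod
    def get_result(self) -> Dict[str, float]:
        """Compute results without altering the states."""


class ToyAccuracy(ToyMetric):
    name = "toy-accuracy"

    def build_initial_states(self) -> Dict[str, float]:
        return {"correct": 0.0, "weight": 0.0}

    def update(self, y_true, y_pred, s_wght=None) -> None:
        # Default to unit weights when none are provided.
        weights = s_wght if s_wght is not None else [1.0] * len(y_true)
        for true, pred, wght in zip(y_true, y_pred, weights):
            self._states["correct"] += wght * float(true == pred)
            self._states["weight"] += wght

    def get_result(self) -> Dict[str, float]:
        total = self._states["weight"]
        return {self.name: self._states["correct"] / total if total else 0.0}


metric = ToyAccuracy()
metric.update([0, 1, 1], [0, 1, 0])
metric.update([1], [1], s_wght=[2.0])
result = metric.get_result()
```

Because `update` only accumulates counters, it can be called once per batch, and `get_result` can be called at any point without disturbing later updates.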

Overridable

Some methods may be overridden based on the concrete Metric's needs:

  • reset(): Reset the metric to its initial state.
  • get_states() -> MetricState: Return a copy of the current state variables.
  • set_states(MetricState): Replace current state variables with a copy of inputs.

Finally, depending on the hyper-parameters defined by the subclass's __init__, one should adjust JSON-configuration-interfacing methods:

  • get_config() -> dict[str, any]: Return a JSON-serializable configuration dict for this Metric.
  • from_config(config: dict[str, any]) -> Self: Instantiate a Metric from its configuration dict.
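
As a rough sketch of how these two methods pair up, consider a hypothetical `ToyThresholdMetric` with a single `threshold` hyper-parameter (both names are assumptions made for this example, not part of DecLearn):

```python
import json
from typing import Any, Dict


class ToyThresholdMetric:
    """Hypothetical metric with one hyper-parameter (not a DecLearn class)."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold

    def get_config(self) -> Dict[str, Any]:
        # Only JSON-serializable values belong in the config dict.
        return {"threshold": self.threshold}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyThresholdMetric":
        # Keys of the config dict map onto __init__ keyword arguments.
        return cls(**config)


original = ToyThresholdMetric(threshold=0.7)
dumped = json.dumps(original.get_config())
restored = ToyThresholdMetric.from_config(json.loads(dumped))
```

The round trip works as long as the keys returned by `get_config` match the keyword arguments accepted by `__init__`.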

Inheritance

When a subclass inheriting from Metric is declared, it is automatically registered under the "Metric" group using its class-attribute name. This can be prevented by adding register=False to the inheritance specs (e.g. class MyCls(Metric, register=False)).

See declearn.utils.register_type for details on types registration.
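
The registration mechanism described here can be illustrated with plain Python. In the sketch below, the module-level `REGISTRY` dict is a hypothetical stand-in for DecLearn's "Metric" registry group, and `ToyBase` is a stand-in for `Metric`:

```python
from typing import Any, Dict

REGISTRY: Dict[str, type] = {}  # hypothetical stand-in for the "Metric" group


class ToyBase:
    """Hypothetical base class that registers its subclasses by name."""

    name: str

    def __init_subclass__(cls, register: bool = True, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if register:
            REGISTRY[cls.name] = cls


class Registered(ToyBase):
    name = "registered"


class Unregistered(ToyBase, register=False):
    name = "unregistered"


registered_names = sorted(REGISTRY)
```

Passing `register=False` in the class statement is what keeps `Unregistered` out of the mapping; this is the same keyword the `Metric` inheritance specs accept.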

Source code in declearn/metrics/_api.py
@create_types_registry(name="Metric")
class Metric(Generic[MetricStateT], metaclass=abc.ABCMeta):
    """Abstract class defining an API to compute federative metrics.

    This class defines an API to instantiate stateful containers
    for one or multiple metrics, that enable computing the final
    results through iterative update steps that may additionally
    be run in a federative way.

    Usage
    -----
    Single-party usage:
    ```
    >>> metric = MetricSubclass()
    >>> metric.update(y_true, y_pred)  # take one update step
    >>> metric.get_result()    # after one or multiple updates
    >>> metric.reset()  # reset before a next evaluation round
    ```

    Multiple-parties usage:
    ```
    >>> # Instantiate 2+ metrics and run local update steps.
    >>> metric_0 = MetricSubclass()
    >>> metric_1 = MetricSubclass()
    >>> metric_0.update(y_true_0, y_pred_0)
    >>> metric_1.update(y_true_1, y_pred_1)
    >>> # Gather and share metric states (aggregated information).
    >>> states_0 = metric_0.get_states()  # metric_0 is unaltered
    >>> states_1 = metric_1.get_states()  # metric_1 is unaltered
    >>> # Compute results that aggregate info from both clients.
    >>> states = states_0 + states_1
    >>> metric_0.set_states(states)  # would work the same with metric_1
    >>> metric_0.get_result()
    ```

    Abstract
    --------
    To define a concrete Metric, one must subclass it and define:

    - name: str class attribute
        Name identifier of the class (should be unique across existing
        Metric classes). Also used for automatic types-registration of
        the class (see `Inheritance` section below).
    - build_initial_states() -> MetricState:
        Return the initial states for this Metric instance.
        This method is called to initialize the `_states` attribute,
        that should be used and updated by other abstract methods.
    - update(y_true: np.ndarray, y_pred: np.ndarray, s_wght: np.ndarray|None):
        Update the metric's internal state based on a data batch.
        This method should update `self._states` in-place.
    - get_result() -> dict[str, (float | np.ndarray)]:
        Compute the metric(s), based on the current state variables.
        This method should make use of `self._states` and prevent
        side effects on its contents.

    Overridable
    -----------
    Some methods may be overridden based on the concrete Metric's needs:

    - reset():
        Reset the metric to its initial state.
    - get_states() -> MetricState:
        Return a copy of the current state variables.
    - set_states(MetricState):
        Replace current state variables with a copy of inputs.

    Finally, depending on the hyper-parameters defined by the subclass's
    `__init__`, one should adjust JSON-configuration-interfacing methods:

    - get_config() -> dict[str, any]:
        Return a JSON-serializable configuration dict for this Metric.
    - from_config(config: dict[str, any]) -> Self:
        Instantiate a Metric from its configuration dict.

    Inheritance
    -----------
    When a subclass inheriting from `Metric` is declared, it is automatically
    registered under the "Metric" group using its class-attribute `name`.
    This can be prevented by adding `register=False` to the inheritance specs
    (e.g. `class MyCls(Metric, register=False)`).

    See `declearn.utils.register_type` for details on types registration.
    """

    name: ClassVar[str]
    """Name identifier of the class, unique across Metric classes."""

    state_cls: ClassVar[Type[MetricState]]
    """Type of 'MetricState' data structure used by this 'Metric' class."""

    def __init__(
        self,
    ) -> None:
        """Instantiate the metric object."""
        self._states = self.build_initial_states()

    @abc.abstractmethod
    def build_initial_states(
        self,
    ) -> MetricStateT:
        """Return the initial states for this Metric instance.

        Returns
        -------
        states:
            Initial internal states for this object, as a `MetricState`.
        """

    @abc.abstractmethod
    def get_result(
        self,
    ) -> Dict[str, Union[float, np.ndarray]]:
        """Compute finalized metric(s), based on the current state variables.

        Returns
        -------
        results: dict[str, float or numpy.ndarray]
            Dict of named result metrics, that may either be
            unitary float scores or numpy arrays.
        """

    @abc.abstractmethod
    def update(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        s_wght: Optional[np.ndarray] = None,
    ) -> None:
        """Update the metric's internal state based on a data batch.

        Parameters
        ----------
        y_true: numpy.ndarray
            True labels or values that were to be predicted.
        y_pred: numpy.ndarray
            Predictions (scores or values) that are to be evaluated.
        s_wght: numpy.ndarray or None, default=None
            Optional sample weights to take into account in scores.
        """

    def reset(
        self,
    ) -> None:
        """Reset the metric to its initial state."""
        self._states = self.build_initial_states()

    def get_states(
        self,
    ) -> MetricStateT:
        """Return a copy of the current state variables.

        This method is designed to expose and share partial results
        that may be aggregated with those of other instances of the
        same metric before computing overall results.

        Returns
        -------
        states:
            Copy of current states, as a `MetricState` instance.
        """
        return deepcopy(self._states)

    def set_states(
        self,
        states: MetricStateT,
    ) -> None:
        """Replace internal states with a copy of incoming ones.

        Parameters
        ----------
        states:
            Replacement states, as a compatible `MetricState` instance.

        Raises
        ------
        TypeError
            If `states` is of improper type.
        """
        if not isinstance(states, self.state_cls):
            raise TypeError(
                f"'{self.__class__.__name__}.set_states' expected "
                f"'{self.state_cls}' inputs, got '{type(states)}'."
            )
        self._states = deepcopy(states)  # type: ignore

    def agg_states(
        self,
        states: MetricStateT,
    ) -> None:
        """Aggregate provided state variables into self ones.

        This method is DEPRECATED as of DecLearn v2.4, in favor of
        merely aggregating `MetricState` instances, using either
        their `aggregate` method or the overloaded `+` operator.
        It will be removed in DecLearn 2.6 and/or 3.0.

        This method is designed to aggregate results from multiple
        similar metrics objects into a single one before computing
        its results.

        Parameters
        ----------
        states:
            `MetricState` emitted by another instance of this class
            via its `get_states` method.

        Raises
        ------
        TypeError
            If `states` is of improper type.
        """
        warnings.warn(
            "'Metric.agg_states' was deprecated in DecLearn v2.4, in favor "
            "of aggregating 'MetricState' instances directly, and setting "
            "final aggregated states using 'Metric.set_states'. It will be "
            "removed in DecLearn 2.6 and/or 3.0.",
            DeprecationWarning,
        )
        if not isinstance(states, self.state_cls):
            raise TypeError(
                f"'{self.__class__.__name__}.set_states' expected "
                f"'{self.state_cls}' inputs, got '{type(states)}'."
            )
        self.set_states(self._states + states)

    def __init_subclass__(
        cls,
        register: bool = True,
        **kwargs: Any,
    ) -> None:
        """Automatically type-register Metric subclasses."""
        super().__init_subclass__(**kwargs)
        if register:
            register_type(cls, name=cls.name, group="Metric")

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable configuration dict for this Metric."""
        return {}

    @classmethod
    def from_config(
        cls,
        config: Dict[str, Any],
    ) -> Self:
        """Instantiate a Metric from its configuration dict."""
        return cls(**config)

    @staticmethod
    def from_specs(
        name: str,
        config: Optional[Dict[str, Any]] = None,
    ) -> "Metric":
        """Instantiate a Metric from its registered name and config dict.

        Parameters
        ----------
        name: str
            Name based on which the metric can be retrieved.
            Available as a class attribute.
        config: dict[str, any] or None
            Configuration dict of the metric, that is to be
            passed to its `from_config` class constructor.

        Raises
        ------
        KeyError
            If the provided `name` fails to be mapped to a registered
            Metric subclass.
        """
        try:
            cls = access_registered(name, group="Metric")
        except KeyError as exc:
            raise KeyError(
                f"Failed to retrieve Metric subclass from name '{name}'."
            ) from exc
        return cls.from_config(config or {})

    @staticmethod
    def _prepare_sample_weights(
        s_wght: Optional[np.ndarray],
        n_samples: int,
    ) -> np.ndarray:
        """Flatten or generate sample weights and validate their shape.

        This method is a shared util that may or may not be used as part
        of concrete Metric classes' backend depending on their formula.

        Parameters
        ----------
        s_wght: np.ndarray or None
            1-d (or squeezable) array of sample-wise non-negative scalar
            weights. If None, an array of ones is generated.
        n_samples: int
            Expected length of the sample weights.

        Returns
        -------
        s_wght: np.ndarray
            Input (opt. squeezed) `s_wght`, or `np.ones(n_samples)`
            if input was None.

        Raises
        ------
        ValueError
            If the input array has improper shape or negative values.
        """
        if s_wght is None:
            return np.ones(shape=(n_samples,))
        s_wght = s_wght.squeeze()
        if s_wght.shape != (n_samples,) or np.any(s_wght < 0):
            raise ValueError(
                "Improper shape for 's_wght': should be a 1-d array "
                "of sample-wise positive scalar weights."
            )
        return s_wght
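
To make the behavior of `_prepare_sample_weights` concrete, here is a standalone sketch that copies its validation logic into a free function. The function name `prepare_sample_weights` and the shortened error message are choices made for this illustration:

```python
from typing import Optional

import numpy as np


def prepare_sample_weights(
    s_wght: Optional[np.ndarray],
    n_samples: int,
) -> np.ndarray:
    """Standalone copy of the validation logic, for illustration only."""
    if s_wght is None:
        # No weights provided: every sample counts equally.
        return np.ones(shape=(n_samples,))
    s_wght = s_wght.squeeze()
    if s_wght.shape != (n_samples,) or np.any(s_wght < 0):
        raise ValueError("expected a 1-d array of non-negative weights")
    return s_wght


default = prepare_sample_weights(None, n_samples=3)
squeezed = prepare_sample_weights(np.array([[0.5], [1.0], [2.0]]), n_samples=3)
try:
    prepare_sample_weights(np.array([1.0, -1.0, 1.0]), n_samples=3)
    rejected = False
except ValueError:
    rejected = True
```

One consequence of `squeeze()` that may be worth checking in practice: an array of shape `(1,)` squeezes to shape `()`, so this logic would reject a single-sample weights array when `n_samples` is 1.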

name: ClassVar[str] class-attribute

Name identifier of the class, unique across Metric classes.

state_cls: ClassVar[Type[MetricState]] class-attribute

Type of 'MetricState' data structure used by this 'Metric' class.

__init__()

Instantiate the metric object.

Source code in declearn/metrics/_api.py
def __init__(
    self,
) -> None:
    """Instantiate the metric object."""
    self._states = self.build_initial_states()

__init_subclass__(register=True, **kwargs)

Automatically type-register Metric subclasses.

Source code in declearn/metrics/_api.py
def __init_subclass__(
    cls,
    register: bool = True,
    **kwargs: Any,
) -> None:
    """Automatically type-register Metric subclasses."""
    super().__init_subclass__(**kwargs)
    if register:
        register_type(cls, name=cls.name, group="Metric")

agg_states(states)

Aggregate provided state variables into self ones.

This method is DEPRECATED as of DecLearn v2.4, in favor of merely aggregating MetricState instances, using either their aggregate method or the overloaded + operator. It will be removed in DecLearn 2.6 and/or 3.0.

This method is designed to aggregate results from multiple similar metrics objects into a single one before computing its results.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| states | MetricStateT | MetricState emitted by another instance of this class via its get_states method. | required |

Raises:

| Type | Description |
|---|---|
| TypeError | If states is of improper type. |

Source code in declearn/metrics/_api.py
def agg_states(
    self,
    states: MetricStateT,
) -> None:
    """Aggregate provided state variables into self ones.

    This method is DEPRECATED as of DecLearn v2.4, in favor of
    merely aggregating `MetricState` instances, using either
    their `aggregate` method or the overloaded `+` operator.
    It will be removed in DecLearn 2.6 and/or 3.0.

    This method is designed to aggregate results from multiple
    similar metrics objects into a single one before computing
    its results.

    Parameters
    ----------
    states:
        `MetricState` emitted by another instance of this class
        via its `get_states` method.

    Raises
    ------
    TypeError
        If `states` is of improper type.
    """
    warnings.warn(
        "'Metric.agg_states' was deprecated in DecLearn v2.4, in favor "
        "of aggregating 'MetricState' instances directly, and setting "
        "final aggregated states using 'Metric.set_states'. It will be "
        "removed in DecLearn 2.6 and/or 3.0.",
        DeprecationWarning,
    )
    if not isinstance(states, self.state_cls):
        raise TypeError(
            f"'{self.__class__.__name__}.set_states' expected "
            f"'{self.state_cls}' inputs, got '{type(states)}'."
        )
    self.set_states(self._states + states)
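
The migration away from `agg_states` can be sketched as follows. `ToyCountMetric` is a hypothetical class whose state is a plain integer, chosen so that `+` works without any DecLearn import; it is not the library's implementation:

```python
import warnings
from copy import deepcopy


class ToyCountMetric:
    """Hypothetical metric whose state is a plain integer count."""

    def __init__(self, count: int = 0) -> None:
        self._states = count

    def get_states(self) -> int:
        return deepcopy(self._states)

    def set_states(self, states: int) -> None:
        self._states = deepcopy(states)

    def agg_states(self, states: int) -> None:
        # Deprecated path: warn, then delegate to the recommended one.
        warnings.warn("agg_states is deprecated", DeprecationWarning)
        self.set_states(self._states + states)


old_style = ToyCountMetric(2)
new_style = ToyCountMetric(2)
other = ToyCountMetric(3)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style.agg_states(other.get_states())

# Recommended replacement: aggregate the states first, then set them.
new_style.set_states(new_style.get_states() + other.get_states())
```

Both paths end in the same state; the second simply avoids the deprecated call and its warning.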

build_initial_states() abstractmethod

Return the initial states for this Metric instance.

Returns:

| Name | Type | Description |
|---|---|---|
| states | MetricStateT | Initial internal states for this object, as a MetricState. |

Source code in declearn/metrics/_api.py
@abc.abstractmethod
def build_initial_states(
    self,
) -> MetricStateT:
    """Return the initial states for this Metric instance.

    Returns
    -------
    states:
        Initial internal states for this object, as a `MetricState`.
    """

from_config(config) classmethod

Instantiate a Metric from its configuration dict.

Source code in declearn/metrics/_api.py
@classmethod
def from_config(
    cls,
    config: Dict[str, Any],
) -> Self:
    """Instantiate a Metric from its configuration dict."""
    return cls(**config)

from_specs(name, config=None) staticmethod

Instantiate a Metric from its registered name and config dict.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name based on which the metric can be retrieved. Available as a class attribute. | required |
| config | Optional[Dict[str, Any]] | Configuration dict of the metric, that is to be passed to its from_config class constructor. | None |

Raises:

| Type | Description |
|---|---|
| KeyError | If the provided name fails to be mapped to a registered Metric subclass. |

Source code in declearn/metrics/_api.py
@staticmethod
def from_specs(
    name: str,
    config: Optional[Dict[str, Any]] = None,
) -> "Metric":
    """Instantiate a Metric from its registered name and config dict.

    Parameters
    ----------
    name: str
        Name based on which the metric can be retrieved.
        Available as a class attribute.
    config: dict[str, any] or None
        Configuration dict of the metric, that is to be
        passed to its `from_config` class constructor.

    Raises
    ------
    KeyError
        If the provided `name` fails to be mapped to a registered
        Metric subclass.
    """
    try:
        cls = access_registered(name, group="Metric")
    except KeyError as exc:
        raise KeyError(
            f"Failed to retrieve Metric subclass from name '{name}'."
        ) from exc
    return cls.from_config(config or {})
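
The lookup-then-instantiate flow of `from_specs` can be reproduced in a few lines. In this sketch, `REGISTRY` is a hypothetical name-to-class dict that stands in for `access_registered`, and `ToyScaled` is an invented class:

```python
from typing import Any, Dict, Optional

REGISTRY: Dict[str, type] = {}  # hypothetical name-to-class mapping


class ToyScaled:
    """Hypothetical configurable class (not a DecLearn metric)."""

    name = "toy-scaled"

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyScaled":
        return cls(**config)


REGISTRY[ToyScaled.name] = ToyScaled


def from_specs(name: str, config: Optional[Dict[str, Any]] = None):
    """Instantiate a registered class from its name and optional config."""
    try:
        cls = REGISTRY[name]
    except KeyError as exc:
        # Re-raise with a clearer message, chaining the original error.
        raise KeyError(f"Failed to retrieve class from name '{name}'.") from exc
    return cls.from_config(config or {})


default = from_specs("toy-scaled")
custom = from_specs("toy-scaled", {"scale": 2.0})
```

Passing `config=None` falls back to an empty dict, so the class defaults apply.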

get_config()

Return a JSON-serializable configuration dict for this Metric.

Source code in declearn/metrics/_api.py
def get_config(
    self,
) -> Dict[str, Any]:
    """Return a JSON-serializable configuration dict for this Metric."""
    return {}

get_result() abstractmethod

Compute finalized metric(s), based on the current state variables.

Returns:

| Name | Type | Description |
|---|---|---|
| results | dict[str, float or numpy.ndarray] | Dict of named result metrics, that may either be unitary float scores or numpy arrays. |

Source code in declearn/metrics/_api.py
@abc.abstractmethod
def get_result(
    self,
) -> Dict[str, Union[float, np.ndarray]]:
    """Compute finalized metric(s), based on the current state variables.

    Returns
    -------
    results: dict[str, float or numpy.ndarray]
        Dict of named result metrics, that may either be
        unitary float scores or numpy arrays.
    """

get_states()

Return a copy of the current state variables.

This method is designed to expose and share partial results that may be aggregated with those of other instances of the same metric before computing overall results.

Returns:

| Name | Type | Description |
|---|---|---|
| states | MetricStateT | Copy of current states, as a MetricState instance. |

Source code in declearn/metrics/_api.py
def get_states(
    self,
) -> MetricStateT:
    """Return a copy of the current state variables.

    This method is designed to expose and share partial results
    that may be aggregated with those of other instances of the
    same metric before computing overall results.

    Returns
    -------
    states:
        Copy of current states, as a `MetricState` instance.
    """
    return deepcopy(self._states)
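
The practical effect of returning a deep copy is that callers cannot corrupt the metric by mutating what they receive. A small sketch, using a hypothetical `ToyListMetric` with a dict-of-list state (not a DecLearn class):

```python
from copy import deepcopy


class ToyListMetric:
    """Hypothetical metric holding a nested, mutable state."""

    def __init__(self) -> None:
        self._states = {"values": [1, 2]}

    def get_states(self) -> dict:
        # A deep copy also duplicates the nested list.
        return deepcopy(self._states)


metric = ToyListMetric()
exported = metric.get_states()
exported["values"].append(99)  # mutates the copy only
internal_after = metric.get_states()["values"]
```

A shallow copy would not give the same guarantee here, since the nested list would still be shared.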

reset()

Reset the metric to its initial state.

Source code in declearn/metrics/_api.py
def reset(
    self,
) -> None:
    """Reset the metric to its initial state."""
    self._states = self.build_initial_states()

set_states(states)

Replace internal states with a copy of incoming ones.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| states | MetricStateT | Replacement states, as a compatible MetricState instance. | required |

Raises:

| Type | Description |
|---|---|
| TypeError | If states is of improper type. |

Source code in declearn/metrics/_api.py
def set_states(
    self,
    states: MetricStateT,
) -> None:
    """Replace internal states with a copy of incoming ones.

    Parameters
    ----------
    states:
        Replacement states, as a compatible `MetricState` instance.

    Raises
    ------
    TypeError
        If `states` is of improper type.
    """
    if not isinstance(states, self.state_cls):
        raise TypeError(
            f"'{self.__class__.__name__}.set_states' expected "
            f"'{self.state_cls}' inputs, got '{type(states)}'."
        )
    self._states = deepcopy(states)  # type: ignore

update(y_true, y_pred, s_wght=None) abstractmethod

Update the metric's internal state based on a data batch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| y_true | np.ndarray | True labels or values that were to be predicted. | required |
| y_pred | np.ndarray | Predictions (scores or values) that are to be evaluated. | required |
| s_wght | Optional[np.ndarray] | Optional sample weights to take into account in scores. | None |

Source code in declearn/metrics/_api.py
@abc.abstractmethod
def update(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    s_wght: Optional[np.ndarray] = None,
) -> None:
    """Update the metric's internal state based on a data batch.

    Parameters
    ----------
    y_true: numpy.ndarray
        True labels or values that were to be predicted.
    y_pred: numpy.ndarray
        Predictions (scores or values) that are to be evaluated.
    s_wght: numpy.ndarray or None, default=None
        Optional sample weights to take into account in scores.
    """