declearn.metrics.MetricSet

Wrapper for an ensemble of Metric objects.

This class is designed to wrap together a collection of Metric instances (see declearn.metrics.Metric) and to expose the key API methods in a grouped fashion, i.e. it internalizes the boilerplate loops over the metrics to update them based on a batch of inputs, gather their states, compute their end results, reset them, etc.

This class also enables specifying an ensemble of metrics through a modular specification system, where each metric may be provided either as an instance, a name identifier string, or a tuple with both a name identifier and a configuration dict (enabling the use of non-default hyper-parameters).
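
The three accepted spec forms can be illustrated with a minimal, self-contained sketch. The `ToyMetric` classes, the `REGISTRY` dict and `resolve_spec` are hypothetical stand-ins written for this example, not declearn's API; in declearn the name lookup is delegated to `Metric.from_specs`.

```python
from typing import Any, Dict, Tuple, Union


class ToyMetric:
    """Hypothetical stand-in for a metric base class (not declearn's)."""

    name = "toy"

    def __init__(self, **config: Any) -> None:
        self.config = config


class ToyMae(ToyMetric):
    name = "mae"


class ToyAcc(ToyMetric):
    name = "accuracy"


class ToyF1(ToyMetric):
    name = "f1"


# Hypothetical name-to-class registry used to resolve string identifiers.
REGISTRY = {cls.name: cls for cls in (ToyMae, ToyAcc, ToyF1)}

Spec = Union[ToyMetric, str, Tuple[str, Dict[str, Any]]]


def resolve_spec(spec: Spec) -> ToyMetric:
    """Turn an instance, a name, or a (name, config) pair into an instance."""
    if isinstance(spec, ToyMetric):
        return spec
    if isinstance(spec, str):
        return REGISTRY[spec]()  # raises KeyError on unknown names
    if (
        isinstance(spec, (tuple, list))
        and len(spec) == 2
        and isinstance(spec[0], str)
        and isinstance(spec[1], dict)
    ):
        return REGISTRY[spec[0]](**spec[1])
    raise TypeError("Expected an instance, a name, or a (name, config) pair.")


specs = [ToyMae(), "accuracy", ("f1", {"thresh": 0.3})]
resolved = [resolve_spec(spec) for spec in specs]
```

The tuple form is the only one of the three that carries non-default hyper-parameters without requiring the caller to import the metric class.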

Source code in declearn/metrics/_wrapper.py
class MetricSet:
    """Wrapper for an ensemble of Metric objects.

    This class is designed to wrap together a collection of `Metric`
    instances (see `declearn.metrics.Metric`), and expose the key API
    methods in a grouped fashion, i.e. internalizing the boilerplate
    loops on the metrics to update them based on a batch of inputs,
    gather their states, compute their end results, reset them, etc.

    This class also enables specifying an ensemble of metrics through
    a modular specification system, where each metric may be provided
    either as an instance, a name identifier string, or a tuple with
    both the former identifier and a configuration dict (enabling the
    use of non-default hyper-parameters).
    """

    def __init__(
        self,
        metrics: List[MetricInputType],
    ) -> None:
        """Instantiate the grouped ensemble of Metric instances.

        Parameters
        ----------
        metrics: list[Metric, str, tuple(str, dict[str, any])]
            List of metrics to bind together. The metrics may be provided
            either as a Metric instance, a name identifier string, or a
            tuple with both a name identifier and a configuration dict.

        Raises
        ------
        TypeError
            If one of the input `metrics` elements is of improper type.
        KeyError
            If a metric name identifier fails to be mapped to a Metric class.
        RuntimeError
            If multiple metrics are of the same final type.
        """
        # REVISE: store metrics into a Dict and adjust labels when needed
        self.metrics = []  # type: List[Metric]
        for metric in metrics:
            if isinstance(metric, str):
                metric = Metric.from_specs(metric)
            if isinstance(metric, (tuple, list)):
                if (
                    (len(metric) == 2)
                    and isinstance(metric[0], str)
                    and isinstance(metric[1], dict)
                ):
                    metric = Metric.from_specs(*metric)
            if not isinstance(metric, Metric):
                raise TypeError(
                    "'MetricSet' inputs must be Metric instances, string "
                    "identifiers or (string identifier, config dict) tuples."
                )
            self.metrics.append(metric)
        if len(set(type(m) for m in self.metrics)) < len(self.metrics):
            raise RuntimeError(
                "'MetricSet' cannot wrap multiple metrics of the same type."
            )

    @classmethod
    def from_specs(
        cls,
        metrics: Union[List[MetricInputType], "MetricSet", None],
    ) -> Self:
        """Type-check and/or transform inputs into a MetricSet instance.

        This classmethod is merely implemented to avoid duplicate and
        boilerplate code from polluting FL orchestrating classes.

        Parameters
        ----------
        metrics: list[MetricInputType] or MetricSet or None
            Inputs to set up a MetricSet instance, an instance to type-check,
            or None, resulting in an empty MetricSet being returned.

        Returns
        -------
        metricset: MetricSet
            MetricSet instance, type-checked or instantiated from inputs.

        Raises
        ------
        TypeError
            If `metrics` is of improper type.

        Other exceptions may be raised when calling this class's `__init__`.
        """
        if metrics is None:
            metrics = cls([])
        if isinstance(metrics, list):
            metrics = cls(metrics)
        if not isinstance(metrics, cls):
            raise TypeError(
                f"'metrics' should be a `{cls.__name__}`, a valid list of "
                "Metric instances and/or specs to wrap into one, or None."
            )
        return metrics

    def get_result(
        self,
    ) -> Dict[str, Union[float, np.ndarray]]:
        """Compute the metric(s), based on the current state variables.

        Returns
        -------
        results: dict[str, float or numpy.ndarray]
            Dict of named result metrics, that may either be
            unitary float scores or numpy arrays.
        """
        results = {}
        for metric in self.metrics:
            results.update(metric.get_result())
        return results

    def update(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        s_wght: Optional[np.ndarray] = None,
    ) -> None:
        """Update the metric's internal state based on a data batch.

        Parameters
        ----------
        y_true: numpy.ndarray
            True labels or values that were to be predicted.
        y_pred: numpy.ndarray
            Predictions (scores or values) that are to be evaluated.
        s_wght: numpy.ndarray or None, default=None
            Optional sample weights to take into account in scores.
        """
        for metric in self.metrics:
            metric.update(y_true, y_pred, s_wght)

    def reset(
        self,
    ) -> None:
        """Reset the metric to its initial state."""
        for metric in self.metrics:
            metric.reset()

    def get_states(
        self,
    ) -> Dict[str, MetricState]:
        """Return a copy of the current state variables.

        This method is designed to expose and share partial results
        that may be aggregated with those of other instances of the
        same metric before computing overall results.

        Returns
        -------
        states:
            Dict of metric states that may be aggregated with their
            counterparts and re-assigned for finalization using the
            `set_states` then `get_result` methods of this object.
        """
        return {metric.name: metric.get_states() for metric in self.metrics}

    def set_states(
        self,
        states: Dict[str, MetricState],
    ) -> None:
        """Replace internal states with a copy of incoming ones.

        Parameters
        ----------
        states:
            Replacement states, as a dict mapping metric names to compatible
            `MetricState` instances.

        Raises
        ------
        TypeError
            If any metric states are of improper type.
        """
        for metric in self.metrics:
            if metric.name in states:
                metric.set_states(states[metric.name])

    def agg_states(
        self,
        states: Dict[str, MetricState],
    ) -> None:
        """Aggregate provided state variables into self ones.

        This method is DEPRECATED as of DecLearn v2.4, in favor of
        merely aggregating `MetricState` instances, using either
        their `aggregate` method or the overloaded `+` operator.
        It will be removed in DecLearn 2.6 and/or 3.0.

        This method is designed to aggregate results from multiple
        similar metrics objects into a single one before computing
        its results.

        Parameters
        ----------
        states: dict[str, MetricState]
            Dict of states emitted by another instance of this class
            via its `get_states` method.

        Raises
        ------
        KeyError
            If any state variable is missing from `states`.
        TypeError
            If any state variable is of improper type.
        ValueError
            If any array state variable is of improper shape.
        """
        warnings.warn(
            "'MetricSet.agg_states' was deprecated in DecLearn v2.4, in favor "
            "of aggregating 'MetricState' instances directly, and setting "
            "final aggregated states using 'MetricSet.set_states'. It will be "
            "removed in DecLearn 2.6 and/or 3.0.",
            DeprecationWarning,
        )
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=DeprecationWarning)
            for metric in self.metrics:
                if metric.name in states:
                    metric.agg_states(states[metric.name])

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return a JSON-serializable configuration dict for this MetricSet."""
        cfg = [(metric.name, metric.get_config()) for metric in self.metrics]
        return {"metrics": cfg}

    @classmethod
    def from_config(
        cls,
        config: Dict[str, Any],
    ) -> Self:
        """Instantiate a MetricSet from its configuration dict."""
        return cls(**config)

__init__(metrics)

Instantiate the grouped ensemble of Metric instances.

Parameters:

Name | Type | Description | Default
metrics | List[MetricInputType] | List of metrics to bind together. The metrics may be provided either as a Metric instance, a name identifier string, or a tuple with both a name identifier and a configuration dict. | required

Raises:

Type | Description
TypeError | If one of the input metrics elements is of improper type.
KeyError | If a metric name identifier fails to be mapped to a Metric class.
RuntimeError | If multiple metrics are of the same final type.
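
The uniqueness rule behind the RuntimeError compares concrete types, not names or configurations. A minimal sketch of that check, using hypothetical placeholder classes rather than real declearn metrics:

```python
class Mae:
    """Hypothetical placeholder metric class."""


class Mse:
    """Hypothetical placeholder metric class."""


def has_duplicate_types(metrics: list) -> bool:
    """Mirror the check in `__init__`: fewer distinct types than items."""
    return len({type(m) for m in metrics}) < len(metrics)


distinct = has_duplicate_types([Mae(), Mse()])
clash = has_duplicate_types([Mae(), Mae()])
```

Under this rule, two instances of the same class are rejected even when they are configured differently, whereas a subclass counts as a distinct type because the comparison uses `type(m)`.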

Source code in declearn/metrics/_wrapper.py
def __init__(
    self,
    metrics: List[MetricInputType],
) -> None:
    """Instantiate the grouped ensemble of Metric instances.

    Parameters
    ----------
    metrics: list[Metric, str, tuple(str, dict[str, any])]
        List of metrics to bind together. The metrics may be provided
        either as a Metric instance, a name identifier string, or a
        tuple with both a name identifier and a configuration dict.

    Raises
    ------
    TypeError
        If one of the input `metrics` elements is of improper type.
    KeyError
        If a metric name identifier fails to be mapped to a Metric class.
    RuntimeError
        If multiple metrics are of the same final type.
    """
    # REVISE: store metrics into a Dict and adjust labels when needed
    self.metrics = []  # type: List[Metric]
    for metric in metrics:
        if isinstance(metric, str):
            metric = Metric.from_specs(metric)
        if isinstance(metric, (tuple, list)):
            if (
                (len(metric) == 2)
                and isinstance(metric[0], str)
                and isinstance(metric[1], dict)
            ):
                metric = Metric.from_specs(*metric)
        if not isinstance(metric, Metric):
            raise TypeError(
                "'MetricSet' inputs must be Metric instances, string "
                "identifiers or (string identifier, config dict) tuples."
            )
        self.metrics.append(metric)
    if len(set(type(m) for m in self.metrics)) < len(self.metrics):
        raise RuntimeError(
            "'MetricSet' cannot wrap multiple metrics of the same type."
        )

agg_states(states)

Aggregate provided state variables into self ones.

This method is DEPRECATED as of DecLearn v2.4, in favor of merely aggregating MetricState instances, using either their aggregate method or the overloaded + operator. It will be removed in DecLearn 2.6 and/or 3.0.

This method is designed to aggregate results from multiple similar metrics objects into a single one before computing its results.

Parameters:

Name | Type | Description | Default
states | Dict[str, MetricState] | Dict of states emitted by another instance of this class via its get_states method. | required

Raises:

Type | Description
KeyError | If any state variable is missing from states.
TypeError | If any state variable is of improper type.
ValueError | If any array state variable is of improper shape.
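
The replacement workflow named in the deprecation note can be sketched as follows. `ToyState` is a hypothetical stand-in for a MetricState subclass (its fields and its `__add__` are assumptions made for illustration); only the pattern of summing per-metric states keyed by name, then handing the result to `set_states`, is taken from the text above.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyState:
    """Hypothetical state: running sum of errors and sample count."""

    err_sum: float
    n_samples: int

    def __add__(self, other: "ToyState") -> "ToyState":
        return ToyState(
            self.err_sum + other.err_sum,
            self.n_samples + other.n_samples,
        )


def aggregate_states(
    all_states: List[Dict[str, ToyState]],
) -> Dict[str, ToyState]:
    """Sum name-keyed state dicts, as emitted by several `get_states` calls."""
    merged: Dict[str, ToyState] = {}
    for states in all_states:
        for name, state in states.items():
            merged[name] = merged[name] + state if name in merged else state
    return merged


client_a = {"mae": ToyState(err_sum=3.0, n_samples=4)}
client_b = {"mae": ToyState(err_sum=1.0, n_samples=6)}
merged = aggregate_states([client_a, client_b])
# `merged` is the kind of dict that would then be passed to `set_states`.
```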

Source code in declearn/metrics/_wrapper.py
def agg_states(
    self,
    states: Dict[str, MetricState],
) -> None:
    """Aggregate provided state variables into self ones.

    This method is DEPRECATED as of DecLearn v2.4, in favor of
    merely aggregating `MetricState` instances, using either
    their `aggregate` method or the overloaded `+` operator.
    It will be removed in DecLearn 2.6 and/or 3.0.

    This method is designed to aggregate results from multiple
    similar metrics objects into a single one before computing
    its results.

    Parameters
    ----------
    states: dict[str, MetricState]
        Dict of states emitted by another instance of this class
        via its `get_states` method.

    Raises
    ------
    KeyError
        If any state variable is missing from `states`.
    TypeError
        If any state variable is of improper type.
    ValueError
        If any array state variable is of improper shape.
    """
    warnings.warn(
        "'MetricSet.agg_states' was deprecated in DecLearn v2.4, in favor "
        "of aggregating 'MetricState' instances directly, and setting "
        "final aggregated states using 'MetricSet.set_states'. It will be "
        "removed in DecLearn 2.6 and/or 3.0.",
        DeprecationWarning,
    )
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=DeprecationWarning)
        for metric in self.metrics:
            if metric.name in states:
                metric.agg_states(states[metric.name])

from_config(config) classmethod

Instantiate a MetricSet from its configuration dict.

Source code in declearn/metrics/_wrapper.py
@classmethod
def from_config(
    cls,
    config: Dict[str, Any],
) -> Self:
    """Instantiate a MetricSet from its configuration dict."""
    return cls(**config)

from_specs(metrics) classmethod

Type-check and/or transform inputs into a MetricSet instance.

This classmethod is merely implemented to avoid duplicate and boilerplate code from polluting FL orchestrating classes.

Parameters:

Name | Type | Description | Default
metrics | Union[List[MetricInputType], MetricSet, None] | Inputs to set up a MetricSet instance, an existing instance to type-check, or None, which results in an empty MetricSet being returned. | required

Returns:

Name | Type | Description
metricset | MetricSet | MetricSet instance, type-checked or instantiated from inputs.

Raises:

Type | Description
TypeError | If metrics is of improper type.

Other exceptions may be raised when calling this class's __init__.

Source code in declearn/metrics/_wrapper.py
@classmethod
def from_specs(
    cls,
    metrics: Union[List[MetricInputType], "MetricSet", None],
) -> Self:
    """Type-check and/or transform inputs into a MetricSet instance.

    This classmethod is merely implemented to avoid duplicate and
    boilerplate code from polluting FL orchestrating classes.

    Parameters
    ----------
    metrics: list[MetricInputType] or MetricSet or None
        Inputs to set up a MetricSet instance, an instance to type-check,
        or None, resulting in an empty MetricSet being returned.

    Returns
    -------
    metricset: MetricSet
        MetricSet instance, type-checked or instantiated from inputs.

    Raises
    ------
    TypeError
        If `metrics` is of improper type.

    Other exceptions may be raised when calling this class's `__init__`.
    """
    if metrics is None:
        metrics = cls([])
    if isinstance(metrics, list):
        metrics = cls(metrics)
    if not isinstance(metrics, cls):
        raise TypeError(
            f"'metrics' should be a `{cls.__name__}`, a valid list of "
            "Metric instances and/or specs to wrap into one, or None."
        )
    return metrics
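
The normalization performed by `from_specs` can be sketched on a toy class. `ToySet` below is hypothetical and wraps plain strings instead of metrics; it only reproduces the three-way branching shown in the source above.

```python
from typing import List, Union


class ToySet:
    """Hypothetical wrapper mimicking the `from_specs` normalization."""

    def __init__(self, items: List[str]) -> None:
        self.items = list(items)

    @classmethod
    def from_specs(
        cls,
        items: Union[List[str], "ToySet", None],
    ) -> "ToySet":
        if items is None:
            items = cls([])  # None -> empty wrapper
        if isinstance(items, list):
            items = cls(items)  # list -> new wrapper
        if not isinstance(items, cls):
            raise TypeError("Expected a ToySet, a list, or None.")
        return items  # an existing instance is returned as-is


existing = ToySet(["mae"])
same = ToySet.from_specs(existing)
empty = ToySet.from_specs(None)
built = ToySet.from_specs(["mae", "mse"])
```

Note that only a `list` triggers wrapping: with this branching, a tuple of specs falls through to the final check and raises a TypeError.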

get_config()

Return a JSON-serializable configuration dict for this MetricSet.

Source code in declearn/metrics/_wrapper.py
def get_config(
    self,
) -> Dict[str, Any]:
    """Return a JSON-serializable configuration dict for this MetricSet."""
    cfg = [(metric.name, metric.get_config()) for metric in self.metrics]
    return {"metrics": cfg}
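
A short sketch of what happens to this configuration through a JSON round trip, using only the standard library. The metric names and the `thresh` key are placeholders chosen for illustration, not identifiers taken from declearn.

```python
import json

# Shape returned by `get_config`: a list of (name, config) pairs.
config = {"metrics": [("mae", {}), ("f1", {"thresh": 0.5})]}

# JSON has no tuple type, so each pair comes back as a list.
restored = json.loads(json.dumps(config))
first_spec = restored["metrics"][0]
```

After the round trip each pair is a list rather than a tuple. The `__init__` source shown earlier checks `isinstance(metric, (tuple, list))`, so such a restored dict can still be passed to `from_config`.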

get_result()

Compute the metric(s), based on the current state variables.

Returns:

Name | Type | Description
results | dict[str, float or numpy.ndarray] | Dict of named result metrics, that may either be unitary float scores or numpy arrays.

Source code in declearn/metrics/_wrapper.py
def get_result(
    self,
) -> Dict[str, Union[float, np.ndarray]]:
    """Compute the metric(s), based on the current state variables.

    Returns
    -------
    results: dict[str, float or numpy.ndarray]
        Dict of named result metrics, that may either be
        unitary float scores or numpy arrays.
    """
    results = {}
    for metric in self.metrics:
        results.update(metric.get_result())
    return results
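
The merge step can be illustrated in isolation. The per-metric result dicts below are hypothetical placeholders; the point is only that the results of all wrapped metrics end up in one flat dict.

```python
from typing import Dict, List

# Hypothetical per-metric results; the keys are placeholders.
per_metric: List[Dict[str, float]] = [
    {"mae": 0.25},
    {"precision": 0.8, "recall": 0.6},
]

results: Dict[str, float] = {}
for partial in per_metric:
    results.update(partial)  # a later dict would overwrite a duplicate key
```

Because the merge relies on `dict.update`, two metrics that emitted the same result key would leave only the later value in the output.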

get_states()

Return a copy of the current state variables.

This method is designed to expose and share partial results that may be aggregated with those of other instances of the same metric before computing overall results.

Returns:

Name | Type | Description
states | Dict[str, MetricState] | Dict of metric states that may be aggregated with their counterparts and re-assigned for finalization using the set_states then get_result methods of this object.

Source code in declearn/metrics/_wrapper.py
def get_states(
    self,
) -> Dict[str, MetricState]:
    """Return a copy of the current state variables.

    This method is designed to expose and share partial results
    that may be aggregated with those of other instances of the
    same metric before computing overall results.

    Returns
    -------
    states:
        Dict of metric states that may be aggregated with their
        counterparts and re-assigned for finalization using the
        `set_states` then `get_result` methods of this object.
    """
    return {metric.name: metric.get_states() for metric in self.metrics}

reset()

Reset the metric to its initial state.

Source code in declearn/metrics/_wrapper.py
def reset(
    self,
) -> None:
    """Reset the metric to its initial state."""
    for metric in self.metrics:
        metric.reset()

set_states(states)

Replace internal states with a copy of incoming ones.

Parameters:

Name | Type | Description | Default
states | Dict[str, MetricState] | Replacement states, as a dict mapping metric names to compatible MetricState instances. | required

Raises:

Type | Description
TypeError | If any metric states are of improper type.

Source code in declearn/metrics/_wrapper.py
def set_states(
    self,
    states: Dict[str, MetricState],
) -> None:
    """Replace internal states with a copy of incoming ones.

    Parameters
    ----------
    states:
        Replacement states, as a dict mapping metric names to compatible
        `MetricState` instances.

    Raises
    ------
    TypeError
        If any metric states are of improper type.
    """
    for metric in self.metrics:
        if metric.name in states:
            metric.set_states(states[metric.name])

update(y_true, y_pred, s_wght=None)

Update the metric's internal state based on a data batch.

Parameters:

Name | Type | Description | Default
y_true | np.ndarray | True labels or values that were to be predicted. | required
y_pred | np.ndarray | Predictions (scores or values) that are to be evaluated. | required
s_wght | Optional[np.ndarray] | Optional sample weights to take into account in scores. | None

Source code in declearn/metrics/_wrapper.py
def update(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    s_wght: Optional[np.ndarray] = None,
) -> None:
    """Update the metric's internal state based on a data batch.

    Parameters
    ----------
    y_true: numpy.ndarray
        True labels or values that were to be predicted.
    y_pred: numpy.ndarray
        Predictions (scores or values) that are to be evaluated.
    s_wght: numpy.ndarray or None, default=None
        Optional sample weights to take into account in scores.
    """
    for metric in self.metrics:
        metric.update(y_true, y_pred, s_wght)
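
The batch-wise lifecycle (update on each batch, compute results, reset) can be sketched with a toy streaming metric. `ToyMae` is a hypothetical stand-in that works on plain Python lists instead of numpy arrays and is not a declearn class.

```python
from typing import Dict, List, Optional, Sequence


class ToyMae:
    """Hypothetical streaming mean-absolute-error metric (plain lists)."""

    name = "mae"

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.err_sum = 0.0
        self.wgt_sum = 0.0

    def update(
        self,
        y_true: Sequence[float],
        y_pred: Sequence[float],
        s_wght: Optional[Sequence[float]] = None,
    ) -> None:
        # Default to unit weights when no sample weights are given.
        weights = [1.0] * len(y_true) if s_wght is None else s_wght
        for true, pred, wgt in zip(y_true, y_pred, weights):
            self.err_sum += wgt * abs(true - pred)
            self.wgt_sum += wgt

    def get_result(self) -> Dict[str, float]:
        return {self.name: self.err_sum / max(self.wgt_sum, 1e-12)}


metrics: List[ToyMae] = [ToyMae()]
batches = [([1.0, 2.0], [1.5, 2.0]), ([0.0, 4.0], [1.0, 3.0])]
for y_true, y_pred in batches:
    for metric in metrics:  # the loop that `update` internalizes
        metric.update(y_true, y_pred)
result = {key: val for m in metrics for key, val in m.get_result().items()}
```

Calling `reset` on each metric afterwards clears the accumulated sums, so the same objects can be reused for the next evaluation round.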