
declearn.metrics.MeanSquaredError

Bases: MeanMetric

Mean Squared Error (MSE) metric.

This metric applies to regression models and computes the (optionally weighted) mean of sample-wise squared errors. For inputs with multiple channels, the squared channel-wise errors are summed for each sample, and these per-sample sums are then averaged across samples.
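Written out, the computation reads as follows. This assumes that the optional weighting is normalized by the total sample weight, with every weight w_i equal to 1 when no weights are given. Here n is the number of samples and c ranges over all non-sample entries (e.g. channels) of sample i:

$$
\mathrm{mse} = \frac{\sum_{i=1}^{n} w_i \sum_{c} \left(y_{i,c} - \hat{y}_{i,c}\right)^2}{\sum_{i=1}^{n} w_i}
$$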

The computed metric is the following:

  • mse (float): Mean squared error, averaged across samples. For inputs with two or more dimensions, squared errors are first summed over all non-sample dimensions (e.g. channels).
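The reduction described above can be sketched in plain NumPy. This is a minimal illustration, not declearn's implementation. `sample_wise_squared_error`, `weighted_mse` and the `weights` parameter are hypothetical names, and the sketch assumes both arrays already share the same shape (the listed source first aligns shapes with `squeeze_into_identical_shapes`, which is not reproduced here).

```python
from typing import Optional

import numpy as np


def sample_wise_squared_error(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return one summed squared error per sample (axis 0)."""
    # Assumption: both arrays already share the same shape, (n_samples, ...).
    errors = np.square(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    # Collapse every trailing axis, one at a time, until one value per sample remains.
    while errors.ndim > 1:
        errors = errors.sum(axis=-1)
    return errors


def weighted_mse(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    weights: Optional[np.ndarray] = None,
) -> float:
    """Weighted mean of sample-wise squared errors (unit weights by default)."""
    errors = sample_wise_squared_error(y_true, y_pred)
    w = np.ones_like(errors) if weights is None else np.asarray(weights, dtype=float)
    # Normalize by the total weight (assumed convention for the weighted mean).
    return float(np.sum(w * errors) / np.sum(w))


y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = np.array([[1.0, 0.0], [0.0, 4.0]])
print(sample_wise_squared_error(y_true, y_pred))  # per-sample sums: [4. 9.]
print(weighted_mse(y_true, y_pred))  # (4 + 9) / 2 = 6.5
print(weighted_mse(y_true, y_pred, np.array([3.0, 1.0])))  # (12 + 9) / 4 = 5.25
```

With two channels per sample, each sample contributes the sum of its two squared errors, so the unweighted result (6.5) is twice what a mean over all individual entries (3.25) would give.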
Source code in declearn/metrics/_mean.py
class MeanSquaredError(MeanMetric):
    """Mean Squared Error (MSE) metric.

    This metric applies to a regression model, and computes the (opt.
    weighted) mean sample-wise squared error. Note that for inputs
    with multiple channels, the sum of squared channel-wise errors
    is computed for each sample, and averaged across samples.

    Computed metric is the following:

    * mse: float
        Mean squared error, averaged across samples (possibly
        summed over channels for (>=2)-dimensional inputs).
    """

    name = "mse"

    def metric_func(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
    ) -> np.ndarray:
        # Sample-wise (sum of) squared error function.
        y_true, y_pred = squeeze_into_identical_shapes(y_true, y_pred)
        errors = np.square(y_true - y_pred)
        while errors.ndim > 1:
            errors = errors.sum(axis=-1)
        return errors
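Here `metric_func` only produces per-sample scores, and averaging them is left to the `MeanMetric` base class. The sketch below shows one way such an (optionally weighted) mean can be aggregated over successive batches. It assumes the base keeps a running weighted sum of scores and a running total weight. `RunningMSE`, `update`, `result` and `weights` are hypothetical names chosen for illustration, not declearn's API.

```python
from typing import Optional

import numpy as np


class RunningMSE:
    """Hypothetical streaming MSE accumulator (not declearn's API)."""

    def __init__(self) -> None:
        self.weighted_sum = 0.0  # running sum of w_i * e_i
        self.total_weight = 0.0  # running sum of w_i

    def update(self, y_true, y_pred, weights: Optional[np.ndarray] = None) -> None:
        # Sample-wise sum of squared errors, mirroring metric_func's loop.
        errors = np.square(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
        while errors.ndim > 1:
            errors = errors.sum(axis=-1)
        w = np.ones_like(errors) if weights is None else np.asarray(weights, dtype=float)
        self.weighted_sum += float(np.sum(w * errors))
        self.total_weight += float(np.sum(w))

    def result(self) -> float:
        # Raises ZeroDivisionError until at least one sample has been seen.
        return self.weighted_sum / self.total_weight


y_true = np.arange(12.0).reshape(6, 2)
y_pred = np.zeros((6, 2))

full = RunningMSE()
full.update(y_true, y_pred)

batched = RunningMSE()
batched.update(y_true[:4], y_pred[:4])
batched.update(y_true[4:], y_pred[4:])

print(full.result(), batched.result())  # both equal 506 / 6
```

Because only two scalars are carried across batches, evaluating the data batch by batch gives the same value as a single pass over the whole dataset.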