
declearn.optimizer.regularizers.FedProxRegularizer

Bases: Regularizer

FedProx loss-regularization plug-in.

The FedProx algorithm is implemented through this regularizer, which adds a proximal term to the loss function so as to handle heterogeneity across clients in a federated learning context. See paper [1].

This regularizer implements the following term:

loss += alpha / 2 * (weights - ref_wgt)^2
w/ ref_wgt := weights at the 1st step of the round

To do so, it applies the following correction to the gradients, which is simply the gradient of the proximal term with respect to the weights:

grads += alpha * (weights - ref_wgt)

In other words, this regularizer penalizes the departure of the weights (as a result of local optimization steps) from their initial (shared) values.
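
For intuition, the following is a minimal NumPy sketch of the proximal loss term and the matching gradient correction, using plain arrays as stand-ins for declearn's Vector containers; the array values are arbitrary, and this is an illustration of the formulas above, not the library's implementation.

import numpy as np

alpha = 0.01
ref_wgt = np.array([0.5, -1.0, 2.0])    # shared weights at the start of the round
weights = np.array([0.6, -0.8, 1.7])    # weights after some local steps
gradients = np.array([0.1, 0.0, -0.2])  # raw loss gradients at the current step

# Proximal term added to the loss: alpha / 2 * (weights - ref_wgt)^2
prox_loss = 0.5 * alpha * np.sum((weights - ref_wgt) ** 2)

# Matching correction applied to the gradients: alpha * (weights - ref_wgt)
corrected = gradients + alpha * (weights - ref_wgt)

print(prox_loss)   # 0.0007
print(corrected)   # [ 0.101  0.002 -0.203]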

References

[1] Li et al., 2020. Federated Optimization in Heterogeneous Networks. https://arxiv.org/abs/1812.06127

Source code in declearn/optimizer/regularizers/_base.py
class FedProxRegularizer(Regularizer):
    """FedProx loss-regularization plug-in.

    The FedProx algorithm is implemented through this regularizer,
    that adds a proximal term to the loss function so as to handle
    heterogeneity across clients in a federated learning context.
    See paper [1].

    This regularizer implements the following term:

        loss += alpha / 2 * (weights - ref_wgt)^2
        w/ ref_wgt := weights at the 1st step of the round

    To do so, it applies the following correction to gradients:

        grads += alpha * (weights - ref_wgt)

    In other words, this regularizer penalizes weights' departure
    (as a result from local optimization steps) from their initial
    (shared) values.

    References
    ----------
    [1] Li et al., 2020.
        Federated Optimization in Heterogeneous Networks.
        https://arxiv.org/abs/1812.06127
    """

    name: ClassVar[str] = "fedprox"

    def __init__(
        self,
        alpha: float = 0.01,
    ) -> None:
        super().__init__(alpha)
        self.ref_wgt = None  # type: Optional[Vector]

    def on_round_start(
        self,
    ) -> None:
        self.ref_wgt = None

    def run(
        self,
        gradients: Vector,
        weights: Vector,
    ) -> Vector:
        if self.ref_wgt is None:
            # First step of the round: record the shared weights
            # and leave the gradients unchanged.
            self.ref_wgt = weights
            return gradients
        # Subsequent steps: add the proximal correction term.
        correct = self.alpha * (weights - self.ref_wgt)
        return gradients + correct
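
Usage sketch (not taken from the declearn documentation): the regularizer is typically plugged into a declearn Optimizer, either as an instance or through its registered name "fedprox". The Optimizer constructor arguments below (lrate, regularizers) are assumed from recent declearn versions and may differ in yours.

from declearn.optimizer import Optimizer
from declearn.optimizer.regularizers import FedProxRegularizer

# Assumed signature: Optimizer(lrate=..., regularizers=[...]) accepting
# plug-in instances or (name, config) specifications.
optim = Optimizer(
    lrate=0.01,
    regularizers=[FedProxRegularizer(alpha=0.1)],
)

# Equivalently, using the plug-in's registered name and a config dict:
optim = Optimizer(lrate=0.01, regularizers=[("fedprox", {"alpha": 0.1})])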