declearn.optimizer.modules.YogiMomentumModule

Bases: EWMAModule

Yogi-specific momentum gradient-acceleration module.

This module implements the following algorithm:

Init(beta):
    state = 0
Step(grads):
    state = state - sign(state-grads)*(1-beta)*grads
    grads = state

In other words, gradients are corrected in a fashion somewhat similar to the base momentum formula, but such that the magnitude of the state update is a function of the inputs only, rather than of both the inputs and the previous state [1].

Note that this module is actually meant to be used to compute a learning-rate adaptation term based on squared gradients.
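To make the update rule concrete, here is a minimal sketch using plain NumPy arrays in place of declearn Vector objects; the beta value and toy gradients are illustrative assumptions, not part of the declearn API:

import numpy as np

beta = 0.9           # EWMA-like decay factor (illustrative value)
state = np.zeros(3)  # Init(beta): state = 0

# Toy squared gradients, since the module is meant to accumulate grads ** 2.
for grads in (np.array([0.4, 0.1, 0.9]) ** 2,
              np.array([0.2, 0.3, 0.5]) ** 2):
    # Step(grads): move the state towards the inputs by a step whose
    # magnitude depends only on (1 - beta) * grads, not on the state.
    sign = np.sign(state - grads)
    state = state - sign * (1 - beta) * grads
    print(state)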

References

[1] Zaheer and Reddi et al., 2018. Adaptive Methods for Nonconvex Optimization.

Source code in declearn/optimizer/modules/_momentum.py
class YogiMomentumModule(EWMAModule):
    """Yogi-specific momentum gradient-acceleration module.

    This module implements the following algorithm:

        Init(beta):
            state = 0
        Step(grads):
            state = state - sign(state-grads)*(1-beta)*grads
            grads = state

    In other words, gradients are corrected in a fashion somewhat
    similar to the base momentum formula, but such that the
    magnitude of the state update is a function of the inputs only,
    rather than of both the inputs and the previous state [1].

    Note that this module is actually meant to be used to compute
    a learning-rate adaptation term based on squared gradients.

    References
    ----------
    [1] Zaheer and Reddi et al., 2018.
        Adaptive Methods for Nonconvex Optimization.
    """

    name: ClassVar[str] = "yogi-momentum"

    def run(
        self,
        gradients: Vector,
    ) -> Vector:
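        # Compute the element-wise sign of (state - grads).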
        sign = (self.state - gradients).sign()
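        # Take a fixed-magnitude step towards the inputs: the step size
        # depends only on (1 - beta) * grads, not on the previous state.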
        self.state = self.state - (sign * (1 - self.beta) * gradients)
        return self.state