Skip to content

declearn.optimizer.modules.ScaffoldAuxVar

Bases: AuxVar

AuxVar subclass for Scaffold.

  • In the Server -> Client direction, `state` should be specified.
  • In the Client -> Server direction, `delta` should be specified.
Source code in declearn/optimizer/modules/_scaffold.py (lines 50–112):
@dataclasses.dataclass
class ScaffoldAuxVar(AuxVar):
    """AuxVar subclass for Scaffold.

    Exactly one of the two value fields is expected to be set:

    - In Server -> Client direction, `state` should be specified.
    - In Client -> Server direction, `delta` should be specified.

    The `clients` field holds a set of client identifiers, whose
    aggregation rule (set union) is defined by `aggregate_clients`.
    """

    # Server -> Client payload: a Vector, or the conventional 0.0 value.
    state: Union[Vector, float, None] = None
    # Client -> Server payload: a Vector, or the conventional 0.0 value.
    delta: Union[Vector, float, None] = None
    # Set of client identifiers (strings).
    clients: Set[str] = dataclasses.field(default_factory=set)

    def __post_init__(
        self,
    ) -> None:
        # Reject instances where both or neither of 'state' and 'delta'
        # are set: exactly one of them must be provided.
        if (self.state is not None) == (self.delta is not None):
            raise ValueError(
                "'ScaffoldAuxVar' should have exactly one of 'state' or "
                "'delta' specified as a Vector or conventional 0.0 value."
            )
        # Accept 'clients' as a list (the form emitted by `to_dict`) and
        # store it as a set.
        if isinstance(self.clients, list):
            self.clients = set(self.clients)

    @staticmethod
    def aggregate_clients(
        val_a: Set[str],
        val_b: Set[str],
    ) -> Set[str]:
        """Custom aggregation rule for 'clients' field: set union."""
        combined = val_a.union(val_b)
        return combined

    @classmethod
    def aggregate_state(
        cls,
        val_a: Union[Vector, float, None],
        val_b: Union[Vector, float, None],
    ) -> None:
        """Custom aggregation rule for 'state' field, raising when due.

        'state' values are never aggregated: two None inputs yield None,
        while any non-None input raises a NotImplementedError.
        """
        if val_a is None and val_b is None:
            return None
        raise NotImplementedError(
            "'ScaffoldAuxVar' should not be aggregating 'state' values."
        )

    def to_dict(
        self,
    ) -> Dict[str, Any]:
        # Keep only the value fields that are set (in 'state', 'delta'
        # order), then 'clients' as a list when it is non-empty.
        output: Dict[str, Any] = {
            name: value
            for name, value in (("state", self.state), ("delta", self.delta))
            if value is not None
        }
        if self.clients:
            output["clients"] = list(self.clients)
        return output

    def prepare_for_secagg(
        self,
    ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
        # Only 'delta'-bearing instances may undergo SecAgg: 'delta' goes
        # into the first returned dict, 'clients' into the second one.
        if self.state is None:
            return {"delta": self.delta}, {"clients": self.clients}
        raise NotImplementedError(
            "'ScaffoldAuxVar' with 'state' should not undergo SecAgg."
        )

aggregate_clients(val_a, val_b) staticmethod

Custom aggregation rule for 'clients' field.

Source code in declearn/optimizer/modules/_scaffold.py (lines 73–79):
@staticmethod
def aggregate_clients(
    val_a: Set[str],
    val_b: Set[str],
) -> Set[str]:
    """Custom aggregation rule for 'clients' field.

    Return a new set holding every identifier found in either input;
    neither input set is modified.
    """
    combined = val_a.union(val_b)
    return combined

aggregate_state(val_a, val_b) classmethod

Custom aggregation rule for 'state' field, raising when due.

Source code in declearn/optimizer/modules/_scaffold.py (lines 81–91):
@classmethod
def aggregate_state(
    cls,
    val_a: Union[Vector, float, None],
    val_b: Union[Vector, float, None],
) -> None:
    """Custom aggregation rule for 'state' field, raising when due.

    'state' values are not meant to be aggregated: return None when both
    inputs are None, and raise otherwise.

    Raises
    ------
    NotImplementedError
        If either `val_a` or `val_b` is not None.
    """
    if val_a is None and val_b is None:
        return None
    raise NotImplementedError(
        "'ScaffoldAuxVar' should not be aggregating 'state' values."
    )