
declearn.main.config.TrainingConfig

Dataclass wrapping parameters for a training round.

The parameters wrapped by this class are those of declearn.dataset.Dataset.generate_batches and declearn.communication.messaging.TrainRequest.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| batch_size | int | Number of samples per processed data batch. |
| shuffle | bool | Whether to shuffle data samples prior to batching. |
| drop_remainder | bool | Whether to drop the last batch if it contains fewer samples than `batch_size`, or yield it anyway. |
| poisson | bool | Whether to use Poisson sampling to generate the batches. Useful to maintain tight Differential Privacy guarantees. |
| n_epoch | int or None | Maximum number of local data-processing epochs to perform. May be overridden by `n_steps` or `timeout`. |
| n_steps | int or None | Maximum number of local data-processing steps to perform. May be overridden by `n_epoch` or `timeout`. |
| timeout | int or None | Time, in seconds (> 0), beyond which to interrupt processing, regardless of the actual number of steps taken. |
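
A minimal usage sketch: only `batch_size` is required, and the remaining fields fall back to the defaults visible in the source code below.

from declearn.main.config import TrainingConfig

# Shuffle 32-sample batches; run at most 5 local epochs,
# stopping earlier if 1000 steps are reached first.
config = TrainingConfig(
    batch_size=32,
    shuffle=True,
    n_epoch=5,
    n_steps=1000,
)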

Source code in declearn/main/config/_dataclasses.py
# Imports needed by this excerpt (defined at the top of the module):
import dataclasses
from typing import Any, Dict, Optional

@dataclasses.dataclass
class TrainingConfig:
    """Dataclass wrapping parameters for a training round.

    The parameters wrapped by this class are those of
    `declearn.dataset.Dataset.generate_batches` and
    `declearn.communication.messaging.TrainRequest`.

    Attributes
    ----------
    batch_size: int
        Number of samples per processed data batch.
    shuffle: bool
        Whether to shuffle data samples prior to batching.
    drop_remainder: bool
        Whether to drop the last batch if it contains fewer
        samples than `batch_size`, or yield it anyway.
    poisson: bool
        Whether to use Poisson sampling to generate the batches.
        Useful to maintain tight Differential Privacy guarantees.
    n_epoch: int or None
        Maximum number of local data-processing epochs to
        perform. May be overridden by `n_steps` or `timeout`.
    n_steps: int or None
        Maximum number of local data-processing steps to
        perform. May be overridden by `n_epoch` or `timeout`.
    timeout: int or None
        Time, in seconds (> 0), beyond which to interrupt
        processing, regardless of the actual number of steps taken.
    """

    # Dataset.generate_batches() parameters
    batch_size: int
    shuffle: bool = False
    drop_remainder: bool = True
    poisson: bool = False
    # training effort constraints
    n_epoch: Optional[int] = 1
    n_steps: Optional[int] = None
    timeout: Optional[int] = None

    def __post_init__(self) -> None:
        if all(v is None for v in (self.n_epoch, self.n_steps, self.timeout)):
            raise ValueError(
                "At least one effort constraint must be set: "
                "n_epoch, n_steps and timeout cannot all be None."
            )

    @property
    def batch_cfg(self) -> Dict[str, Any]:
        """Batches-generation parameters from this config."""
        return {
            "batch_size": self.batch_size,
            "shuffle": self.shuffle,
            "drop_remainder": self.drop_remainder,
            "poisson": self.poisson,
        }

    @property
    def message_params(self) -> Dict[str, Any]:
        """TrainRequest message parameters from this config."""
        return {
            "batches": self.batch_cfg,
            "n_epoch": self.n_epoch,
            "n_steps": self.n_steps,
            "timeout": self.timeout,
        }
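
The sketch below illustrates the two properties and the `__post_init__` check; the expected dictionaries follow directly from the source above.

from declearn.main.config import TrainingConfig

config = TrainingConfig(batch_size=32, shuffle=True, n_epoch=5, n_steps=1000)

# `batch_cfg` gathers the Dataset.generate_batches parameters.
assert config.batch_cfg == {
    "batch_size": 32,
    "shuffle": True,
    "drop_remainder": True,  # default value
    "poisson": False,        # default value
}

# `message_params` nests `batch_cfg` together with the effort constraints.
assert config.message_params == {
    "batches": config.batch_cfg,
    "n_epoch": 5,
    "n_steps": 1000,
    "timeout": None,
}

# Disabling all three effort constraints is rejected at instantiation time.
try:
    TrainingConfig(batch_size=32, n_epoch=None)
except ValueError as err:
    print(err)  # at least one of n_epoch, n_steps, timeout must be set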

batch_cfg: Dict[str, Any] property

Batches-generation parameters from this config.

message_params: Dict[str, Any] property

TrainRequest message parameters from this config.
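
Since `batch_cfg` mirrors the parameters of `declearn.dataset.Dataset.generate_batches`, it can be unpacked straight into that method. In the hypothetical sketch below, `my_dataset` stands in for any concrete Dataset implementation, and `config` is the TrainingConfig built in the examples above.

# `my_dataset` is a placeholder for a concrete declearn Dataset instance.
for batch in my_dataset.generate_batches(**config.batch_cfg):
    ...  # run one local training step on this batch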