Skip to content

declearn.main.FederatedClient

Client-side Federated Learning orchestrating class.

Source code in declearn/main/_client.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
class FederatedClient:
    """Client-side Federated Learning orchestrating class."""

    # one-too-many attribute; pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        netwk: Union[NetworkClient, NetworkClientConfig, Dict[str, Any], str],
        train_data: Union[Dataset, str],
        valid_data: Optional[Union[Dataset, str]] = None,
        checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
        secagg: Union[SecaggConfigClient, Dict[str, Any], None] = None,
        share_metrics: bool = True,
        logger: Union[logging.Logger, str, None] = None,
        verbose: bool = True,
    ) -> None:
        """Instantiate a client to participate in a federated learning task.

        Parameters
        ----------
        netwk: NetworkClient or NetworkClientConfig or dict or str
            NetworkClient communication endpoint instance, or configuration
            dict, dataclass or path to a TOML file enabling its instantiation.
            In the latter three cases, the object's default logger will be set
            to that of this `FederatedClient`.
        train_data: Dataset or str
            Dataset instance wrapping the training data.
            (DEPRECATED) May be a path to a JSON dump file.
        valid_data: Dataset or str or None
            Optional Dataset instance wrapping validation data.
            If None, run evaluation rounds over `train_data`.
            (DEPRECATED) May be a path to a JSON dump file.
        checkpoint: Checkpointer or dict or str or None, default=None
            Optional Checkpointer instance or instantiation dict to be
            used so as to save round-wise model, optimizer and metrics.
            If a single string is provided, treat it as the checkpoint
            folder path and use default values for other parameters.
        secagg: SecaggConfigClient or dict or None, default=None
            Optional SecAgg config and setup controller
            or dict of kwargs to set one up.
        share_metrics: bool, default=True
            Whether to share evaluation metrics with the server,
            or save them locally and only send the model's loss.
            This may prevent information leakage, e.g. as to the
            local distribution of target labels or values.
        logger: logging.Logger or str or None, default=None,
            Logger to use, or name of a logger to set up with
            `declearn.utils.get_logger`.
            If None, use `type(self):netwk.name`.
        verbose: bool, default=True
            Whether to verbose about ongoing operations.
            If True, display progress bars during training and validation
            rounds. If False and `logger is None`, set the logger's level
            to filter off most routine information.
        """
        # arguments serve modularity; pylint: disable=too-many-arguments
        # Assign the wrapped NetworkClient.
        self.netwk, replace_netwk_logger = self._parse_netwk(netwk)
        # Assign the logger and optionally replace that of the network client.
        if not isinstance(logger, logging.Logger):
            logger = get_logger(
                name=logger or f"{type(self).__name__}-{self.netwk.name}",
                level=logging.INFO if verbose else LOGGING_LEVEL_MAJOR,
            )
        self.logger = logger
        if replace_netwk_logger:
            self.netwk.logger = self.logger
        # Assign the wrapped training dataset.
        if isinstance(train_data, str):
            train_data = load_dataset_from_json(train_data)
        if not isinstance(train_data, Dataset):
            raise TypeError("'train_data' should be a Dataset or path to one.")
        self.train_data = train_data
        # Assign the wrapped validation dataset (if any).
        if isinstance(valid_data, str):
            valid_data = load_dataset_from_json(valid_data)
        if not (valid_data is None or isinstance(valid_data, Dataset)):
            raise TypeError("'valid_data' should be a Dataset or path to one.")
        self.valid_data = valid_data
        # Assign an optional checkpointer.
        if checkpoint is not None:
            checkpoint = Checkpointer.from_specs(checkpoint)
        self.ckptr = checkpoint
        # Assign the optional SecAgg config and declare an Encrypter slot.
        self.secagg = self._parse_secagg(secagg)
        self._encrypter = None  # type: Optional[Encrypter]
        # Record the metric-sharing and verbosity bool values.
        self.share_metrics = bool(share_metrics)
        if (self.secagg is not None) and not self.share_metrics:
            # note: fixed missing space between the two string parts
            msg = (
                "Disabling metrics' sharing with SecAgg enabled is likely "
                "to cause errors, unless each and every client does so."
            )
            self.logger.warning(msg)
            # stacklevel=2 attributes the warning to the caller of __init__
            # (the former stacklevel=-1 was not a valid stack level).
            warnings.warn(msg, UserWarning, stacklevel=2)
        self.verbose = bool(verbose)
        # Create a TrainingManager slot, populated at initialization phase.
        self.trainmanager = None  # type: Optional[TrainingManager]

    @staticmethod
    def _parse_netwk(netwk) -> Tuple[NetworkClient, bool]:
        """Parse 'netwrk' instantiation argument.

        Return both a 'NetworkClient' instance and a bool indicating
        whether that instance's logger should be replaced with that
        of the client (set up at a latter step).
        """
        # Case when a NetworkClient instance is provided: return.
        if isinstance(netwk, NetworkClient):
            return netwk, False
        # Case when a NetworkClientConfig is expected: verify or parse.
        if isinstance(netwk, NetworkClientConfig):
            config = netwk
        elif isinstance(netwk, str):
            config = NetworkClientConfig.from_toml(netwk)
        elif isinstance(netwk, dict):
            config = NetworkClientConfig.from_params(**netwk)
        else:
            raise TypeError(
                "'netwk' should be a 'NetworkClient' instance or the valid "
                f"configuration of one, not '{type(netwk)}'"
            )
        # Instantiate from the (parsed) config. The logger is to be replaced
        # if and only if the config does not specify one.
        replace_netwk_logger = config.logger is None
        return config.build_client(), replace_netwk_logger

    @staticmethod
    def _parse_secagg(
        secagg: Union[SecaggConfigClient, Dict[str, Any], None],
    ) -> Optional[SecaggConfigClient]:
        """Parse 'secagg' instantiation argument."""
        if secagg is None:
            return None
        if isinstance(secagg, SecaggConfigClient):
            return secagg
        if isinstance(secagg, dict):
            try:
                return parse_secagg_config_client(**secagg)
            except Exception as exc:
                raise TypeError("Failed to parse 'secagg' inputs.") from exc
        raise TypeError(
            "'secagg' should be a 'SecaggConfigClient' instance or a dict "
            f"of keyword arguments to set one up, not '{type(secagg)}'."
        )

    def run(
        self,
    ) -> None:
        """Participate in the federated learning process.

        * Connect to the orchestrating `FederatedServer` and register
          for training, sharing some metadata about `self.train_data`.
        * Await initialization instructions to spawn the Model that is
          to be trained and the local Optimizer used to do so.
        * Participate in training and evaluation rounds based on the
          server's requests, checkpointing the model and local loss.
        * Expect instructions to stop training, or to cancel it in
          case errors are reported during the process.
        """
        asyncio.run(self.async_run())

    async def async_run(
        self,
    ) -> None:
        """Participate in the federated learning process.

        Note: this method is the async backend of `self.run`.
        """
        async with self.netwk:
            # Register for training, then collect initialization information.
            await self.register()
            await self.initialize()
            # Process server instructions as they come.
            while True:
                message = await self.netwk.recv_message()
                stoprun = await self.handle_message(message)
                if stoprun:
                    break

    async def handle_message(
        self,
        message: SerializedMessage,
    ) -> bool:
        """Handle an incoming message from the server.

        Parameters
        ----------
        message: SerializedMessage
            Serialized message that needs triage and processing.

        Returns
        -------
        exit_loop: bool
            Whether to interrupt the client's message-receiving loop.
        """
        exit_loop = False
        if issubclass(message.message_cls, messaging.TrainRequest):
            await self.training_round(message.deserialize())
        elif issubclass(message.message_cls, messaging.EvaluationRequest):
            await self.evaluation_round(message.deserialize())
        elif issubclass(message.message_cls, SecaggSetupQuery):
            await self.setup_secagg(message)  # note: keep serialized
        elif issubclass(message.message_cls, messaging.StopTraining):
            await self.stop_training(message.deserialize())
            exit_loop = True
        elif issubclass(message.message_cls, messaging.CancelTraining):
            await self.cancel_training(message.deserialize())
        else:
            error = "Unexpected message type received from server: "
            error += message.message_cls.__name__
            self.logger.error(error)
            raise ValueError(error)
        return exit_loop

    async def register(
        self,
    ) -> None:
        """Register for participation in the federated learning process.

        Raises
        ------
        RuntimeError
            If registration has failed 10 times (with a 1 minute delay
            between connection and registration attempts).
        """
        max_attempts = 10
        for i in range(max_attempts):
            self.logger.info(
                "Attempting to join training (attempt n°%s)", i + 1
            )
            registered = await self.netwk.register()
            if registered:
                break
            # Wait before retrying, but not after the final failed attempt
            # (previously a pointless 1-minute delay preceded the raise).
            if i + 1 < max_attempts:
                await asyncio.sleep(60)  # delay between retries (1 minute)
        else:
            raise RuntimeError("Failed to register for training.")

    async def initialize(
        self,
    ) -> None:
        """Set up a Model and an Optimizer based on server instructions.

        Await server instructions (as an InitRequest message) and conduct
        initialization, populating `self.trainmanager` as a result.

        Raises
        ------
        RuntimeError
            If initialization failed, either because the message was not
            received or was of incorrect type, or because instantiation
            of the objects it specifies failed.
        """
        # Await initialization instructions.
        self.logger.info("Awaiting initialization instructions from server.")
        received = await self.netwk.recv_message()
        # If a MetadataQuery is received, process it, then await InitRequest.
        if issubclass(received.message_cls, messaging.MetadataQuery):
            await self._collect_and_send_metadata(received.deserialize())
            received = await self.netwk.recv_message()
        # Ensure that an 'InitRequest' was received.
        message = await verify_server_message_validity(
            self.netwk, received, expected=messaging.InitRequest
        )
        # Verify that SecAgg type is coherent across peers.
        secagg_type = None if self.secagg is None else self.secagg.secagg_type
        if message.secagg != secagg_type:
            error = (
                "SecAgg configuration mismatch: server set "
                f"'{message.secagg}', client set '{secagg_type}'."
            )
            self.logger.error(error)
            await self.netwk.send_message(messaging.Error(error))
            raise RuntimeError(f"Initialization failed: {error}")
        # Perform initialization, catching errors to report them to the server.
        try:
            self.trainmanager = TrainingManager(
                model=message.model,
                optim=message.optim,
                aggrg=message.aggrg,
                train_data=self.train_data,
                valid_data=self.valid_data,
                metrics=message.metrics,
                logger=self.logger,
                verbose=self.verbose,
            )
        except Exception as exc:
            await self.netwk.send_message(messaging.Error(repr(exc)))
            raise RuntimeError("Initialization failed.") from exc
        # If instructed to do so, run additional steps to set up DP-SGD.
        if message.dpsgd:
            await self._initialize_dpsgd()
        # Send back an empty message to indicate that all went fine.
        self.logger.info("Notifying the server that initialization went fine.")
        await self.netwk.send_message(messaging.InitReply())
        # Optionally checkpoint the received model and optimizer.
        if self.ckptr:
            self.ckptr.checkpoint(
                model=self.trainmanager.model,
                optimizer=self.trainmanager.optim,
                first_call=True,
            )

    async def _collect_and_send_metadata(
        self,
        message: messaging.MetadataQuery,
    ) -> None:
        """Collect and report some metadata based on server instructions."""
        self.logger.info("Collecting metadata to send to the server.")
        metadata = dataclasses.asdict(self.train_data.get_data_specs())
        if missing := set(message.fields).difference(metadata):
            err_msg = f"Metadata query for undefined fields: {missing}."
            await self.netwk.send_message(messaging.Error(err_msg))
            raise RuntimeError(err_msg)
        data_info = {key: metadata[key] for key in message.fields}
        self.logger.info(
            "Sending training dataset metadata to the server: %s.",
            list(data_info),
        )
        await self.netwk.send_message(messaging.MetadataReply(data_info))

    async def _initialize_dpsgd(
        self,
    ) -> None:
        """Set up differentially-private training as part of initialization.

        This method wraps the `make_private` one in the context of
        `initialize` and should never be called in another context.
        """
        received = await self.netwk.recv_message()
        try:
            message = await verify_server_message_validity(
                self.netwk, received, expected=messaging.PrivacyRequest
            )
        except Exception as exc:
            raise RuntimeError("DP-SGD initialization failed.") from exc
        self.logger.info("Received a request to set up DP-SGD.")
        try:
            self.make_private(message)
        except Exception as exc:  # pylint: disable=broad-except
            self.logger.error(
                "Exception encountered in `make_private`: %s", exc
            )
            await self.netwk.send_message(messaging.Error(repr(exc)))
            raise RuntimeError("DP-SGD initialization failed.") from exc
        # If things went right, notify the server.
        self.logger.info("Notifying the server that DP-SGD setup went fine.")
        await self.netwk.send_message(messaging.PrivacyReply())

    def make_private(
        self,
        message: messaging.PrivacyRequest,
    ) -> None:
        """Set up differentially-private training, using DP-SGD.

        Based on the server message, replace the wrapped `trainmanager`
        attribute by an instance of a subclass that provides with DP-SGD.

        Note that this method triggers the import of `declearn.main.privacy`
        which may result in an error if the third-party dependency 'opacus'
        is not available.

        Parameters
        ----------
        message: PrivacyRequest
            Instructions from the server regarding the DP-SGD setup.
        """
        assert self.trainmanager is not None
        # fmt: off
        # lazy-import the DPTrainingManager, that involves some optional,
        # heavy-loadtime dependencies; pylint: disable=import-outside-toplevel
        from declearn.main.privacy import DPTrainingManager

        # pylint: enable=import-outside-toplevel
        self.trainmanager = DPTrainingManager(
            model=self.trainmanager.model,
            optim=self.trainmanager.optim,
            aggrg=self.trainmanager.aggrg,
            train_data=self.trainmanager.train_data,
            valid_data=self.trainmanager.valid_data,
            metrics=self.trainmanager.metrics,
            logger=self.trainmanager.logger,
            verbose=self.trainmanager.verbose,
        )
        self.trainmanager.make_private(message)

    async def setup_secagg(
        self,
        received: SerializedMessage[SecaggSetupQuery],
    ) -> None:
        """Participate in a SecAgg setup protocol.

        Process a setup request from the server, run a method-specific
        protocol (that may involve additional communications) and update
        the held SecAgg `Encrypter` with the resulting one.

        Parameters
        ----------
        received:
            Serialized `SecaggSetupQuery` request received from the server,
            the exact type of which depends on the SecAgg method being set.
        """
        # If no SecAgg setup controller was set, send an Error message.
        if self.secagg is None:
            error = (
                "Received a SecAgg setup request, but SecAgg is not "
                "configured to be used."
            )
            self.logger.error(error)
            await self.netwk.send_message(messaging.Error(error))
            return
        # Otherwise, participate in the SecAgg setup protocol.
        self.logger.info("Received a SecAgg setup request.")
        try:
            self._encrypter = await self.secagg.setup_encrypter(
                netwk=self.netwk, query=received
            )
        except (KeyError, RuntimeError, ValueError) as exc:
            self.logger.error("SecAgg setup failed: %s", repr(exc))

    async def training_round(
        self,
        message: messaging.TrainRequest,
    ) -> None:
        """Run a local training round.

        If an exception is raised during the local process, wrap
        it as an Error message and send it to the server instead
        of raising it.

        Parameters
        ----------
        message: TrainRequest
            Instructions from the server regarding the training round.
        """
        assert self.trainmanager is not None
        # When SecAgg is to be used, verify that it was set up.
        if self.secagg is not None and self._encrypter is None:
            # note: fixed missing space between the two string parts
            error = (
                f"Refusing to participate in training round {message.round_i}"
                " as SecAgg is configured to be used but was not set up."
            )
            self.logger.error(error)
            await self.netwk.send_message(messaging.Error(error))
            return
        # Run the training round.
        reply = self.trainmanager.training_round(message)  # type: Message
        # Collect and optionally record batch-wise training losses.
        # Note: collection enables purging them from memory.
        losses = self.trainmanager.model.collect_training_losses()
        if self.ckptr is not None:
            self.ckptr.save_metrics(
                metrics={"training_losses": np.array(losses)},
                prefix="training_losses",
                append=True,
                timestamp=f"round_{message.round_i}",
            )
        # Optionally SecAgg-encrypt the reply.
        if self._encrypter is not None and isinstance(
            reply, messaging.TrainReply
        ):
            reply = SecaggTrainReply.from_cleartext_message(
                cleartext=reply, encrypter=self._encrypter
            )
        # Send training results (or error message) to the server.
        await self.netwk.send_message(reply)

    async def evaluation_round(
        self,
        message: messaging.EvaluationRequest,
    ) -> None:
        """Run a local evaluation round.

        If an exception is raised during the local process, wrap
        it as an Error message and send it to the server instead
        of raising it.

        If a checkpointer is set, record the local loss, and the
        model weights received from the server.

        Parameters
        ----------
        message: EvaluationRequest
            Instructions from the server regarding the evaluation round.
        """
        assert self.trainmanager is not None
        # When SecAgg is to be used, verify that it was set up.
        if self.secagg is not None and self._encrypter is None:
            error = (
                "Refusing to participate in evaluation round "
                f"{message.round_i} as SecAgg is configured to be used "
                "but was not set up."
            )
            self.logger.error(error)
            await self.netwk.send_message(messaging.Error(error))
            return
        # Run the evaluation round.
        reply = self.trainmanager.evaluation_round(message)  # type: Message
        # Post-process the results.
        if isinstance(reply, messaging.EvaluationReply):  # not an Error
            # Optionally checkpoint the model, optimizer and local loss.
            if self.ckptr:
                self.ckptr.checkpoint(
                    model=self.trainmanager.model,
                    optimizer=self.trainmanager.optim,
                    metrics=self.trainmanager.metrics.get_result(),
                )
            # Optionally prevent sharing metrics (save for the loss).
            if not self.share_metrics:
                reply.metrics.clear()
            # Optionally SecAgg-encrypt results.
            if self._encrypter is not None:
                reply = SecaggEvaluationReply.from_cleartext_message(
                    cleartext=reply, encrypter=self._encrypter
                )
        # Send evaluation results (or error message) to the server.
        await self.netwk.send_message(reply)

    async def stop_training(
        self,
        message: messaging.StopTraining,
    ) -> None:
        """Handle a server request to stop training.

        Parameters
        ----------
        message: StopTraining
            StopTraining message received from the server.
        """
        self.logger.info(
            "Training is now over, after %s rounds. Global loss: %s",
            message.rounds,
            message.loss,
        )
        if self.ckptr:
            path = os.path.join(self.ckptr.folder, "model_state_best.json")
            self.logger.info("Checkpointing final weights under %s.", path)
            assert self.trainmanager is not None  # for mypy
            self.trainmanager.model.set_weights(message.weights)
            self.ckptr.save_model(self.trainmanager.model, timestamp="best")

    async def cancel_training(
        self,
        message: messaging.CancelTraining,
    ) -> None:
        """Handle a server request to cancel training.

        Parameters
        ----------
        message: CancelTraining
            CancelTraining message received from the server.
        """
        error = "Training was cancelled by the server, with reason:\n"
        error += message.reason
        self.logger.warning(error)
        raise RuntimeError(error)

__init__(netwk, train_data, valid_data=None, checkpoint=None, secagg=None, share_metrics=True, logger=None, verbose=True)

Instantiate a client to participate in a federated learning task.

Parameters:

Name Type Description Default
netwk Union[NetworkClient, NetworkClientConfig, Dict[str, Any], str]

NetworkClient communication endpoint instance, or configuration dict, dataclass or path to a TOML file enabling its instantiation. In the latter three cases, the object's default logger will be set to that of this FederatedClient.

required
train_data Union[Dataset, str]

Dataset instance wrapping the training data. (DEPRECATED) May be a path to a JSON dump file.

required
valid_data Optional[Union[Dataset, str]]

Optional Dataset instance wrapping validation data. If None, run evaluation rounds over train_data. (DEPRECATED) May be a path to a JSON dump file.

None
checkpoint Union[Checkpointer, Dict[str, Any], str, None]

Optional Checkpointer instance or instantiation dict to be used so as to save round-wise model, optimizer and metrics. If a single string is provided, treat it as the checkpoint folder path and use default values for other parameters.

None
secagg Union[SecaggConfigClient, Dict[str, Any], None]

Optional SecAgg config and setup controller or dict of kwargs to set one up.

None
share_metrics bool

Whether to share evaluation metrics with the server, or save them locally and only send the model's loss. This may prevent information leakage, e.g. as to the local distribution of target labels or values.

True
logger Union[logging.Logger, str, None]

Logger to use, or name of a logger to set up with declearn.utils.get_logger. If None, use type(self):netwk.name.

None
verbose bool

Whether to verbose about ongoing operations. If True, display progress bars during training and validation rounds. If False and logger is None, set the logger's level to filter off most routine information.

True
Source code in declearn/main/_client.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def __init__(
    self,
    netwk: Union[NetworkClient, NetworkClientConfig, Dict[str, Any], str],
    train_data: Union[Dataset, str],
    valid_data: Optional[Union[Dataset, str]] = None,
    checkpoint: Union[Checkpointer, Dict[str, Any], str, None] = None,
    secagg: Union[SecaggConfigClient, Dict[str, Any], None] = None,
    share_metrics: bool = True,
    logger: Union[logging.Logger, str, None] = None,
    verbose: bool = True,
) -> None:
    """Instantiate a client to participate in a federated learning task.

    Parameters
    ----------
    netwk: NetworkClient or NetworkClientConfig or dict or str
        NetworkClient communication endpoint instance, or configuration
        dict, dataclass or path to a TOML file enabling its instantiation.
        In the latter three cases, the object's default logger will be set
        to that of this `FederatedClient`.
    train_data: Dataset or str
        Dataset instance wrapping the training data.
        (DEPRECATED) May be a path to a JSON dump file.
    valid_data: Dataset or str or None
        Optional Dataset instance wrapping validation data.
        If None, run evaluation rounds over `train_data`.
        (DEPRECATED) May be a path to a JSON dump file.
    checkpoint: Checkpointer or dict or str or None, default=None
        Optional Checkpointer instance or instantiation dict to be
        used so as to save round-wise model, optimizer and metrics.
        If a single string is provided, treat it as the checkpoint
        folder path and use default values for other parameters.
    secagg: SecaggConfigClient or dict or None, default=None
        Optional SecAgg config and setup controller
        or dict of kwargs to set one up.
    share_metrics: bool, default=True
        Whether to share evaluation metrics with the server,
        or save them locally and only send the model's loss.
        This may prevent information leakage, e.g. as to the
        local distribution of target labels or values.
    logger: logging.Logger or str or None, default=None,
        Logger to use, or name of a logger to set up with
        `declearn.utils.get_logger`.
        If None, use `type(self):netwk.name`.
    verbose: bool, default=True
        Whether to verbose about ongoing operations.
        If True, display progress bars during training and validation
        rounds. If False and `logger is None`, set the logger's level
        to filter off most routine information.
    """
    # arguments serve modularity; pylint: disable=too-many-arguments
    # Assign the wrapped NetworkClient.
    self.netwk, replace_netwk_logger = self._parse_netwk(netwk)
    # Assign the logger and optionally replace that of the network client.
    if not isinstance(logger, logging.Logger):
        logger = get_logger(
            name=logger or f"{type(self).__name__}-{self.netwk.name}",
            level=logging.INFO if verbose else LOGGING_LEVEL_MAJOR,
        )
    self.logger = logger
    if replace_netwk_logger:
        self.netwk.logger = self.logger
    # Assign the wrapped training dataset.
    if isinstance(train_data, str):
        train_data = load_dataset_from_json(train_data)
    if not isinstance(train_data, Dataset):
        raise TypeError("'train_data' should be a Dataset or path to one.")
    self.train_data = train_data
    # Assign the wrapped validation dataset (if any).
    if isinstance(valid_data, str):
        valid_data = load_dataset_from_json(valid_data)
    if not (valid_data is None or isinstance(valid_data, Dataset)):
        raise TypeError("'valid_data' should be a Dataset or path to one.")
    self.valid_data = valid_data
    # Assign an optional checkpointer.
    if checkpoint is not None:
        checkpoint = Checkpointer.from_specs(checkpoint)
    self.ckptr = checkpoint
    # Assign the optional SecAgg config and declare an Encrypter slot.
    self.secagg = self._parse_secagg(secagg)
    self._encrypter = None  # type: Optional[Encrypter]
    # Record the metric-sharing and verbosity bool values.
    self.share_metrics = bool(share_metrics)
    if (self.secagg is not None) and not self.share_metrics:
        msg = (
            "Disabling metrics' sharing with SecAgg enabled is likely"
            "to cause errors, unless each and every client does so."
        )
        self.logger.warning(msg)
        warnings.warn(msg, UserWarning, stacklevel=-1)
    self.verbose = bool(verbose)
    # Create a TrainingManager slot, populated at initialization phase.
    self.trainmanager = None  # type: Optional[TrainingManager]

async_run() async

Participate in the federated learning process.

Note: this method is the async backend of self.run.

Source code in declearn/main/_client.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
async def async_run(
    self,
) -> None:
    """Participate in the federated learning process.

    Note: this method is the async backend of `self.run`.
    """
    async with self.netwk:
        # Join the process, then gather initialization information.
        await self.register()
        await self.initialize()
        # Keep processing server instructions until told to exit.
        exit_loop = False
        while not exit_loop:
            received = await self.netwk.recv_message()
            exit_loop = await self.handle_message(received)

cancel_training(message) async

Handle a server request to cancel training.

Parameters:

Name Type Description Default
message messaging.CancelTraining

CancelTraining message received from the server.

required
Source code in declearn/main/_client.py
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
async def cancel_training(
    self,
    message: messaging.CancelTraining,
) -> None:
    """Handle a server request to cancel training.

    Parameters
    ----------
    message: CancelTraining
        CancelTraining message received from the server.
    """
    # Log the server-provided reason, then abort by raising.
    error = (
        "Training was cancelled by the server, with reason:\n"
        + message.reason
    )
    self.logger.warning(error)
    raise RuntimeError(error)

evaluation_round(message) async

Run a local evaluation round.

If an exception is raised during the local process, wrap it as an Error message and send it to the server instead of raising it.

If a checkpointer is set, record the local loss, and the model weights received from the server.

Parameters:

Name Type Description Default
message messaging.EvaluationRequest

Instructions from the server regarding the evaluation round.

required
Source code in declearn/main/_client.py
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
async def evaluation_round(
    self,
    message: messaging.EvaluationRequest,
) -> None:
    """Run a local evaluation round.

    Any exception raised by the local routine is wrapped into an
    Error message and sent to the server rather than being raised.

    When a checkpointer is set, record the local loss, together
    with the model weights received from the server.

    Parameters
    ----------
    message: EvaluationRequest
        Instructions from the server regarding the evaluation round.
    """
    assert self.trainmanager is not None
    # Refuse to take part when SecAgg is expected but was not set up.
    if (self.secagg is not None) and (self._encrypter is None):
        error = (
            "Refusing to participate in evaluation round "
            f"{message.round_i} as SecAgg is configured to be used "
            "but was not set up."
        )
        self.logger.error(error)
        await self.netwk.send_message(messaging.Error(error))
        return
    # Delegate the actual evaluation to the training manager.
    reply = self.trainmanager.evaluation_round(message)  # type: Message
    # Forward Error replies as-is, without further processing.
    if not isinstance(reply, messaging.EvaluationReply):
        await self.netwk.send_message(reply)
        return
    # Optionally checkpoint the model, optimizer and local loss.
    if self.ckptr:
        self.ckptr.checkpoint(
            model=self.trainmanager.model,
            optimizer=self.trainmanager.optim,
            metrics=self.trainmanager.metrics.get_result(),
        )
    # Optionally drop detailed metrics (the loss is kept).
    if not self.share_metrics:
        reply.metrics.clear()
    # Optionally SecAgg-encrypt the results.
    if self._encrypter is not None:
        reply = SecaggEvaluationReply.from_cleartext_message(
            cleartext=reply, encrypter=self._encrypter
        )
    # Send the (possibly encrypted) evaluation results to the server.
    await self.netwk.send_message(reply)

handle_message(message) async

Handle an incoming message from the server.

Parameters:

Name Type Description Default
message SerializedMessage

Serialized message that needs triage and processing.

required

Returns:

Name Type Description
exit_loop bool

Whether to interrupt the client's message-receiving loop.

Source code in declearn/main/_client.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
async def handle_message(
    self,
    message: SerializedMessage,
) -> bool:
    """Handle an incoming message from the server.

    Parameters
    ----------
    message: SerializedMessage
        Serialized message that needs triage and processing.

    Returns
    -------
    exit_loop: bool
        Whether to interrupt the client's message-receiving loop.
    """
    msg_cls = message.message_cls
    if issubclass(msg_cls, messaging.TrainRequest):
        await self.training_round(message.deserialize())
        return False
    if issubclass(msg_cls, messaging.EvaluationRequest):
        await self.evaluation_round(message.deserialize())
        return False
    if issubclass(msg_cls, SecaggSetupQuery):
        # Keep the message serialized: the setup routine handles it.
        await self.setup_secagg(message)
        return False
    if issubclass(msg_cls, messaging.StopTraining):
        # This is the only case that interrupts the receiving loop.
        await self.stop_training(message.deserialize())
        return True
    if issubclass(msg_cls, messaging.CancelTraining):
        await self.cancel_training(message.deserialize())
        return False
    # Any other message type is unexpected: log and raise.
    error = "Unexpected message type received from server: "
    error += msg_cls.__name__
    self.logger.error(error)
    raise ValueError(error)

initialize() async

Set up a Model and an Optimizer based on server instructions.

Await server instructions (as an InitRequest message) and conduct initialization.

Raises:

Type Description
RuntimeError

If initialization failed, either because the message was not received or was of incorrect type, or because instantiation of the objects it specifies failed.

Returns:

Name Type Description
model Model

Model that is to be trained (with shared initial parameters).

optim Optimizer

Optimizer that is to be used locally to train the model.

Source code in declearn/main/_client.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
async def initialize(
    self,
) -> None:
    """Set up a Model and an Optimizer based on server instructions.

    Await server instructions (as an InitRequest message) and conduct
    initialization, assigning the resulting `TrainingManager` to the
    `trainmanager` attribute. This method returns None.

    Raises
    ------
    RuntimeError
        If initialization failed, either because the message was not
        received or was of incorrect type, or because instantiation
        of the objects it specifies failed.
    """
    # Await initialization instructions.
    self.logger.info("Awaiting initialization instructions from server.")
    received = await self.netwk.recv_message()
    # If a MetadataQuery is received, process it, then await InitRequest.
    if issubclass(received.message_cls, messaging.MetadataQuery):
        await self._collect_and_send_metadata(received.deserialize())
        received = await self.netwk.recv_message()
    # Ensure that an 'InitRequest' was received.
    message = await verify_server_message_validity(
        self.netwk, received, expected=messaging.InitRequest
    )
    # Verify that SecAgg type is coherent across peers.
    secagg_type = None if self.secagg is None else self.secagg.secagg_type
    if message.secagg != secagg_type:
        # Fixed typo in the user-facing error message
        # ("configurgation" -> "configuration").
        error = (
            "SecAgg configuration mismatch: server set "
            f"'{message.secagg}', client set '{secagg_type}'."
        )
        self.logger.error(error)
        await self.netwk.send_message(messaging.Error(error))
        raise RuntimeError(f"Initialization failed: {error}.")
    # Perform initialization, catching errors to report them to the server.
    try:
        self.trainmanager = TrainingManager(
            model=message.model,
            optim=message.optim,
            aggrg=message.aggrg,
            train_data=self.train_data,
            valid_data=self.valid_data,
            metrics=message.metrics,
            logger=self.logger,
            verbose=self.verbose,
        )
    except Exception as exc:
        await self.netwk.send_message(messaging.Error(repr(exc)))
        raise RuntimeError("Initialization failed.") from exc
    # If instructed to do so, run additional steps to set up DP-SGD.
    if message.dpsgd:
        await self._initialize_dpsgd()
    # Send back an empty message to indicate that all went fine.
    self.logger.info("Notifying the server that initialization went fine.")
    await self.netwk.send_message(messaging.InitReply())
    # Optionally checkpoint the received model and optimizer.
    if self.ckptr:
        self.ckptr.checkpoint(
            model=self.trainmanager.model,
            optimizer=self.trainmanager.optim,
            first_call=True,
        )

make_private(message)

Set up differentially-private training, using DP-SGD.

Based on the server message, replace the wrapped trainmanager attribute by an instance of a subclass that provides with DP-SGD.

Note that this method triggers the import of declearn.main.privacy which may result in an error if the third-party dependency 'opacus' is not available.

Parameters:

message: PrivacyRequest Instructions from the server regarding the DP-SGD setup.

Source code in declearn/main/_client.py
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
def make_private(
    self,
    message: messaging.PrivacyRequest,
) -> None:
    """Set up differentially-private training, using DP-SGD.

    Based on the server message, replace the wrapped `trainmanager`
    attribute by an instance of a subclass that provides with DP-SGD.

    Note that this method triggers the import of `declearn.main.privacy`
    which may result in an error if the third-party dependency 'opacus'
    is not available.

    Parameters
    ----------
    message: PrivacyRequest
        Instructions from the server regarding the DP-SGD setup.
    """
    assert self.trainmanager is not None
    # Lazy-import DPTrainingManager, which involves some optional,
    # heavy-loadtime dependencies.
    # pylint: disable=import-outside-toplevel
    from declearn.main.privacy import DPTrainingManager

    # pylint: enable=import-outside-toplevel
    # Re-wrap the current manager's components into a DP-enabled one.
    base = self.trainmanager
    self.trainmanager = DPTrainingManager(
        model=base.model,
        optim=base.optim,
        aggrg=base.aggrg,
        train_data=base.train_data,
        valid_data=base.valid_data,
        metrics=base.metrics,
        logger=base.logger,
        verbose=base.verbose,
    )
    self.trainmanager.make_private(message)

register() async

Register for participation in the federated learning process.

Raises:

Type Description
RuntimeError

If registration has failed 10 times (with a 1 minute delay between connection and registration attempts).

Source code in declearn/main/_client.py
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
async def register(
    self,
) -> None:
    """Register for participation in the federated learning process.

    Raises
    ------
    RuntimeError
        If registration has failed 10 times (with a 1 minute delay
        between connection and registration attempts).
    """
    max_attempts = 10
    for attempt in range(1, max_attempts + 1):
        self.logger.info(
            "Attempting to join training (attempt n°%s)", attempt
        )
        if await self.netwk.register():
            break
        # Wait for one minute before the next attempt.
        await asyncio.sleep(60)
    else:
        raise RuntimeError("Failed to register for training.")

run()

Participate in the federated learning process.

  • Connect to the orchestrating FederatedServer and register for training, sharing some metadata about self.train_data.
  • Await initialization instructions to spawn the Model that is to be trained and the local Optimizer used to do so.
  • Participate in training and evaluation rounds based on the server's requests, checkpointing the model and local loss.
  • Expect instructions to stop training, or to cancel it in case errors are reported during the process.
Source code in declearn/main/_client.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def run(
    self,
) -> None:
    """Participate in the federated learning process.

    * Connect to the orchestrating `FederatedServer` and register
      for training, sharing some metadata about `self.train_data`.
    * Await initialization instructions to spawn the Model that is
      to be trained and the local Optimizer used to do so.
    * Participate in training and evaluation rounds based on the
      server's requests, checkpointing the model and local loss.
    * Expect instructions to stop training, or to cancel it in
      case errors are reported during the process.
    """
    # Delegate to the async backend, running it to completion.
    coroutine = self.async_run()
    asyncio.run(coroutine)

setup_secagg(received) async

Participate in a SecAgg setup protocol.

Process a setup request from the server, run a method-specific protocol (that may involve additional communications) and update the held SecAgg Encrypter with the resulting one.

Parameters:

Name Type Description Default
received SerializedMessage[SecaggSetupQuery]

Serialized SecaggSetupQuery request received from the server, the exact type of which depends on the SecAgg method being set.

required
Source code in declearn/main/_client.py
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
async def setup_secagg(
    self,
    received: SerializedMessage[SecaggSetupQuery],
) -> None:
    """Participate in a SecAgg setup protocol.

    Process a setup request from the server, run a method-specific
    protocol (that may involve additional communications) and update
    the held SecAgg `Encrypter` with the resulting one.

    Parameters
    ----------
    received:
        Serialized `SecaggSetupQuery` request received from the server,
        the exact type of which depends on the SecAgg method being set.
    """
    if self.secagg is None:
        # SecAgg was not configured client-side: report an Error.
        error = (
            "Received a SecAgg setup request, but SecAgg is not "
            "configured to be used."
        )
        self.logger.error(error)
        await self.netwk.send_message(messaging.Error(error))
    else:
        # Run the setup protocol, logging (not raising) on failure.
        self.logger.info("Received a SecAgg setup request.")
        try:
            self._encrypter = await self.secagg.setup_encrypter(
                netwk=self.netwk, query=received
            )
        except (KeyError, RuntimeError, ValueError) as exc:
            self.logger.error("SecAgg setup failed: %s", repr(exc))

stop_training(message) async

Handle a server request to stop training.

Parameters:

Name Type Description Default
message messaging.StopTraining

StopTraining message received from the server.

required
Source code in declearn/main/_client.py
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
async def stop_training(
    self,
    message: messaging.StopTraining,
) -> None:
    """Handle a server request to stop training.

    Parameters
    ----------
    message: StopTraining
        StopTraining message received from the server.
    """
    self.logger.info(
        "Training is now over, after %s rounds. Global loss: %s",
        message.rounds,
        message.loss,
    )
    if not self.ckptr:
        return
    # Record the final (best) weights received from the server.
    path = os.path.join(self.ckptr.folder, "model_state_best.json")
    self.logger.info("Checkpointing final weights under %s.", path)
    assert self.trainmanager is not None  # for mypy
    self.trainmanager.model.set_weights(message.weights)
    self.ckptr.save_model(self.trainmanager.model, timestamp="best")

training_round(message) async

Run a local training round.

If an exception is raised during the local process, wrap it as an Error message and send it to the server instead of raising it.

Parameters:

Name Type Description Default
message messaging.TrainRequest

Instructions from the server regarding the training round.

required
Source code in declearn/main/_client.py
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
async def training_round(
    self,
    message: messaging.TrainRequest,
) -> None:
    """Run a local training round.

    If an exception is raised during the local process, wrap
    it as an Error message and send it to the server instead
    of raising it.

    Parameters
    ----------
    message: TrainRequest
        Instructions from the server regarding the training round.
    """
    assert self.trainmanager is not None
    # When SecAgg is to be used, verify that it was set up.
    if self.secagg is not None and self._encrypter is None:
        # Fixed: the implicitly-concatenated string segments lacked a
        # separating space, yielding e.g. "round 1as SecAgg ...".
        error = (
            f"Refusing to participate in training round {message.round_i}"
            " as SecAgg is configured to be used but was not set up."
        )
        self.logger.error(error)
        await self.netwk.send_message(messaging.Error(error))
        return
    # Run the training round.
    reply = self.trainmanager.training_round(message)  # type: Message
    # Collect and optionally record batch-wise training losses.
    # Note: collection enables purging them from memory.
    losses = self.trainmanager.model.collect_training_losses()
    if self.ckptr is not None:
        self.ckptr.save_metrics(
            metrics={"training_losses": np.array(losses)},
            prefix="training_losses",
            append=True,
            timestamp=f"round_{message.round_i}",
        )
    # Optionally SecAgg-encrypt the reply.
    if self._encrypter is not None and isinstance(
        reply, messaging.TrainReply
    ):
        reply = SecaggTrainReply.from_cleartext_message(
            cleartext=reply, encrypter=self._encrypter
        )
    # Send training results (or error message) to the server.
    await self.netwk.send_message(reply)