declearn.model.tensorflow.TensorflowModel

Bases: Model

Model wrapper for TensorFlow Model instances.

This Model subclass is designed to wrap a tf_keras.Model instance to be trained federatively.

Notes regarding device management (CPU, GPU, etc.):

  • By default, tensorflow places data and operations on GPU whenever one is available.
  • Our TensorflowModel instead consults the device-placement policy (via declearn.utils.get_device_policy), places the wrapped keras model's weights there, and runs computations defined under public methods in a tensorflow.device context, to enforce that policy.
  • Calling a private method directly is not guaranteed to abide by that policy. Hence, be careful when writing custom code, and use your own context managers to get guarantees.
  • If the global device-placement policy is updated, existing instances only pick up the change when their update_device_policy method is called manually.
  • You may consult the device policy enforced by a TensorflowModel instance by accessing its device_policy property.
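
The device_policy property rebuilds a policy from the selected device by reading the index that trails the device name. A minimal standalone sketch of that parsing step follows, assuming device names shaped like "/device:GPU:0". SketchPolicy and policy_from_device are hypothetical stand-ins written for illustration; they are not declearn's DevicePolicy API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SketchPolicy:
    """Hypothetical stand-in for a device-placement policy."""

    gpu: bool
    idx: Optional[int]


def policy_from_device(name: str, device_type: str) -> SketchPolicy:
    """Recover a policy from a device's name and type (illustrative only)."""
    try:
        # The index is whatever follows the last colon of the name.
        idx: Optional[int] = int(name.rsplit(":", 1)[-1])
    except ValueError:
        # No trailing integer: leave the index unspecified.
        idx = None
    return SketchPolicy(gpu=(device_type == "GPU"), idx=idx)
```

With these assumptions, a name without a trailing integer yields idx=None rather than an error, which matches the fallback used by the property in the source listing.
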
Source code in declearn/model/tensorflow/_model.py
@register_type(name="TensorflowModel", group="Model")
class TensorflowModel(Model):
    """Model wrapper for TensorFlow Model instances.

    This `Model` subclass is designed to wrap a `tf_keras.Model` instance
    to be trained federatively.

    Notes regarding device management (CPU, GPU, etc.):

    - By default, tensorflow places data and operations on GPU whenever one
      is available.
    - Our `TensorflowModel` instead consults the device-placement policy (via
      `declearn.utils.get_device_policy`), places the wrapped keras model's
      weights there, and runs computations defined under public methods in
      a `tensorflow.device` context, to enforce that policy.
    - Note that there is no guarantee that calling a private method directly
      will result in abiding by that policy. Hence, be careful when writing
      custom code, and use your own context managers to get guarantees.
    - Note that if the global device-placement policy is updated, this will
      only be propagated to existing instances by manually calling their
      `update_device_policy` method.
    - You may consult the device policy enforced by a TensorflowModel
      instance by accessing its `device_policy` property.
    """

    def __init__(
        self,
        model: tf_keras.layers.Layer,
        loss: Union[str, tf_keras.losses.Loss],
        metrics: Optional[List[Union[str, tf_keras.metrics.Metric]]] = None,
        _from_config: bool = False,
        **kwargs: Any,
    ) -> None:
        """Instantiate a Model interface wrapping a tensorflow.keras model.

        Parameters
        ----------
        model: tf.keras.layers.Layer
            Keras Layer (or Model) instance that defines the model's
            architecture. If a Layer is provided, it will be wrapped
            into a keras Sequential Model.
        loss: tf.keras.losses.Loss or str
            Keras Loss instance, or name of one. If a function (name)
            is provided, it will be converted to a Loss instance, and
            an exception may be raised if that fails.
        metrics: list[str or tf.keras.metrics.Metric] or None
            List of keras Metric instances, or their names. These are
            compiled with the model and computed using the `evaluate`
            method of the returned TensorflowModel instance.
        **kwargs: Any
            Any additional keyword argument to `tf_keras.Model.compile`
            may be passed.
        """
        # Type-check the input Model and wrap it up.
        if not isinstance(model, tf_keras.layers.Layer):
            raise TypeError(
                "'model' should be a tf_keras.layers.Layer instance."
            )
        if not isinstance(model, tf_keras.Model):
            model = tf_keras.Sequential([model])
        super().__init__(model)
        # Ensure the loss is a keras.Loss object and set its reduction to none.
        loss = build_keras_loss(loss, reduction=tf_keras.losses.Reduction.NONE)
        # Select the device where to place computations and move the model.
        policy = get_device_policy()
        self._device = select_device(gpu=policy.gpu, idx=policy.idx)
        if not _from_config:
            self._model = move_layer_to_device(self._model, self._device)
        # Finalize initialization using the selected device.
        # Compile the wrapped model and retain compilation arguments.
        with tf.device(self._device):
            kwargs.update({"loss": loss, "metrics": metrics})
            self._model.compile(**kwargs)
            self._kwargs = kwargs

    @property
    def device_policy(
        self,
    ) -> DevicePolicy:
        device = self._device
        try:
            idx = int(device.name.rsplit(":", 1)[-1])
        except ValueError:
            idx = None
        return DevicePolicy(gpu=(device.device_type == "GPU"), idx=idx)

    @property
    def required_data_info(
        self,
    ) -> Set[str]:
        return set() if self._model.built else {"features_shape"}

    def initialize(
        self,
        data_info: Dict[str, Any],
    ) -> None:
        if not self._model.built:
            # Type- and value-check the provided information.
            data_info = aggregate_data_info(
                [data_info], self.required_data_info
            )
            # Build the model using the specified input shape.
            with tf.device(self._device):
                self._model.build((None, *data_info["features_shape"]))

    def get_config(
        self,
    ) -> Dict[str, Any]:
        config = tf_keras.layers.serialize(self._model)  # type: Dict[str, Any]
        kwargs = deepcopy(self._kwargs)
        loss = tf_keras.losses.serialize(kwargs.pop("loss"))
        return {"model": config, "loss": loss, "kwargs": kwargs}

    @classmethod
    def from_config(
        cls,
        config: Dict[str, Any],
    ) -> Self:
        """Instantiate a TensorflowModel from a configuration dict."""
        for key in ("model", "loss", "kwargs"):
            if key not in config.keys():
                raise KeyError(f"Missing key '{key}' in the config dict.")
        # Set up the device policy.
        policy = get_device_policy()
        device = select_device(gpu=policy.gpu, idx=policy.idx)
        # Deserialize the model and loss keras objects on the device.
        with tf.device(device):
            model = tf_keras.layers.deserialize(config["model"])
            loss = tf_keras.losses.deserialize(config["loss"])
        # Instantiate the TensorflowModel, avoiding device-to-device copies.
        return cls(model, loss, **config["kwargs"], _from_config=True)

    def get_weights(
        self,
        trainable: bool = False,
    ) -> TensorflowVector:
        variables = self._get_weight_variables(trainable)
        return TensorflowVector({var.name: var.value() for var in variables})

    def _get_weight_variables(
        self,
        trainable: bool,
    ) -> Iterable[tf.Variable]:
        """Access TensorFlow Variables wrapping model weight tensors."""
        variables = (
            self._model.trainable_weights if trainable else self._model.weights
        )
        if hasattr(tf_keras, "version") and tf_keras.version().startswith("3"):
            variables = (var.value for var in variables)
        return variables

    def set_weights(
        self,
        weights: TensorflowVector,
        trainable: bool = False,
    ) -> None:
        if not isinstance(weights, TensorflowVector):
            raise TypeError(
                "TensorflowModel requires TensorflowVector weights."
            )
        self._verify_weights_compatibility(weights, trainable=trainable)
        variables = {
            var.name: var for var in self._get_weight_variables(trainable)
        }
        with tf.device(self._device):
            for name, value in weights.coefs.items():
                variables[name].assign(value, read_value=False)

    def _verify_weights_compatibility(
        self,
        vector: TensorflowVector,
        trainable: bool = False,
    ) -> None:
        """Verify that a vector has the same names as the model's weights.

        Parameters
        ----------
        vector: TensorflowVector
            Vector wrapping weight-related coefficients (e.g. weight
            values or gradient-based updates).
        trainable: bool, default=False
            Whether to restrict the comparison to the model's trainable
            weights rather than to all of its weights.

        Raises
        ------
        KeyError
            In case some expected keys are missing, or additional keys
            are present. Be verbose about the identified mismatch(es).
        """
        variables = self._get_weight_variables(trainable)
        raise_on_stringsets_mismatch(
            received=set(vector.coefs),
            expected={var.name for var in variables},
            context="model weights",
        )

    def compute_batch_gradients(
        self,
        batch: Batch,
        max_norm: Optional[float] = None,
    ) -> TensorflowVector:
        with tf.device(self._device):
            data = self._unpack_batch(batch)
            if max_norm is None:
                grads, loss = self._compute_batch_gradients(*data)
            else:
                norm = tf.constant(max_norm)
                grads, loss = self._compute_clipped_gradients(*data, norm)
        self._loss_history.append(float(loss.numpy()))
        grads_and_vars = zip(grads, self._get_weight_variables(trainable=True))
        return TensorflowVector(
            {var.name: grad for grad, var in grads_and_vars}
        )

    def _unpack_batch(
        self,
        batch: Batch,
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[tf.Tensor]]:
        """Unpack and enforce Tensor conversion to an input data batch."""
        # fmt: off
        # Define an array-to-tensor conversion routine.
        def convert(data: Optional[ArrayLike]) -> Optional[tf.Tensor]:
            if (data is None) or tf.is_tensor(data):
                return data
            return tf.convert_to_tensor(data)
        # Apply it to the batched elements.
        return tf.nest.map_structure(convert, batch)
        # fmt: on

    @tf.function  # optimize tensorflow runtime
    def _compute_batch_gradients(
        self,
        inputs: tf.Tensor,
        y_true: Optional[tf.Tensor],
        s_wght: Optional[tf.Tensor],
    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
        """Compute and return batch-averaged gradients of trainable weights."""
        with tf.GradientTape() as tape:
            y_pred = self._model(inputs, training=True)
            loss = self._model.compute_loss(inputs, y_true, y_pred, s_wght)
            loss = tf.reduce_mean(loss)
            grad = tape.gradient(loss, self._model.trainable_weights)
        return grad, loss

    @tf.function  # optimize tensorflow runtime
    def _compute_clipped_gradients(
        self,
        inputs: tf.Tensor,
        y_true: Optional[tf.Tensor],
        s_wght: Optional[tf.Tensor],
        max_norm: Union[tf.Tensor, float],
    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
        """Compute and return sample-wise-clipped batch-averaged gradients."""
        grad, loss = self._compute_samplewise_gradients(inputs, y_true)
        if s_wght is None:
            s_wght = tf.cast(1, grad[0].dtype)
        grad = self._clip_and_average_gradients(grad, max_norm, s_wght)
        return grad, loss

    @tf.function  # optimize tensorflow runtime
    def _compute_samplewise_gradients(
        self,
        inputs: tf.Tensor,
        y_true: Optional[tf.Tensor],
    ) -> Tuple[List[tf.Tensor], tf.Tensor]:
        """Compute and return sample-wise gradients for a given batch."""
        with tf.GradientTape() as tape:
            y_pred = self._model(inputs, training=True)
            loss = self._model.compute_loss(inputs, y_true, y_pred)
            grad = tape.jacobian(loss, self._model.trainable_weights)
        return grad, tf.reduce_mean(loss)

    @staticmethod
    @tf.function  # optimize tensorflow runtime
    def _clip_and_average_gradients(
        gradients: List[tf.Tensor],
        max_norm: Union[tf.Tensor, float],
        s_wght: tf.Tensor,
    ) -> List[tf.Tensor]:
        """Clip sample-wise gradients then batch-average them."""
        outp = []  # type: List[tf.Tensor]
        for grad in gradients:
            dims = list(range(1, grad.shape.rank))
            grad = tf.clip_by_norm(grad, max_norm, axes=dims)
            outp.append(tf.reduce_mean(grad * s_wght, axis=0))
        return outp

    def apply_updates(
        self,
        updates: TensorflowVector,
    ) -> None:
        self._verify_weights_compatibility(updates, trainable=True)
        with tf.device(self._device):
            for var in self._get_weight_variables(trainable=True):
                updt = updates.coefs[var.name]
                if isinstance(updt, tf.IndexedSlices):
                    var.scatter_add(updt)
                else:
                    var.assign_add(updt, read_value=False)

    def evaluate(
        self,
        dataset: Iterable[Batch],
    ) -> Dict[str, float]:
        """Compute the model's built-in evaluation metrics on a given dataset.

        Parameters
        ----------
        dataset: iterable of batches
            Iterable yielding batch structures that are to be unpacked
            into (input_features, target_labels, [sample_weights]).
            If set, sample weights will affect metrics' averaging.

        Returns
        -------
        metrics: dict[str, float]
            Dictionary associating evaluation metrics' values to their name.
        """
        with tf.device(self._device):
            return self._model.evaluate(dataset, return_dict=True)

    def compute_batch_predictions(
        self,
        batch: Batch,
    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        with tf.device(self._device):
            inputs, y_true, s_wght = self._unpack_batch(batch)
            if y_true is None:
                raise TypeError(
                    "`TensorflowModel.compute_batch_predictions` received a "
                    "batch with `y_true=None`, which is unsupported. Please "
                    "correct the inputs, or override this method to support "
                    "creating labels from the base inputs."
                )
            y_pred = self._model(inputs, training=False).numpy()
        y_true = y_true.numpy()
        s_wght = s_wght.numpy() if s_wght is not None else s_wght
        return y_true, y_pred, s_wght

    def loss_function(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
    ) -> np.ndarray:
        with tf.device(self._device):
            tns_true = tf.convert_to_tensor(y_true)
            tns_pred = tf.convert_to_tensor(y_pred)
            s_loss = self._model.compute_loss(y=tns_true, y_pred=tns_pred)
        return s_loss.numpy().squeeze()

    def update_device_policy(
        self,
        policy: Optional[DevicePolicy] = None,
    ) -> None:
        # Select the device to use based on the provided or global policy.
        if policy is None:
            policy = get_device_policy()
        device = select_device(gpu=policy.gpu, idx=policy.idx)
        # When needed, re-create the model to force moving it to the device.
        if self._device is not device:
            self._device = device
            self._model = move_layer_to_device(self._model, self._device)
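
The _clip_and_average_gradients step in the listing above clips each sample's gradient by its L2 norm, applies sample weights, and averages over the batch. The idea can be sketched without TensorFlow as below. This is a simplified illustration on flat Python lists (one list of floats per sample), not the library's implementation; clip_and_average is a hypothetical helper name.

```python
import math
from typing import List, Optional, Sequence


def clip_and_average(
    sample_grads: Sequence[Sequence[float]],
    max_norm: float,
    sample_weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Clip per-sample gradients to `max_norm` (L2), weight, then average."""
    if sample_weights is None:
        # Assumption: unweighted samples all count with weight 1.
        sample_weights = [1.0] * len(sample_grads)
    total = [0.0] * len(sample_grads[0])
    for grad, weight in zip(sample_grads, sample_weights):
        norm = math.sqrt(sum(g * g for g in grad))
        # Rescale only gradients whose norm exceeds the threshold.
        scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
        for i, g in enumerate(grad):
            total[i] += weight * scale * g
    # Batch-average, i.e. divide by the number of samples.
    return [t / len(sample_grads) for t in total]
```

For instance, with max_norm=1.0 a sample gradient of [3.0, 4.0] (norm 5) is scaled down to roughly [0.6, 0.8], while [0.3, 0.4] (norm 0.5) is left untouched before averaging.
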

__init__(model, loss, metrics=None, _from_config=False, **kwargs)

Instantiate a Model interface wrapping a tensorflow.keras model.

Parameters:

  • model (tf_keras.layers.Layer, required): Keras Layer (or Model) instance that defines the model's architecture. If a Layer is provided, it will be wrapped into a keras Sequential Model.
  • loss (Union[str, tf_keras.losses.Loss], required): Keras Loss instance, or name of one. If a function (name) is provided, it will be converted to a Loss instance, and an exception may be raised if that fails.
  • metrics (Optional[List[Union[str, tf_keras.metrics.Metric]]], default: None): List of keras Metric instances, or their names. These are compiled with the model and computed using the evaluate method of the returned TensorflowModel instance.
  • **kwargs (Any, default: {}): Any additional keyword argument to tf_keras.Model.compile may be passed.

Source code in declearn/model/tensorflow/_model.py
def __init__(
    self,
    model: tf_keras.layers.Layer,
    loss: Union[str, tf_keras.losses.Loss],
    metrics: Optional[List[Union[str, tf_keras.metrics.Metric]]] = None,
    _from_config: bool = False,
    **kwargs: Any,
) -> None:
    """Instantiate a Model interface wrapping a tensorflow.keras model.

    Parameters
    ----------
    model: tf.keras.layers.Layer
        Keras Layer (or Model) instance that defines the model's
        architecture. If a Layer is provided, it will be wrapped
        into a keras Sequential Model.
    loss: tf.keras.losses.Loss or str
        Keras Loss instance, or name of one. If a function (name)
        is provided, it will be converted to a Loss instance, and
        an exception may be raised if that fails.
    metrics: list[str or tf.keras.metrics.Metric] or None
        List of keras Metric instances, or their names. These are
        compiled with the model and computed using the `evaluate`
        method of the returned TensorflowModel instance.
    **kwargs: Any
        Any additional keyword argument to `tf_keras.Model.compile`
        may be passed.
    """
    # Type-check the input Model and wrap it up.
    if not isinstance(model, tf_keras.layers.Layer):
        raise TypeError(
            "'model' should be a tf_keras.layers.Layer instance."
        )
    if not isinstance(model, tf_keras.Model):
        model = tf_keras.Sequential([model])
    super().__init__(model)
    # Ensure the loss is a keras.Loss object and set its reduction to none.
    loss = build_keras_loss(loss, reduction=tf_keras.losses.Reduction.NONE)
    # Select the device where to place computations and move the model.
    policy = get_device_policy()
    self._device = select_device(gpu=policy.gpu, idx=policy.idx)
    if not _from_config:
        self._model = move_layer_to_device(self._model, self._device)
    # Finalize initialization using the selected device.
    # Compile the wrapped model and retain compilation arguments.
    with tf.device(self._device):
        kwargs.update({"loss": loss, "metrics": metrics})
        self._model.compile(**kwargs)
        self._kwargs = kwargs

evaluate(dataset)

Compute the model's built-in evaluation metrics on a given dataset.

Parameters:

  • dataset (Iterable[Batch], required): Iterable yielding batch structures that are to be unpacked into (input_features, target_labels, [sample_weights]). If set, sample weights will affect metrics' averaging.

Returns:

  • metrics (dict[str, float]): Dictionary associating evaluation metrics' values to their name.
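
To illustrate how sample weights can change a metric's average, here is a minimal sketch of a weighted mean absolute error accumulated over batches shaped as (inputs, targets, optional weights). The SketchBatch alias, the predict callable and the weighted_mae helper are hypothetical; keras computes its own metrics internally, so this only illustrates the weighting idea under those assumptions.

```python
from typing import Callable, Iterable, List, Optional, Tuple

# Hypothetical batch layout: (inputs, targets, optional sample weights).
SketchBatch = Tuple[List[float], List[float], Optional[List[float]]]


def weighted_mae(
    dataset: Iterable[SketchBatch],
    predict: Callable[[float], float],
) -> float:
    """Sample-weighted mean absolute error over all batches."""
    total, weight_sum = 0.0, 0.0
    for inputs, targets, weights in dataset:
        if weights is None:
            # Assumption: missing weights mean every sample counts once.
            weights = [1.0] * len(inputs)
        for x, y, w in zip(inputs, targets, weights):
            total += w * abs(predict(x) - y)
            weight_sum += w
    return total / weight_sum
```

In this sketch, a sample with weight 2.0 contributes twice as much to both the error sum and the normalizing constant as a sample with weight 1.0.
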

Source code in declearn/model/tensorflow/_model.py
def evaluate(
    self,
    dataset: Iterable[Batch],
) -> Dict[str, float]:
    """Compute the model's built-in evaluation metrics on a given dataset.

    Parameters
    ----------
    dataset: iterable of batches
        Iterable yielding batch structures that are to be unpacked
        into (input_features, target_labels, [sample_weights]).
        If set, sample weights will affect metrics' averaging.

    Returns
    -------
    metrics: dict[str, float]
        Dictionary associating evaluation metrics' values to their name.
    """
    with tf.device(self._device):
        return self._model.evaluate(dataset, return_dict=True)

from_config(config) classmethod

Instantiate a TensorflowModel from a configuration dict.
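
The configuration dict is expected to hold the three keys returned by get_config: "model", "loss" and "kwargs". A minimal sketch of that validation step in isolation is given below; check_config is a hypothetical helper name, and the dict passed in the usage comment is placeholder content rather than real serialized keras objects.

```python
from typing import Any, Dict

# Keys that the configuration dict must contain, per the source listing.
REQUIRED_KEYS = ("model", "loss", "kwargs")


def check_config(config: Dict[str, Any]) -> None:
    """Raise a KeyError naming the first required key that is missing."""
    for key in REQUIRED_KEYS:
        if key not in config:
            raise KeyError(f"Missing key '{key}' in the config dict.")


# Usage: passes silently when all three keys are present.
check_config({"model": {}, "loss": "placeholder", "kwargs": {}})
```

Validating the keys up front gives a clear error message before any keras deserialization is attempted.
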

Source code in declearn/model/tensorflow/_model.py
@classmethod
def from_config(
    cls,
    config: Dict[str, Any],
) -> Self:
    """Instantiate a TensorflowModel from a configuration dict."""
    for key in ("model", "loss", "kwargs"):
        if key not in config.keys():
            raise KeyError(f"Missing key '{key}' in the config dict.")
    # Set up the device policy.
    policy = get_device_policy()
    device = select_device(gpu=policy.gpu, idx=policy.idx)
    # Deserialize the model and loss keras objects on the device.
    with tf.device(device):
        model = tf_keras.layers.deserialize(config["model"])
        loss = tf_keras.losses.deserialize(config["loss"])
    # Instantiate the TensorflowModel, avoiding device-to-device copies.
    return cls(model, loss, **config["kwargs"], _from_config=True)