
declearn.model.tensorflow.utils.preserve_tensor_device

Wrap a tensor-processing function so that it runs on the device of its first input tensor.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| func | Callable[..., tf.Tensor] | Function to wrap. It must take a TensorFlow Tensor as its first argument. | required |

Returns:

| Name | Type | Description |
|------|------|-------------|
| func | Callable[..., tf.Tensor] | Function that behaves like the input one, but runs under a tf.device context so that its computations are placed on the device of its first input tensor. |
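To make the mechanism concrete without requiring TensorFlow, here is a stdlib-only sketch of the same pattern. Everything in it is hypothetical: FakeTensor stands in for tf.Tensor, device for tf.device, add for a TensorFlow op, and preserve_device mirrors the wrapper's structure. The mock places each op's output on the innermost active device scope; its fallback device is an arbitrary choice for the sketch, not TensorFlow's own placement policy.

```python
import contextlib
import functools
import threading
from dataclasses import dataclass

# Hypothetical stand-ins: nothing below is TensorFlow or declearn API.
_SCOPE = threading.local()
FALLBACK_DEVICE = "/device:CPU:0"  # arbitrary default for this sketch


@dataclass
class FakeTensor:
    """Stand-in for tf.Tensor: a value plus the device it lives on."""
    value: float
    device: str


@contextlib.contextmanager
def device(name):
    """Stand-in for tf.device: push a placement scope for this thread."""
    previous = getattr(_SCOPE, "stack", [])
    _SCOPE.stack = previous + [name]
    try:
        yield
    finally:
        _SCOPE.stack = previous  # restore the enclosing scopes on exit


def current_device():
    """Return the innermost active scope, or the fallback device."""
    stack = getattr(_SCOPE, "stack", [])
    return stack[-1] if stack else FALLBACK_DEVICE


def add(x, y):
    """Stand-in op: its output is placed on the current device scope."""
    return FakeTensor(x.value + y.value, current_device())


def preserve_device(func):
    """Mirrors preserve_tensor_device, with the mock `device` swapped in."""
    @functools.wraps(func)
    def wrapped(tensor, *args, **kwargs):
        with device(tensor.device):
            return func(tensor, *args, **kwargs)
    return wrapped


x = FakeTensor(1.0, "/device:GPU:1")
y = FakeTensor(2.0, "/device:GPU:0")

unscoped = add(x, y)                  # placed on the fallback device
scoped = preserve_device(add)(x, y)   # placed on x's device
print(unscoped.device, scoped.device)
```

Only the first argument's device matters: y lives on another device, yet the wrapped call runs under x's scope. Leaving the with block restores the previous scope, so later unwrapped calls are unaffected.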

Source code in declearn/model/tensorflow/utils/_gpu.py
```python
# Imports needed to run this excerpt on its own.
import functools
from typing import Any, Callable

import tensorflow as tf


def preserve_tensor_device(
    func: Callable[..., tf.Tensor],
) -> Callable[..., tf.Tensor]:
    """Wrap a tensor-processing function to have it run on its inputs' device.

    Parameters
    ----------
    func: function(tf.Tensor, ...) -> tf.Tensor:
        Function to wrap, that takes a tensorflow Tensor as first argument.

    Returns
    -------
    func: function(tf.Tensor, ...) -> tf.Tensor:
        Similar function to the input one, that operates under a `tf.device`
        context so as to run computations on the first input tensor's device.
    """

    @functools.wraps(func)
    def wrapped(tensor: tf.Tensor, *args: Any, **kwargs: Any) -> tf.Tensor:
        """Wrapped function, running under a `tf.device` context."""
        with tf.device(tensor.device):
            return func(tensor, *args, **kwargs)

    return wrapped
```
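Since the source applies functools.wraps, preserve_tensor_device can also be used as a decorator: the result keeps the wrapped function's name and docstring (and exposes the original as __wrapped__), while forwarding any extra positional or keyword arguments. The stdlib-only sketch below checks this; scale, the recording device context and the SimpleNamespace "tensor" are hypothetical stand-ins, not TensorFlow objects.

```python
import contextlib
import functools
from types import SimpleNamespace

# Hypothetical stand-in for tf.device that records each scope entered.
entered_scopes = []


@contextlib.contextmanager
def device(name):
    entered_scopes.append(name)
    yield


def preserve_device(func):
    """Same structure as preserve_tensor_device, with the mock `device`."""
    @functools.wraps(func)  # copies __name__, __doc__, etc. onto `wrapped`
    def wrapped(tensor, *args, **kwargs):
        with device(tensor.device):
            return func(tensor, *args, **kwargs)
    return wrapped


@preserve_device
def scale(tensor, factor, *, offset=0.0):
    """Multiply a tensor's value by `factor`, then add `offset`."""
    return tensor.value * factor + offset


t = SimpleNamespace(value=2.0, device="/device:GPU:1")  # fake tensor
print(scale(t, 3.0, offset=1.0))  # extra arguments are forwarded unchanged
print(scale.__name__, entered_scopes)
```

Because the wrapper's signature is wrapped(tensor, *args, **kwargs), the tensor must be passed positionally (or under the keyword name tensor) for its device to be picked up.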