
declearn.model.tensorflow.utils.add_indexed_slices_support

Wrap an input function to overload the handling of tf.IndexedSlices.

Parameters:

Name Type Description Default
tf_op Callable[[tf.Tensor, Any], tf.Tensor]

Tensor-processing operation that needs wrapping.

required
inplc bool

Whether to replace the second argument of tf_op with None. Use this when wrapping tensor-processing functions (which, in general, have a trailing name=None argument) rather than two-argument operations.

False
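The effect of inplc comes down to functools.partial binding the second argument. The pure-Python sketch below is an illustration, not declearn code. It assumes three hypothetical names: dispatch is a stand-in for apply_func_to_tensor_or_slices that only forwards its arguments (no slice handling), square_root is a toy function mimicking the (x, name=None) signature of tf.math.sqrt, and wrap_op mirrors the wrapping logic shown in the source listing.

```python
import functools


def dispatch(first, other, tf_op):
    """Hypothetical stand-in for apply_func_to_tensor_or_slices.

    It only forwards its arguments; the tf.IndexedSlices handling is omitted.
    """
    return tf_op(first, other)


def square_root(x, name=None):
    """Toy function with a trailing name=None argument, like tf.math.sqrt."""
    return [v ** 0.5 for v in x]


def wrap_op(tf_op, inplc=False):
    """Hypothetical mirror of the partial-binding logic of the wrapper."""
    func = functools.partial(dispatch, tf_op=tf_op)
    if inplc:
        # Bind the second argument to None, so callers pass a single input.
        func = functools.partial(func, other=None)
    # Copy tf_op's metadata (__name__, __doc__, ...) onto the partial object.
    return functools.wraps(tf_op)(func)


sqrt = wrap_op(square_root, inplc=True)
result = sqrt([4.0, 9.0])  # calls square_root([4.0, 9.0], None)
```

Without inplc=True, calling the wrapped square_root with a single input fails, because the dispatcher still expects its second argument; this is why the flag targets functions whose second parameter is an optional name.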

Returns:

Name Type Description
func Callable[[TensorT, Any], TensorT]

Tensor-processing operation that wraps tf_op but supports and preserves tf.IndexedSlices inputs as first (and optionally second) argument. Note that in the rare case where func(slices, dense) is called, the output will be dense and a RuntimeWarning will be raised.
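The dispatch behavior described above can be sketched in pure Python. This is not declearn's implementation: FakeSlices is a hypothetical 1-D stand-in for tf.IndexedSlices, apply_to_tensor_or_slices is a hypothetical dispatcher, and two behaviors are assumptions not stated in this page (slices-plus-slices inputs must share the same indices, and a dense first input densifies any slices passed as second argument).

```python
import functools
import warnings
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class FakeSlices:
    """Hypothetical 1-D stand-in for tf.IndexedSlices (not a real TF class)."""

    values: List[float]
    indices: List[int]
    dense_shape: int

    def to_dense(self) -> List[float]:
        """Scatter stored values into a zero-filled list (summing duplicates)."""
        dense = [0.0] * self.dense_shape
        for idx, val in zip(self.indices, self.values):
            dense[idx] += val
        return dense


def apply_to_tensor_or_slices(first: Any, other: Any, tf_op: Callable) -> Any:
    """Hypothetical dispatcher (not declearn's apply_func_to_tensor_or_slices)."""
    if isinstance(first, FakeSlices):
        if isinstance(other, FakeSlices):
            # Assumption: both slices inputs share the same indices.
            if other.indices != first.indices:
                raise ValueError("Slices with different indices are not handled.")
            values = tf_op(first.values, other.values)
            return FakeSlices(values, list(first.indices), first.dense_shape)
        if isinstance(other, list):
            # Mixed (slices, dense) inputs: densify, warn, return dense output.
            warnings.warn("Densifying slices input.", RuntimeWarning)
            return tf_op(first.to_dense(), other)
        # Scalar or None second argument: process the stored values only.
        values = tf_op(first.values, other)
        return FakeSlices(values, list(first.indices), first.dense_shape)
    # Assumption: a dense first input densifies any slices given as `other`.
    if isinstance(other, FakeSlices):
        other = other.to_dense()
    return tf_op(first, other)


def multiply(x: List[float], y: Any) -> List[float]:
    """Toy element-wise product of a list with a list or a scalar."""
    if isinstance(y, list):
        return [a * b for a, b in zip(x, y)]
    return [a * y for a in x]


mul = functools.partial(apply_to_tensor_or_slices, tf_op=multiply)
grads = FakeSlices(values=[1.0, 2.0], indices=[0, 3], dense_shape=5)
sparse_out = mul(grads, 3.0)  # slices in, slices out
```

Applying the operation to the stored values alone only matches the dense result for operations that map zero to zero (scaling, square root, sign); adding a nonzero scalar to the values of a sparse input, for instance, would not.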

Source code in declearn/model/tensorflow/utils/_slices.py
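# Imports this excerpt relies on; `TensorT` and
# `apply_func_to_tensor_or_slices` are defined outside this excerpt.
import functools
from typing import Any, Callable

import tensorflow as tf


def add_indexed_slices_support(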
    tf_op: Callable[[tf.Tensor, Any], tf.Tensor],
    inplc: bool = False,
) -> Callable[[TensorT, Any], TensorT]:
    """Wrap an input function to overload the handling of tf.IndexedSlices.

    Parameters
    ----------
    tf_op: function(tf.Tensor, [any]) -> tf.Tensor
        Tensor-processing operation that needs wrapping.
    inplc: bool, default=False
        Whether to replace the second argument of `tf_op` with None.
        Use this to transform tensor-processing functions (which, in
        general, have a `name=None` argument) rather than operations.

    Returns
    -------
    func:
        Tensor-processing operation that wraps `tf_op` but supports and
        preserves tf.IndexedSlices inputs as first (and optionally second)
        argument.
        Note that in the rare case when func(slices, dense) is called,
        the output will be dense, and a RuntimeWarning will be raised.
    """
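    # Bind `tf_op` as a keyword argument of the generic dispatcher.
    func = functools.partial(apply_func_to_tensor_or_slices, tf_op=tf_op)
    if inplc:
        # Fix the second argument to None (e.g. a `name=None` parameter),
        # so that the returned function only needs the input tensor.
        func = functools.partial(func, other=None)
    # Copy `tf_op`'s metadata (name, docstring, ...) onto the wrapper.
    return functools.wraps(tf_op)(func)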