
declearn.model.tensorflow.utils.build_keras_loss

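Type-check, deserialize and/or wrap a keras loss into a Loss object.

If `loss` is already a Loss object, its `reduction` attribute is overwritten in place, so the returned object is the very instance that was passed in.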

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss` | `Union[str, tf_keras.losses.Loss, CallableLoss]` | Either a keras Loss object, the name of a keras loss, or a loss function that needs wrapping into a Loss object. | required |
| `reduction` | `str` | Reduction scheme to apply to point-wise loss values. | `tf_keras.losses.Reduction.NONE` |

Returns:

| Name | Type | Description |
|---|---|---|
| `loss_obj` | `tf_keras.losses.Loss` | Loss object, configured to apply the `reduction` scheme. |
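The `reduction` parameter follows the standard Keras reduction schemes. As a minimal illustration, assuming the usual string values of `tf_keras.losses.Reduction` (`"none"`, `"sum"`, `"sum_over_batch_size"`), the plain-Python sketch below shows what each scheme does to a batch of point-wise loss values. The `reduce_losses` helper is hypothetical and is not part of declearn or Keras.

```python
from typing import List, Union


def reduce_losses(
    values: List[float], reduction: str = "none"
) -> Union[List[float], float]:
    """Apply a Keras-style reduction scheme to point-wise losses (illustrative)."""
    if reduction == "none":
        # Keep one loss value per sample, unreduced.
        return list(values)
    if reduction == "sum":
        # Add all point-wise values up into a scalar.
        return sum(values)
    if reduction == "sum_over_batch_size":
        # Scalar sum divided by the number of values; guard the empty case.
        return sum(values) / len(values) if values else 0.0
    raise ValueError(f"Unsupported reduction scheme: {reduction!r}")


losses = [0.5, 1.5, 1.0]
print(reduce_losses(losses))  # [0.5, 1.5, 1.0]
print(reduce_losses(losses, "sum"))  # 3.0
print(reduce_losses(losses, "sum_over_batch_size"))  # 1.0
```

With the default `NONE` scheme, the resulting Loss object returns one value per sample and leaves any aggregation to the caller.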

Source code in declearn/model/tensorflow/utils/_loss.py
def build_keras_loss(
    loss: Union[str, tf_keras.losses.Loss, CallableLoss],
    reduction: str = tf_keras.losses.Reduction.NONE,
) -> tf_keras.losses.Loss:
    """Type-check, deserialize and/or wrap a keras loss into a Loss object.

    Parameters
    ----------
    loss: str or tf_keras.losses.Loss or function(y_true, y_pred)->loss
        Either a keras Loss object, the name of a keras loss, or a loss
        function that needs wrapping into a Loss object.
    reduction: str, default=`tf_keras.losses.Reduction.NONE`
        Reduction scheme to apply to point-wise loss values.

    Returns
    -------
    loss_obj: tf_keras.losses.Loss
        Loss object, configured to apply the `reduction` scheme.
    """
    # Case when 'loss' is already a Loss object.
    if isinstance(loss, tf_keras.losses.Loss):
        loss.reduction = reduction
    # Case when 'loss' is a string: deserialize and/or wrap into a Loss object.
    elif isinstance(loss, str):
        loss = get_keras_loss_from_string(name=loss, reduction=reduction)
    # Case when 'loss' is a function: wrap it up using LossFunction.
    elif inspect.isfunction(loss):
        loss = LossFunction(loss, reduction=reduction)
    # Case when 'loss' is of invalid type: raise a TypeError.
    if not isinstance(loss, tf_keras.losses.Loss):
        raise TypeError(
            "'loss' should be a keras Loss object, the name of one, "
            "or a loss function."
        )
    # Otherwise, return the Loss object (its reduction scheme is now set).
    return loss
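The dispatch logic above can be mirrored without TensorFlow, which makes the three accepted input kinds easy to see. In this sketch, `Loss`, `FunctionLoss`, `REGISTRY` and `build_loss` are hypothetical stand-ins for `tf_keras.losses.Loss`, declearn's `LossFunction`, `get_keras_loss_from_string` and `build_keras_loss` respectively; the real classes operate on tensors rather than Python lists.

```python
import inspect
from typing import Callable, Dict, List, Union

# Point-wise loss function signature: (y_true, y_pred) -> per-sample losses.
LossFn = Callable[[List[float], List[float]], List[float]]


class Loss:
    """Hypothetical stand-in for tf_keras.losses.Loss."""

    def __init__(self, reduction: str = "none") -> None:
        self.reduction = reduction

    def call(self, y_true: List[float], y_pred: List[float]) -> List[float]:
        raise NotImplementedError

    def __call__(self, y_true: List[float], y_pred: List[float]):
        values = self.call(y_true, y_pred)
        # Apply the configured reduction scheme to point-wise values.
        if self.reduction == "sum":
            return sum(values)
        if self.reduction == "sum_over_batch_size":
            return sum(values) / len(values)
        return values  # "none": keep one value per sample


class FunctionLoss(Loss):
    """Hypothetical stand-in for declearn's LossFunction wrapper."""

    def __init__(self, fn: LossFn, reduction: str = "none") -> None:
        super().__init__(reduction)
        self.fn = fn

    def call(self, y_true: List[float], y_pred: List[float]) -> List[float]:
        return self.fn(y_true, y_pred)


def squared_error(y_true: List[float], y_pred: List[float]) -> List[float]:
    return [(t - p) ** 2 for t, p in zip(y_true, y_pred)]


# Hypothetical name registry, standing in for get_keras_loss_from_string.
REGISTRY: Dict[str, LossFn] = {"mse": squared_error}


def build_loss(loss: Union[str, Loss, LossFn], reduction: str = "none") -> Loss:
    if isinstance(loss, Loss):
        # Already a Loss object: overwrite its reduction in place.
        loss.reduction = reduction
    elif isinstance(loss, str):
        # A name: look it up and wrap it (unknown names raise KeyError here).
        loss = FunctionLoss(REGISTRY[loss], reduction=reduction)
    elif inspect.isfunction(loss):
        # A plain function: wrap it into a Loss object.
        loss = FunctionLoss(loss, reduction=reduction)
    if not isinstance(loss, Loss):
        raise TypeError(
            "'loss' should be a Loss object, the name of one, or a function."
        )
    return loss


y_true, y_pred = [1.0, 2.0], [1.5, 1.0]
print(build_loss("mse")(y_true, y_pred))  # [0.25, 1.0]
print(build_loss(squared_error, "sum")(y_true, y_pred))  # 1.25
```

As in the original, the function branch relies on `inspect.isfunction`, which is true for `def`-defined functions and lambdas but false for callable class instances, `functools.partial` objects and builtins; such inputs therefore end up raising a TypeError rather than being wrapped.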