Skip to content

declearn.data_info.FeaturesShapeField

Bases: DataInfoField

Specifications for 'features_shape' data_info field.

Source code in `declearn/data_info/_fields.py` (lines 98–129):
@register_data_info_field
class FeaturesShapeField(DataInfoField):
    """Specifications for 'features_shape' data_info field.

    A valid value is a tuple or list of dimensions (excluding the batch
    size), each of which is either a strictly positive int or None.
    Combining several values requires all of them to describe the same
    shape.
    """

    field = "features_shape"
    types = (tuple, list)
    doc = "Input features' shape, excluding batch size, checked to be equal."

    @classmethod
    def is_valid(
        cls,
        value: Any,
    ) -> bool:
        """Check whether a value is a valid 'features_shape' specification.

        Parameters
        ----------
        value: Any
            Candidate value.

        Returns
        -------
        valid: bool
            Whether `value` is a tuple or list, all elements of which are
            either strictly positive ints or None. Booleans are rejected
            even though `bool` is a subclass of `int` in Python.
        """
        if not isinstance(value, cls.types):
            return False
        # Explicitly exclude bool: `True` would otherwise pass as a
        # dimension of size 1 (isinstance(True, int) and True > 0).
        return all(
            (val is None)
            or (
                isinstance(val, int)
                and not isinstance(val, bool)
                and val > 0
            )
            for val in value
        )

    @classmethod
    def combine(
        cls,
        *values: Any,
    ) -> Tuple[Optional[int], ...]:
        """Combine multiple 'features_shape' values into a single one.

        Parameters
        ----------
        *values: Any
            Shape values to combine, each expected to be a tuple or list.

        Returns
        -------
        shape: tuple of (int or None)
            The shape shared by all inputs, converted to a tuple.

        Raises
        ------
        ValueError
            If no values are provided, or if the provided shapes are not
            all equal. A list and a tuple with the same elements count as
            equal. Errors raised by the parent class's type checks
            propagate as well.
        """
        # Type check each and every input shape.
        super().combine(*values)
        # Guard against empty inputs, which would otherwise be reported
        # through a misleading "non-unique shapes" error.
        if not values:
            raise ValueError(f"Cannot combine '{cls.field}': no values.")
        # Check that all shapes are the same. Casting to tuple makes list
        # inputs hashable and comparable with tuple inputs.
        unique_shapes = {tuple(shp) for shp in values}
        if len(unique_shapes) != 1:
            raise ValueError(
                f"Cannot combine '{cls.field}': non-unique shapes."
            )
        return unique_shapes.pop()