Skip to content

declearn.data_info.InputShapeField

Bases: DataInfoField

Specifications for the 'input_shape' data_info field (deprecated as of declearn v2.2; use 'SingleInputShapeField' instead).

Source code in declearn/data_info/_fields.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
@register_data_info_field
class InputShapeField(DataInfoField):
    """Specifications for 'input_shape' data_info field.

    Values of this field are batched input shapes: tuples or lists of
    at least two dimensions, each of which is either an int or None.

    When combining values, all shapes must share the same length, and
    each non-batching dimension must agree across inputs wherever it is
    specified (None entries are ignored). The leading (batching) dimension
    of the combined shape is always None.

    This field is deprecated as of declearn v2.2: please use the
    'SingleInputShapeField' one instead.
    """

    field = "input_shape"
    types = (tuple, list)
    doc = "Input features' batched shape, checked to be equal."

    @classmethod
    def is_valid(
        cls,
        value: Any,
    ) -> bool:
        """Check whether `value` is a valid 'input_shape' value.

        A valid value is an instance of `cls.types` (tuple or list) with
        at least two elements, each of which is either an int or None.
        """
        return (
            isinstance(value, cls.types)
            and (len(value) >= 2)
            and all(isinstance(val, int) or (val is None) for val in value)
        )

    @classmethod
    def combine(
        cls,
        *values: Any,
    ) -> List[Optional[int]]:
        """Combine multiple input shapes into a single unified one.

        Emits a DeprecationWarning, as this field is deprecated.

        Parameters
        ----------
        *values:
            Input shapes to combine, each expected to pass `is_valid`.

        Returns
        -------
        shape:
            List of the same length as the input shapes, whose first
            (batching) dimension is None, and whose other dimensions are
            the unique non-None value shared by inputs, or None if all
            inputs left that dimension unspecified.

        Raises
        ------
        ValueError
            If no values are provided, if inputs have various lengths,
            or if specified (non-None) dimensions differ across inputs.
            Other exceptions may be raised by the parent class's type
            checks (`super().combine`).
        """
        # Warn about this class being deprecated.
        # stacklevel=3 presumably targets the caller of whatever wrapper
        # invokes `combine` -- NOTE(review): confirm against call sites.
        warnings.warn(
            "'InputShapeField' has been deprecated as of declearn v2.2,"
            " and will be removed in v2.4 and/or v3.0."
            " Please use 'SingleInputShapeField' instead.",
            DeprecationWarning,
            stacklevel=3,
        )
        # Type check each and every input shape.
        super().combine(*values)
        # Check that all shapes are of same length.
        unique = list({len(shp) for shp in values})
        if not unique:
            raise ValueError(
                f"Cannot combine '{cls.field}': no values were provided."
            )
        if len(unique) != 1:
            raise ValueError(
                f"Cannot combine '{cls.field}': inputs have various lengths."
            )
        # Fill-in the unified shape: expect all-None or (None or unique) value.
        # Note: batching dimension is set to None by default (no check).
        shape = [None] * unique[0]  # type: List[Optional[int]]
        for i in range(1, unique[0]):
            known = {shp[i] for shp in values if shp[i] is not None}
            if len(known) > 1:
                raise ValueError(
                    f"Cannot combine '{cls.field}': provided shapes differ."
                )
            shape[i] = known.pop() if known else None
        # Return the combined shape.
        return shape