Bases: DataInfoField
Specifications for 'features_shape' data_info field.
Source code in declearn/data_info/_fields.py
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
@register_data_info_field
class FeaturesShapeField(DataInfoField):
    """Specifications for 'features_shape' data_info field.

    A valid value is a tuple or list whose entries are each either a
    strictly positive int or None. Combining several values succeeds
    only when they all describe the same shape.
    """

    field = "features_shape"
    types = (tuple, list)
    doc = "Input features' shape, excluding batch size, checked to be equal."

    @classmethod
    def is_valid(
        cls,
        value: Any,
    ) -> bool:
        """Return whether `value` is an acceptable features shape.

        Parameters
        ----------
        value:
            Candidate value for the 'features_shape' field.

        Returns
        -------
        bool
            True if `value` is an instance of one of `cls.types` and each
            of its entries is None or an int greater than zero.
        """
        # Reject anything that is not one of the accepted container types.
        if not isinstance(value, cls.types):
            return False
        for dim in value:
            # None entries are accepted as-is.
            if dim is None:
                continue
            # Any other entry must be a strictly positive int.
            if not (isinstance(dim, int) and dim > 0):
                return False
        return True

    @classmethod
    def combine(
        cls,
        *values: Any,
    ) -> Tuple[Optional[int], ...]:
        """Combine multiple features shapes into a single shared one.

        Parameters
        ----------
        *values:
            Shapes to combine; each is converted to a tuple prior to
            being compared with the others.

        Returns
        -------
        tuple
            The one shape shared by all inputs.

        Raises
        ------
        ValueError
            If the inputs do not reduce to exactly one distinct shape.
        """
        # Let the parent class run its own checks on every input first
        # (type-checking of each shape is delegated to it).
        super().combine(*values)
        # Deduplicate shapes, treating lists and tuples interchangeably.
        distinct = {tuple(shape) for shape in values}
        if len(distinct) == 1:
            (shared,) = distinct
            return shared
        raise ValueError(
            f"Cannot combine '{cls.field}': non-unique shapes."
        )
|