Skip to content

declearn.data_info.ClassesField

Bases: DataInfoField

Specifications for 'classes' data_info field.

Source code in `declearn/data_info/_fields.py` (lines 34–57):
@register_data_info_field
class ClassesField(DataInfoField):
    """Specifications for 'classes' data_info field.

    Accepted values are collections of classification target labels:
    lists, sets, tuples or one-dimensional numpy arrays. Multiple values
    are combined by taking the union of their elements.
    """

    field = "classes"
    types = (list, set, tuple, np.ndarray)
    doc = "Set of classification targets, combined by union."

    @classmethod
    def is_valid(
        cls,
        value: Any,
    ) -> bool:
        """Check whether a value is a valid 'classes' specification.

        Numpy arrays are only accepted when one-dimensional, so that
        iterating over them yields individual labels rather than
        (unhashable) sub-arrays. Any other value is checked by the
        parent class' `is_valid` (presumably against `cls.types` --
        TODO confirm against `DataInfoField.is_valid`).

        Parameters
        ----------
        value:
            Candidate value for the 'classes' field.

        Returns
        -------
        valid:
            Whether `value` is an acceptable 'classes' specification.
        """
        if isinstance(value, np.ndarray):
            return value.ndim == 1
        return super().is_valid(value)

    @classmethod
    def combine(
        cls,
        *values: Any,
    ) -> Set[Any]:
        """Combine 'classes' values into the union of their elements.

        Parameters
        ----------
        *values:
            Collections of class labels to aggregate.

        Returns
        -------
        classes:
            Set of all labels found in any of the input values.
            An empty set is returned when no values are provided.

        Raises
        ------
        Exception
            Whatever the parent class' `combine` raises when some input
            fails validation (see `DataInfoField.combine`).
        """
        # Delegate input type-checking to the parent; its result is unused.
        super().combine(*values)
        # `set().union` accepts any number of iterables, including zero,
        # whereas `set.union(*...)` raises TypeError when given no argument.
        return set().union(*values)