Skip to content

declearn.utils.dataclass_from_init

Automatically build a dataclass matching a class's init signature.

Parameters:

Name Type Description Default
cls Type[S]

Class, the `__init__` signature of which to wrap up as a dataclass.

required
name Optional[str]

Name to attach to the returned dataclass. If None, use cls.__name__ + "Config" (e.g. "MyClassConfig" for a "MyClass" input class).

None

Returns:

Name Type Description
dataclass Dataclass-built type

Dataclass, the fields of which are the input arguments to the `cls.__init__` method (with `*args` as a list and `**kwargs` as a dict), exposing an `instantiate` method that triggers calling `cls(...)` with the wrapped parameters.

Source code in declearn/utils/_dataclass.py (lines 127–184)
def dataclass_from_init(
    cls: Type[S],
    name: Optional[str] = None,
) -> Type[DataclassFromInit[S]]:
    """Automatically build a dataclass matching a class's init signature.

    Parameters
    ----------
    cls: Type
        Class, the __init__ signature of which to wrap up as a dataclass.
    name: str or None, default=None
        Name to attach to the returned dataclass.
        If None, use `cls.__name__` + "Config"
        (e.g. "MyClassConfig" for a "MyClass" input class).

    Returns
    -------
    dataclass: Dataclass-built type
        Dataclass, the fields of which are the input arguments to the
        `cls.__init__` method (with *args as a list and **kwargs as a
        dict), exposing an `instantiate` method that triggers calling
        `cls(...)` with the wrapped parameters.

    Notes
    -----
    The `instantiate` method passes field values to `cls` as they are:
    they are neither deep-copied nor recursively converted to dicts, so
    that dataclass-typed arguments reach `cls` unaltered. Positional-only
    parameters, as well as parameters preceding a *args one, are passed
    by position so that the call matches the wrapped signature.
    """
    # Parse the class's __init__ signature (dropping its first, `self`,
    # parameter) into dataclass Field instances.
    parameters = list(inspect.signature(cls.__init__).parameters.values())[1:]
    fields = _parameters_to_fields(parameters)
    # Make a dataclass out of the former fields.
    name = name or f"{cls.__name__}Config"
    dcls = dataclasses.make_dataclass(name, fields)  # type: Type
    # Bind the dataclass's main and __init__ docstrings.
    docs = f"Dataclass for {cls.__name__} instantiation parameters.\n"
    dcls.__doc__ = docs
    dcls.__init__.__doc__ = (
        docs + (cls.__init__.__doc__ or "").split("\n", 1)[-1]
    )
    # Sort parameters by how they must be passed to `cls(...)`.
    # Positional-only ones cannot be passed by keyword; and when there is
    # a *args parameter, the ones preceding it must be passed by position,
    # otherwise non-empty *args would clash with them ("multiple values").
    # Signature order guarantees positional parameters precede *args.
    has_varargs = any(p.kind is p.VAR_POSITIONAL for p in parameters)
    args_field = kwargs_field = None  # type: Optional[str]
    positional = []  # type: list
    keywords = []  # type: list
    for param in parameters:
        if param.kind is param.VAR_POSITIONAL:
            args_field = param.name
        elif param.kind is param.VAR_KEYWORD:
            kwargs_field = param.name
        elif param.kind is param.POSITIONAL_ONLY or (
            has_varargs and param.kind is param.POSITIONAL_OR_KEYWORD
        ):
            positional.append(param.name)
        else:
            keywords.append(param.name)

    # Add a method to instantiate from the dataclass.
    def instantiate(self) -> cls:  # type: ignore
        """Instantiate from the wrapped init parameters."""
        # Read field values directly rather than via `dataclasses.asdict`,
        # which would recursively turn dataclass-valued arguments into
        # dicts and deep-copy (possibly non-copyable) values.
        args = [getattr(self, key) for key in positional]
        if args_field:
            args.extend(getattr(self, args_field))
        params = {key: getattr(self, key) for key in keywords}
        kwargs = getattr(self, kwargs_field) if kwargs_field else {}
        # Keep **params and **kwargs separate so that duplicate keys still
        # raise a TypeError rather than being silently overridden.
        return cls(*args, **params, **kwargs)

    instantiate.__doc__ = (
        f"Instantiate a {cls.__name__} from the wrapped init parameters."
    )
    dcls.instantiate = instantiate
    # Return the generated dataclass.
    return dcls