Skip to content

declearn.utils.dataclass_from_func

Automatically build a dataclass matching a function's signature.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `func` | `Callable[..., S]` | Function, the input signature of which to wrap up as a dataclass. | *required* |
| `name` | `Optional[str]` | Name to attach to the returned dataclass. If `None`, use CamelCase-converted `func.__name__` + `"Config"` (e.g. `"MyFuncConfig"` for a `"my_func"` input function). | `None` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `dataclass` | Dataclass-built type | Dataclass, the fields of which are the input arguments to `func` (with `*args` as a list and `**kwargs` as a dict), exposing a `call` method that triggers calling `func` with the wrapped parameters. |

Source code in `declearn/utils/_dataclass.py` (lines 62–124):
def dataclass_from_func(
    func: Callable[..., S],
    name: Optional[str] = None,
) -> Type[DataclassFromFunc[S]]:
    """Automatically build a dataclass matching a function's signature.

    Parameters
    ----------
    func: callable
        Function, the input signature of which to wrap up as a dataclass.
    name: str or None, default=None
        Name to attach to the returned dataclass.
        If None, use CamelCase-converted `func.__name__` + "Config"
        (e.g. "MyFuncConfig" for a "my_func" input function).

    Returns
    -------
    dataclass: Dataclass-built type
        Dataclass, the fields of which are the input arguments to `func`
        (with *args as a list and **kwargs as a dict), exposing a `call`
        method that triggers calling `func` with the wrapped parameters.

    Notes
    -----
    The `call` method passes field values to `func` as-is (they are not
    deep-copied). Positional-only parameters are always passed by position.
    When `func` accepts *args, the parameters preceding it are passed by
    position too, so that they do not clash with the variadic values.
    All other parameters are passed by keyword.
    """
    # Parse the function's signature into dataclass Field instances.
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    fields = _parameters_to_fields(parameters)
    # Make a dataclass out of the former fields.
    if not name:
        name = "".join(w.capitalize() for w in func.__name__.split("_"))
        name += "Config"
    dcls = dataclasses.make_dataclass(name, fields)  # type: Type
    # Bind the dataclass's main and __init__ docstrings.
    docs = f"Dataclass for {func.__name__} instantiation parameters.\n"
    dcls.__doc__ = docs
    dcls.__init__.__doc__ = docs + (func.__doc__ or "").split("\n", 1)[-1]
    # If the signature comprises *args / **kwargs parameters, record it.
    args_field = next(
        (p.name for p in parameters if p.kind is p.VAR_POSITIONAL), None
    )
    kwargs_field = next(
        (p.name for p in parameters if p.kind is p.VAR_KEYWORD), None
    )
    # Sort the remaining parameters by how their values must be passed.
    # Positional-only ones cannot be passed by keyword. When *args is
    # present, the positional-or-keyword ones (which all precede it) must
    # be passed by position, else they would collide with unpacked *args.
    positional = [
        p.name
        for p in parameters
        if p.kind is p.POSITIONAL_ONLY
        or (p.kind is p.POSITIONAL_OR_KEYWORD and args_field is not None)
    ]
    keywords = [
        p.name
        for p in parameters
        if p.kind is p.KEYWORD_ONLY
        or (p.kind is p.POSITIONAL_OR_KEYWORD and args_field is None)
    ]
    # Add a method to instantiate from the dataclass.
    r_type = (
        Any
        if signature.return_annotation is signature.empty
        else signature.return_annotation
    )

    def call(self) -> r_type:  # type: ignore
        """Call from the wrapped parameters."""
        # Read attributes directly rather than via `dataclasses.asdict`,
        # which deep-copies values and recursively turns nested dataclass
        # instances into dicts before they would reach `func`.
        args = [getattr(self, field) for field in positional]
        if args_field:
            args.extend(getattr(self, args_field))
        kwargs = {field: getattr(self, field) for field in keywords}
        extra = getattr(self, kwargs_field) if kwargs_field else {}
        # Unpack named and **kwargs values separately, so that duplicate
        # keys still raise a TypeError (as with a direct function call).
        return func(*args, **kwargs, **extra)

    call.__doc__ = (
        f"Call function {func.__name__} from the wrapped parameters."
    )
    dcls.call = call
    # Return the generated dataclass.
    return dcls