Automatically build a dataclass matching a function's signature.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `func` | `Callable[..., S]` | Function, the input signature of which to wrap up as a dataclass. | *required* |
| `name` | `Optional[str]` | Name to attach to the returned dataclass. If None, use CamelCase-converted `func.__name__` + "Config" (e.g. "MyFuncConfig" for a "my_func" input function). | `None` |
Returns:
| Name | Type | Description |
| --- | --- | --- |
| `dataclass` | Dataclass-built type | Dataclass, the fields of which are the input arguments to `func` (with `*args` as a list and `**kwargs` as a dict), exposing a `call` method that triggers calling `func` with the wrapped parameters. |
Source code in declearn/utils/_dataclass.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def dataclass_from_func(
    func: Callable[..., S],
    name: Optional[str] = None,
) -> Type[DataclassFromFunc[S]]:
    """Automatically build a dataclass matching a function's signature.

    Parameters
    ----------
    func: callable
        Function, the input signature of which to wrap up as a dataclass.
    name: str or None, default=None
        Name to attach to the returned dataclass.
        If None, use CamelCase-converted `func.__name__` + "Config"
        (e.g. "MyFuncConfig" for a "my_func" input function).

    Returns
    -------
    dataclass: Dataclass-built type
        Dataclass, the fields of which are the input arguments to `func`
        (with *args as a list and **kwargs as a dict), exposing a `call`
        method that triggers calling `func` with the wrapped parameters.
    """
    # Parse the function's signature into dataclass Field instances.
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    fields = _parameters_to_fields(parameters)
    # Make a dataclass out of the former fields.
    if not name:
        name = "".join(w.capitalize() for w in func.__name__.split("_"))
        name += "Config"
    dcls = dataclasses.make_dataclass(name, fields)  # type: Type
    # Bind the dataclass's main and __init__ docstrings.
    docs = f"Dataclass for {func.__name__} instantiation parameters.\n"
    dcls.__doc__ = docs
    dcls.__init__.__doc__ = docs + (func.__doc__ or "").split("\n", 1)[-1]
    # Record the names of parameters that may be passed positionally, in
    # declaration order. Passing them positionally (rather than as keywords)
    # is required for positional-only parameters, and avoids a "multiple
    # values for argument" error when a non-empty *args follows them.
    positional = tuple(
        param.name
        for param in parameters
        if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
    )
    # If the signature comprises *args / **kwargs parameters, record it.
    args_field = kwargs_field = None  # type: Optional[str]
    for param in parameters:
        if param.kind is param.VAR_POSITIONAL:
            args_field = param.name
        elif param.kind is param.VAR_KEYWORD:
            kwargs_field = param.name
    # Add a method to instantiate from the dataclass.
    r_type = (
        Any
        if signature.return_annotation is signature.empty
        else signature.return_annotation
    )

    def call(self) -> r_type:  # type: ignore
        """Call from the wrapped parameters."""
        # Read field values as-is: `dataclasses.asdict` would recursively
        # turn nested dataclass instances into dicts and deep-copy values,
        # so `func` would not receive the objects that were wrapped.
        params = {
            field.name: getattr(self, field.name)
            for field in dataclasses.fields(self)
        }
        # Positional parameters first (in order), then the *args values.
        args = []  # type: list
        for pname in positional:
            if pname not in params:
                # Without a field, later ones cannot be placed positionally;
                # leave them in `params` to be passed as keywords.
                break
            args.append(params.pop(pname))
        if args_field:
            args.extend(params.pop(args_field))
        kwargs = params.pop(kwargs_field) if kwargs_field else {}
        # Whatever remains in `params` (keyword-only parameters) goes by name.
        return func(*args, **params, **kwargs)

    call.__doc__ = (
        f"Call function {func.__name__} from the wrapped parameters."
    )
    dcls.call = call
    # Return the generated dataclass.
    return dcls
|