declearn.dataset.torch.collate_with_padding

Collate input elements into batches, with padding when required.

This custom collate function is designed to enable padding samples of variable length as part of their stacking into mini-batches. It relies on the `torch.nn.utils.rnn.pad_sequence` utility function and supports receiving samples that mix inputs that require padding with ones that do not (e.g. variable-length token sequences as inputs but fixed-size values as labels).

It may be used as the `collate_fn` argument to the declearn `TorchDataset` to wrap up data that needs such collation; users remain free to set up and use their own custom function if this one does not fit their data.
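To illustrate what this collation amounts to, here is a minimal sketch for a batch mixing variable-length token sequences with scalar labels, using `torch.nn.utils.rnn.pad_sequence` directly (assuming only that `torch` is installed; the sample data is made up for illustration):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples: variable-length token sequences as inputs, scalar labels.
samples = [
    (torch.tensor([1, 2, 3]), torch.tensor(0)),
    (torch.tensor([4, 5]), torch.tensor(1)),
]

# Variable-length elements are padded (with zeros) and stacked.
inputs = pad_sequence([smp[0] for smp in samples], batch_first=True)

# Fixed-size (here, 0-d) elements are simply stacked.
labels = torch.stack([smp[1] for smp in samples])

print(inputs.tolist())  # [[1, 2, 3], [4, 5, 0]]
print(labels.tolist())  # [0, 1]
```

The second input sequence is padded from length 2 to length 3, so that both batched elements come out as regular tensors of shape `(2, 3)` and `(2,)` respectively.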

Parameters:

    samples: List[Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]] (required)
        Sample-level records, formatted as (same-structure) tuples with
        torch tensors and/or lists of tensors as elements. None elements
        are also supported.

Returns:

    batch: Tuple[Union[List[torch.Tensor], torch.Tensor], ...]
        Tuple with the same structure as the input ones, collating
        sample-level records into batched tensors.

Source code in declearn/dataset/torch/_utils.py
def collate_with_padding(
    samples: List[Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]],
) -> Tuple[Union[List[torch.Tensor], torch.Tensor], ...]:
    """Collate input elements into batches, with padding when required.

    This custom collate function is designed to enable padding samples of
    variable length as part of their stacking into mini-batches. It relies
    on the `torch.nn.utils.rnn.pad_sequence` utility function and supports
    receiving samples that contain both inputs that need padding and that
    do not (e.g. variable-length token sequences as inputs but fixed-size
    values as labels).

    It may be used as `collate_fn` argument to the declearn `TorchDataset`
    to wrap up data that needs such collation - but users are free to set
    up and use their own custom function if this fails to fit their data.

    Parameters
    ----------
    samples:
        Sample-level records, formatted as (same-structure) tuples with
        torch tensors and/or lists of tensors as elements. None elements
        are also supported.

    Returns
    -------
    batch:
        Tuple with the same structure as input ones, collating sample-level
        records into batched tensors.
    """
    output = []  # type: List[Union[List[torch.Tensor], torch.Tensor]]
    for i, element in enumerate(samples[0]):
        if element is None:
            output.append(None)
            continue
        if isinstance(element, (list, tuple)):
            out = [
                torch.nn.utils.rnn.pad_sequence(
                    [smp[i][j] for smp in samples],
                    batch_first=True,
                )
                for j in range(len(element))
            ]  # type: Union[torch.Tensor, List[torch.Tensor]]
        elif element.shape:
            out = torch.nn.utils.rnn.pad_sequence(
                [smp[i] for smp in samples],  # type: ignore  # false-positive
                batch_first=True,
            )
        else:
            out = torch.stack(  # pylint: disable=no-member
                [smp[i] for smp in samples]  # type: ignore
            )
        output.append(out)
    return tuple(output)