declearn.dataset.torch.collate_with_padding
Collate input elements into batches, with padding when required.
This custom collate function pads samples of variable length as part
of stacking them into mini-batches. It relies on the
torch.nn.utils.rnn.pad_sequence utility function and supports samples
that mix elements that need padding with elements that do not (e.g.
variable-length token sequences as inputs but fixed-size values as
labels).
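As a minimal illustration of the padding step described above, the snippet below calls torch.nn.utils.rnn.pad_sequence directly on three sequences of different lengths. The sample values are hypothetical, and the choice of batch_first=True and a padding value of 0 is an assumption made for this sketch, not a statement about the arguments declearn passes internally.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three token sequences of different lengths (hypothetical data).
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5]),
    torch.tensor([6]),
]

# With batch_first=True the output has shape (batch, max_len);
# shorter sequences are right-padded with padding_value.
padded = pad_sequence(sequences, batch_first=True, padding_value=0)

print(padded.shape)     # torch.Size([3, 3])
print(padded.tolist())  # [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
```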
It may be passed as the collate_fn argument to the declearn
TorchDataset to wrap data that needs such collation. Users remain free
to set up and use their own custom function if this one does not fit
their data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
samples | List[Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]] | Sample-level records, formatted as (same-structure) tuples with torch tensors and/or lists of tensors as elements. None elements are also supported. | required |
Returns:
Name | Type | Description |
---|---|---|
batch | Tuple[Union[List[torch.Tensor], torch.Tensor], ...] | Tuple with the same structure as input ones, collating sample-level records into batched tensors. |
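To make the input and output structures above concrete, here is a simplified, hypothetical stand-in for this kind of collation. The name sketch_collate is invented for illustration and this is not the declearn implementation: it handles only tensor and None elements (not lists of tensors), and passing None through unchanged is an assumption about the intended behavior.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def sketch_collate(samples):
    """Hypothetical, simplified collate function (not declearn's code)."""
    batch = []
    # zip(*samples) groups the i-th element of every sample together.
    for elements in zip(*samples):
        if elements[0] is None:
            # Assumption: a None element stays None at batch level.
            batch.append(None)
        elif all(e.shape == elements[0].shape for e in elements):
            # Same-shape tensors can simply be stacked.
            batch.append(torch.stack(list(elements)))
        else:
            # Variable-length tensors are padded to the longest one.
            batch.append(pad_sequence(list(elements), batch_first=True))
    return tuple(batch)


# Each sample is an (inputs, label, weight) tuple with the same structure.
samples = [
    (torch.tensor([1, 2, 3]), torch.tensor(0.0), None),
    (torch.tensor([4, 5]), torch.tensor(1.0), None),
]
inputs, labels, weights = sketch_collate(samples)

print(inputs.tolist())  # [[1, 2, 3], [4, 5, 0]]
print(labels.tolist())  # [0.0, 1.0]
print(weights)          # None
```

A function with this shape can be given to any loader that accepts a collate_fn callable, such as torch.utils.data.DataLoader.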
Source code in declearn/dataset/torch/_utils.py