declearn.dataset.tensorflow.TensorflowDataset

Bases: Dataset

Dataset subclass to wrap up 'tensorflow.data.Dataset' instances.

Source code in declearn/dataset/tensorflow/_tensorflow.py
@register_type(group="Dataset")
class TensorflowDataset(Dataset):
    """Dataset subclass to wrap up 'tensorflow.data.Dataset' instances."""

    def __init__(
        self,
        dataset: tf.data.Dataset,
        buffer_size: Optional[int] = None,
        batch_mode: BatchMode = "default",
        seed: Optional[int] = None,
    ) -> None:
        """Wrap up a 'tensorflow.data.Dataset' into a declearn Dataset.

        Parameters
        ----------
        dataset: tensorflow.data.Dataset
            A tensorflow Dataset instance to be wrapped for declearn use.
            The dataset is expected to yield sample-level records, made
            of one to three (tuples of) tensorflow tensors: model inputs,
            target labels and/or sample weights.
        buffer_size: int or None, default=None
            Optional buffer size denoting the number of samples to pre-fetch
            and shuffle when sampling from the original dataset. The higher,
            the better the shuffling, but also the more memory costly.
            If None, use `batch_size * 10`, based on the batch size
            passed when generating batches.
        batch_mode: str in {"default", "padded", "ragged"}
            Flag specifying how to batch inputs. Use "padded" or "ragged" to
            batch up variable-dimension samples (e.g. sequences of tokens),
            using either `tf.data.Dataset.padded_batch` or `.ragged_batch`.
        seed: int or None, default=None
            Optional seed for the random number generator based on which
            the dataset is (optionally) shuffled when generating batches.
            Note that successive batch-generating calls will not yield
            the same results, as the seeded state is not reset on each
            call.

        Notes
        -----
        The wrapped `tensorflow.data.Dataset`:

        - *must* have a fixed length (with TensorFlow <2.13) / *should*
          have an established `cardinality` (TensorFlow >=2.13).
        - should return sample-level (unbatched) elements, as either:
            - (inputs,)
            - (inputs, labels)
            - (inputs, labels, weights)
          where each element may be a (nested structure of) tensor(s).
        - when using `declearn.model.tensorflow.TensorflowModel`:
            - inputs may be a single tensor or list of tensors
            - labels may be a single tensor or None (usually, not None)
            - weights may be a single tensor or None
        """
        # Assign the dataset, parse and validate its specifications.
        self.dataset = dataset
        self._dspecs = parse_and_validate_tensorflow_dataset(self.dataset)
        warn_if_dataset_is_likely_batched(self.dataset)
        # Assign additional parameters and set up an opt.-seeded RNG.
        self.batch_mode = batch_mode
        self.buffer_size = buffer_size
        self.rng = np.random.default_rng(seed)

    def get_data_specs(
        self,
    ) -> DataSpecs:
        return DataSpecs(
            n_samples=self._dspecs.n_samples,
            features_shape=self._dspecs.input_shp,
            classes=self._dspecs.y_classes,
            data_type=self._dspecs.data_type,
        )

    def generate_batches(
        self,
        batch_size: int,
        shuffle: bool = False,
        drop_remainder: bool = True,
        replacement: bool = False,
        poisson: bool = False,
    ) -> Iterator[Batch]:
        # inherited signature; pylint: disable=too-many-arguments
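        # Delegate to the Poisson-sampling or standard batching backend.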
        if poisson:
            generator = self._generate_batches_poisson(
                batch_size=batch_size,
                drop_remainder=drop_remainder,
            )
        else:
            generator = self._generate_batches_batching(
                batch_size=batch_size,
                drop_remainder=drop_remainder,
                shuffle=shuffle,
                replacement=replacement,
            )
        yield from generator

    def _generate_batches_batching(
        self,
        batch_size: int,
        drop_remainder: bool,
        shuffle: bool,
        replacement: bool,
    ) -> Iterator[Batch]:
        """Backend to `generate_batches` when `poisson=False`."""
        # Start setting up the dataset, then set up optional shuffling.
        dataset, _, n_batches, none_pads = self._prepare_dataset(
            batch_size=batch_size, drop_remainder=drop_remainder
        )
        if shuffle:
            dataset = self._setup_shuffling(
                dataset, batch_size=batch_size, replacement=replacement
            )
        # Set up batching, opt. in padded or ragged mode.
        batch = get_batch_function(self.batch_mode, dataset=dataset)
        dataset = batch(batch_size=batch_size, drop_remainder=drop_remainder)
        dataset = dataset.take(n_batches).map(lambda *s: (*s, *none_pads))
        yield from dataset

    def _generate_batches_poisson(
        self,
        batch_size: int,
        drop_remainder: bool,
    ) -> Iterator[Batch]:
        """Backend to `generate_batches` when `poisson=True`."""
        # Start setting up the dataset, then set up shuffling with replacement.
        dataset, n_samples, n_batches, none_pads = self._prepare_dataset(
            batch_size=batch_size, drop_remainder=drop_remainder
        )
        dataset = self._setup_shuffling(
            dataset, batch_size=batch_size, replacement=True
        )
        # Draw the batch sizes, which follow a binomial distribution.
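        # Under Poisson sampling each sample is independently included
        # with probability `srate`, hence sizes have mean `batch_size`.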
        srate = batch_size / n_samples
        sizes = self.rng.binomial(n=n_samples, p=srate, size=n_batches)
        # Fetch and batch up samples manually.
        itersamples = iter(dataset)
        stack = get_stack_function(self.batch_mode)
        for size in sizes:
            if not size:  # skip empty batches (edge case)
                continue
            samples = [
                # infinite iterator; pylint: disable=stop-iteration-return
                next(itersamples)
                for _ in range(size)
            ]
            batch = tf.nest.map_structure(stack, *samples)
            yield (*batch, *none_pads)  # type: ignore

    def _prepare_dataset(
        self,
        batch_size: int,
        drop_remainder: bool,
    ) -> Tuple[tf.data.Dataset, int, int, List[None]]:
        """Run initial preparations to generate batches from the dataset.

        Return
        - the wrapped dataset (optionally with some transforms),
        - the number of samples in the initial dataset,
        - the number of batches that should be yielded,
        - a list of None values appended to yielded batches so as to
          complete them into (inputs, labels, weights) triplets.
        """
        dataset = self.dataset
        if self._dspecs.single_el:
            dataset = dataset.map(lambda x: (x,))
        none_pads = [None] * self._dspecs.n_padding
        # Compute the number of batches that are to be yielded.
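        # When `drop_remainder=False` and `batch_size` does not divide
        # `n_samples` evenly, one extra, smaller batch is yielded.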
        n_samples = self._dspecs.n_samples
        n_batches = n_samples // batch_size
        n_batches += bool((not drop_remainder) and (n_samples % batch_size))
        # Return that information.
        return dataset, n_samples, n_batches, none_pads

    def _setup_shuffling(
        self,
        dataset: tf.data.Dataset,
        batch_size: int,
        replacement: bool,
    ) -> tf.data.Dataset:
        """Transform a dataset into a shuffled one, opt. with replacement.

        Use `self.buffer_size` or `10 * batch_size` as buffer size.
        If `replacement`, the returned dataset is an infinite iterator.
        """
        if replacement:
            dataset = dataset.repeat(count=None)
        return dataset.shuffle(
            seed=self.rng.integers(2**63),
            buffer_size=self.buffer_size or batch_size * 10,
        )

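Below is a minimal usage sketch; the toy arrays, shapes and parameter
values are illustrative assumptions, not part of declearn.

import numpy as np
import tensorflow as tf

from declearn.dataset.tensorflow import TensorflowDataset

# Build a toy sample-level dataset of (inputs, labels) records.
x_dat = np.random.normal(size=(100, 8)).astype("float32")
y_dat = np.random.choice(2, size=(100,)).astype("float32")
tf_dst = tf.data.Dataset.from_tensor_slices((x_dat, y_dat))

# Wrap it up and iterate over shuffled batches of samples.
dataset = TensorflowDataset(tf_dst, seed=0)
for inputs, labels, weights in dataset.generate_batches(
    batch_size=32, shuffle=True, drop_remainder=False
):
    ...  # here, `weights` is None, padding the yielded triplet

With `poisson=True`, batches would instead be drawn with replacement,
with variable, binomially-distributed sizes.
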
__init__(dataset, buffer_size=None, batch_mode='default', seed=None)

Wrap up a 'tensorflow.data.Dataset' into a declearn Dataset.

Parameters:

dataset: tf.data.Dataset, required
    A tensorflow Dataset instance to be wrapped for declearn use. The
    dataset is expected to yield sample-level records, made of one to
    three (tuples of) tensorflow tensors: model inputs, target labels
    and/or sample weights.
buffer_size: Optional[int], default=None
    Optional buffer size denoting the number of samples to pre-fetch
    and shuffle when sampling from the original dataset. The higher,
    the better the shuffling, but also the more memory costly. If
    None, use `batch_size * 10`, based on the batch size passed when
    generating batches.
batch_mode: BatchMode, default='default'
    Flag specifying how to batch inputs. Use "padded" or "ragged" to
    batch up variable-dimension samples (e.g. sequences of tokens),
    using either `tf.data.Dataset.padded_batch` or `.ragged_batch`.
seed: Optional[int], default=None
    Optional seed for the random number generator based on which the
    dataset is (optionally) shuffled when generating batches. Note
    that successive batch-generating calls will not yield the same
    results, as the seeded state is not reset on each call.

Notes

The wrapped tensorflow.data.Dataset:

  • must have a fixed length (with TensorFlow <2.13) / should have an established cardinality (TensorFlow >=2.13).
  • should return sample-level (unbatched) elements, as either:
    • (inputs,)
    • (inputs, labels)
    • (inputs, labels, weights)
    where each element may be a (nested structure of) tensor(s).
  • when using declearn.model.tensorflow.TensorflowModel:
    • inputs may be a single tensor or list of tensors
    • labels may be a single tensor or None (usually, not None)
    • weights may be a single tensor or None
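
As an illustration of the variable-dimension case, here is a hedged
sketch using `batch_mode="ragged"`; the toy token sequences below are
assumptions, not part of declearn.

import tensorflow as tf

from declearn.dataset.tensorflow import TensorflowDataset

# Toy sample-level records with variable-length integer token inputs.
tokens = tf.ragged.constant([[1, 2, 3], [4, 5], [6]])
labels = tf.constant([0.0, 1.0, 0.0])
tf_dst = tf.data.Dataset.from_tensor_slices((tokens, labels))

# Batches are built with `tf.data.Dataset.ragged_batch`.
dataset = TensorflowDataset(tf_dst, batch_mode="ragged", seed=0)
for inputs, labels, weights in dataset.generate_batches(batch_size=2):
    ...  # `inputs` is a `tf.RaggedTensor`; `weights` is None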