Skip to content

declearn.dataset.split_data

Randomly split a dataset into shards.

The resulting folder structure is:

folder/
└─── data*/
    └─── client*/
    │      train_data.* - training data
    │      train_target.* - training labels
    │      valid_data.* - validation data
    │      valid_target.* - validation labels
    └─── client*/
    │    ...

Parameters:

Name Type Description Default
folder str

Path to the folder where to add a data folder holding output shard-wise files

'.'
data_file Optional[str]

Optional path to a folder where to find the data. If None, default to the MNIST example.

None
label_file Optional[Union[str, int]]

If str, path to the labels file to import, or name of a data column to use as labels (only if data points to a csv file). If int, index of a data column to use as labels. Required if data is not None, ignored if data is None.

None
n_shards int

Number of shards between which to split the data.

3
scheme str

Splitting scheme(s) to use. In all cases, shards contain mutually- exclusive samples and cover the full raw training data. - If "iid", split the dataset through iid random sampling. - If "labels", split into shards that hold all samples associated with mutually-exclusive target classes. - If "biased", split the dataset through random sampling according to a shard-specific random labels distribution.

'iid'
perc_train float

Train/validation split in each client dataset, must be in the ]0,1] range.

0.8
seed Optional[int]

Optional seed to the RNG used for all sampling operations.

None
Source code in declearn/dataset/_split_data.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def split_data(
    folder: str = ".",
    data_file: Optional[str] = None,
    label_file: Optional[Union[str, int]] = None,
    n_shards: int = 3,
    scheme: str = "iid",
    perc_train: float = 0.8,
    seed: Optional[int] = None,
) -> None:
    """Randomly split a dataset into shards.

    The resulting folder structure is:

        folder/
        └─── data*/
            └─── client*/
            │      train_data.* - training data
            │      train_target.* - training labels
            │      valid_data.* - validation data
            │      valid_target.* - validation labels
            └─── client*/
            │    ...

    Parameters
    ----------
    folder: str, default = "."
        Path to the folder where to add a data folder
        holding output shard-wise files
    data_file: str or None, default=None
        Optional path to a folder where to find the data.
        If None, default to the MNIST example.
    label_file: str or int or None, default=None
        If str, path to the labels file to import, or name of a `data`
        column to use as labels (only if `data` points to a csv file).
        If int, index of a `data` column to use as labels.
        Required if data is not None, ignored if data is None.
    n_shards: int, default=3
        Number of shards between which to split the data.
    scheme: {"iid", "labels", "biased"}, default="iid"
        Splitting scheme(s) to use. In all cases, shards contain mutually-
        exclusive samples and cover the full raw training data.
        - If "iid", split the dataset through iid random sampling.
        - If "labels", split into shards that hold all samples associated
        with mutually-exclusive target classes.
        - If "biased", split the dataset through random sampling according
        to a shard-specific random labels distribution.
    perc_train: float, default=0.8
        Train/validation split in each client dataset, must be in the
        ]0,1] range.
    seed: int or None, default=None
        Optional seed to the RNG used for all sampling operations.

    Raises
    ------
    ValueError
        If 'perc_train' is not a float in the ]0,1] range.
    """
    # pylint: disable=too-many-arguments,too-many-locals
    # Value-check the 'perc_train' parameter before doing any other work,
    # so that invalid input fails fast.
    if not (isinstance(perc_train, float) and (0.0 < perc_train <= 1.0)):
        raise ValueError("'perc_train' should be a float in ]0,1]")
    # Select the scheme-specific output folder.
    folder = os.path.join(folder, f"data_{scheme}")
    # Load the dataset and split it into client-wise shards.
    inputs, labels = load_data(data_file, label_file)
    print(
        f"Splitting data into {n_shards} shards using the '{scheme}' scheme."
    )
    split = split_multi_classif_dataset(
        dataset=(inputs, labels),
        n_shards=n_shards,
        scheme=scheme,  # type: ignore
        p_valid=(1 - perc_train),
        seed=seed,
    )
    # Export the resulting shard-wise data to files, one subdir per client.
    for idx, ((x_train, y_train), (x_valid, y_valid)) in enumerate(split):
        subdir = os.path.join(folder, f"client_{idx}")
        os.makedirs(subdir, exist_ok=True)
        save_data_array(os.path.join(subdir, "train_data"), x_train)
        save_data_array(os.path.join(subdir, "train_target"), y_train)
        # Only write validation files when the validation split is non-empty
        # (perc_train == 1.0 yields empty validation shards).
        if x_valid.shape[0]:
            save_data_array(os.path.join(subdir, "valid_data"), x_valid)
            save_data_array(os.path.join(subdir, "valid_target"), y_valid)