declearn.dataset.utils.split_multi_classif_dataset
Split a classification dataset into (optionally heterogeneous) shards.
The data-splitting schemes are the following:
- If "iid", split the dataset through iid random sampling.
- If "labels", split into shards that hold all samples associated with mutually-exclusive target classes.
- If "dirichlet", split the dataset through random sampling using label-wise shard-assignment probabilities drawn from a symmetrical Dirichlet distribution, parametrized by an alpha parameter.
- If "biased", split the dataset through random sampling according to a shard-specific random labels distribution.
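The "dirichlet" scheme described above can be illustrated with a minimal NumPy sketch. This is not the library's implementation: the helper name `dirichlet_split_indices`, its signature, and the value `alpha=0.5` are assumptions made for illustration only. For each label, shard-assignment probabilities are drawn from a symmetric Dirichlet distribution, and each sample carrying that label is then assigned to a shard at random according to those probabilities.

```python
import numpy as np


def dirichlet_split_indices(labels, n_shards, alpha=0.5, seed=None):
    """Hypothetical helper: return one array of sample indices per shard."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    shard_of = np.empty(labels.shape[0], dtype=int)
    for label in np.unique(labels):
        # Label-wise shard-assignment probabilities, drawn from a symmetric
        # Dirichlet distribution (the same alpha for every shard).
        probs = rng.dirichlet(np.full(n_shards, alpha))
        idx = np.flatnonzero(labels == label)
        # Assign every sample with this label to a shard at random.
        shard_of[idx] = rng.choice(n_shards, size=idx.size, p=probs)
    return [np.flatnonzero(shard_of == s) for s in range(n_shards)]


# Toy labels (illustrative data): 4 classes with 50 samples each.
labels = np.repeat(np.arange(4), 50)
shards = dirichlet_split_indices(labels, n_shards=3, alpha=0.5, seed=0)
```

Because every sample is assigned to exactly one shard, the resulting index sets are mutually exclusive and cover the full dataset. Smaller alpha values concentrate each label on fewer shards, while larger values approach an even spread. Note that this sketch does not prevent a shard from ending up empty.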
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Tuple[Union[np.ndarray, spmatrix], np.ndarray] | Raw dataset, as a pair of numpy arrays that respectively contain the input features and (aligned) labels. Input features may also be a scipy sparse matrix, that will temporarily be cast to CSR. | required |
n_shards | int | Number of shards between which to split the dataset. | required |
scheme | Literal['iid', 'labels', 'dirichlet', 'biased'] | Splitting scheme to use. In all cases, shards contain mutually-exclusive samples and cover the full dataset. See details above. | required |
p_valid | float | Share of each shard to turn into a validation subset. | 0.2 |
seed | Optional[int] | Optional seed to the RNG used for all sampling operations. | None |
**kwargs | Any | Additional hyper-parameters specific to the split scheme. Exhaustive list of possible values: | {} |
Returns:
Name | Type | Description |
---|---|---|
shards | List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]] | List of dataset shards, where each element is formatted as a tuple of tuples: |
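To make the nested return type and the role of p_valid more concrete, here is a small sketch of how a single shard could be divided into two (inputs, labels) pairs. The helper name `split_train_valid`, the (training, validation) ordering of the two pairs, and the rounding rule are assumptions made for illustration; the source does not specify them.

```python
import numpy as np


def split_train_valid(inputs, target, p_valid=0.2, seed=None):
    """Hypothetical helper: split one shard into two (inputs, labels) pairs."""
    rng = np.random.default_rng(seed)
    n_samples = target.shape[0]
    # Number of samples set aside for validation (rounding is an assumption).
    n_valid = int(round(n_samples * p_valid))
    # Shuffle sample indices, then cut the permutation in two.
    perm = rng.permutation(n_samples)
    idx_valid, idx_train = perm[:n_valid], perm[n_valid:]
    return (
        (inputs[idx_train], target[idx_train]),
        (inputs[idx_valid], target[idx_valid]),
    )


# Toy shard (illustrative data): 10 samples with 2 features each.
x = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10) % 2
(x_train, y_train), (x_valid, y_valid) = split_train_valid(
    x, y, p_valid=0.2, seed=0
)
```

Each element of the returned list would then match the documented type `Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]`, with the features and labels of each pair kept aligned.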
Raises:
Type | Description |
---|---|
TypeError | If |
ValueError | If |
Source code in declearn/dataset/utils/_split_classif.py