declearn.dataset.utils.split_multi_classif_dataset

Split a classification dataset into (opt. heterogeneous) shards.

The data-splitting schemes are the following (a usage sketch follows the list):

  • If "iid", split the dataset through iid random sampling.
  • If "labels", split into shards that hold all samples associated with mutually-exclusive target classes.
  • If "dirichlet", split the dataset through random sampling using label-wise shard-assignment probabilities drawn from a symmetrical Dirichlet distribution, parametrized by an alpha parameter.
  • If "biased", split the dataset through random sampling according to a shard-specific random labels distribution.
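
For instance, the following minimal sketch (using randomly generated toy data; only the function and its documented arguments come from declearn) splits a 3-class dataset into four heterogeneous shards with the "dirichlet" scheme:

import numpy as np

from declearn.dataset.utils import split_multi_classif_dataset

# Build a toy 3-class dataset (600 samples, 10 features).
rng = np.random.default_rng(0)
x = rng.normal(size=(600, 10))
y = rng.integers(0, 3, size=600)

# Split it into 4 shards with label-wise Dirichlet heterogeneity.
shards = split_multi_classif_dataset(
    (x, y), n_shards=4, scheme="dirichlet", alpha=0.5, seed=42
)

# Each shard is a ((x_train, y_train), (x_valid, y_valid)) tuple.
(x_train, y_train), (x_valid, y_valid) = shards[0]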

Parameters:

  dataset : Tuple[Union[np.ndarray, spmatrix], np.ndarray], required
      Raw dataset, as a pair of numpy arrays that respectively contain
      the input features and (aligned) labels. Input features may also
      be a scipy sparse matrix, which will temporarily be cast to CSR.
  n_shards : int, required
      Number of shards between which to split the dataset.
  scheme : Literal['iid', 'labels', 'dirichlet', 'biased'], required
      Splitting scheme to use. In all cases, shards contain
      mutually-exclusive samples and cover the full dataset.
      See details above.
  p_valid : float, default=0.2
      Share of each shard to turn into a validation subset.
  seed : Optional[int], default=None
      Optional seed for the RNG used for all sampling operations.
  **kwargs : Any, default={}
      Additional hyper-parameters specific to the split scheme.
      Exhaustive list of possible values:
        • alpha: float = 0.5 for scheme="dirichlet"
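
To give intuition for the alpha hyper-parameter: under the "dirichlet" scheme, each class's samples are distributed across shards following a probability vector drawn from a symmetric Dirichlet(alpha) distribution, so smaller alpha values yield more skewed, heterogeneous shards. A standalone sketch of that draw (an illustration, not declearn's internal code):

import numpy as np

rng = np.random.default_rng(42)
n_shards = 4

# One shard-assignment probability vector per class; with alpha=0.5 the
# draw is typically skewed, while a large alpha approaches uniform sharing.
skewed = rng.dirichlet(np.full(n_shards, 0.5))
near_uniform = rng.dirichlet(np.full(n_shards, 100.0))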

Returns:

  shards : List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
      List of dataset shards, where each element is formatted as a
      tuple of tuples: ((x_train, y_train), (x_valid, y_valid)).
      Output features will be of the same type as the input ones.

Raises:

  TypeError
      If the input features are not a numpy array or scipy sparse matrix.
  ValueError
      If scheme has an invalid value.

Source code in declearn/dataset/utils/_split_classif.py
def split_multi_classif_dataset(
    dataset: Tuple[Union[np.ndarray, spmatrix], np.ndarray],
    n_shards: int,
    scheme: Literal["iid", "labels", "dirichlet", "biased"],
    p_valid: float = 0.2,
    seed: Optional[int] = None,
    **kwargs: Any,
) -> List[Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]:
    """Split a classification dataset into (opt. heterogeneous) shards.

    The data-splitting schemes are the following:

    - If "iid", split the dataset through iid random sampling.
    - If "labels", split into shards that hold all samples associated
      with mutually-exclusive target classes.
    - If "dirichlet", split the dataset through random sampling using
      label-wise shard-assignment probabilities drawn from a symmetrical
      Dirichlet distribution, parametrized by an `alpha` parameter.
    - If "biased", split the dataset through random sampling according
      to a shard-specific random labels distribution.

    Parameters
    ----------
    dataset: tuple(np.ndarray|spmatrix, np.ndarray)
        Raw dataset, as a pair of numpy arrays that respectively contain
        the input features and (aligned) labels. Input features may also
        be a scipy sparse matrix, which will temporarily be cast to CSR.
    n_shards: int
        Number of shards between which to split the dataset.
    scheme: {"iid", "labels", "dirichlet", "biased"}
        Splitting scheme to use. In all cases, shards contain mutually-
        exclusive samples and cover the full dataset. See details above.
    p_valid: float, default=0.2
        Share of each shard to turn into a validation subset.
    seed: int or None, default=None
        Optional seed for the RNG used for all sampling operations.
    **kwargs:
        Additional hyper-parameters specific to the split scheme.
        Exhaustive list of possible values:
            - `alpha: float = 0.5` for `scheme="dirichlet"`

    Returns
    -------
    shards:
        List of dataset shards, where each element is formatted as a
        tuple of tuples: `((x_train, y_train), (x_valid, y_valid))`.
        Input features will be of the same type as `inputs`.

    Raises
    ------
    TypeError
        If `inputs` is not a numpy array or scipy sparse matrix.
    ValueError
        If `scheme` has an invalid value.
    """
    # Select the splitting function to be used.
    if scheme == "iid":
        func = split_iid
    elif scheme == "labels":
        func = split_labels
    elif scheme == "dirichlet":
        func = functools.partial(
            split_dirichlet, alpha=kwargs.get("alpha", 0.5)
        )
    elif scheme == "biased":
        func = split_biased
    else:
        raise ValueError(f"Invalid 'scheme' value: '{scheme}'.")
    # Set up the RNG and unpack the dataset.
    rng = np.random.default_rng(seed)
    inputs, target = dataset
    # Optionally handle sparse matrix inputs.
    sp_type = None  # type: Optional[Type[spmatrix]]
    if isinstance(inputs, spmatrix):
        sp_type = type(inputs)
        inputs = csr_matrix(inputs)
    elif not isinstance(inputs, np.ndarray):
        raise TypeError(
            "'inputs' should be a numpy array or scipy sparse matrix."
        )
    # Split the dataset into shards.
    split = func(inputs, target, n_shards, rng)
    # Further split shards into training and validation subsets.
    shards = [train_valid_split(inp, tgt, p_valid, rng) for inp, tgt in split]
    # Optionally convert back sparse inputs, then return.
    if sp_type is not None:
        shards = [
            ((sp_type(xt), yt), (sp_type(xv), yv))
            for (xt, yt), (xv, yv) in shards
        ]
    return shards
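
As the final conversion step above suggests, sparse inputs are cast to CSR internally and restored to their original sparse type in the returned shards. A minimal illustrative check of that round-trip (toy data; the behavior itself is documented above):

import numpy as np
from scipy.sparse import coo_matrix

from declearn.dataset.utils import split_multi_classif_dataset

# Sparse features go in as COO, are handled as CSR internally,
# and come back out as COO in each shard.
x_sp = coo_matrix(np.eye(100))
y = np.arange(100) % 5
shards = split_multi_classif_dataset((x_sp, y), n_shards=2, scheme="iid", seed=0)
assert isinstance(shards[0][0][0], coo_matrix)  # x_train of the first shard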