declearn.dataset.examples.load_mnist

Load and/or download the MNIST digit-classification dataset.

See [https://en.wikipedia.org/wiki/MNIST_database] for information on the MNIST dataset.

Arguments

train: bool, default=True
    Whether to return the 60k training subset, or the 10k testing one.
folder: str or None, default=None
    Optional path to a root folder where to find or download the raw MNIST data. If None, download the data but only store it in memory.

Returns

images: np.ndarray
    Input images, as a (n_images, 28, 28) float numpy array. May be passed as data of a declearn InMemoryDataset.
labels: np.ndarray
    Target labels, as a (n_images) int numpy array. May be passed as target of a declearn InMemoryDataset.
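
The documented return contract can be checked with a short sketch such as the one below. It is illustrative only: `describe_mnist_arrays` is a hypothetical helper, not part of declearn, and the arrays are synthetic stand-ins so that the example runs without downloading anything. The commented lines show how the real loader would presumably be called, based on the signature documented above.

```python
import numpy as np


def describe_mnist_arrays(images: np.ndarray, labels: np.ndarray) -> dict:
    """Validate the documented shapes and summarize the two arrays.

    Hypothetical helper written for this example; not a declearn function.
    """
    # Documented contract: images are (n_images, 28, 28); labels hold one
    # integer per image.
    if images.ndim != 3 or images.shape[1:] != (28, 28):
        raise ValueError(f"Unexpected images shape: {images.shape}")
    if labels.shape != (images.shape[0],):
        raise ValueError(f"Unexpected labels shape: {labels.shape}")
    return {
        "n_images": images.shape[0],
        "images_dtype": str(images.dtype),
        "classes": sorted(int(c) for c in np.unique(labels)),
    }


# Synthetic stand-in data with the documented layout. With declearn
# installed, these arrays would instead come from the loader:
#     from declearn.dataset.examples import load_mnist
#     images, labels = load_mnist(train=False, folder=None)
rng = np.random.default_rng(seed=0)
images = rng.random((16, 28, 28), dtype=np.float32)
labels = np.arange(16) % 10

summary = describe_mnist_arrays(images, labels)
print(summary)

# Per the description above, the arrays may then be passed on as the
# `data` and `target` of a declearn InMemoryDataset.
```

Running the same check on the real arrays is a cheap way to catch a wrong or partially downloaded folder before wrapping the data for training.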

Source code in declearn/dataset/examples/_mnist.py
def load_mnist(
    train: bool = True,
    folder: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load and/or download the MNIST digit-classification dataset.

    See [https://en.wikipedia.org/wiki/MNIST_database] for information
    on the MNIST dataset.

    Arguments
    ---------
    train: bool, default=True
        Whether to return the 60k training subset, or the 10k testing one.
    folder: str or None, default=None
        Optional path to a root folder where to find or download the
        raw MNIST data. If None, download the data but only store it
        in memory.

    Returns
    -------
    images: np.ndarray
        Input images, as a (n_images, 28, 28) float numpy array.
        May be passed as `data` of a declearn `InMemoryDataset`.
    labels: np.ndarray
        Target labels, as a (n_images) int numpy array.
        May be passed as `target` of a declearn `InMemoryDataset`.
    """
    tag = "train" if train else "t10k"
    images = _load_mnist_data(folder, tag, images=True)
    labels = _load_mnist_data(folder, tag, images=False)
    return images, labels