declearn.dataset.torch.PoissonSampler

Bases: torch.utils.data.Sampler

Custom torch.utils.data.Sampler implementing Poisson sampling.

This sampler is equivalent to the UniformWithReplacementSampler from the third-party opacus library, with the exception that it skips empty batches, preventing issues at collate time.

Source code in declearn/dataset/torch/_utils.py
class PoissonSampler(torch.utils.data.Sampler):
    """Custom `torch.utils.data.Sampler` implementing Poisson sampling.

    This sampler is equivalent to the `UniformWithReplacementSampler`
    from the third-party `opacus` library, with the exception that it
    skips empty batches, preventing issues at collate time.
    """

    def __init__(
        self,
        num_samples: int,
        sample_rate: float,
        generator: Optional[torch.Generator] = None,
        # false-positive on 'torch.Generator'; pylint: disable=no-member
    ) -> None:
        """Instantiate a Poisson (UniformWithReplacement) Sampler.

        Parameters
        ----------
        num_samples: int
            Number of samples in the dataset to sample from.
        sample_rate: float
            Sampling rate, i.e. probability for each sample to be included
            in any given batch. The average number of samples per batch
            is therefore `num_samples * sample_rate`.
        generator: torch.Generator or None, default=None
            Optional RNG, that may be used to produce seeded results.
        """
        # super init is empty and its signature will change in torch 2.2
        # pylint: disable=super-init-not-called
        self.num_samples = num_samples
        self.sample_rate = sample_rate
        self.generator = generator

    def __len__(self):
        return int(1 / self.sample_rate)

    def __iter__(self):
        for _ in range(len(self)):
            # Draw a random batch of samples based on Poisson sampling.
            rand = torch.rand(  # false-positive; pylint: disable=no-member
                self.num_samples, generator=self.generator, device="cpu"
            )
            indx = (rand < self.sample_rate).nonzero().reshape(-1).tolist()
            # Yield selected indices, unless the batch would be empty.
            if not indx:
                continue
            yield indx
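
The sampler yields lists of indices rather than single indices, so it fits the `batch_sampler` argument of `torch.utils.data.DataLoader`. The following is a minimal sketch of that usage, not the library's own example. To keep it self-contained, it defines `PoissonSamplerSketch`, a stand-in that mirrors the source above; in practice one would presumably use `declearn.dataset.torch.PoissonSampler` instead. The toy dataset, the seeds and the helper `draw_batches` are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class PoissonSamplerSketch(Sampler):
    """Stand-in mirroring the source above (assumption: same behaviour)."""

    def __init__(self, num_samples, sample_rate, generator=None):
        self.num_samples = num_samples
        self.sample_rate = sample_rate
        self.generator = generator

    def __len__(self):
        return int(1 / self.sample_rate)

    def __iter__(self):
        for _ in range(len(self)):
            rand = torch.rand(self.num_samples, generator=self.generator)
            indx = (rand < self.sample_rate).nonzero().reshape(-1).tolist()
            if indx:  # skip empty batches
                yield indx


def draw_batches(seed):
    """Return the index batches drawn from a generator seeded with `seed`."""
    generator = torch.Generator().manual_seed(seed)
    return list(PoissonSamplerSketch(100, 0.1, generator=generator))


# Hypothetical toy dataset: 100 samples with 4 features each.
dataset = TensorDataset(torch.arange(400, dtype=torch.float32).reshape(100, 4))
sampler = PoissonSamplerSketch(
    num_samples=len(dataset),
    sample_rate=0.1,
    generator=torch.Generator().manual_seed(0),
)
# Each item yielded by the sampler is a full batch of indices.
loader = DataLoader(dataset, batch_sampler=sampler)
batch_sizes = [inputs.shape[0] for (inputs,) in loader]
```

Batch sizes vary from one batch to the next, since each sample is drawn independently. Passing identically seeded generators, as `draw_batches` does, should give the same sequence of batches across runs.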

__init__(num_samples, sample_rate, generator=None)

Instantiate a Poisson (UniformWithReplacement) Sampler.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_samples` | `int` | Number of samples in the dataset to sample from. | required |
| `sample_rate` | `float` | Sampling rate, i.e. probability for each sample to be included in any given batch. The average number of samples per batch is therefore `num_samples * sample_rate`. | required |
| `generator` | `Optional[torch.Generator]` | Optional RNG, that may be used to produce seeded results. | `None` |
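
To make the role of `sample_rate` concrete, here is a small arithmetic sketch using only the standard library. It assumes each sample is included independently with probability `sample_rate`, as in the source of the class. The helper name `poisson_batch_stats` and the example values are illustrative and not part of declearn.

```python
def poisson_batch_stats(num_samples: int, sample_rate: float) -> dict:
    """Summarize what to expect from Poisson sampling with these settings."""
    return {
        # Each sample is kept with probability `sample_rate`.
        "expected_batch_size": num_samples * sample_rate,
        # A batch is empty only if every sample is left out.
        "prob_empty_batch": (1.0 - sample_rate) ** num_samples,
        # Value returned by `__len__`: an upper bound on the number of
        # batches per pass, since empty batches are skipped.
        "max_batches": int(1 / sample_rate),
    }


stats = poisson_batch_stats(num_samples=8, sample_rate=0.25)
```

With these values, the expected batch size is 2 samples, at most 4 batches are yielded per pass, and roughly one batch in ten would come out empty and be skipped.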
Source code in declearn/dataset/torch/_utils.py
def __init__(
    self,
    num_samples: int,
    sample_rate: float,
    generator: Optional[torch.Generator] = None,
    # false-positive on 'torch.Generator'; pylint: disable=no-member
) -> None:
    """Instantiate a Poisson (UniformWithReplacement) Sampler.

    Parameters
    ----------
    num_samples: int
        Number of samples in the dataset to sample from.
    sample_rate: float
        Sampling rate, i.e. probability for each sample to be included
        in any given batch. The average number of samples per batch
        is therefore `num_samples * sample_rate`.
    generator: torch.Generator or None, default=None
        Optional RNG, that may be used to produce seeded results.
    """
    # super init is empty and its signature will change in torch 2.2
    # pylint: disable=super-init-not-called
    self.num_samples = num_samples
    self.sample_rate = sample_rate
    self.generator = generator