Bases: `torch.utils.data.Sampler`

Custom `torch.utils.data.Sampler` implementing Poisson sampling.

This sampler is equivalent to the `UniformWithReplacementSampler` from the third-party `opacus` library, with the exception that it skips empty batches, preventing issues at collate time.
Source code in declearn/dataset/torch/_utils.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class PoissonSampler(torch.utils.data.Sampler):
    """Custom `torch.utils.data.Sampler` implementing Poisson sampling.

    This sampler is equivalent to the `UniformWithReplacementSampler`
    from the third-party `opacus` library, with the exception that it
    skips empty batches, preventing issues at collate time.

    Each iteration performs `len(self)` independent draws. In each draw,
    every sample index is included with probability `sample_rate`.
    Draws that select no index at all are skipped, so an iteration
    may yield fewer than `len(self)` batches.
    """

    def __init__(
        self,
        num_samples: int,
        sample_rate: float,
        generator: Optional[torch.Generator] = None,
        # false-positive on 'torch.Generator'; pylint: disable=no-member
    ) -> None:
        """Instantiate a Poisson (UniformWithReplacement) Sampler.

        Parameters
        ----------
        num_samples: int
            Number of samples in the dataset to sample from.
        sample_rate: float
            Sampling rate, i.e. probability for each sample to be included
            in any given batch. Hence, the expected number of samples per
            batch is `sample_rate * num_samples`.
        generator: torch.Generator or None, default=None
            Optional RNG, that may be used to produce seeded results.

        Raises
        ------
        ValueError
            If `num_samples` is negative, or if `sample_rate` does not
            lie in the ]0, 1] interval.
        """
        # super init is empty and its signature will change in torch 2.2
        # pylint: disable=super-init-not-called
        if num_samples < 0:
            raise ValueError(
                f"'num_samples' must be non-negative, got {num_samples}."
            )
        # Reject rates that would break `__len__`: division by zero or a
        # negative length. Also reject rates above 1, which would make
        # `int(1 / rate) == 0` and silently yield no batch at all.
        # The chained comparison also rejects NaN.
        if not 0.0 < sample_rate <= 1.0:
            raise ValueError(
                f"'sample_rate' must be in ]0, 1], got {sample_rate}."
            )
        self.num_samples = num_samples
        self.sample_rate = sample_rate
        self.generator = generator

    def __len__(self) -> int:
        """Return the number of draws per iteration: `int(1 / sample_rate)`.

        Empty draws are skipped by `__iter__`, so this is an upper
        bound on the number of batches an iteration actually yields.
        """
        return int(1 / self.sample_rate)

    def __iter__(self):
        """Yield lists of dataset indices selected by Poisson sampling.

        Each of the `len(self)` draws includes every index independently
        with probability `sample_rate`. Draws that select nothing are
        not yielded, so downstream collate functions never receive an
        empty batch.
        """
        for _ in range(len(self)):
            # Draw a random batch of samples based on Poisson sampling.
            rand = torch.rand(  # false-positive; pylint: disable=no-member
                self.num_samples, generator=self.generator, device="cpu"
            )
            indx = (rand < self.sample_rate).nonzero().reshape(-1).tolist()
            # Yield selected indices, unless the batch would be empty.
            if indx:
                yield indx
|
__init__(num_samples, sample_rate, generator=None)
Instantiate a Poisson (UniformWithReplacement) Sampler.
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `num_samples` | `int` | Number of samples in the dataset to sample from. | *required* |
| `sample_rate` | `float` | Sampling rate, i.e. probability for each sample to be included in any given batch. Hence, the expected number of samples per batch is `sample_rate * num_samples`. | *required* |
| `generator` | `Optional[torch.Generator]` | Optional RNG, that may be used to produce seeded results. | `None` |
Source code in declearn/dataset/torch/_utils.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def __init__(
    self,
    num_samples: int,
    sample_rate: float,
    generator: Optional[torch.Generator] = None,
    # false-positive on 'torch.Generator'; pylint: disable=no-member
) -> None:
    """Instantiate a Poisson (UniformWithReplacement) Sampler.

    Parameters
    ----------
    num_samples: int
        Number of samples in the dataset to sample from.
    sample_rate: float
        Sampling rate, i.e. probability for each sample to be included
        in any given batch. Hence, the expected number of samples per
        batch is `sample_rate * num_samples`.
    generator: torch.Generator or None, default=None
        Optional RNG, that may be used to produce seeded results.

    Raises
    ------
    ValueError
        If `num_samples` is negative, or if `sample_rate` does not
        lie in the ]0, 1] interval.
    """
    # super init is empty and its signature will change in torch 2.2
    # pylint: disable=super-init-not-called
    if num_samples < 0:
        raise ValueError(
            f"'num_samples' must be non-negative, got {num_samples}."
        )
    # Reject rates that would break `int(1 / sample_rate)` batch counting:
    # zero divides by zero, negatives give a negative length, and rates
    # above 1 give zero batches. The chained comparison also rejects NaN.
    if not 0.0 < sample_rate <= 1.0:
        raise ValueError(
            f"'sample_rate' must be in ]0, 1], got {sample_rate}."
        )
    self.num_samples = num_samples
    self.sample_rate = sample_rate
    self.generator = generator
|