Created
August 25, 2020 10:50
-
-
Save usamec/1b3b4dcbafad2d58faa71a9633eea6a5 to your computer and use it in GitHub Desktop.
Resumable (and savable) random sampler for Pytorch data loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class ResumableRandomSampler(torch.utils.data.Sampler):
    r"""Random sampler (without replacement) whose position can be saved and restored.

    A permutation of ``range(len(data_source))`` is drawn from a private,
    seeded :class:`torch.Generator`. Iteration walks through that permutation
    and records how far it has got in ``perm_index``. Calling ``iter()`` again
    before the permutation is exhausted continues from the recorded position
    rather than starting over; once it is exhausted, the next ``iter()`` draws
    a fresh permutation from the same generator.

    ``get_state()`` / ``set_state()`` expose the permutation, the position and
    the generator state, so a sampler rebuilt later can continue where the
    saved one stopped.

    Arguments:
        data_source (Dataset): dataset to sample from; only ``len()`` is used.
        seed (int): seed for the private generator, default=``47``.
    """

    def __init__(self, data_source, seed: int = 47):
        self.data_source = data_source
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

        self.perm_index = 0
        self.perm = torch.randperm(self.num_samples, generator=self.generator)

    @property
    def num_samples(self) -> int:
        """Number of elements in ``data_source``."""
        return len(self.data_source)

    def __iter__(self):
        """Yield the remaining indices of the current permutation as ``int``.

        A new permutation is drawn only when the current one has been fully
        consumed; otherwise iteration resumes at ``perm_index``.
        """
        if self.perm_index >= len(self.perm):
            self.perm_index = 0
            self.perm = torch.randperm(self.num_samples, generator=self.generator)

        while self.perm_index < len(self.perm):
            # Advance the cursor *before* yielding so that a state captured
            # while the consumer holds this index already counts it as used.
            self.perm_index += 1
            # .item() converts the 0-dim tensor into a plain Python int.
            yield self.perm[self.perm_index - 1].item()

    def __len__(self) -> int:
        """Return the dataset size (not the number of indices still pending)."""
        return self.num_samples

    def get_state(self) -> dict:
        """Return a dict with keys ``perm``, ``perm_index`` and ``generator_state``."""
        return {
            "perm": self.perm,
            "perm_index": self.perm_index,
            "generator_state": self.generator.get_state(),
        }

    def set_state(self, state) -> None:
        """Restore the sampler from a dict produced by :meth:`get_state`."""
        self.perm = state["perm"]
        self.perm_index = state["perm_index"]
        self.generator.set_state(state["generator_state"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment