Hyunji

mnist

1 +import torch
2 +from torchvision.datasets import MNIST
3 +
4 +from lib.data.util import DATA_FOLDER
5 +
6 +
7 +class Mnist(MNIST):
8 + def __init__(self, root=f"{DATA_FOLDER}/mnist", train=True, transform=None,
9 + target_transform=None, download=True, init_transform=None,
10 + init_target_transform=None, seed=None, fraction=1.0):
11 + super().__init__(root, train=train, transform=transform, target_transform=target_transform,
12 + download=download)
13 +
14 + if seed is not None:
15 + rng_state = torch.get_rng_state()
16 + torch.manual_seed(seed)
17 +
18 + N = len(self.data)
19 + n = None
20 +
21 + if 0 < fraction < 1.0:
22 + n = int(N * fraction)
23 + elif N > fraction > 1:
24 + n = int(fraction)
25 + if n:
26 + indices = torch.randperm(N)[:n]
27 + self.data, self.targets = self.data[indices], self.targets[indices]
28 +
29 + if init_transform:
30 + self.data = self.data = init_transform(self.data)
31 + if init_target_transform:
32 + self.targets = init_target_transform(self.targets)
33 +
34 + if seed is not None:
35 + torch.set_rng_state(rng_state)