Showing
1 changed file
with
35 additions
and
0 deletions
2DCNN/lib/data/mnist.py
0 → 100644
1 | +import torch | ||
2 | +from torchvision.datasets import MNIST | ||
3 | + | ||
4 | +from lib.data.util import DATA_FOLDER | ||
5 | + | ||
6 | + | ||
7 | +class Mnist(MNIST): | ||
8 | + def __init__(self, root=f"{DATA_FOLDER}/mnist", train=True, transform=None, | ||
9 | + target_transform=None, download=True, init_transform=None, | ||
10 | + init_target_transform=None, seed=None, fraction=1.0): | ||
11 | + super().__init__(root, train=train, transform=transform, target_transform=target_transform, | ||
12 | + download=download) | ||
13 | + | ||
14 | + if seed is not None: | ||
15 | + rng_state = torch.get_rng_state() | ||
16 | + torch.manual_seed(seed) | ||
17 | + | ||
18 | + N = len(self.data) | ||
19 | + n = None | ||
20 | + | ||
21 | + if 0 < fraction < 1.0: | ||
22 | + n = int(N * fraction) | ||
23 | + elif N > fraction > 1: | ||
24 | + n = int(fraction) | ||
25 | + if n: | ||
26 | + indices = torch.randperm(N)[:n] | ||
27 | + self.data, self.targets = self.data[indices], self.targets[indices] | ||
28 | + | ||
29 | + if init_transform: | ||
30 | + self.data = self.data = init_transform(self.data) | ||
31 | + if init_target_transform: | ||
32 | + self.targets = init_target_transform(self.targets) | ||
33 | + | ||
34 | + if seed is not None: | ||
35 | + torch.set_rng_state(rng_state) |
-
Please register or login to post a comment