Showing
1 changed file
with
75 additions
and
0 deletions
2DCNN/lib/data/util.py
0 → 100644
1 | +""" general utility function for data: mostly transformations """ | ||
2 | + | ||
3 | +import logging | ||
4 | +import os | ||
5 | +import random | ||
6 | + | ||
7 | +import numpy | ||
8 | +import numpy as np | ||
9 | +from PIL import ImageFilter | ||
10 | + | ||
11 | +logger = logging.getLogger() | ||
12 | +DATA_FOLDER = os.getenv("DATA") if os.getenv("DATA") else "data" | ||
13 | + | ||
14 | + | ||
15 | +def uniform_label_noise(p, labels, seed=None): | ||
16 | + if seed is not None: | ||
17 | + rng_state = numpy.random.get_state() | ||
18 | + numpy.random.seed(seed) | ||
19 | + | ||
20 | + labels = numpy.array(labels.tolist()) | ||
21 | + N = len(labels) | ||
22 | + lst = numpy.unique(labels) | ||
23 | + | ||
24 | + # generate random labels | ||
25 | + rnd_labels = numpy.random.choice(lst, size=N) | ||
26 | + | ||
27 | + flip = numpy.random.rand(N) <= p | ||
28 | + labels = labels * (1 - flip) + rnd_labels * flip | ||
29 | + | ||
30 | + if seed is not None: | ||
31 | + numpy.random.set_state(rng_state) | ||
32 | + | ||
33 | + return labels | ||
34 | + | ||
35 | + | ||
36 | +class GaussianBlur(object): | ||
37 | + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" | ||
38 | + | ||
39 | + def __init__(self, sigma=None): | ||
40 | + if sigma is None: | ||
41 | + sigma = [0.1, 2.0] | ||
42 | + self.sigma = sigma | ||
43 | + | ||
44 | + def __call__(self, x): | ||
45 | + sigma = random.uniform(self.sigma[0], self.sigma[1]) | ||
46 | + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) | ||
47 | + return x | ||
48 | + | ||
49 | + | ||
50 | +def lines_to_np_array(lines): | ||
51 | + return np.array([[int(i) for i in line.split()] for line in lines]) | ||
52 | + | ||
53 | + | ||
54 | +def load_binary_mnist(): | ||
55 | + with open(os.path.join(DATA_FOLDER, "binary-mnist", "binarized_mnist_train.amat")) as f: | ||
56 | + lines = f.readlines() | ||
57 | + train_data = lines_to_np_array(lines).astype("float32") | ||
58 | + with open(os.path.join(DATA_FOLDER, "binary-mnist", "binarized_mnist_valid.amat")) as f: | ||
59 | + lines = f.readlines() | ||
60 | + validation_data = lines_to_np_array(lines).astype("float32") | ||
61 | + with open(os.path.join(DATA_FOLDER, "binary-mnist", "binarized_mnist_test.amat")) as f: | ||
62 | + lines = f.readlines() | ||
63 | + test_data = lines_to_np_array(lines).astype("float32") | ||
64 | + | ||
65 | + return {"train": train_data, "valid": validation_data, "test": test_data} | ||
66 | + | ||
67 | + | ||
68 | +def load_mnist(): | ||
69 | + import gzip | ||
70 | + import _pickle | ||
71 | + | ||
72 | + train, valid, test = _pickle.load( | ||
73 | + gzip.open(os.path.join(DATA_FOLDER, "mnist", "mnist.pkl.gz")), encoding="latin1", | ||
74 | + ) | ||
75 | + return {"train": train, "valid": valid, "test": test} |
-
Please register or login to post a comment