Hyunji

util

"""General utility functions for data: mostly transformations."""
2 +
3 +import logging
4 +import os
5 +import random
6 +
7 +import numpy
8 +import numpy as np
9 +from PIL import ImageFilter
10 +
# getLogger() without a name returns the root logger.
logger = logging.getLogger()
# Dataset root: the DATA environment variable when it is set and non-empty,
# otherwise the relative directory "data".
DATA_FOLDER = os.getenv("DATA") or "data"
13 +
14 +
def uniform_label_noise(p, labels, seed=None):
    """Replace a random subset of ``labels`` with uniformly drawn labels.

    Each entry is selected independently (a uniform draw in [0, 1) is compared
    against ``p``). A selected entry receives a label drawn uniformly from the
    distinct values present in ``labels``, so it can be redrawn as its own
    original value.

    Args:
        p: Per-entry probability of replacement.
        labels: Collection of labels, expected to be one-dimensional. Objects
            exposing ``tolist()`` (such as NumPy arrays) and plain sequences
            are both accepted.
        seed: Optional seed. When given, the global NumPy RNG is seeded for
            the duration of the call and its previous state is restored
            afterwards, even if an error occurs.

    Returns:
        A new ``numpy.ndarray`` with the same length as ``labels``; the input
        is not modified.
    """
    # Build a fresh array so the caller's data is never touched.
    if hasattr(labels, "tolist"):
        labels = labels.tolist()
    labels = numpy.array(labels)

    saved_state = None
    if seed is not None:
        saved_state = numpy.random.get_state()
        numpy.random.seed(seed)

    try:
        count = len(labels)
        candidates = numpy.unique(labels)

        # Draw order (choice first, then rand) is kept so that results for a
        # given seed stay reproducible.
        rnd_labels = numpy.random.choice(candidates, size=count)
        flip = numpy.random.rand(count) <= p

        # Element-wise selection instead of arithmetic blending: also works
        # for non-numeric labels and never multiplies a label by zero.
        return numpy.where(flip, rnd_labels, labels)
    finally:
        # Restore the global RNG state on every exit path.
        if saved_state is not None:
            numpy.random.set_state(saved_state)
34 +
35 +
class GaussianBlur:
    """Random Gaussian blur, the augmentation described in SimCLR.

    Reference: https://arxiv.org/abs/2002.05709

    Calling an instance draws a blur radius uniformly between ``sigma[0]`` and
    ``sigma[1]`` and applies ``ImageFilter.GaussianBlur`` with that radius via
    the input's ``filter`` method.
    """

    def __init__(self, sigma=None):
        # The default range is created per instance rather than shared.
        self.sigma = [0.1, 2.0] if sigma is None else sigma

    def __call__(self, x):
        # x is presumably a PIL image; only its ``filter`` method is used here.
        lower, upper = self.sigma[0], self.sigma[1]
        radius = random.uniform(lower, upper)
        return x.filter(ImageFilter.GaussianBlur(radius=radius))
48 +
49 +
def lines_to_np_array(lines):
    """Parse rows of whitespace-separated integers into a NumPy array.

    Every element of ``lines`` is split on whitespace, each token is converted
    with ``int``, and the resulting rows are passed to ``np.array``.
    """
    parsed_rows = [list(map(int, row_text.split())) for row_text in lines]
    return np.array(parsed_rows)
52 +
53 +
def load_binary_mnist():
    """Load the binarized MNIST splits from ``<DATA_FOLDER>/binary-mnist``.

    Each ``.amat`` file is read as text, parsed with ``lines_to_np_array`` and
    cast to ``float32``.

    Returns:
        dict with keys "train", "valid" and "test" mapping to the parsed
        arrays.
    """
    # Split key -> file name, in the order the files are read.
    split_files = (
        ("train", "binarized_mnist_train.amat"),
        ("valid", "binarized_mnist_valid.amat"),
        ("test", "binarized_mnist_test.amat"),
    )

    splits = {}
    for split_name, file_name in split_files:
        path = os.path.join(DATA_FOLDER, "binary-mnist", file_name)
        with open(path) as handle:
            rows = handle.readlines()
        splits[split_name] = lines_to_np_array(rows).astype("float32")
    return splits
66 +
67 +
def load_mnist():
    """Load the pickled MNIST splits from ``<DATA_FOLDER>/mnist/mnist.pkl.gz``.

    The archive is expected to unpickle to exactly three objects, which are
    taken as the train, validation and test splits in that order.

    Returns:
        dict with keys "train", "valid" and "test" holding the three
        unpickled objects.
    """
    import gzip
    import pickle

    archive_path = os.path.join(DATA_FOLDER, "mnist", "mnist.pkl.gz")

    # NOTE: unpickling can execute arbitrary code; only load archives from a
    # trusted source.
    # The context manager closes the gzip handle, which was previously leaked.
    with gzip.open(archive_path) as archive:
        # encoding="latin1" is presumably needed because the archive was
        # pickled under Python 2 -- confirm against the data source.
        train, valid, test = pickle.load(archive, encoding="latin1")

    return {"train": train, "valid": valid, "test": test}