Showing
1 changed file
with
54 additions
and
0 deletions
3DCNN_VGGNet_2DResNet/transforms.py
0 → 100644
1 | +import numpy as np | ||
2 | +from imgaug import augmenters as iaa | ||
3 | +from scipy.ndimage import interpolation, rotate | ||
4 | +import torch | ||
5 | + | ||
6 | +class ImgAugTranslation(object): | ||
7 | + """Translation | ||
8 | + Arg: | ||
9 | + Pixels: number of pixels to apply translation to the image""" | ||
10 | + def __init__(self, pixels): | ||
11 | + n_pixels = int(pixels) | ||
12 | + self.aug = iaa.Affine(translate_px=(-n_pixels, n_pixels)) | ||
13 | + | ||
14 | + def __call__(self, img): | ||
15 | + img = np.array(img) | ||
16 | + return self.aug.augment_image(img) | ||
17 | + | ||
18 | +class ImgAugRotation(object): | ||
19 | + """Rotation | ||
20 | + Arg: | ||
21 | + Degrees: number of degrees to rotate the image""" | ||
22 | + def __init__(self, degrees): | ||
23 | + n_degrees = float(degrees) | ||
24 | + self.aug = iaa.Affine(rotate=(-n_degrees, n_degrees), mode='symmetric') | ||
25 | + | ||
26 | + def __call__(self, img): | ||
27 | + img = np.array(img) | ||
28 | + return self.aug.augment_image(img) | ||
29 | + | ||
30 | + | ||
31 | +class Translation(object): | ||
32 | + """Translation""" | ||
33 | + def __init__(self, offset, order=0, isseg=False, mode='nearest'): | ||
34 | + self.order = order if isseg else 5 | ||
35 | + self.offset = offset | ||
36 | + self.mode = 'nearest' if isseg else 'mirror' | ||
37 | + | ||
38 | + def __call__(self, img): | ||
39 | + return interpolation.shift(img, self.offset , order=self.order, mode=self.mode) | ||
40 | + | ||
41 | +class Rotation(object): | ||
42 | + """Rotation""" | ||
43 | + def __init__(self, theta, order=0, isseg=False, mode='nearest'): | ||
44 | + self.order = order if isseg else 5 | ||
45 | + self.theta = float(theta) | ||
46 | + self.mode = 'nearest' if isseg else 'mirror' | ||
47 | + | ||
48 | + def __call__(self, img): | ||
49 | + return rotate(img, self.theta, reshape=False, order=self.order, mode=self.mode) | ||
50 | + | ||
51 | +class ToTensor(object): | ||
52 | + """Convert ndarrays in sample to Tensors.""" | ||
53 | + def __call__(self, sample): | ||
54 | + return torch.from_numpy(np.asarray(sample).astype(np.float32)) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment