조현아

random policy

1 +
2 +import os
3 +import fire
4 +import json
5 +from pprint import pprint
6 +import pickle
7 +import random
8 +import numpy as np
9 +import cv2
10 +
11 +import torch
12 +import torch.nn as nn
13 +from torchvision.utils import save_image
14 +import torchvision.transforms as transforms
15 +
16 +from transforms import *
17 +from utils import *
18 +
19 +# command
20 +# python getRDAugmented_saveimg.py --model_path='logs/April_26_17:36:17_NL_resnet50__None'
21 +
22 +
23 +DEFALUT_CANDIDATES = [
24 + ShearXY,
25 + TranslateXY,
26 + # Rotate,
27 + AutoContrast,
28 + # Invert,
29 + # Equalize, # Histogram Equalize --> white tumor
30 + # Solarize,
31 + Posterize,
32 + Contrast,
33 + # Color,
34 + Brightness,
35 + Sharpness,
36 + Cutout
37 +]
38 +
39 +def get_next_subpolicy(transform_candidates = DEFALUT_CANDIDATES, op_per_subpolicy=2):
40 + if not transform_candidates:
41 + transform_candidates = DEFALUT_CANDIDATES
42 +
43 + n_candidates = len(transform_candidates)
44 + subpolicy = []
45 +
46 + for i in range(op_per_subpolicy):
47 + indx = random.randrange(n_candidates)
48 + prob = random.random()
49 + mag = random.random()
50 + subpolicy.append(transform_candidates[indx](prob, mag))
51 +
52 + subpolicy = transforms.Compose([
53 + transforms.Pad(4),
54 + transforms.RandomHorizontalFlip(),
55 + *subpolicy,
56 + transforms.Resize([240, 240]),
57 + transforms.ToTensor()
58 + ])
59 +
60 + return subpolicy
61 +
62 +def eval(model_path):
63 + print('\n[+] Parse arguments')
64 + kwargs_path = os.path.join(model_path, 'kwargs.json')
65 + kwargs = json.loads(open(kwargs_path).read())
66 + args, kwargs = parse_args(kwargs)
67 + pprint(args)
68 + device = torch.device('cuda' if args.use_cuda else 'cpu')
69 +
70 + cp_path = os.path.join(model_path, 'augmentation.cp')
71 +
72 +
73 + print('\n[+] Load transform')
74 + # list to tensor
75 +
76 + aug_transform_list = []
77 +
78 + for i in range (16):
79 + aug_transform_list.append(get_next_subpolicy(DEFALUT_CANDIDATES))
80 +
81 +
82 + transform = transforms.RandomChoice(aug_transform_list)
83 + print(transform)
84 +"""
85 + print('\n[+] Load dataset')
86 +
87 + dataset = get_dataset(args, transform, 'train')
88 + loader = iter(get_aug_dataloader(args, dataset))
89 +
90 +
91 + print('\n[+] Save 1 random policy')
92 +
93 + save_dir = os.path.join(model_path, 'RD_aug_synthesized')
94 + if not os.path.exists(save_dir):
95 + os.makedirs(save_dir)
96 +
97 + # non lesion
98 + normal_dir = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/NonLesion_flair_frame_all'
99 +
100 + for i, (image, target) in enumerate(loader):
101 + image = image.view(240, 240)
102 +
103 + # get random normal brain img
104 + nor_file = random.choice(os.listdir(normal_dir))
105 + nor_img = cv2.imread(os.path.join(normal_dir, nor_file), cv2.IMREAD_GRAYSCALE)
106 + # print(nor_img.shape) # (256, 224)
107 + nor_img = cv2.resize(nor_img, (240, 240))
108 +
109 + # synthesize
110 + image = np.asarray(image)
111 + image_255 = image * 255
112 + image_255[image_255 < 10] = 0
113 + nor_img[image_255 > 10] = 0
114 + syn_image = nor_img + image_255
115 +
116 + # save synthesized img
117 + cv2.imwrite(os.path.join(save_dir, 'aug_'+ str(i) + '.png'), syn_image)
118 +
119 + if((i+1) % 1000 == 0):
120 + print("\n saved images: ", i)
121 + break
122 +
123 + print('\n[+] Finished to save')
124 + """
125 +
126 +if __name__ == '__main__':
127 + fire.Fire(eval)
128 +
129 +
130 +
131 +
...@@ -25,12 +25,12 @@ from networks import basenet, grayResNet2 ...@@ -25,12 +25,12 @@ from networks import basenet, grayResNet2
25 25
26 DATASET_PATH = '/content/drive/My Drive/CD2 Project/' 26 DATASET_PATH = '/content/drive/My Drive/CD2 Project/'
27 27
28 -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' 28 +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/aug&HGG+NL_train/'
29 -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' 29 +TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/train_augNL_classify_target.csv'
30 -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_val/' 30 +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/NL+HGG_val/'
31 -VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/val_nonaug_classify_target.csv' 31 +VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/val_nonaugNL_classify_target.csv'
32 -TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_test/' 32 +TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/NL+HGG_test/'
33 -TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/test_nonaug_classify_target.csv' 33 +TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/test_nonaugNL_classify_target.csv'
34 34
35 current_epoch = 0 35 current_epoch = 0
36 36
......