Showing 2 changed files with 137 additions and 6 deletions
code/FAA2_VM/getRDAugmented_saveimg.py
new file mode 100644
import os
import fire
import json
from pprint import pprint
import pickle
import random
import numpy as np
import cv2

import torch
import torch.nn as nn
from torchvision.utils import save_image
import torchvision.transforms as transforms

from transforms import *
from utils import *

# command
# python getRDAugmented_saveimg.py --model_path='logs/April_26_17:36:17_NL_resnet50__None'

DEFAULT_CANDIDATES = [
    ShearXY,
    TranslateXY,
    # Rotate,
    AutoContrast,
    # Invert,
    # Equalize,  # Histogram Equalize --> white tumor
    # Solarize,
    Posterize,
    Contrast,
    # Color,
    Brightness,
    Sharpness,
    Cutout
]
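# Each candidate above is a transform class; get_next_subpolicy() instantiates it
# with a random application probability and magnitude.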


def get_next_subpolicy(transform_candidates=DEFAULT_CANDIDATES, op_per_subpolicy=2):
    if not transform_candidates:
        transform_candidates = DEFAULT_CANDIDATES

    n_candidates = len(transform_candidates)
    subpolicy = []

    for i in range(op_per_subpolicy):
        indx = random.randrange(n_candidates)
        prob = random.random()
        mag = random.random()
        subpolicy.append(transform_candidates[indx](prob, mag))

    subpolicy = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        *subpolicy,
        transforms.Resize([240, 240]),
        transforms.ToTensor()
    ])
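    # Each subpolicy thus pads by 4, applies a random horizontal flip and the two
    # sampled ops, then resizes back to 240x240 and converts to a tensor.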

    return subpolicy


def eval(model_path):
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    kwargs = json.loads(open(kwargs_path).read())
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    cp_path = os.path.join(model_path, 'augmentation.cp')

    print('\n[+] Load transform')
    # build 16 random subpolicies; one of them is chosen at random per image
    aug_transform_list = []

    for i in range(16):
        aug_transform_list.append(get_next_subpolicy(DEFAULT_CANDIDATES))

    transform = transforms.RandomChoice(aug_transform_list)
    print(transform)
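    # transforms.RandomChoice applies exactly one of the 16 subpolicies to each image.

# NOTE: the dataset-loading / synthesis / saving block below is currently disabled
# (it is kept as a string literal rather than executed).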
"""
    print('\n[+] Load dataset')

    dataset = get_dataset(args, transform, 'train')
    loader = iter(get_aug_dataloader(args, dataset))


    print('\n[+] Save 1 random policy')

    save_dir = os.path.join(model_path, 'RD_aug_synthesized')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # non lesion
    normal_dir = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/NonLesion_flair_frame_all'

    for i, (image, target) in enumerate(loader):
        image = image.view(240, 240)

        # get random normal brain img
        nor_file = random.choice(os.listdir(normal_dir))
        nor_img = cv2.imread(os.path.join(normal_dir, nor_file), cv2.IMREAD_GRAYSCALE)
        # print(nor_img.shape)  # (256, 224)
        nor_img = cv2.resize(nor_img, (240, 240))

        # synthesize
        image = np.asarray(image)
        image_255 = image * 255
        image_255[image_255 < 10] = 0
        nor_img[image_255 > 10] = 0
        syn_image = nor_img + image_255
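        # The lines above paste the augmented lesion onto the normal slice: the lesion
        # image is scaled to 0-255, near-zero background is suppressed, the matching
        # region of the normal slice is blanked, and the two are summed.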

        # save synthesized img
        cv2.imwrite(os.path.join(save_dir, 'aug_' + str(i) + '.png'), syn_image)

        # stop once 1000 synthesized images have been written
        if (i + 1) % 1000 == 0:
            print("\n saved images: ", i)
            break

    print('\n[+] Finished saving')
    """


if __name__ == '__main__':
    fire.Fire(eval)

Second changed file (path not shown in the diff):

@@ -25,12 +25,12 @@ from networks import basenet, grayResNet2

 DATASET_PATH = '/content/drive/My Drive/CD2 Project/'

-TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/'
+TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/aug&HGG+NL_train/'
-TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv'
+TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/train_augNL_classify_target.csv'
-VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_val/'
+VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/NL+HGG_val/'
-VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/val_nonaug_classify_target.csv'
+VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/val_nonaugNL_classify_target.csv'
-TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_test/'
+TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/NL+HGG_test/'
-TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/test_nonaug_classify_target.csv'
+TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/test_nonaugNL_classify_target.csv'

 current_epoch = 0