Showing
2 changed files
with
137 additions
and
6 deletions
code/FAA2_VM/getRDAugmented_saveimg.py
0 → 100644
| 1 | + | ||
| 2 | +import os | ||
| 3 | +import fire | ||
| 4 | +import json | ||
| 5 | +from pprint import pprint | ||
| 6 | +import pickle | ||
| 7 | +import random | ||
| 8 | +import numpy as np | ||
| 9 | +import cv2 | ||
| 10 | + | ||
| 11 | +import torch | ||
| 12 | +import torch.nn as nn | ||
| 13 | +from torchvision.utils import save_image | ||
| 14 | +import torchvision.transforms as transforms | ||
| 15 | + | ||
| 16 | +from transforms import * | ||
| 17 | +from utils import * | ||
| 18 | + | ||
| 19 | +# command | ||
| 20 | +# python getRDAugmented_saveimg.py --model_path='logs/April_26_17:36:17_NL_resnet50__None' | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +DEFALUT_CANDIDATES = [ | ||
| 24 | + ShearXY, | ||
| 25 | + TranslateXY, | ||
| 26 | + # Rotate, | ||
| 27 | + AutoContrast, | ||
| 28 | + # Invert, | ||
| 29 | + # Equalize, # Histogram Equalize --> white tumor | ||
| 30 | + # Solarize, | ||
| 31 | + Posterize, | ||
| 32 | + Contrast, | ||
| 33 | + # Color, | ||
| 34 | + Brightness, | ||
| 35 | + Sharpness, | ||
| 36 | + Cutout | ||
| 37 | +] | ||
| 38 | + | ||
| 39 | +def get_next_subpolicy(transform_candidates = DEFALUT_CANDIDATES, op_per_subpolicy=2): | ||
| 40 | + if not transform_candidates: | ||
| 41 | + transform_candidates = DEFALUT_CANDIDATES | ||
| 42 | + | ||
| 43 | + n_candidates = len(transform_candidates) | ||
| 44 | + subpolicy = [] | ||
| 45 | + | ||
| 46 | + for i in range(op_per_subpolicy): | ||
| 47 | + indx = random.randrange(n_candidates) | ||
| 48 | + prob = random.random() | ||
| 49 | + mag = random.random() | ||
| 50 | + subpolicy.append(transform_candidates[indx](prob, mag)) | ||
| 51 | + | ||
| 52 | + subpolicy = transforms.Compose([ | ||
| 53 | + transforms.Pad(4), | ||
| 54 | + transforms.RandomHorizontalFlip(), | ||
| 55 | + *subpolicy, | ||
| 56 | + transforms.Resize([240, 240]), | ||
| 57 | + transforms.ToTensor() | ||
| 58 | + ]) | ||
| 59 | + | ||
| 60 | + return subpolicy | ||
| 61 | + | ||
| 62 | +def eval(model_path): | ||
| 63 | + print('\n[+] Parse arguments') | ||
| 64 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 65 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 66 | + args, kwargs = parse_args(kwargs) | ||
| 67 | + pprint(args) | ||
| 68 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 69 | + | ||
| 70 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 71 | + | ||
| 72 | + | ||
| 73 | + print('\n[+] Load transform') | ||
| 74 | + # list to tensor | ||
| 75 | + | ||
| 76 | + aug_transform_list = [] | ||
| 77 | + | ||
| 78 | + for i in range (16): | ||
| 79 | + aug_transform_list.append(get_next_subpolicy(DEFALUT_CANDIDATES)) | ||
| 80 | + | ||
| 81 | + | ||
| 82 | + transform = transforms.RandomChoice(aug_transform_list) | ||
| 83 | + print(transform) | ||
| 84 | +""" | ||
| 85 | + print('\n[+] Load dataset') | ||
| 86 | + | ||
| 87 | + dataset = get_dataset(args, transform, 'train') | ||
| 88 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 89 | + | ||
| 90 | + | ||
| 91 | + print('\n[+] Save 1 random policy') | ||
| 92 | + | ||
| 93 | + save_dir = os.path.join(model_path, 'RD_aug_synthesized') | ||
| 94 | + if not os.path.exists(save_dir): | ||
| 95 | + os.makedirs(save_dir) | ||
| 96 | + | ||
| 97 | + # non lesion | ||
| 98 | + normal_dir = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/NonLesion_flair_frame_all' | ||
| 99 | + | ||
| 100 | + for i, (image, target) in enumerate(loader): | ||
| 101 | + image = image.view(240, 240) | ||
| 102 | + | ||
| 103 | + # get random normal brain img | ||
| 104 | + nor_file = random.choice(os.listdir(normal_dir)) | ||
| 105 | + nor_img = cv2.imread(os.path.join(normal_dir, nor_file), cv2.IMREAD_GRAYSCALE) | ||
| 106 | + # print(nor_img.shape) # (256, 224) | ||
| 107 | + nor_img = cv2.resize(nor_img, (240, 240)) | ||
| 108 | + | ||
| 109 | + # synthesize | ||
| 110 | + image = np.asarray(image) | ||
| 111 | + image_255 = image * 255 | ||
| 112 | + image_255[image_255 < 10] = 0 | ||
| 113 | + nor_img[image_255 > 10] = 0 | ||
| 114 | + syn_image = nor_img + image_255 | ||
| 115 | + | ||
| 116 | + # save synthesized img | ||
| 117 | + cv2.imwrite(os.path.join(save_dir, 'aug_'+ str(i) + '.png'), syn_image) | ||
| 118 | + | ||
| 119 | + if((i+1) % 1000 == 0): | ||
| 120 | + print("\n saved images: ", i) | ||
| 121 | + break | ||
| 122 | + | ||
| 123 | + print('\n[+] Finished to save') | ||
| 124 | + """ | ||
| 125 | + | ||
| 126 | +if __name__ == '__main__': | ||
| 127 | + fire.Fire(eval) | ||
| 128 | + | ||
| 129 | + | ||
| 130 | + | ||
| 131 | + |
| ... | @@ -25,12 +25,12 @@ from networks import basenet, grayResNet2 | ... | @@ -25,12 +25,12 @@ from networks import basenet, grayResNet2 |
| 25 | 25 | ||
| 26 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/' | 26 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/' |
| 27 | 27 | ||
| 28 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' | 28 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/aug&HGG+NL_train/' |
| 29 | -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' | 29 | +TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/train_augNL_classify_target.csv' |
| 30 | -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_val/' | 30 | +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/NL+HGG_val/' |
| 31 | -VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/val_nonaug_classify_target.csv' | 31 | +VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/val_nonaugNL_classify_target.csv' |
| 32 | -TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_test/' | 32 | +TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/NL+HGG_test/' |
| 33 | -TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/test_nonaug_classify_target.csv' | 33 | +TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/classification data/test_nonaugNL_classify_target.csv' |
| 34 | 34 | ||
| 35 | current_epoch = 0 | 35 | current_epoch = 0 |
| 36 | 36 | ... | ... |
-
Please register or login to post a comment