# getAugmented_saveimg.py (2.72 KB)
import os
import fire
import json
from pprint import pprint
import pickle
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import *
# command
# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/'
#   optional: --num_images=10          (or --num_images=None for every loader sample)
#             --normal_dir=/path/to/Normal_frames_all

# Side length images are reshaped / resized to before compositing.
_IMG_SIZE = 240
# Directory of normal-brain frames used as synthesis backgrounds.
_DEFAULT_NORMAL_DIR = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/Normal_frames_all'


def _synthesize(lesion, normal, threshold=5):
    """Paste a lesion image onto a normal-brain image and return an 8-bit result.

    Args:
        lesion: 2-D float array. It is multiplied by 255 before compositing
            (assumes values roughly in [0, 1] — TODO confirm against the
            dataset transform).
        normal: 2-D uint8 array with the same shape as ``lesion``.
        threshold: scaled lesion intensities below this value are treated as
            background, so the normal image shows through there.

    Returns:
        uint8 array of the same shape. It takes the scaled lesion value where
        the lesion is non-background and the normal image everywhere else.
    """
    lesion_255 = np.asarray(lesion, dtype=np.float32) * 255
    # Suppress faint lesion pixels so they do not overwrite the background.
    lesion_255[lesion_255 < threshold] = 0
    composite = np.where(lesion_255 != 0, lesion_255, normal.astype(np.float32))
    # Make the 8-bit conversion explicit instead of relying on how
    # cv2.imwrite handles float input.
    return np.clip(np.rint(composite), 0, 255).astype(np.uint8)


def _load_normal(normal_dir, names, size):
    """Return a random frame from ``names`` in ``normal_dir`` as grayscale, resized to size x size.

    The original author noted these frames are (256, 224) before resizing.

    Raises:
        ValueError: if the chosen file cannot be decoded as an image.
    """
    path = os.path.join(normal_dir, random.choice(names))
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError(f'could not read normal image: {path}')
    return cv2.resize(img, (size, size))


def eval(model_path, num_images=1, normal_dir=_DEFAULT_NORMAL_DIR):
    """Write augmented lesion images composited onto random normal-brain frames.

    Reads ``kwargs.json`` and ``augmentation.cp`` from ``model_path``. It builds
    the 'train' dataset with a ``RandomChoice`` over the pickled augmentation
    list, then composites each augmented sample onto a random frame from
    ``normal_dir``. Output is written to
    ``<model_path>/aug_synthesized/aug_<i>.png``.

    Args:
        model_path: run directory containing ``kwargs.json`` and ``augmentation.cp``.
        num_images: number of images to write. The default of 1 matches the
            script's previous hard-coded behaviour. Pass None to write one
            image per loader sample.
        normal_dir: directory of normal-brain background frames.

    Raises:
        ValueError: if ``num_images`` is < 1, or a normal frame cannot be read.
        FileNotFoundError: if ``normal_dir`` contains no entries.
        OSError: if an output PNG cannot be written.
    """
    if num_images is not None and num_images < 1:
        raise ValueError(f'num_images must be >= 1 or None, got {num_images!r}')

    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    with open(kwargs_path) as f:
        kwargs = json.load(f)
    args, _ = parse_args(kwargs)
    pprint(args)

    print('\n[+] Load transform')
    cp_path = os.path.join(model_path, 'augmentation.cp')
    # SECURITY: pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(cp_path, 'rb') as f:
        aug_transform_list = pickle.load(f)
    # `transforms` is expected to be torchvision.transforms, provided via
    # `from utils import *` — confirm.
    transform = transforms.RandomChoice(aug_transform_list)

    print('\n[+] Load dataset')
    dataset = get_dataset(args, transform, 'train')
    loader = iter(get_aug_dataloader(args, dataset))

    print('\n[+] Save 1 random policy')
    save_dir = os.path.join(model_path, 'aug_synthesized')
    os.makedirs(save_dir, exist_ok=True)
    # List the background pool once instead of on every iteration.
    normal_names = os.listdir(normal_dir)
    if not normal_names:
        raise FileNotFoundError(f'no normal images found in {normal_dir}')

    for i, (image, _target) in enumerate(loader):
        # view() requires exactly _IMG_SIZE**2 elements, i.e. a batch holding
        # one single-channel image (assumed loader configuration — confirm).
        lesion = image.view(_IMG_SIZE, _IMG_SIZE).detach().cpu().numpy()
        normal = _load_normal(normal_dir, normal_names, _IMG_SIZE)
        syn_image = _synthesize(lesion, normal)

        out_path = os.path.join(save_dir, f'aug_{i}.png')
        if not cv2.imwrite(out_path, syn_image):
            raise OSError(f'failed to write {out_path}')
        if i % 100 == 0:
            print("\n saved images: ", i)
        # Stop after writing, so no extra sample is drawn from the loader.
        if num_images is not None and i + 1 >= num_images:
            break
    print('\n[+] Finished to save')
if __name__ == '__main__':
    # Hand `eval` to Python Fire so each of its parameters becomes a CLI flag.
    fire.Fire(component=eval)