조현아

grayscale efficientnet

...@@ -4,11 +4,12 @@ import json ...@@ -4,11 +4,12 @@ import json
4 from pprint import pprint 4 from pprint import pprint
5 import pickle 5 import pickle
6 import random 6 import random
7 +import numpy as np
8 +import cv2
7 9
8 import torch 10 import torch
9 import torch.nn as nn 11 import torch.nn as nn
10 from torchvision.utils import save_image 12 from torchvision.utils import save_image
11 -from torch.utils.tensorboard import SummaryWriter
12 13
13 from utils import * 14 from utils import *
14 15
...@@ -24,8 +25,6 @@ def eval(model_path): ...@@ -24,8 +25,6 @@ def eval(model_path):
24 device = torch.device('cuda' if args.use_cuda else 'cpu') 25 device = torch.device('cuda' if args.use_cuda else 'cpu')
25 26
26 cp_path = os.path.join(model_path, 'augmentation.cp') 27 cp_path = os.path.join(model_path, 'augmentation.cp')
27 -
28 - writer = SummaryWriter(log_dir=model_path)
29 28
30 29
31 print('\n[+] Load transform') 30 print('\n[+] Load transform')
...@@ -43,17 +42,55 @@ def eval(model_path): ...@@ -43,17 +42,55 @@ def eval(model_path):
43 42
44 43
45 print('\n[+] Save 1 random policy') 44 print('\n[+] Save 1 random policy')
46 - os.makedirs(os.path.join(model_path, 'augmented_imgs')) 45 +
47 - save_dir = os.path.join(model_path, 'augmented_imgs') 46 + # save segmented lesion images
47 + # os.makedirs(os.path.join(model_path, 'aug_seg'))
48 + # save_dir = os.path.join(model_path, 'aug_seg')
49 +
50 + # for i, (image, target) in enumerate(loader):
51 + # image = image.view(240, 240)
52 + # # save img
53 + # save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png'))
54 +
55 + # if(i % 100 == 0):
56 + # print("\n saved images: ", i)
57 +
58 +
59 + # save synthesized images
60 + save_dir = os.path.join(model_path, 'aug_synthesized')
61 + if not os.path.exists(save_dir):
62 + os.makedirs(save_dir)
63 +
64 + normal_dir = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/Normal_frames_all'
48 65
49 for i, (image, target) in enumerate(loader): 66 for i, (image, target) in enumerate(loader):
50 image = image.view(240, 240) 67 image = image.view(240, 240)
51 - # save img 68 +
52 - save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png')) 69 + # get random normal brain img
70 + nor_file = random.choice(os.listdir(normal_dir))
71 + nor_img = cv2.imread(os.path.join(normal_dir, nor_file), cv2.IMREAD_GRAYSCALE)
72 + # print(nor_img.shape) # (256, 224)
73 + nor_img = cv2.resize(nor_img, (240, 240))
74 +
75 + # save normal, lesion image
76 + # save_image(image, os.path.join(save_dir, 'lesion_'+ str(i) + '.png'))
77 + # cv2.imwrite(os.path.join(save_dir, 'nor_'+ str(i) + '.png'), nor_img)
78 +
79 +
80 + # synthesize
81 + image = np.asarray(image)
82 + image_255 = image * 255
83 + image_255[image_255 < 5] = 0
84 + nor_img[image_255 != 0] = 0
85 + syn_image = nor_img + image_255
86 +
87 + # save synthesized img
88 + cv2.imwrite(os.path.join(save_dir, 'aug_'+ str(i) + '.png'), syn_image)
53 89
54 if(i % 100 == 0): 90 if(i % 100 == 0):
55 print("\n saved images: ", i) 91 print("\n saved images: ", i)
56 - 92 + break
93 +
57 print('\n[+] Finished to save') 94 print('\n[+] Finished to save')
58 95
59 if __name__ == '__main__': 96 if __name__ == '__main__':
......
1 +import math
2 +import torch.nn as nn
3 +import torch.nn.functional as F
4 +
5 +
6 +def round_fn(orig, multiplier):
7 + if not multiplier:
8 + return orig
9 +
10 + return int(math.ceil(multiplier * orig))
11 +
12 +
13 +def get_activation_fn(activation):
14 + if activation == "swish":
15 + return Swish
16 +
17 + elif activation == "relu":
18 + return nn.ReLU
19 +
20 + else:
21 + raise Exception('Unkown activation %s' % activation)
22 +
23 +
24 +class Swish(nn.Module):
25 + """ Swish activation function, s(x) = x * sigmoid(x) """
26 +
27 + def __init__(self, inplace=False):
28 + super().__init__()
29 + self.inplace = True
30 +
31 + def forward(self, x):
32 + if self.inplace:
33 + x.mul_(F.sigmoid(x))
34 + return x
35 + else:
36 + return x * F.sigmoid(x)
37 +
38 +
39 +class ConvBlock(nn.Module):
40 + """ Conv + BatchNorm + Activation """
41 +
42 + def __init__(self, in_channel, out_channel, kernel_size,
43 + padding=0, stride=1, activation="swish"):
44 + super().__init__()
45 + self.fw = nn.Sequential(
46 + nn.Conv2d(in_channel, out_channel, kernel_size,
47 + padding=padding, stride=stride, bias=False),
48 + nn.BatchNorm2d(out_channel),
49 + get_activation_fn(activation)())
50 +
51 + def forward(self, x):
52 + return self.fw(x)
53 +
54 +
55 +class DepthwiseConvBlock(nn.Module):
56 + """ DepthwiseConv2D + BatchNorm + Activation """
57 +
58 + def __init__(self, in_channel, kernel_size,
59 + padding=0, stride=1, activation="swish"):
60 + super().__init__()
61 + self.fw = nn.Sequential(
62 + nn.Conv2d(in_channel, in_channel, kernel_size,
63 + padding=padding, stride=stride, groups=in_channel, bias=False),
64 + nn.BatchNorm2d(in_channel),
65 + get_activation_fn(activation)())
66 +
67 + def forward(self, x):
68 + return self.fw(x)
69 +
70 +
71 +class MBConv(nn.Module):
72 + """ Inverted residual block """
73 +
74 + def __init__(self, in_channel, out_channel, kernel_size,
75 + stride=1, expand_ratio=1, activation="swish"):
76 + super().__init__()
77 + self.in_channel = in_channel
78 + self.out_channel = out_channel
79 + self.expand_ratio = expand_ratio
80 + self.stride = stride
81 +
82 + if expand_ratio != 1:
83 + self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1,
84 + activation=activation)
85 +
86 + self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size,
87 + padding=(kernel_size-1)//2,
88 + stride=stride, activation=activation)
89 +
90 + self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1,
91 + activation=activation)
92 +
93 + def forward(self, inputs):
94 + if self.expand_ratio != 1:
95 + x = self.expand(inputs)
96 + else:
97 + x = inputs
98 +
99 + x = self.dw_conv(x)
100 + x = self.pw_conv(x)
101 +
102 + if self.in_channel == self.out_channel and \
103 + self.stride == 1:
104 + x = x + inputs
105 +
106 + return x
107 +
108 +
109 +class Net(nn.Module):
110 + """ EfficientNet """
111 +
112 + def __init__(self, args):
113 + super(Net, self).__init__()
114 + pi = args.pi
115 + activation = args.activation
116 + num_classes = args.num_classes
117 +
118 + self.d = 1.2 ** pi
119 + self.w = 1.1 ** pi
120 + self.r = 1.15 ** pi
121 + self.img_size = (round_fn(224, self.r), round_fn(224, self.r))
122 +
123 + self.stage1 = ConvBlock(1, round_fn(32, self.w),
124 + kernel_size=3, padding=1, stride=2, activation=activation)
125 +
126 + self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
127 + depth=round_fn(1, self.d), kernel_size=3,
128 + half_resolution=False, expand_ratio=1, activation=activation)
129 +
130 + self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
131 + depth=round_fn(2, self.d), kernel_size=3,
132 + half_resolution=True, expand_ratio=6, activation=activation)
133 +
134 + self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
135 + depth=round_fn(2, self.d), kernel_size=5,
136 + half_resolution=True, expand_ratio=6, activation=activation)
137 +
138 + self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
139 + depth=round_fn(3, self.d), kernel_size=3,
140 + half_resolution=True, expand_ratio=6, activation=activation)
141 +
142 + self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
143 + depth=round_fn(3, self.d), kernel_size=5,
144 + half_resolution=False, expand_ratio=6, activation=activation)
145 +
146 + self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
147 + depth=round_fn(4, self.d), kernel_size=5,
148 + half_resolution=True, expand_ratio=6, activation=activation)
149 +
150 + self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
151 + depth=round_fn(1, self.d), kernel_size=3,
152 + half_resolution=False, expand_ratio=6, activation=activation)
153 +
154 + self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
155 + kernel_size=1, activation=activation)
156 +
157 + self.fc = nn.Linear(round_fn(7*7*1280, self.w), num_classes)
158 +
159 + def make_layers(self, in_channel, out_channel, depth, kernel_size,
160 + half_resolution=False, expand_ratio=1, activation="swish"):
161 + blocks = []
162 + for i in range(depth):
163 + stride = 2 if half_resolution and i==0 else 1
164 + blocks.append(
165 + MBConv(in_channel, out_channel, kernel_size,
166 + stride=stride, expand_ratio=expand_ratio, activation=activation))
167 + in_channel = out_channel
168 +
169 + return nn.Sequential(*blocks)
170 +
171 + def forward(self, x):
172 + assert x.size()[-2:] == self.img_size, \
173 + 'Image size must be %r, but %r given' % (self.img_size, x.size()[-2])
174 +
175 + x = self.stage1(x)
176 + x = self.stage2(x)
177 + x = self.stage3(x)
178 + x = self.stage4(x)
179 + x = self.stage5(x)
180 + x = self.stage6(x)
181 + x = self.stage7(x)
182 + x = self.stage8(x)
183 + x = self.stage9(x)
184 + x = x.reshape(x.size(0), -1)
185 + x = self.fc(x)
186 + return x, x
187 +
1 +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\';
2 +outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\NonLesion_flair_frame\';
3 +
4 +files = dir(inputheader);
5 +id = {files.name};
6 +% files + dir dir
7 +dirFlag = [files.isdir] & ~strcmp(id, '.') & ~strcmp(id, '..');
8 +subFolders = files(dirFlag);
9 +
10 +
11 +% get filenames in getname_path folder
12 +for i = 1 : length(subFolders)
13 + id = subFolders(i).name;
14 + fprintf('\nSub folder #%d = %s: ', i, id);
15 +
16 + type = 'seg.nii';
17 + filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii
18 + seg_path = strcat(inputheader, id, '\', filename,'\', filename);
19 + seg = niftiread(seg_path); %size 240x240x155
20 + segdata = seg;
21 +
22 + [x,y,z] = size(segdata);
23 +
24 +
25 + type = 'flair.nii';
26 + filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii
27 + flair_path = strcat(inputheader, id, '\', filename,'\', filename);
28 + flair = niftiread(flair_path); %size 240x240x155
29 + flairdata = flair;
30 +
31 +
32 + idx = 1;
33 + frames = zeros(1,1);
34 + for j = 1 : z
35 + n_seg_nonblack = numel(find(segdata(:,:,j) > 0));
36 + n_flair_nonblack = numel(find(flairdata(:,:,j) > 0));
37 + if((n_seg_nonblack == 0) && n_flair_nonblack > 12000) % frames without lesions
38 + frames(idx,1) = j;
39 + idx = idx + 1;
40 + end
41 + end
42 +
43 + c = 0;
44 + [nrow, ncol] = size(frames);
45 +
46 + if frames(1, 1) ~= 0 % n(non lesion frames) > 0
47 + for k = 1 : nrow
48 + fprintf('%d ', frames(k, 1));
49 + type = '.png';
50 + filename = strcat('nml_', id, '_', int2str(c), type); % BraTS19_2013_2_1_c.png
51 + outpath = convertCharsToStrings(strcat(outfolder, filename));
52 + % typecase int16 to double, range[0, 1], rotate 90 and filp updown
53 + % range [0, 1]
54 + cp_data = flipud(rot90(mat2gray(double(flairdata(:,:,frames(k, 1))))));
55 + % M = max(cp_data(:));
56 + % disp(M);
57 + imwrite(cp_data, outpath);
58 +
59 + c = c+ 1;
60 + end
61 + end
62 +
63 +
64 +end
65 +
66 +
1 +% modified from shape_exception_handling_all.m
2 +% get index of 10 frames from seg.nii
3 +% find the same index of flair.nii data and save them to outfolder
4 +
5 +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\';
6 +outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_flair_frames\';
7 +
8 +files = dir(inputheader);
9 +id = {files.name};
10 +% files + dir dir
11 +dirFlag = [files.isdir] & ~strcmp(id, '.') & ~strcmp(id, '..');
12 +subFolders = files(dirFlag);
13 +
14 +ecp_cnt = 0;
15 +% get filenames in getname_path folder
16 +for i = 1 : length(subFolders)
17 + id = subFolders(i).name;
18 + fprintf('\nSub folder #%d = %s: ', i, id);
19 +
20 + type = 'seg.nii';
21 + filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii
22 + seg_path = strcat(inputheader, id, '\', filename,'\', filename);
23 + seg = niftiread(seg_path); %size 240x240x155
24 + segdata = seg;
25 +
26 + [x,y,z] = size(segdata);
27 +
28 +
29 + type = 'flair.nii';
30 + filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii
31 + flair_path = strcat(inputheader, id, '\', filename,'\', filename);
32 + flair = niftiread(flair_path); %size 240x240x155
33 + flairdata = flair;
34 +
35 +
36 +
37 + idx = 1;
38 + frames = zeros(1,1);
39 + for j = 1 : z
40 + n_nonblack = numel(find(segdata(:,:,j) > 0));
41 + if(n_nonblack > 70)
42 + frames(idx,1) = j;
43 + idx = idx + 1;
44 + end
45 + end
46 +
47 + c = 0;
48 + [nrow, ncol] = size(frames);
49 + step = round(nrow/11);
50 +
51 +
52 + for k = 1 : step : step*10
53 + %fprintf('%d ', k);
54 + if k > size(frames, 1)
55 + k = size(frames, 1);
56 + fprintf('%s', 'EXCEPTION occured');
57 + ecp_cnt = ecp_cnt +1;
58 + end
59 + fprintf('%d ', frames(k, 1));
60 + type = '.png';
61 + filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_c.png
62 + outpath = convertCharsToStrings(strcat(outfolder, filename));
63 + % typecase int16 to double, range[0, 1], rotate 90 and filp updown
64 + % range [0, 1]
65 + cp_data = flipud(rot90(mat2gray(double(flairdata(:,:,frames(k, 1))))));
66 +% M = max(cp_data(:));
67 +% disp(M);
68 + imwrite(cp_data, outpath);
69 +
70 + c = c+ 1;
71 + end
72 +
73 +
74 +end
75 +fprintf('\n%s: %d', 'num exception', ecp_cnt);
76 +