조현아

grayscale efficientnet

......@@ -4,11 +4,12 @@ import json
from pprint import pprint
import pickle
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from utils import *
......@@ -24,8 +25,6 @@ def eval(model_path):
device = torch.device('cuda' if args.use_cuda else 'cpu')
cp_path = os.path.join(model_path, 'augmentation.cp')
writer = SummaryWriter(log_dir=model_path)
print('\n[+] Load transform')
......@@ -43,17 +42,55 @@ def eval(model_path):
print('\n[+] Save 1 random policy')
os.makedirs(os.path.join(model_path, 'augmented_imgs'))
save_dir = os.path.join(model_path, 'augmented_imgs')
# save segmented lesion images
# os.makedirs(os.path.join(model_path, 'aug_seg'))
# save_dir = os.path.join(model_path, 'aug_seg')
# for i, (image, target) in enumerate(loader):
# image = image.view(240, 240)
# # save img
# save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png'))
# if(i % 100 == 0):
# print("\n saved images: ", i)
# save synthesized images
save_dir = os.path.join(model_path, 'aug_synthesized')
if not os.path.exists(save_dir):
os.makedirs(save_dir)
normal_dir = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/Normal_frames_all'
for i, (image, target) in enumerate(loader):
image = image.view(240, 240)
# save img
save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png'))
# get random normal brain img
nor_file = random.choice(os.listdir(normal_dir))
nor_img = cv2.imread(os.path.join(normal_dir, nor_file), cv2.IMREAD_GRAYSCALE)
# print(nor_img.shape) # (256, 224)
nor_img = cv2.resize(nor_img, (240, 240))
# save normal, lesion image
# save_image(image, os.path.join(save_dir, 'lesion_'+ str(i) + '.png'))
# cv2.imwrite(os.path.join(save_dir, 'nor_'+ str(i) + '.png'), nor_img)
# synthesize
image = np.asarray(image)
image_255 = image * 255
image_255[image_255 < 5] = 0
nor_img[image_255 != 0] = 0
syn_image = nor_img + image_255
# save synthesized img
cv2.imwrite(os.path.join(save_dir, 'aug_'+ str(i) + '.png'), syn_image)
if(i % 100 == 0):
print("\n saved images: ", i)
break
print('\n[+] Finished to save')
if __name__ == '__main__':
......
import math
import torch.nn as nn
import torch.nn.functional as F
def round_fn(orig, multiplier):
if not multiplier:
return orig
return int(math.ceil(multiplier * orig))
def get_activation_fn(activation):
if activation == "swish":
return Swish
elif activation == "relu":
return nn.ReLU
else:
raise Exception('Unkown activation %s' % activation)
class Swish(nn.Module):
""" Swish activation function, s(x) = x * sigmoid(x) """
def __init__(self, inplace=False):
super().__init__()
self.inplace = True
def forward(self, x):
if self.inplace:
x.mul_(F.sigmoid(x))
return x
else:
return x * F.sigmoid(x)
class ConvBlock(nn.Module):
""" Conv + BatchNorm + Activation """
def __init__(self, in_channel, out_channel, kernel_size,
padding=0, stride=1, activation="swish"):
super().__init__()
self.fw = nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size,
padding=padding, stride=stride, bias=False),
nn.BatchNorm2d(out_channel),
get_activation_fn(activation)())
def forward(self, x):
return self.fw(x)
class DepthwiseConvBlock(nn.Module):
""" DepthwiseConv2D + BatchNorm + Activation """
def __init__(self, in_channel, kernel_size,
padding=0, stride=1, activation="swish"):
super().__init__()
self.fw = nn.Sequential(
nn.Conv2d(in_channel, in_channel, kernel_size,
padding=padding, stride=stride, groups=in_channel, bias=False),
nn.BatchNorm2d(in_channel),
get_activation_fn(activation)())
def forward(self, x):
return self.fw(x)
class MBConv(nn.Module):
""" Inverted residual block """
def __init__(self, in_channel, out_channel, kernel_size,
stride=1, expand_ratio=1, activation="swish"):
super().__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.expand_ratio = expand_ratio
self.stride = stride
if expand_ratio != 1:
self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1,
activation=activation)
self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size,
padding=(kernel_size-1)//2,
stride=stride, activation=activation)
self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1,
activation=activation)
def forward(self, inputs):
if self.expand_ratio != 1:
x = self.expand(inputs)
else:
x = inputs
x = self.dw_conv(x)
x = self.pw_conv(x)
if self.in_channel == self.out_channel and \
self.stride == 1:
x = x + inputs
return x
class Net(nn.Module):
""" EfficientNet """
def __init__(self, args):
super(Net, self).__init__()
pi = args.pi
activation = args.activation
num_classes = args.num_classes
self.d = 1.2 ** pi
self.w = 1.1 ** pi
self.r = 1.15 ** pi
self.img_size = (round_fn(224, self.r), round_fn(224, self.r))
self.stage1 = ConvBlock(1, round_fn(32, self.w),
kernel_size=3, padding=1, stride=2, activation=activation)
self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
depth=round_fn(1, self.d), kernel_size=3,
half_resolution=False, expand_ratio=1, activation=activation)
self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
depth=round_fn(2, self.d), kernel_size=3,
half_resolution=True, expand_ratio=6, activation=activation)
self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
depth=round_fn(2, self.d), kernel_size=5,
half_resolution=True, expand_ratio=6, activation=activation)
self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
depth=round_fn(3, self.d), kernel_size=3,
half_resolution=True, expand_ratio=6, activation=activation)
self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
depth=round_fn(3, self.d), kernel_size=5,
half_resolution=False, expand_ratio=6, activation=activation)
self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
depth=round_fn(4, self.d), kernel_size=5,
half_resolution=True, expand_ratio=6, activation=activation)
self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
depth=round_fn(1, self.d), kernel_size=3,
half_resolution=False, expand_ratio=6, activation=activation)
self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
kernel_size=1, activation=activation)
self.fc = nn.Linear(round_fn(7*7*1280, self.w), num_classes)
def make_layers(self, in_channel, out_channel, depth, kernel_size,
half_resolution=False, expand_ratio=1, activation="swish"):
blocks = []
for i in range(depth):
stride = 2 if half_resolution and i==0 else 1
blocks.append(
MBConv(in_channel, out_channel, kernel_size,
stride=stride, expand_ratio=expand_ratio, activation=activation))
in_channel = out_channel
return nn.Sequential(*blocks)
def forward(self, x):
assert x.size()[-2:] == self.img_size, \
'Image size must be %r, but %r given' % (self.img_size, x.size()[-2])
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.stage5(x)
x = self.stage6(x)
x = self.stage7(x)
x = self.stage8(x)
x = self.stage9(x)
x = x.reshape(x.size(0), -1)
x = self.fc(x)
return x, x
inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\';
outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\NonLesion_flair_frame\';
files = dir(inputheader);
id = {files.name};
% files + dir dir
dirFlag = [files.isdir] & ~strcmp(id, '.') & ~strcmp(id, '..');
subFolders = files(dirFlag);
% get filenames in getname_path folder
for i = 1 : length(subFolders)
id = subFolders(i).name;
fprintf('\nSub folder #%d = %s: ', i, id);
type = 'seg.nii';
filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii
seg_path = strcat(inputheader, id, '\', filename,'\', filename);
seg = niftiread(seg_path); %size 240x240x155
segdata = seg;
[x,y,z] = size(segdata);
type = 'flair.nii';
filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii
flair_path = strcat(inputheader, id, '\', filename,'\', filename);
flair = niftiread(flair_path); %size 240x240x155
flairdata = flair;
idx = 1;
frames = zeros(1,1);
for j = 1 : z
n_seg_nonblack = numel(find(segdata(:,:,j) > 0));
n_flair_nonblack = numel(find(flairdata(:,:,j) > 0));
if((n_seg_nonblack == 0) && n_flair_nonblack > 12000) % frames without lesions
frames(idx,1) = j;
idx = idx + 1;
end
end
c = 0;
[nrow, ncol] = size(frames);
if frames(1, 1) ~= 0 % n(non lesion frames) > 0
for k = 1 : nrow
fprintf('%d ', frames(k, 1));
type = '.png';
filename = strcat('nml_', id, '_', int2str(c), type); % BraTS19_2013_2_1_c.png
outpath = convertCharsToStrings(strcat(outfolder, filename));
% typecase int16 to double, range[0, 1], rotate 90 and filp updown
% range [0, 1]
cp_data = flipud(rot90(mat2gray(double(flairdata(:,:,frames(k, 1))))));
% M = max(cp_data(:));
% disp(M);
imwrite(cp_data, outpath);
c = c+ 1;
end
end
end
% modified from shape_exception_handling_all.m
% get index of 10 frames from seg.nii
% find the same index of flair.nii data and save them to outfolder
inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\';
outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_flair_frames\';
files = dir(inputheader);
id = {files.name};
% files + dir dir
dirFlag = [files.isdir] & ~strcmp(id, '.') & ~strcmp(id, '..');
subFolders = files(dirFlag);
ecp_cnt = 0;
% get filenames in getname_path folder
for i = 1 : length(subFolders)
id = subFolders(i).name;
fprintf('\nSub folder #%d = %s: ', i, id);
type = 'seg.nii';
filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii
seg_path = strcat(inputheader, id, '\', filename,'\', filename);
seg = niftiread(seg_path); %size 240x240x155
segdata = seg;
[x,y,z] = size(segdata);
type = 'flair.nii';
filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii
flair_path = strcat(inputheader, id, '\', filename,'\', filename);
flair = niftiread(flair_path); %size 240x240x155
flairdata = flair;
idx = 1;
frames = zeros(1,1);
for j = 1 : z
n_nonblack = numel(find(segdata(:,:,j) > 0));
if(n_nonblack > 70)
frames(idx,1) = j;
idx = idx + 1;
end
end
c = 0;
[nrow, ncol] = size(frames);
step = round(nrow/11);
for k = 1 : step : step*10
%fprintf('%d ', k);
if k > size(frames, 1)
k = size(frames, 1);
fprintf('%s', 'EXCEPTION occured');
ecp_cnt = ecp_cnt +1;
end
fprintf('%d ', frames(k, 1));
type = '.png';
filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_c.png
outpath = convertCharsToStrings(strcat(outfolder, filename));
% typecase int16 to double, range[0, 1], rotate 90 and filp updown
% range [0, 1]
cp_data = flipud(rot90(mat2gray(double(flairdata(:,:,frames(k, 1))))));
% M = max(cp_data(:));
% disp(M);
imwrite(cp_data, outpath);
c = c+ 1;
end
end
fprintf('\n%s: %d', 'num exception', ecp_cnt);