조현아

FAA getBraTS_2

......@@ -5,6 +5,7 @@ import collections
import pickle as cp
import glob
import numpy as np
import pandas as pd
import torch
import torchvision
......@@ -31,7 +32,7 @@ current_epoch = 0
def split_dataset(args, dataset, k):
# load dataset
X = list(range(len(dataset)))
Y = dataset.targets
Y = dataset
# split to k-fold
assert len(X) == len(Y)
......@@ -162,9 +163,11 @@ class CustomDataset(Dataset):
return self.len
def __getitem__(self, idx):
if self.transforms is not None:
img = self.transforms(img)
return img
img, targets = self.img[idx], self.targets[idx]
if self.transform is not None:
img = self.transform(img)
return img, targets
def get_dataset(args, transform, split='train'):
assert split in ['train', 'val', 'test', 'trainval']
......
......@@ -15,6 +15,9 @@ from torchvision.transforms import transforms
from sklearn.model_selection import StratifiedShuffleSplit
from theconf import Config as C
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet
from FastAutoAugment.augmentations import *
from FastAutoAugment.common import get_logger
......@@ -79,6 +82,29 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
elif 'BraTS' in dataset:
input_size = 240
sized_size = 256
if 'efficientnet' in C.get()['model']['type']:
input_size = EfficientNet.get_image_size(C.get()['model']['type'])
sized_size = input_size + 16 # TODO
logger.info('size changed to %d/%d.' % (input_size, sized_size))
transform_train = transforms.Compose([
EfficientNetRandomCrop(input_size),
transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
transform_test = transforms.Compose([
EfficientNetCenterCrop(input_size),
transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
transforms.ToTensor(),
])
else:
raise ValueError('dataset=%s' % dataset)
......@@ -111,7 +137,10 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode
if C.get()['cutout'] > 0:
transform_train.transforms.append(CutoutDefault(C.get()['cutout']))
if dataset == 'cifar10':
if dataset == 'BraTS':
total_trainset =
testset =
elif dataset == 'cifar10':
total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
elif dataset == 'reduced_cifar10':
......
......@@ -16,6 +16,9 @@ from ray.tune.suggest import HyperOptSearch
from ray.tune import register_trainable, run_experiments
from tqdm import tqdm
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from FastAutoAugment.archive import remove_deplicates, policy_decoder
from FastAutoAugment.augmentations import augment_list
from FastAutoAugment.common import get_logger, add_filehandler
......
......@@ -19,6 +19,9 @@ import torch.distributed as dist
from tqdm import tqdm
from theconf import Config as C, ConfigArgumentParser
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from FastAutoAugment.common import get_logger, EMA, add_filehandler
from FastAutoAugment.data import get_dataloaders
from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet
......
......@@ -32,6 +32,7 @@ for i = 1 : length(subFolders)
% copy flair, segment flair data
% seg의 검은 부분(정보 x)과 같은 인덱스 = 0
cp_flair(seg == 0) = 0;
% save a segmented data
......
inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\';
outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\frame\';
outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\total_frame\';
files = dir(inputheader);
id = {files.name};
......@@ -38,14 +38,14 @@ for i = 1 : length(files)
c = 0;
step = round(((en) - (st))/11);
for k = st + step : step : st + step*10
c = c+ 1;
type = '.png';
filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_seg_flair_c.png
outpath = strcat(outfolder, filename);
% typecase int16 to double, range[0, 1], rotate 90 and filp updown
cp_data = flipud(rot90(mat2gray(double(data(:,:,k)))));
imwrite(cp_data, outpath);
c = c+ 1;
end
end
......