조현아

FAA getBraTS_2

...@@ -5,6 +5,7 @@ import collections ...@@ -5,6 +5,7 @@ import collections
5 import pickle as cp 5 import pickle as cp
6 import glob 6 import glob
7 import numpy as np 7 import numpy as np
8 +import pandas as pd
8 9
9 import torch 10 import torch
10 import torchvision 11 import torchvision
...@@ -31,7 +32,7 @@ current_epoch = 0 ...@@ -31,7 +32,7 @@ current_epoch = 0
31 def split_dataset(args, dataset, k): 32 def split_dataset(args, dataset, k):
32 # load dataset 33 # load dataset
33 X = list(range(len(dataset))) 34 X = list(range(len(dataset)))
34 - Y = dataset.targets 35 + Y = dataset
35 36
36 # split to k-fold 37 # split to k-fold
37 assert len(X) == len(Y) 38 assert len(X) == len(Y)
...@@ -162,9 +163,11 @@ class CustomDataset(Dataset): ...@@ -162,9 +163,11 @@ class CustomDataset(Dataset):
162 return self.len 163 return self.len
163 164
164 def __getitem__(self, idx): 165 def __getitem__(self, idx):
165 - if self.transforms is not None: 166 + img, targets = self.img[idx], self.targets[idx]
166 - img = self.transforms(img) 167 +
167 - return img 168 + if self.transform is not None:
169 + img = self.transform(img)
170 + return img, targets
168 171
169 def get_dataset(args, transform, split='train'): 172 def get_dataset(args, transform, split='train'):
170 assert split in ['train', 'val', 'test', 'trainval'] 173 assert split in ['train', 'val', 'test', 'trainval']
......
...@@ -15,6 +15,9 @@ from torchvision.transforms import transforms ...@@ -15,6 +15,9 @@ from torchvision.transforms import transforms
15 from sklearn.model_selection import StratifiedShuffleSplit 15 from sklearn.model_selection import StratifiedShuffleSplit
16 from theconf import Config as C 16 from theconf import Config as C
17 17
18 +
19 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
20 +
18 from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet 21 from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet
19 from FastAutoAugment.augmentations import * 22 from FastAutoAugment.augmentations import *
20 from FastAutoAugment.common import get_logger 23 from FastAutoAugment.common import get_logger
...@@ -79,6 +82,29 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode ...@@ -79,6 +82,29 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode
79 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 82 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80 ]) 83 ])
81 84
85 + elif 'BraTS' in dataset:
86 + input_size = 240
87 + sized_size = 256
88 +
89 + if 'efficientnet' in C.get()['model']['type']:
90 + input_size = EfficientNet.get_image_size(C.get()['model']['type'])
91 + sized_size = input_size + 16 # TODO
92 +
93 + logger.info('size changed to %d/%d.' % (input_size, sized_size))
94 +
95 + transform_train = transforms.Compose([
96 + EfficientNetRandomCrop(input_size),
97 + transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
98 + transforms.RandomHorizontalFlip(),
99 + transforms.ToTensor(),
100 + ])
101 +
102 + transform_test = transforms.Compose([
103 + EfficientNetCenterCrop(input_size),
104 + transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
105 + transforms.ToTensor(),
106 + ])
107 +
82 else: 108 else:
83 raise ValueError('dataset=%s' % dataset) 109 raise ValueError('dataset=%s' % dataset)
84 110
...@@ -111,7 +137,10 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode ...@@ -111,7 +137,10 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode
111 if C.get()['cutout'] > 0: 137 if C.get()['cutout'] > 0:
112 transform_train.transforms.append(CutoutDefault(C.get()['cutout'])) 138 transform_train.transforms.append(CutoutDefault(C.get()['cutout']))
113 139
114 - if dataset == 'cifar10': 140 + if dataset == 'BraTS':
141 + total_trainset =
142 + testset =
143 + elif dataset == 'cifar10':
115 total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) 144 total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
116 testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) 145 testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
117 elif dataset == 'reduced_cifar10': 146 elif dataset == 'reduced_cifar10':
......
...@@ -16,6 +16,9 @@ from ray.tune.suggest import HyperOptSearch ...@@ -16,6 +16,9 @@ from ray.tune.suggest import HyperOptSearch
16 from ray.tune import register_trainable, run_experiments 16 from ray.tune import register_trainable, run_experiments
17 from tqdm import tqdm 17 from tqdm import tqdm
18 18
19 +
20 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
21 +
19 from FastAutoAugment.archive import remove_deplicates, policy_decoder 22 from FastAutoAugment.archive import remove_deplicates, policy_decoder
20 from FastAutoAugment.augmentations import augment_list 23 from FastAutoAugment.augmentations import augment_list
21 from FastAutoAugment.common import get_logger, add_filehandler 24 from FastAutoAugment.common import get_logger, add_filehandler
......
...@@ -19,6 +19,9 @@ import torch.distributed as dist ...@@ -19,6 +19,9 @@ import torch.distributed as dist
19 from tqdm import tqdm 19 from tqdm import tqdm
20 from theconf import Config as C, ConfigArgumentParser 20 from theconf import Config as C, ConfigArgumentParser
21 21
22 +
23 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
24 +
22 from FastAutoAugment.common import get_logger, EMA, add_filehandler 25 from FastAutoAugment.common import get_logger, EMA, add_filehandler
23 from FastAutoAugment.data import get_dataloaders 26 from FastAutoAugment.data import get_dataloaders
24 from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet 27 from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet
......
...@@ -32,6 +32,7 @@ for i = 1 : length(subFolders) ...@@ -32,6 +32,7 @@ for i = 1 : length(subFolders)
32 32
33 % copy flair, segment flair data 33 % copy flair, segment flair data
34 34
35 + % seg의 검은 부분(정보 x)과 같은 인덱스 = 0
35 cp_flair(seg == 0) = 0; 36 cp_flair(seg == 0) = 0;
36 37
37 % save a segmented data 38 % save a segmented data
......
1 inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\'; 1 inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\';
2 -outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\frame\'; 2 +outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\total_frame\';
3 3
4 files = dir(inputheader); 4 files = dir(inputheader);
5 id = {files.name}; 5 id = {files.name};
...@@ -38,14 +38,14 @@ for i = 1 : length(files) ...@@ -38,14 +38,14 @@ for i = 1 : length(files)
38 c = 0; 38 c = 0;
39 step = round(((en) - (st))/11); 39 step = round(((en) - (st))/11);
40 for k = st + step : step : st + step*10 40 for k = st + step : step : st + step*10
41 - c = c+ 1;
42 -
43 type = '.png'; 41 type = '.png';
44 filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_seg_flair_c.png 42 filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_seg_flair_c.png
45 outpath = strcat(outfolder, filename); 43 outpath = strcat(outfolder, filename);
46 % typecase int16 to double, range[0, 1], rotate 90 and filp updown 44 % typecase int16 to double, range[0, 1], rotate 90 and filp updown
47 cp_data = flipud(rot90(mat2gray(double(data(:,:,k))))); 45 cp_data = flipud(rot90(mat2gray(double(data(:,:,k)))));
48 imwrite(cp_data, outpath); 46 imwrite(cp_data, outpath);
47 +
48 + c = c+ 1;
49 end 49 end
50 50
51 end 51 end
......