Showing 6 changed files with 47 additions and 8 deletions
@@ -5,6 +5,7 @@ import collections
 import pickle as cp
 import glob
 import numpy as np
+import pandas as pd

 import torch
 import torchvision
@@ -31,7 +32,7 @@ current_epoch = 0
 def split_dataset(args, dataset, k):
     # load dataset
     X = list(range(len(dataset)))
-    Y = dataset.targets
+    Y = dataset

     # split to k-fold
     assert len(X) == len(Y)
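Note on the change above: with Y = dataset, whatever object is passed in as dataset must itself behave like a per-sample label sequence, otherwise the assert len(X) == len(Y) check fails. The fold split itself is outside this hunk; below is a minimal sketch of how an index list X and a label list Y of equal length would feed a stratified k-fold split, assuming sklearn's StratifiedKFold, which this diff does not confirm.

from sklearn.model_selection import StratifiedKFold

def split_indices(X, Y, k):
    # X: sample indices, Y: one label per sample, len(X) == len(Y)
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=0)
    # each element is a (train_indices, valid_indices) pair for one fold
    return list(skf.split(X, Y))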
@@ -162,9 +163,11 @@ class CustomDataset(Dataset):
         return self.len

     def __getitem__(self, idx):
-        if self.transforms is not None:
-            img = self.transforms(img)
-        return img
+        img, targets = self.img[idx], self.targets[idx]
+
+        if self.transform is not None:
+            img = self.transform(img)
+        return img, targets

 def get_dataset(args, transform, split='train'):
     assert split in ['train', 'val', 'test', 'trainval']
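The rewritten __getitem__ now looks up the stored image and target at idx and returns the (img, target) pair a PyTorch DataLoader expects, instead of returning only the transformed image. A self-contained sketch of such a dataset, with attribute names mirroring the hunk (the class name and constructor are illustrative, not the author's):

from torch.utils.data import Dataset

class PairDataset(Dataset):
    def __init__(self, images, targets, transform=None):
        assert len(images) == len(targets)
        self.img = images          # e.g. a list of PIL images or numpy arrays
        self.targets = targets     # one label per image
        self.transform = transform
        self.len = len(images)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        img, target = self.img[idx], self.targets[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, target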
@@ -15,6 +15,9 @@ from torchvision.transforms import transforms
 from sklearn.model_selection import StratifiedShuffleSplit
 from theconf import Config as C

+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet
 from FastAutoAugment.augmentations import *
 from FastAutoAugment.common import get_logger
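The inserted sys.path line prepends the repository root to the import search path so the FastAutoAugment.* imports below still resolve when the file is executed directly from inside the package directory. It relies on os and sys already being imported earlier in the file; stated explicitly:

import os
import sys

# make the parent directory (the repository root) importable so that
# "import FastAutoAugment.xxx" works when this script is run directly
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

An alternative that avoids the path manipulation is running the entry scripts as modules from the repository root, e.g. python -m FastAutoAugment.train.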
@@ -79,6 +82,29 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

+    elif 'BraTS' in dataset:
+        input_size = 240
+        sized_size = 256
+
+        if 'efficientnet' in C.get()['model']['type']:
+            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
+            sized_size = input_size + 16  # TODO
+
+        logger.info('size changed to %d/%d.' % (input_size, sized_size))
+
+        transform_train = transforms.Compose([
+            EfficientNetRandomCrop(input_size),
+            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+        ])
+
+        transform_test = transforms.Compose([
+            EfficientNetCenterCrop(input_size),
+            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+        ])
+
     else:
         raise ValueError('dataset=%s' % dataset)

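A quick sanity check of the new BraTS branch, applied to a single exported PNG frame (the file path below is hypothetical, and converting the grayscale frame to RGB to match a 3-channel pipeline is an assumption, not something this diff specifies):

from PIL import Image

frame = Image.open('../data/MICCAI_BraTS_2019_Data_Training/total_frame/BraTS19_2013_2_1_seg_flair_1.png').convert('RGB')
x = transform_train(frame)   # tensor of shape [3, input_size, input_size] after crop, resize, flip, ToTensor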
@@ -111,7 +137,10 @@ def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode
     if C.get()['cutout'] > 0:
         transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

-    if dataset == 'cifar10':
+    if dataset == 'BraTS':
+        total_trainset =
+        testset =
+    elif dataset == 'cifar10':
         total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
         testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
     elif dataset == 'reduced_cifar10':
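The BraTS branch above is left unfinished in this diff (both assignments are blank). Purely as an illustration of one way the blanks could be filled, not the author's intended implementation: if the exported PNG frames were arranged into per-class subfolders, torchvision's ImageFolder could serve as the dataset. The directory layout below is hypothetical.

import os
import torchvision

total_trainset = torchvision.datasets.ImageFolder(
    root=os.path.join(dataroot, 'BraTS', 'train'), transform=transform_train)
testset = torchvision.datasets.ImageFolder(
    root=os.path.join(dataroot, 'BraTS', 'test'), transform=transform_test)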
@@ -16,6 +16,9 @@ from ray.tune.suggest import HyperOptSearch
 from ray.tune import register_trainable, run_experiments
 from tqdm import tqdm

+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 from FastAutoAugment.archive import remove_deplicates, policy_decoder
 from FastAutoAugment.augmentations import augment_list
 from FastAutoAugment.common import get_logger, add_filehandler
@@ -19,6 +19,9 @@ import torch.distributed as dist
 from tqdm import tqdm
 from theconf import Config as C, ConfigArgumentParser

+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 from FastAutoAugment.common import get_logger, EMA, add_filehandler
 from FastAutoAugment.data import get_dataloaders
 from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet
@@ -32,6 +32,7 @@ for i = 1 : length(subFolders)

 % copy flair, segment flair data

+% set the indices matching the black (no-information) region of seg to 0
 cp_flair(seg == 0) = 0;

 % save a segmented data
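The translated comment describes masking: every voxel where the segmentation volume seg is 0 (background, no information) is zeroed in the copied FLAIR volume. The same operation in numpy, with stand-in arrays of the usual 240x240x155 BraTS volume shape:

import numpy as np

flair = np.random.rand(240, 240, 155)              # stand-in for the FLAIR volume
seg = np.random.randint(0, 4, size=flair.shape)    # stand-in for the segmentation labels

cp_flair = flair.copy()
cp_flair[seg == 0] = 0    # voxels labelled 0 carry no information and are zeroed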
@@ -1,5 +1,5 @@
 inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\';
-outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\frame\';
+outfolder = '..\data\MICCAI_BraTS_2019_Data_Training\total_frame\';

 files = dir(inputheader);
 id = {files.name};
@@ -38,14 +38,14 @@ for i = 1 : length(files)
     c = 0;
     step = round(((en) - (st))/11);
     for k = st + step : step : st + step*10
-        c = c + 1;
-
         type = '.png';
         filename = strcat(id, '_', int2str(c), type); % BraTS19_2013_2_1_seg_flair_c.png
         outpath = strcat(outfolder, filename);
         % typecast int16 to double, range [0, 1], rotate 90 and flip updown
         cp_data = flipud(rot90(mat2gray(double(data(:,:,k)))));
         imwrite(cp_data, outpath);
+
+        c = c + 1;
     end

 end
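For reference, the frame-export loop samples 10 evenly spaced axial slices between st and en, rescales each slice to [0, 1] (mat2gray), rotates and flips it, and writes it as a PNG. A rough Python equivalent of that loop, with volume loading and the st/en bound detection left out (names are illustrative):

import numpy as np
from PIL import Image

def export_frames(volume, st, en, out_prefix, n_frames=10):
    # sample n_frames evenly spaced slices strictly inside (st, en)
    step = round((en - st) / (n_frames + 1))
    c = 0
    for k in range(st + step, st + step * n_frames + 1, step):
        sl = volume[:, :, k].astype(np.float64)
        sl = (sl - sl.min()) / (sl.max() - sl.min() + 1e-8)   # mat2gray-style [0, 1] scaling
        sl = np.flipud(np.rot90(sl))                          # mirror the rot90 + flipud above
        Image.fromarray((sl * 255).astype(np.uint8)).save('%s_%d.png' % (out_prefix, c))
        c += 1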