Showing
6 changed files
with
863 additions
and
39 deletions
code/FAA2/cifar_utils.py
0 → 100644
1 | +import os | ||
2 | +import time | ||
3 | +import importlib | ||
4 | +import collections | ||
5 | +import pickle as cp | ||
6 | +import numpy as np | ||
7 | + | ||
8 | + | ||
9 | +import torch | ||
10 | +import torchvision | ||
11 | +import torch.nn.functional as F | ||
12 | +import torchvision.models as models | ||
13 | +import torchvision.transforms as transforms | ||
14 | +from torch.utils.data import Subset | ||
15 | + | ||
16 | +from sklearn.model_selection import StratifiedShuffleSplit | ||
17 | + | ||
18 | +from networks import basenet | ||
19 | +from networks import grayResNet | ||
20 | + | ||
21 | + | ||
22 | +DATASET_PATH = './data/' | ||
23 | +current_epoch = 0 | ||
24 | + | ||
25 | + | ||
26 | +def split_dataset(args, dataset, k): | ||
27 | + # load dataset | ||
28 | + X = list(range(len(dataset))) | ||
29 | + Y = dataset.targets | ||
30 | + | ||
31 | + # split to k-fold | ||
32 | + assert len(X) == len(Y) | ||
33 | + | ||
34 | + def _it_to_list(_it): | ||
35 | + return list(zip(*list(_it))) | ||
36 | + | ||
37 | + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
38 | + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
39 | + | ||
40 | + return Dm_indexes, Da_indexes | ||
41 | + | ||
42 | + | ||
43 | +def concat_image_features(image, features, max_features=3): | ||
44 | + _, h, w = image.shape | ||
45 | + | ||
46 | + max_features = min(features.size(0), max_features) | ||
47 | + image_feature = image.clone() | ||
48 | + | ||
49 | + for i in range(max_features): | ||
50 | + feature = features[i:i+1] | ||
51 | + _min, _max = torch.min(feature), torch.max(feature) | ||
52 | + feature = (feature - _min) / (_max - _min + 1e-6) | ||
53 | + feature = torch.cat([feature]*3, 0) | ||
54 | + feature = feature.view(1, 3, feature.size(1), feature.size(2)) | ||
55 | + feature = F.upsample(feature, size=(h,w), mode="bilinear") | ||
56 | + feature = feature.view(3, h, w) | ||
57 | + image_feature = torch.cat((image_feature, feature), 2) | ||
58 | + | ||
59 | + return image_feature | ||
60 | + | ||
61 | + | ||
62 | +def get_model_name(args): | ||
63 | + from datetime import datetime | ||
64 | + now = datetime.now() | ||
65 | + date_time = now.strftime("%B_%d_%H:%M:%S") | ||
66 | + model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
67 | + return model_name | ||
68 | + | ||
69 | + | ||
70 | +def dict_to_namedtuple(d): | ||
71 | + Args = collections.namedtuple('Args', sorted(d.keys())) | ||
72 | + | ||
73 | + for k,v in d.items(): | ||
74 | + if type(v) is dict: | ||
75 | + d[k] = dict_to_namedtuple(v) | ||
76 | + | ||
77 | + elif type(v) is str: | ||
78 | + try: | ||
79 | + d[k] = eval(v) | ||
80 | + except: | ||
81 | + d[k] = v | ||
82 | + | ||
83 | + args = Args(**d) | ||
84 | + return args | ||
85 | + | ||
86 | + | ||
87 | +def parse_args(kwargs): | ||
88 | + # combine with default args | ||
89 | + kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'cifar10' | ||
90 | + kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet_cifar10' | ||
91 | + kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
92 | + kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.1 | ||
93 | + kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
94 | + kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
95 | + kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
96 | + kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
97 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 2000 | ||
98 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 2000 | ||
99 | + kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
100 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
101 | + kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
102 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 64000 | ||
103 | + kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False | ||
104 | + kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
105 | + | ||
106 | + # to named tuple | ||
107 | + args = dict_to_namedtuple(kwargs) | ||
108 | + return args, kwargs | ||
109 | + | ||
110 | + | ||
111 | +def select_model(args): | ||
112 | + resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(), | ||
113 | + 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()} | ||
114 | + #print("args.network: \n", args.network) | ||
115 | + if args.network in resnet_dict: | ||
116 | + backbone = resnet_dict[args.network] | ||
117 | + model = basenet.BaseNet(backbone, args) | ||
118 | + else: | ||
119 | + Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
120 | + model = Net(args) | ||
121 | + | ||
122 | + print(model) | ||
123 | + return model | ||
124 | + | ||
125 | + | ||
126 | +def select_optimizer(args, model): | ||
127 | + if args.optimizer == 'sgd': | ||
128 | + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
129 | + elif args.optimizer == 'rms': | ||
130 | + #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5) | ||
131 | + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
132 | + elif args.optimizer == 'adam': | ||
133 | + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
134 | + else: | ||
135 | + raise Exception('Unknown Optimizer') | ||
136 | + return optimizer | ||
137 | + | ||
138 | + | ||
139 | +def select_scheduler(args, optimizer): | ||
140 | + if not args.scheduler or args.scheduler == 'None': | ||
141 | + return None | ||
142 | + elif args.scheduler =='clr': | ||
143 | + return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
144 | + elif args.scheduler =='exp': | ||
145 | + return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
146 | + else: | ||
147 | + raise Exception('Unknown Scheduler') | ||
148 | + | ||
149 | + | ||
150 | +def get_dataset(args, transform, split='train'): | ||
151 | + assert split in ['train', 'val', 'test', 'trainval'] | ||
152 | + | ||
153 | + if args.dataset == 'cifar10': | ||
154 | + train = split in ['train', 'val', 'trainval'] | ||
155 | + dataset = torchvision.datasets.CIFAR10(DATASET_PATH, | ||
156 | + train=train, | ||
157 | + transform=transform, | ||
158 | + download=True) | ||
159 | + | ||
160 | + if split in ['train', 'val']: | ||
161 | + split_path = os.path.join(DATASET_PATH, | ||
162 | + 'cifar-10-batches-py', 'train_val_index.cp') | ||
163 | + | ||
164 | + if not os.path.exists(split_path): | ||
165 | + [train_index], [val_index] = split_dataset(args, dataset, k=1) | ||
166 | + split_index = {'train':train_index, 'val':val_index} | ||
167 | + cp.dump(split_index, open(split_path, 'wb')) | ||
168 | + | ||
169 | + split_index = cp.load(open(split_path, 'rb')) | ||
170 | + dataset = Subset(dataset, split_index[split]) | ||
171 | + | ||
172 | + elif args.dataset == 'imagenet': | ||
173 | + dataset = torchvision.datasets.ImageNet(DATASET_PATH, | ||
174 | + split=split, | ||
175 | + transform=transform, | ||
176 | + download=(split is 'val')) | ||
177 | + | ||
178 | + else: | ||
179 | + raise Exception('Unknown dataset') | ||
180 | + | ||
181 | + return dataset | ||
182 | + | ||
183 | + | ||
184 | +def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
185 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
186 | + batch_size=args.batch_size, | ||
187 | + shuffle=shuffle, | ||
188 | + num_workers=args.num_workers, | ||
189 | + pin_memory=pin_memory) | ||
190 | + return data_loader | ||
191 | + | ||
192 | + | ||
193 | +def get_inf_dataloader(args, dataset): | ||
194 | + global current_epoch | ||
195 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
196 | + | ||
197 | + while True: | ||
198 | + try: | ||
199 | + batch = next(data_loader) | ||
200 | + | ||
201 | + except StopIteration: | ||
202 | + current_epoch += 1 | ||
203 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
204 | + batch = next(data_loader) | ||
205 | + | ||
206 | + yield batch | ||
207 | + | ||
208 | + | ||
209 | +def get_train_transform(args, model, log_dir=None): | ||
210 | + if args.fast_auto_augment: | ||
211 | + assert args.dataset == 'cifar10' # TODO: FastAutoAugment for Imagenet | ||
212 | + | ||
213 | + from fast_auto_augment import fast_auto_augment | ||
214 | + if args.augment_path: | ||
215 | + transform = cp.load(open(args.augment_path, 'rb')) | ||
216 | + os.system('cp {} {}'.format( | ||
217 | + args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | ||
218 | + else: | ||
219 | + transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) | ||
220 | + if log_dir: | ||
221 | + cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | ||
222 | + | ||
223 | + elif args.dataset == 'cifar10': | ||
224 | + transform = transforms.Compose([ | ||
225 | + transforms.Pad(4), | ||
226 | + transforms.RandomCrop(32), | ||
227 | + transforms.RandomHorizontalFlip(), | ||
228 | + transforms.ToTensor() | ||
229 | + ]) | ||
230 | + | ||
231 | + elif args.dataset == 'imagenet': | ||
232 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
233 | + transform = transforms.Compose([ | ||
234 | + transforms.Resize([resize_h, resize_w]), | ||
235 | + transforms.RandomCrop(model.img_size), | ||
236 | + transforms.RandomHorizontalFlip(), | ||
237 | + transforms.ToTensor() | ||
238 | + ]) | ||
239 | + | ||
240 | + else: | ||
241 | + raise Exception('Unknown Dataset') | ||
242 | + | ||
243 | + print(transform) | ||
244 | + | ||
245 | + return transform | ||
246 | + | ||
247 | + | ||
248 | +def get_valid_transform(args, model): | ||
249 | + if args.dataset == 'cifar10': | ||
250 | + val_transform = transforms.Compose([ | ||
251 | + transforms.Resize(32), | ||
252 | + transforms.ToTensor() | ||
253 | + ]) | ||
254 | + | ||
255 | + elif args.dataset == 'imagenet': | ||
256 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
257 | + val_transform = transforms.Compose([ | ||
258 | + transforms.Resize([resize_h, resize_w]), | ||
259 | + transforms.ToTensor() | ||
260 | + ]) | ||
261 | + | ||
262 | + else: | ||
263 | + raise Exception('Unknown Dataset') | ||
264 | + | ||
265 | + return val_transform | ||
266 | + | ||
267 | + | ||
268 | +def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
269 | + model.train() | ||
270 | + images, target = batch | ||
271 | + | ||
272 | + if device: | ||
273 | + images = images.to(device) | ||
274 | + target = target.to(device) | ||
275 | + | ||
276 | + elif args.use_cuda: | ||
277 | + images = images.cuda(non_blocking=True) | ||
278 | + target = target.cuda(non_blocking=True) | ||
279 | + | ||
280 | + # compute output | ||
281 | + start_t = time.time() | ||
282 | + output, first = model(images) | ||
283 | + forward_t = time.time() - start_t | ||
284 | + loss = criterion(output, target) | ||
285 | + | ||
286 | + # measure accuracy and record loss | ||
287 | + acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
288 | + acc1 /= images.size(0) | ||
289 | + acc5 /= images.size(0) | ||
290 | + | ||
291 | + # compute gradient and do SGD step | ||
292 | + optimizer.zero_grad() | ||
293 | + start_t = time.time() | ||
294 | + loss.backward() | ||
295 | + backward_t = time.time() - start_t | ||
296 | + optimizer.step() | ||
297 | + if scheduler: scheduler.step() | ||
298 | + | ||
299 | + if writer and step % args.print_step == 0: | ||
300 | + n_imgs = min(images.size(0), 10) | ||
301 | + for j in range(n_imgs): | ||
302 | + writer.add_image('train/input_image', | ||
303 | + concat_image_features(images[j], first[j]), global_step=step) | ||
304 | + | ||
305 | + return acc1, acc5, loss, forward_t, backward_t | ||
306 | + | ||
307 | + | ||
308 | +def validate(args, model, criterion, valid_loader, step, writer, device=None): | ||
309 | + # switch to evaluate mode | ||
310 | + model.eval() | ||
311 | + | ||
312 | + acc1, acc5 = 0, 0 | ||
313 | + samples = 0 | ||
314 | + infer_t = 0 | ||
315 | + | ||
316 | + with torch.no_grad(): | ||
317 | + for i, (images, target) in enumerate(valid_loader): | ||
318 | + | ||
319 | + start_t = time.time() | ||
320 | + if device: | ||
321 | + images = images.to(device) | ||
322 | + target = target.to(device) | ||
323 | + | ||
324 | + elif args.use_cuda is not None: | ||
325 | + images = images.cuda(non_blocking=True) | ||
326 | + target = target.cuda(non_blocking=True) | ||
327 | + | ||
328 | + # compute output | ||
329 | + output, first = model(images) | ||
330 | + loss = criterion(output, target) | ||
331 | + infer_t += time.time() - start_t | ||
332 | + | ||
333 | + # measure accuracy and record loss | ||
334 | + _acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
335 | + acc1 += _acc1 | ||
336 | + acc5 += _acc5 | ||
337 | + samples += images.size(0) | ||
338 | + | ||
339 | + acc1 /= samples | ||
340 | + acc5 /= samples | ||
341 | + | ||
342 | + if writer: | ||
343 | + n_imgs = min(images.size(0), 10) | ||
344 | + for j in range(n_imgs): | ||
345 | + writer.add_image('valid/input_image', | ||
346 | + concat_image_features(images[j], first[j]), global_step=step) | ||
347 | + | ||
348 | + return acc1, acc5, loss, infer_t | ||
349 | + | ||
350 | + | ||
351 | +def accuracy(output, target, topk=(1,)): | ||
352 | + """Computes the accuracy over the k top predictions for the specified values of k""" | ||
353 | + with torch.no_grad(): | ||
354 | + maxk = max(topk) | ||
355 | + batch_size = target.size(0) | ||
356 | + | ||
357 | + _, pred = output.topk(maxk, 1, True, True) | ||
358 | + pred = pred.t() | ||
359 | + correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
360 | + | ||
361 | + res = [] | ||
362 | + for k in topk: | ||
363 | + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
364 | + res.append(correct_k) | ||
365 | + return res |
... | @@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None): | ... | @@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None): |
54 | if torch.cuda.device_count() > 1: | 54 | if torch.cuda.device_count() > 1: |
55 | print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | 55 | print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) |
56 | model = nn.DataParallel(model) | 56 | model = nn.DataParallel(model) |
57 | + elif torch.cuda.device_count() == 1: | ||
58 | + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
57 | 59 | ||
58 | start_t = time.time() | 60 | start_t = time.time() |
59 | for step in range(args.start_step, args.max_step): | 61 | for step in range(args.start_step, args.max_step): |
60 | batch = next(data_loader) | 62 | batch = next(data_loader) |
63 | + | ||
61 | _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) | 64 | _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) |
62 | 65 | ||
63 | if step % args.print_step == 0: | 66 | if step % args.print_step == 0: |
... | @@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
173 | device = torch.device('cuda:%d' % device_id) | 176 | device = torch.device('cuda:%d' % device_id) |
174 | _transform = [] | 177 | _transform = [] |
175 | 178 | ||
176 | - print('[+] Child %d training strated (GPU: %d)' % (k, device_id)) | 179 | + print('[+] Child %d training started (GPU: %d)' % (k, device_id)) |
177 | 180 | ||
178 | # train child model | 181 | # train child model |
179 | child_model = copy.deepcopy(model) | 182 | child_model = copy.deepcopy(model) |
... | @@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
188 | 191 | ||
189 | return _transform | 192 | return _transform |
190 | 193 | ||
191 | - | 194 | +#fast_auto_augment(args, model, K=4, B=1, num_process=4) |
192 | def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | 195 | def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): |
193 | args_str = json.dumps(args._asdict()) | 196 | args_str = json.dumps(args._asdict()) |
194 | dataset = get_dataset(args, None, 'trainval') | 197 | dataset = get_dataset(args, None, 'trainval') | ... | ... |
... | @@ -4,6 +4,12 @@ class BaseNet(nn.Module): | ... | @@ -4,6 +4,12 @@ class BaseNet(nn.Module): |
4 | def __init__(self, backbone, args): | 4 | def __init__(self, backbone, args): |
5 | super(BaseNet, self).__init__() | 5 | super(BaseNet, self).__init__() |
6 | 6 | ||
7 | + #testing | ||
8 | + for layer in backbone.children(): | ||
9 | + print("\nRESNET50 LAYERS\n") | ||
10 | + print(layer) | ||
11 | + | ||
12 | + | ||
7 | # Separate layers | 13 | # Separate layers |
8 | self.first = nn.Sequential(*list(backbone.children())[:1]) | 14 | self.first = nn.Sequential(*list(backbone.children())[:1]) |
9 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) | 15 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) |
... | @@ -14,6 +20,20 @@ class BaseNet(nn.Module): | ... | @@ -14,6 +20,20 @@ class BaseNet(nn.Module): |
14 | def forward(self, x): | 20 | def forward(self, x): |
15 | f = self.first(x) | 21 | f = self.first(x) |
16 | x = self.after(f) | 22 | x = self.after(f) |
17 | - x = x.reshape(x.size(0), -1) | ||
18 | x = self.fc(x) | 23 | x = self.fc(x) |
19 | return x, f | 24 | return x, f |
25 | + | ||
26 | + | ||
27 | +""" | ||
28 | + print("before reshape:\n", x.size()) | ||
29 | + #[128, 2048, 4, 4] | ||
30 | + # #cifar 내장[128, 2048, 1, 1] | ||
31 | + x = x.reshape(x.size(0), -1) | ||
32 | + print("after reshape:\n", x.size()) | ||
33 | + #[128, 32768] | ||
34 | + #cifar [128, 2048] | ||
35 | + #RuntimeError: size mismatch, m1: [128 x 32768], m2: [2048 x 10] | ||
36 | + print("fc :\n", self.fc) | ||
37 | + #Linear(in_features=2048, out_features=10, bias=True) | ||
38 | + #cifar Linear(in_features=2048, out_features=1000, bias=True) | ||
39 | +""" | ... | ... |
code/FAA2/networks/grayResNet.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import torch.nn.functional as F | ||
4 | + | ||
5 | + | ||
6 | +class BasicBlock(nn.Module): | ||
7 | + expansion = 1 | ||
8 | + | ||
9 | + def __init__(self, in_planes, planes, stride=1): | ||
10 | + super(BasicBlock, self).__init__() | ||
11 | + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
12 | + self.bn1 = nn.BatchNorm2d(planes) | ||
13 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) | ||
14 | + self.bn2 = nn.BatchNorm2d(planes) | ||
15 | + | ||
16 | + self.shortcut = nn.Sequential() | ||
17 | + if stride != 1 or in_planes != self.expansion*planes: | ||
18 | + self.shortcut = nn.Sequential( | ||
19 | + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | ||
20 | + nn.BatchNorm2d(self.expansion*planes) | ||
21 | + ) | ||
22 | + | ||
23 | + def forward(self, x): | ||
24 | + out = F.relu(self.bn1(self.conv1(x))) | ||
25 | + out = self.bn2(self.conv2(out)) | ||
26 | + out += self.shortcut(x) | ||
27 | + out = F.relu(out) | ||
28 | + return out | ||
29 | + | ||
30 | + | ||
31 | +class Bottleneck(nn.Module): | ||
32 | + expansion = 4 | ||
33 | + | ||
34 | + def __init__(self, in_planes, planes, stride=1): | ||
35 | + super(Bottleneck, self).__init__() | ||
36 | + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) | ||
37 | + self.bn1 = nn.BatchNorm2d(planes) | ||
38 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
39 | + self.bn2 = nn.BatchNorm2d(planes) | ||
40 | + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) | ||
41 | + self.bn3 = nn.BatchNorm2d(self.expansion*planes) | ||
42 | + | ||
43 | + self.shortcut = nn.Sequential() | ||
44 | + if stride != 1 or in_planes != self.expansion*planes: | ||
45 | + self.shortcut = nn.Sequential( | ||
46 | + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | ||
47 | + nn.BatchNorm2d(self.expansion*planes) | ||
48 | + ) | ||
49 | + | ||
50 | + def forward(self, x): | ||
51 | + out = F.relu(self.bn1(self.conv1(x))) | ||
52 | + out = F.relu(self.bn2(self.conv2(out))) | ||
53 | + out = self.bn3(self.conv3(out)) | ||
54 | + out += self.shortcut(x) | ||
55 | + out = F.relu(out) | ||
56 | + return out | ||
57 | + | ||
58 | + | ||
59 | +class ResNet(nn.Module): | ||
60 | + def __init__(self, block, num_blocks, num_classes=10): | ||
61 | + super(ResNet, self).__init__() | ||
62 | + self.in_planes = 64 | ||
63 | + | ||
64 | + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False) | ||
65 | + self.bn1 = nn.BatchNorm2d(64) | ||
66 | + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) | ||
67 | + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) | ||
68 | + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) | ||
69 | + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) | ||
70 | + self.linear = nn.Linear(512*block.expansion, num_classes) | ||
71 | + | ||
72 | + def _make_layer(self, block, planes, num_blocks, stride): | ||
73 | + strides = [stride] + [1]*(num_blocks-1) | ||
74 | + layers = [] | ||
75 | + for stride in strides: | ||
76 | + layers.append(block(self.in_planes, planes, stride)) | ||
77 | + self.in_planes = planes * block.expansion | ||
78 | + return nn.Sequential(*layers) | ||
79 | + | ||
80 | + def forward(self, x): | ||
81 | + out = F.relu(self.bn1(self.conv1(x))) | ||
82 | + out = self.layer1(out) | ||
83 | + out = self.layer2(out) | ||
84 | + out = self.layer3(out) | ||
85 | + out = self.layer4(out) | ||
86 | + out = F.avg_pool2d(out, 4) | ||
87 | + out = out.view(out.size(0), -1) | ||
88 | + out = self.linear(out) | ||
89 | + return out | ||
90 | + | ||
91 | + | ||
92 | +def ResNet18(): | ||
93 | + return ResNet(BasicBlock, [2,2,2,2]) | ||
94 | + | ||
95 | +def ResNet34(): | ||
96 | + return ResNet(BasicBlock, [3,4,6,3]) | ||
97 | + | ||
98 | +def ResNet50(): | ||
99 | + return ResNet(Bottleneck, [3,4,6,3]) | ||
100 | + | ||
101 | +def ResNet101(): | ||
102 | + return ResNet(Bottleneck, [3,4,23,3]) | ||
103 | + | ||
104 | +def ResNet152(): | ||
105 | + return ResNet(Bottleneck, [3,8,36,3]) | ||
106 | + | ||
107 | + | ||
108 | +def test(): | ||
109 | + net = ResNet18() | ||
110 | + y = net(torch.randn(1,3,32,32)) | ||
111 | + print(y.size()) |
code/FAA2/networks/grayResNet2.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +#from .utils import load_state_dict_from_url | ||
4 | + | ||
5 | + | ||
6 | +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', | ||
7 | + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', | ||
8 | + 'wide_resnet50_2', 'wide_resnet101_2'] | ||
9 | + | ||
10 | + | ||
11 | +model_urls = { | ||
12 | + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', | ||
13 | + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', | ||
14 | + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', | ||
15 | + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', | ||
16 | + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', | ||
17 | + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', | ||
18 | + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', | ||
19 | + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', | ||
20 | + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', | ||
21 | +} | ||
22 | + | ||
23 | + | ||
24 | +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | ||
25 | + """3x3 convolution with padding""" | ||
26 | + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||
27 | + padding=dilation, groups=groups, bias=False, dilation=dilation) | ||
28 | + | ||
29 | + | ||
30 | +def conv1x1(in_planes, out_planes, stride=1): | ||
31 | + """1x1 convolution""" | ||
32 | + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||
33 | + | ||
34 | + | ||
35 | +class BasicBlock(nn.Module): | ||
36 | + expansion = 1 | ||
37 | + | ||
38 | + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||
39 | + base_width=64, dilation=1, norm_layer=None): | ||
40 | + super(BasicBlock, self).__init__() | ||
41 | + if norm_layer is None: | ||
42 | + norm_layer = nn.BatchNorm2d | ||
43 | + if groups != 1 or base_width != 64: | ||
44 | + raise ValueError('BasicBlock only supports groups=1 and base_width=64') | ||
45 | + if dilation > 1: | ||
46 | + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | ||
47 | + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||
48 | + self.conv1 = conv3x3(inplanes, planes, stride) | ||
49 | + self.bn1 = norm_layer(planes) | ||
50 | + self.relu = nn.ReLU(inplace=True) | ||
51 | + self.conv2 = conv3x3(planes, planes) | ||
52 | + self.bn2 = norm_layer(planes) | ||
53 | + self.downsample = downsample | ||
54 | + self.stride = stride | ||
55 | + | ||
56 | + def forward(self, x): | ||
57 | + identity = x | ||
58 | + | ||
59 | + out = self.conv1(x) | ||
60 | + out = self.bn1(out) | ||
61 | + out = self.relu(out) | ||
62 | + | ||
63 | + out = self.conv2(out) | ||
64 | + out = self.bn2(out) | ||
65 | + | ||
66 | + if self.downsample is not None: | ||
67 | + identity = self.downsample(x) | ||
68 | + | ||
69 | + out += identity | ||
70 | + out = self.relu(out) | ||
71 | + | ||
72 | + return out | ||
73 | + | ||
74 | + | ||
75 | +class Bottleneck(nn.Module): | ||
76 | + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) | ||
77 | + # while original implementation places the stride at the first 1x1 convolution(self.conv1) | ||
78 | + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. | ||
79 | + # This variant is also known as ResNet V1.5 and improves accuracy according to | ||
80 | + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. | ||
81 | + | ||
82 | + expansion = 4 | ||
83 | + | ||
84 | + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||
85 | + base_width=64, dilation=1, norm_layer=None): | ||
86 | + super(Bottleneck, self).__init__() | ||
87 | + if norm_layer is None: | ||
88 | + norm_layer = nn.BatchNorm2d | ||
89 | + width = int(planes * (base_width / 64.)) * groups | ||
90 | + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||
91 | + self.conv1 = conv1x1(inplanes, width) | ||
92 | + self.bn1 = norm_layer(width) | ||
93 | + self.conv2 = conv3x3(width, width, stride, groups, dilation) | ||
94 | + self.bn2 = norm_layer(width) | ||
95 | + self.conv3 = conv1x1(width, planes * self.expansion) | ||
96 | + self.bn3 = norm_layer(planes * self.expansion) | ||
97 | + self.relu = nn.ReLU(inplace=True) | ||
98 | + self.downsample = downsample | ||
99 | + self.stride = stride | ||
100 | + | ||
101 | + def forward(self, x): | ||
102 | + identity = x | ||
103 | + | ||
104 | + out = self.conv1(x) | ||
105 | + out = self.bn1(out) | ||
106 | + out = self.relu(out) | ||
107 | + | ||
108 | + out = self.conv2(out) | ||
109 | + out = self.bn2(out) | ||
110 | + out = self.relu(out) | ||
111 | + | ||
112 | + out = self.conv3(out) | ||
113 | + out = self.bn3(out) | ||
114 | + | ||
115 | + if self.downsample is not None: | ||
116 | + identity = self.downsample(x) | ||
117 | + | ||
118 | + out += identity | ||
119 | + out = self.relu(out) | ||
120 | + | ||
121 | + return out | ||
122 | + | ||
123 | + | ||
124 | +class ResNet(nn.Module): | ||
125 | + | ||
126 | + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, | ||
127 | + groups=1, width_per_group=64, replace_stride_with_dilation=None, | ||
128 | + norm_layer=None): | ||
129 | + super(ResNet, self).__init__() | ||
130 | + if norm_layer is None: | ||
131 | + norm_layer = nn.BatchNorm2d | ||
132 | + self._norm_layer = norm_layer | ||
133 | + | ||
134 | + self.inplanes = 64 | ||
135 | + self.dilation = 1 | ||
136 | + if replace_stride_with_dilation is None: | ||
137 | + # each element in the tuple indicates if we should replace | ||
138 | + # the 2x2 stride with a dilated convolution instead | ||
139 | + replace_stride_with_dilation = [False, False, False] | ||
140 | + if len(replace_stride_with_dilation) != 3: | ||
141 | + raise ValueError("replace_stride_with_dilation should be None " | ||
142 | + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | ||
143 | + self.groups = groups | ||
144 | + self.base_width = width_per_group | ||
145 | + # change dimension 3->1 for grayscale input | ||
146 | + self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, | ||
147 | + bias=False) | ||
148 | + self.bn1 = norm_layer(self.inplanes) | ||
149 | + self.relu = nn.ReLU(inplace=True) | ||
150 | + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
151 | + self.layer1 = self._make_layer(block, 64, layers[0]) | ||
152 | + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, | ||
153 | + dilate=replace_stride_with_dilation[0]) | ||
154 | + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, | ||
155 | + dilate=replace_stride_with_dilation[1]) | ||
156 | + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, | ||
157 | + dilate=replace_stride_with_dilation[2]) | ||
158 | + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||
159 | + self.fc = nn.Linear(512 * block.expansion, num_classes) | ||
160 | + | ||
161 | + for m in self.modules(): | ||
162 | + if isinstance(m, nn.Conv2d): | ||
163 | + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
164 | + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||
165 | + nn.init.constant_(m.weight, 1) | ||
166 | + nn.init.constant_(m.bias, 0) | ||
167 | + | ||
168 | + # Zero-initialize the last BN in each residual branch, | ||
169 | + # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||
170 | + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||
171 | + if zero_init_residual: | ||
172 | + for m in self.modules(): | ||
173 | + if isinstance(m, Bottleneck): | ||
174 | + nn.init.constant_(m.bn3.weight, 0) | ||
175 | + elif isinstance(m, BasicBlock): | ||
176 | + nn.init.constant_(m.bn2.weight, 0) | ||
177 | + | ||
178 | + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | ||
179 | + norm_layer = self._norm_layer | ||
180 | + downsample = None | ||
181 | + previous_dilation = self.dilation | ||
182 | + if dilate: | ||
183 | + self.dilation *= stride | ||
184 | + stride = 1 | ||
185 | + if stride != 1 or self.inplanes != planes * block.expansion: | ||
186 | + downsample = nn.Sequential( | ||
187 | + conv1x1(self.inplanes, planes * block.expansion, stride), | ||
188 | + norm_layer(planes * block.expansion), | ||
189 | + ) | ||
190 | + | ||
191 | + layers = [] | ||
192 | + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | ||
193 | + self.base_width, previous_dilation, norm_layer)) | ||
194 | + self.inplanes = planes * block.expansion | ||
195 | + for _ in range(1, blocks): | ||
196 | + layers.append(block(self.inplanes, planes, groups=self.groups, | ||
197 | + base_width=self.base_width, dilation=self.dilation, | ||
198 | + norm_layer=norm_layer)) | ||
199 | + | ||
200 | + return nn.Sequential(*layers) | ||
201 | + | ||
202 | + def _forward_impl(self, x): | ||
203 | + # See note [TorchScript super()] | ||
204 | + x = self.conv1(x) | ||
205 | + x = self.bn1(x) | ||
206 | + x = self.relu(x) | ||
207 | + x = self.maxpool(x) | ||
208 | + | ||
209 | + x = self.layer1(x) | ||
210 | + x = self.layer2(x) | ||
211 | + x = self.layer3(x) | ||
212 | + x = self.layer4(x) | ||
213 | + | ||
214 | + x = self.avgpool(x) | ||
215 | + x = torch.flatten(x, 1) | ||
216 | + x = self.fc(x) | ||
217 | + | ||
218 | + return x | ||
219 | + | ||
220 | + def forward(self, x): | ||
221 | + return self._forward_impl(x) | ||
222 | + | ||
223 | + | ||
224 | +def _resnet(arch, block, layers, pretrained, progress, **kwargs): | ||
225 | + model = ResNet(block, layers, **kwargs) | ||
226 | + # if pretrained: | ||
227 | + # state_dict = load_state_dict_from_url(model_urls[arch], | ||
228 | + # progress=progress) | ||
229 | + # model.load_state_dict(state_dict) | ||
230 | + return model | ||
231 | + | ||
232 | + | ||
233 | +def resnet18(pretrained=False, progress=True, **kwargs): | ||
234 | + r"""ResNet-18 model from | ||
235 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
236 | + Args: | ||
237 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
238 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
239 | + """ | ||
240 | + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, | ||
241 | + **kwargs) | ||
242 | + | ||
243 | + | ||
244 | +def resnet34(pretrained=False, progress=True, **kwargs): | ||
245 | + r"""ResNet-34 model from | ||
246 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
247 | + Args: | ||
248 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
249 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
250 | + """ | ||
251 | + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, | ||
252 | + **kwargs) | ||
253 | + | ||
254 | + | ||
255 | +def resnet50(pretrained=False, progress=True, **kwargs): | ||
256 | + r"""ResNet-50 model from | ||
257 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
258 | + Args: | ||
259 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
260 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
261 | + """ | ||
262 | + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, | ||
263 | + **kwargs) | ||
264 | + | ||
265 | + | ||
266 | +def resnet101(pretrained=False, progress=True, **kwargs): | ||
267 | + r"""ResNet-101 model from | ||
268 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
269 | + Args: | ||
270 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
271 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
272 | + """ | ||
273 | + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, | ||
274 | + **kwargs) | ||
275 | + | ||
276 | + | ||
277 | +def resnet152(pretrained=False, progress=True, **kwargs): | ||
278 | + r"""ResNet-152 model from | ||
279 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
280 | + Args: | ||
281 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
282 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
283 | + """ | ||
284 | + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, | ||
285 | + **kwargs) | ||
286 | + | ||
287 | + | ||
288 | +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): | ||
289 | + r"""ResNeXt-50 32x4d model from | ||
290 | + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||
291 | + Args: | ||
292 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
293 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
294 | + """ | ||
295 | + kwargs['groups'] = 32 | ||
296 | + kwargs['width_per_group'] = 4 | ||
297 | + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], | ||
298 | + pretrained, progress, **kwargs) | ||
299 | + | ||
300 | + | ||
301 | +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): | ||
302 | + r"""ResNeXt-101 32x8d model from | ||
303 | + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||
304 | + Args: | ||
305 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
306 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
307 | + """ | ||
308 | + kwargs['groups'] = 32 | ||
309 | + kwargs['width_per_group'] = 8 | ||
310 | + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], | ||
311 | + pretrained, progress, **kwargs) | ||
312 | + | ||
313 | + | ||
314 | +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): | ||
315 | + r"""Wide ResNet-50-2 model from | ||
316 | + `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||
317 | + The model is the same as ResNet except for the bottleneck number of channels | ||
318 | + which is twice larger in every block. The number of channels in outer 1x1 | ||
319 | + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 | ||
320 | + channels, and in Wide ResNet-50-2 has 2048-1024-2048. | ||
321 | + Args: | ||
322 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
323 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
324 | + """ | ||
325 | + kwargs['width_per_group'] = 64 * 2 | ||
326 | + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], | ||
327 | + pretrained, progress, **kwargs) | ||
328 | + | ||
329 | + | ||
330 | +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): | ||
331 | + r"""Wide ResNet-101-2 model from | ||
332 | + `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||
333 | + The model is the same as ResNet except for the bottleneck number of channels | ||
334 | + which is twice larger in every block. The number of channels in outer 1x1 | ||
335 | + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 | ||
336 | + channels, and in Wide ResNet-50-2 has 2048-1024-2048. | ||
337 | + Args: | ||
338 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
339 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
340 | + """ | ||
341 | + kwargs['width_per_group'] = 64 * 2 | ||
342 | + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], | ||
343 | + pretrained, progress, **kwargs) |
... | @@ -6,6 +6,7 @@ import pickle as cp | ... | @@ -6,6 +6,7 @@ import pickle as cp |
6 | import glob | 6 | import glob |
7 | import numpy as np | 7 | import numpy as np |
8 | import pandas as pd | 8 | import pandas as pd |
9 | + | ||
9 | from natsort import natsorted | 10 | from natsort import natsorted |
10 | from PIL import Image | 11 | from PIL import Image |
11 | import torch | 12 | import torch |
... | @@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split | ... | @@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split |
21 | from sklearn.model_selection import KFold | 22 | from sklearn.model_selection import KFold |
22 | 23 | ||
23 | from networks import basenet | 24 | from networks import basenet |
25 | +from networks import grayResNet, grayResNet2 | ||
24 | 26 | ||
25 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 27 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
26 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 28 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
... | @@ -55,40 +57,6 @@ def split_dataset(args, dataset, k): | ... | @@ -55,40 +57,6 @@ def split_dataset(args, dataset, k): |
55 | 57 | ||
56 | return Dm_indexes, Da_indexes | 58 | return Dm_indexes, Da_indexes |
57 | 59 | ||
58 | -def split_dataset2222(args, dataset, k): | ||
59 | - # load dataset | ||
60 | - X = list(range(len(dataset))) | ||
61 | - | ||
62 | - # split to k-fold | ||
63 | - #assert len(X) == len(Y) | ||
64 | - | ||
65 | - def _it_to_list(_it): | ||
66 | - return list(zip(*list(_it))) | ||
67 | - | ||
68 | - x_train = () | ||
69 | - x_test = () | ||
70 | - | ||
71 | - for i in range(k): | ||
72 | - #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) | ||
73 | - xtr, xte = train_test_split(X, random_state=None, test_size=0.1) | ||
74 | - x_train.append(np.array(xtr)) | ||
75 | - x_test.append(np.array(xte)) | ||
76 | - | ||
77 | - y_train = np.array([0]* len(x_train)) | ||
78 | - y_test = np.array([0]* len(x_test)) | ||
79 | - | ||
80 | - x_train = tuple(x_train) | ||
81 | - x_test = tuple(x_test) | ||
82 | - | ||
83 | - trainset = (zip(x_train, y_train),) | ||
84 | - testset = (zip(x_test, y_test),) | ||
85 | - | ||
86 | - Dm_indexes, Da_indexes = trainset, testset | ||
87 | - | ||
88 | - print(type(Dm_indexes), np.shape(Dm_indexes)) | ||
89 | - print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes) | ||
90 | - | ||
91 | - return Dm_indexes, Da_indexes | ||
92 | 60 | ||
93 | def concat_image_features(image, features, max_features=3): | 61 | def concat_image_features(image, features, max_features=3): |
94 | _, h, w = image.shape | 62 | _, h, w = image.shape |
... | @@ -159,8 +127,22 @@ def parse_args(kwargs): | ... | @@ -159,8 +127,22 @@ def parse_args(kwargs): |
159 | 127 | ||
160 | 128 | ||
161 | def select_model(args): | 129 | def select_model(args): |
162 | - if args.network in models.__dict__: | 130 | + # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(), |
163 | - backbone = models.__dict__[args.network]() | 131 | + # 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()} |
132 | + | ||
133 | + | ||
134 | + # grayResNet2 | ||
135 | + resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
136 | + 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
137 | + | ||
138 | + if args.network in resnet_dict: | ||
139 | + backbone = resnet_dict[args.network] | ||
140 | + #testing | ||
141 | + # print("\nRESNET50 LAYERS\n") | ||
142 | + # for layer in backbone.children(): | ||
143 | + # print(layer) | ||
144 | + # print("LAYER THE END\n") | ||
145 | + | ||
164 | model = basenet.BaseNet(backbone, args) | 146 | model = basenet.BaseNet(backbone, args) |
165 | else: | 147 | else: |
166 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | 148 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ... | ... |
-
Please register or login to post a comment