조현아

run resnet & FAA getBraTS_5

1 +import os
2 +import time
3 +import importlib
4 +import collections
5 +import pickle as cp
6 +import numpy as np
7 +
8 +
9 +import torch
10 +import torchvision
11 +import torch.nn.functional as F
12 +import torchvision.models as models
13 +import torchvision.transforms as transforms
14 +from torch.utils.data import Subset
15 +
16 +from sklearn.model_selection import StratifiedShuffleSplit
17 +
18 +from networks import basenet
19 +from networks import grayResNet
20 +
21 +
22 +DATASET_PATH = './data/'
23 +current_epoch = 0
24 +
25 +
26 +def split_dataset(args, dataset, k):
27 + # load dataset
28 + X = list(range(len(dataset)))
29 + Y = dataset.targets
30 +
31 + # split to k-fold
32 + assert len(X) == len(Y)
33 +
34 + def _it_to_list(_it):
35 + return list(zip(*list(_it)))
36 +
37 + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
38 + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))
39 +
40 + return Dm_indexes, Da_indexes
41 +
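A usage sketch for the splitter (Dm/Da presumably mirror the D_M model-training and D_A augmentation-search splits of the Fast AutoAugment paper; each fold is a class-balanced 90/10 partition of the example indices):

    Dm_indexes, Da_indexes = split_dataset(args, dataset, k=5)
    train_idx, search_idx = Dm_indexes[0], Da_indexes[0]   # index arrays for fold 0
    train_fold = Subset(dataset, train_idx)                # Subset is imported above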
42 +
43 +def concat_image_features(image, features, max_features=3):
44 + _, h, w = image.shape
45 +
46 + max_features = min(features.size(0), max_features)
47 + image_feature = image.clone()
48 +
49 + for i in range(max_features):
50 + feature = features[i:i+1]
51 + _min, _max = torch.min(feature), torch.max(feature)
52 + feature = (feature - _min) / (_max - _min + 1e-6)
53 + feature = torch.cat([feature]*3, 0)
54 + feature = feature.view(1, 3, feature.size(1), feature.size(2))
55 + feature = F.interpolate(feature, size=(h, w), mode="bilinear", align_corners=False)  # F.upsample is deprecated
56 + feature = feature.view(3, h, w)
57 + image_feature = torch.cat((image_feature, feature), 2)
58 +
59 + return image_feature
60 +
61 +
62 +def get_model_name(args):
63 + from datetime import datetime
64 + now = datetime.now()
65 + date_time = now.strftime("%B_%d_%H:%M:%S")
66 + model_name = '__'.join([date_time, args.network, str(args.seed)])
67 + return model_name
68 +
69 +
70 +def dict_to_namedtuple(d):
71 + Args = collections.namedtuple('Args', sorted(d.keys()))
72 +
73 + for k,v in d.items():
74 + if type(v) is dict:
75 + d[k] = dict_to_namedtuple(v)
76 +
77 + elif type(v) is str:
78 + try:
79 + d[k] = eval(v)  # coerce literal strings, e.g. "0.1" -> 0.1, "True" -> True
80 + except Exception:
81 + d[k] = v
82 +
83 + args = Args(**d)
84 + return args
85 +
86 +
87 +def parse_args(kwargs):
88 + # combine with default args
89 + kwargs.setdefault('dataset', 'cifar10')
90 + kwargs.setdefault('network', 'resnet_cifar10')
91 + kwargs.setdefault('optimizer', 'adam')
92 + kwargs.setdefault('learning_rate', 0.1)
93 + kwargs.setdefault('seed', None)
94 + kwargs.setdefault('use_cuda', True)
95 + kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()
96 + kwargs.setdefault('num_workers', 4)
97 + kwargs.setdefault('print_step', 2000)
98 + kwargs.setdefault('val_step', 2000)
99 + kwargs.setdefault('scheduler', 'exp')
100 + kwargs.setdefault('batch_size', 128)
101 + kwargs.setdefault('start_step', 0)
102 + kwargs.setdefault('max_step', 64000)
103 + kwargs.setdefault('fast_auto_augment', False)
104 + kwargs.setdefault('augment_path', None)
105 +
106 + # to named tuple
107 + args = dict_to_namedtuple(kwargs)
108 + return args, kwargs
109 +
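A quick sketch of the config round-trip (defaults are the ones set above; note that dict_to_namedtuple's eval() coerces numeric strings):

    args, kwargs = parse_args({'network': 'resnet50', 'learning_rate': '0.01'})
    args.network         # 'resnet50'
    args.learning_rate   # 0.01 -- the string was coerced to a float
    args.batch_size      # 128, filled in from the defaults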
110 +
111 +def select_model(args):
112 + resnet_dict = {'ResNet18': grayResNet.ResNet18, 'ResNet34': grayResNet.ResNet34,
113 + 'ResNet50': grayResNet.ResNet50, 'ResNet101': grayResNet.ResNet101, 'ResNet152': grayResNet.ResNet152}
114 + # map names to constructors so only the requested network is built
115 + if args.network in resnet_dict:
116 + backbone = resnet_dict[args.network]()
117 + model = basenet.BaseNet(backbone, args)
118 + else:
119 + Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
120 + model = Net(args)
121 +
122 + print(model)
123 + return model
124 +
125 +
126 +def select_optimizer(args, model):
127 + if args.optimizer == 'sgd':
128 + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001)
129 + elif args.optimizer == 'rms':
130 + #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5)
131 + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)
132 + elif args.optimizer == 'adam':
133 + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
134 + else:
135 + raise Exception('Unknown Optimizer')
136 + return optimizer
137 +
138 +
139 +def select_scheduler(args, optimizer):
140 + if not args.scheduler or args.scheduler == 'None':
141 + return None
142 + elif args.scheduler =='clr':
143 + return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False)
144 + elif args.scheduler =='exp':
145 + return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1)
146 + else:
147 + raise Exception('Unknown Scheduler')
148 +
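The 'exp' gamma looks arbitrary but is matched to the default 64 000-step budget: train_step calls scheduler.step() once per batch, and 0.9999283^64000 = e^(64000 · ln 0.9999283) ≈ 0.01, so the learning rate ends a full run at roughly 1% of its starting value.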
149 +
150 +def get_dataset(args, transform, split='train'):
151 + assert split in ['train', 'val', 'test', 'trainval']
152 +
153 + if args.dataset == 'cifar10':
154 + train = split in ['train', 'val', 'trainval']
155 + dataset = torchvision.datasets.CIFAR10(DATASET_PATH,
156 + train=train,
157 + transform=transform,
158 + download=True)
159 +
160 + if split in ['train', 'val']:
161 + split_path = os.path.join(DATASET_PATH,
162 + 'cifar-10-batches-py', 'train_val_index.cp')
163 +
164 + if not os.path.exists(split_path):
165 + [train_index], [val_index] = split_dataset(args, dataset, k=1)
166 + split_index = {'train':train_index, 'val':val_index}
167 + cp.dump(split_index, open(split_path, 'wb'))
168 +
169 + split_index = cp.load(open(split_path, 'rb'))
170 + dataset = Subset(dataset, split_index[split])
171 +
172 + elif args.dataset == 'imagenet':
173 + dataset = torchvision.datasets.ImageNet(DATASET_PATH,
174 + split=split,
175 + transform=transform,
176 + download=(split == 'val'))
177 +
178 + else:
179 + raise Exception('Unknown dataset')
180 +
181 + return dataset
182 +
183 +
184 +def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
185 + data_loader = torch.utils.data.DataLoader(dataset,
186 + batch_size=args.batch_size,
187 + shuffle=shuffle,
188 + num_workers=args.num_workers,
189 + pin_memory=pin_memory)
190 + return data_loader
191 +
192 +
193 +def get_inf_dataloader(args, dataset):
194 + global current_epoch
195 + data_loader = iter(get_dataloader(args, dataset, shuffle=True))
196 +
197 + while True:
198 + try:
199 + batch = next(data_loader)
200 +
201 + except StopIteration:
202 + current_epoch += 1
203 + data_loader = iter(get_dataloader(args, dataset, shuffle=True))
204 + batch = next(data_loader)
205 +
206 + yield batch
207 +
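Usage sketch for the infinite loader -- this is the pattern train_child uses in the fast_auto_augment hunk further down (batch = next(data_loader) inside the step loop):

    data_loader = get_inf_dataloader(args, train_dataset)
    for step in range(args.start_step, args.max_step):
        images, target = next(data_loader)   # reshuffles and bumps current_epoch at epoch boundaries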
208 +
209 +def get_train_transform(args, model, log_dir=None):
210 + if args.fast_auto_augment:
211 + assert args.dataset == 'cifar10' # TODO: FastAutoAugment for Imagenet
212 +
213 + from fast_auto_augment import fast_auto_augment
214 + if args.augment_path:
215 + transform = cp.load(open(args.augment_path, 'rb'))
216 + os.system('cp {} {}'.format(
217 + args.augment_path, os.path.join(log_dir, 'augmentation.cp')))
218 + else:
219 + transform = fast_auto_augment(args, model, K=4, B=1, num_process=4)
220 + if log_dir:
221 + cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb'))
222 +
223 + elif args.dataset == 'cifar10':
224 + transform = transforms.Compose([
225 + transforms.Pad(4),
226 + transforms.RandomCrop(32),
227 + transforms.RandomHorizontalFlip(),
228 + transforms.ToTensor()
229 + ])
230 +
231 + elif args.dataset == 'imagenet':
232 + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875)
233 + transform = transforms.Compose([
234 + transforms.Resize([resize_h, resize_w]),
235 + transforms.RandomCrop(model.img_size),
236 + transforms.RandomHorizontalFlip(),
237 + transforms.ToTensor()
238 + ])
239 +
240 + else:
241 + raise Exception('Unknown Dataset')
242 +
243 + print(transform)
244 +
245 + return transform
246 +
247 +
248 +def get_valid_transform(args, model):
249 + if args.dataset == 'cifar10':
250 + val_transform = transforms.Compose([
251 + transforms.Resize(32),
252 + transforms.ToTensor()
253 + ])
254 +
255 + elif args.dataset == 'imagenet':
256 + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875)
257 + val_transform = transforms.Compose([
258 + transforms.Resize([resize_h, resize_w]),
259 + transforms.ToTensor()
260 + ])
261 +
262 + else:
263 + raise Exception('Unknown Dataset')
264 +
265 + return val_transform
266 +
267 +
268 +def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
269 + model.train()
270 + images, target = batch
271 +
272 + if device:
273 + images = images.to(device)
274 + target = target.to(device)
275 +
276 + elif args.use_cuda:
277 + images = images.cuda(non_blocking=True)
278 + target = target.cuda(non_blocking=True)
279 +
280 + # compute output
281 + start_t = time.time()
282 + output, first = model(images)
283 + forward_t = time.time() - start_t
284 + loss = criterion(output, target)
285 +
286 + # measure accuracy and record loss
287 + acc1, acc5 = accuracy(output, target, topk=(1, 5))
288 + acc1 /= images.size(0)
289 + acc5 /= images.size(0)
290 +
291 + # compute gradient and do SGD step
292 + optimizer.zero_grad()
293 + start_t = time.time()
294 + loss.backward()
295 + backward_t = time.time() - start_t
296 + optimizer.step()
297 + if scheduler: scheduler.step()
298 +
299 + if writer and step % args.print_step == 0:
300 + n_imgs = min(images.size(0), 10)
301 + for j in range(n_imgs):
302 + writer.add_image('train/input_image',
303 + concat_image_features(images[j], first[j]), global_step=step)
304 +
305 + return acc1, acc5, loss, forward_t, backward_t
306 +
307 +
308 +def validate(args, model, criterion, valid_loader, step, writer, device=None):
309 + # switch to evaluate mode
310 + model.eval()
311 +
312 + acc1, acc5 = 0, 0
313 + samples = 0
314 + infer_t = 0
315 +
316 + with torch.no_grad():
317 + for i, (images, target) in enumerate(valid_loader):
318 +
319 + start_t = time.time()
320 + if device:
321 + images = images.to(device)
322 + target = target.to(device)
323 +
324 + elif args.use_cuda:
325 + images = images.cuda(non_blocking=True)
326 + target = target.cuda(non_blocking=True)
327 +
328 + # compute output
329 + output, first = model(images)
330 + loss = criterion(output, target)
331 + infer_t += time.time() - start_t
332 +
333 + # measure accuracy and record loss
334 + _acc1, _acc5 = accuracy(output, target, topk=(1, 5))
335 + acc1 += _acc1
336 + acc5 += _acc5
337 + samples += images.size(0)
338 +
339 + acc1 /= samples
340 + acc5 /= samples
341 +
342 + if writer:
343 + n_imgs = min(images.size(0), 10)
344 + for j in range(n_imgs):
345 + writer.add_image('valid/input_image',
346 + concat_image_features(images[j], first[j]), global_step=step)
347 +
348 + return acc1, acc5, loss, infer_t  # note: loss is from the last validation batch only
349 +
350 +
351 +def accuracy(output, target, topk=(1,)):
352 + """Computes the accuracy over the k top predictions for the specified values of k"""
353 + with torch.no_grad():
354 + maxk = max(topk)
355 + batch_size = target.size(0)
356 +
357 + _, pred = output.topk(maxk, 1, True, True)
358 + pred = pred.t()
359 + correct = pred.eq(target.view(1, -1).expand_as(pred))
360 +
361 + res = []
362 + for k in topk:
363 + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: correct is non-contiguous after .t()
364 + res.append(correct_k)
365 + return res
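Note that accuracy() returns summed correct counts rather than rates -- for a batch of 128 with 96 top-1 hits it returns tensor([96.]) -- which is why train_step divides by images.size(0) and validate divides by the accumulated sample count.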
@@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None):
54 if torch.cuda.device_count() > 1:
55 print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
56 model = nn.DataParallel(model)
57 + elif torch.cuda.device_count() == 1:
58 + print('\n[+] Use 1 GPU')
57
58 start_t = time.time()
59 for step in range(args.start_step, args.max_step):
60 batch = next(data_loader)
63 +
61 _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device)
62
63 if step % args.print_step == 0:
@@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat
173 device = torch.device('cuda:%d' % device_id)
174 _transform = []
175
176 - print('[+] Child %d training strated (GPU: %d)' % (k, device_id))
179 + print('[+] Child %d training started (GPU: %d)' % (k, device_id))
177
178 # train child model
179 child_model = copy.deepcopy(model)
@@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat
188
189 return _transform
190
191 -
194 +# called from get_train_transform as: fast_auto_augment(args, model, K=4, B=1, num_process=4)
192 def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5):
193 args_str = json.dumps(args._asdict())
194 dataset = get_dataset(args, None, 'trainval')
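For orientation (semantics as in the Fast AutoAugment paper, so treat this as a hedged gloss): K is the number of stratified folds (matching split_dataset above), T the number of search iterations per fold, B the number of candidate policies tried per iteration, and N the number of top policies merged into the final transform. The call recorded in the comment above, K=4, B=1, num_process=4, is a deliberately tiny search for quick runs.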
@@ -4,6 +4,12 @@ class BaseNet(nn.Module):
4 def __init__(self, backbone, args):
5 super(BaseNet, self).__init__()
6
7 + # debugging: print the backbone's top-level modules once
8 + print("\nRESNET50 LAYERS\n")
9 + for layer in backbone.children():
10 + print(layer)
11 +
12 +
7 # Separate layers
8 self.first = nn.Sequential(*list(backbone.children())[:1])
9 self.after = nn.Sequential(*list(backbone.children())[1:-1])
@@ -14,6 +20,20 @@ class BaseNet(nn.Module):
14 def forward(self, x):
15 f = self.first(x)
16 x = self.after(f)
17 - x = x.reshape(x.size(0), -1)
18 x = self.fc(x)
19 return x, f
25 +
26 +
27 +"""
28 + print("before reshape:\n", x.size())
29 + #[128, 2048, 4, 4]
30 + # #cifar 내장[128, 2048, 1, 1]
31 + x = x.reshape(x.size(0), -1)
32 + print("after reshape:\n", x.size())
33 + #[128, 32768]
34 + #cifar [128, 2048]
35 + #RuntimeError: size mismatch, m1: [128 x 32768], m2: [2048 x 10]
36 + print("fc :\n", self.fc)
37 + #Linear(in_features=2048, out_features=10, bias=True)
38 + #cifar Linear(in_features=2048, out_features=1000, bias=True)
39 +"""
1 +import torch
2 +import torch.nn as nn
3 +import torch.nn.functional as F
4 +
5 +
6 +class BasicBlock(nn.Module):
7 + expansion = 1
8 +
9 + def __init__(self, in_planes, planes, stride=1):
10 + super(BasicBlock, self).__init__()
11 + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
12 + self.bn1 = nn.BatchNorm2d(planes)
13 + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
14 + self.bn2 = nn.BatchNorm2d(planes)
15 +
16 + self.shortcut = nn.Sequential()
17 + if stride != 1 or in_planes != self.expansion*planes:
18 + self.shortcut = nn.Sequential(
19 + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
20 + nn.BatchNorm2d(self.expansion*planes)
21 + )
22 +
23 + def forward(self, x):
24 + out = F.relu(self.bn1(self.conv1(x)))
25 + out = self.bn2(self.conv2(out))
26 + out += self.shortcut(x)
27 + out = F.relu(out)
28 + return out
29 +
30 +
31 +class Bottleneck(nn.Module):
32 + expansion = 4
33 +
34 + def __init__(self, in_planes, planes, stride=1):
35 + super(Bottleneck, self).__init__()
36 + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
37 + self.bn1 = nn.BatchNorm2d(planes)
38 + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
39 + self.bn2 = nn.BatchNorm2d(planes)
40 + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
41 + self.bn3 = nn.BatchNorm2d(self.expansion*planes)
42 +
43 + self.shortcut = nn.Sequential()
44 + if stride != 1 or in_planes != self.expansion*planes:
45 + self.shortcut = nn.Sequential(
46 + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
47 + nn.BatchNorm2d(self.expansion*planes)
48 + )
49 +
50 + def forward(self, x):
51 + out = F.relu(self.bn1(self.conv1(x)))
52 + out = F.relu(self.bn2(self.conv2(out)))
53 + out = self.bn3(self.conv3(out))
54 + out += self.shortcut(x)
55 + out = F.relu(out)
56 + return out
57 +
58 +
59 +class ResNet(nn.Module):
60 + def __init__(self, block, num_blocks, num_classes=10):
61 + super(ResNet, self).__init__()
62 + self.in_planes = 64
63 +
64 + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 3->1 input channels for grayscale
65 + self.bn1 = nn.BatchNorm2d(64)
66 + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
67 + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
68 + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
69 + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
70 + self.linear = nn.Linear(512*block.expansion, num_classes)
71 +
72 + def _make_layer(self, block, planes, num_blocks, stride):
73 + strides = [stride] + [1]*(num_blocks-1)
74 + layers = []
75 + for stride in strides:
76 + layers.append(block(self.in_planes, planes, stride))
77 + self.in_planes = planes * block.expansion
78 + return nn.Sequential(*layers)
79 +
80 + def forward(self, x):
81 + out = F.relu(self.bn1(self.conv1(x)))
82 + out = self.layer1(out)
83 + out = self.layer2(out)
84 + out = self.layer3(out)
85 + out = self.layer4(out)
86 + out = F.avg_pool2d(out, 4)
87 + out = out.view(out.size(0), -1)
88 + out = self.linear(out)
89 + return out
90 +
91 +
92 +def ResNet18():
93 + return ResNet(BasicBlock, [2,2,2,2])
94 +
95 +def ResNet34():
96 + return ResNet(BasicBlock, [3,4,6,3])
97 +
98 +def ResNet50():
99 + return ResNet(Bottleneck, [3,4,6,3])
100 +
101 +def ResNet101():
102 + return ResNet(Bottleneck, [3,4,23,3])
103 +
104 +def ResNet152():
105 + return ResNet(Bottleneck, [3,8,36,3])
106 +
107 +
108 +def test():
109 + net = ResNet18()
110 + y = net(torch.randn(1,1,32,32))  # single-channel input to match the grayscale conv1
111 + print(y.size())
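A smoke test for the grayscale CIFAR-style nets (a sketch; this module is imported as networks.grayResNet at the top of the commit, and num_classes defaults to 10):

    net = ResNet50()
    out = net(torch.randn(2, 1, 32, 32))   # one input channel now
    print(out.size())                      # torch.Size([2, 10])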
1 +import torch
2 +import torch.nn as nn
3 +#from .utils import load_state_dict_from_url
4 +
5 +
6 +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 + 'wide_resnet50_2', 'wide_resnet101_2']
9 +
10 +
11 +model_urls = {
12 + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 +}
22 +
23 +
24 +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
25 + """3x3 convolution with padding"""
26 + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 + padding=dilation, groups=groups, bias=False, dilation=dilation)
28 +
29 +
30 +def conv1x1(in_planes, out_planes, stride=1):
31 + """1x1 convolution"""
32 + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
33 +
34 +
35 +class BasicBlock(nn.Module):
36 + expansion = 1
37 +
38 + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
39 + base_width=64, dilation=1, norm_layer=None):
40 + super(BasicBlock, self).__init__()
41 + if norm_layer is None:
42 + norm_layer = nn.BatchNorm2d
43 + if groups != 1 or base_width != 64:
44 + raise ValueError('BasicBlock only supports groups=1 and base_width=64')
45 + if dilation > 1:
46 + raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
47 + # Both self.conv1 and self.downsample layers downsample the input when stride != 1
48 + self.conv1 = conv3x3(inplanes, planes, stride)
49 + self.bn1 = norm_layer(planes)
50 + self.relu = nn.ReLU(inplace=True)
51 + self.conv2 = conv3x3(planes, planes)
52 + self.bn2 = norm_layer(planes)
53 + self.downsample = downsample
54 + self.stride = stride
55 +
56 + def forward(self, x):
57 + identity = x
58 +
59 + out = self.conv1(x)
60 + out = self.bn1(out)
61 + out = self.relu(out)
62 +
63 + out = self.conv2(out)
64 + out = self.bn2(out)
65 +
66 + if self.downsample is not None:
67 + identity = self.downsample(x)
68 +
69 + out += identity
70 + out = self.relu(out)
71 +
72 + return out
73 +
74 +
75 +class Bottleneck(nn.Module):
76 + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
77 + # while original implementation places the stride at the first 1x1 convolution(self.conv1)
78 + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
79 + # This variant is also known as ResNet V1.5 and improves accuracy according to
80 + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
81 +
82 + expansion = 4
83 +
84 + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
85 + base_width=64, dilation=1, norm_layer=None):
86 + super(Bottleneck, self).__init__()
87 + if norm_layer is None:
88 + norm_layer = nn.BatchNorm2d
89 + width = int(planes * (base_width / 64.)) * groups
90 + # Both self.conv2 and self.downsample layers downsample the input when stride != 1
91 + self.conv1 = conv1x1(inplanes, width)
92 + self.bn1 = norm_layer(width)
93 + self.conv2 = conv3x3(width, width, stride, groups, dilation)
94 + self.bn2 = norm_layer(width)
95 + self.conv3 = conv1x1(width, planes * self.expansion)
96 + self.bn3 = norm_layer(planes * self.expansion)
97 + self.relu = nn.ReLU(inplace=True)
98 + self.downsample = downsample
99 + self.stride = stride
100 +
101 + def forward(self, x):
102 + identity = x
103 +
104 + out = self.conv1(x)
105 + out = self.bn1(out)
106 + out = self.relu(out)
107 +
108 + out = self.conv2(out)
109 + out = self.bn2(out)
110 + out = self.relu(out)
111 +
112 + out = self.conv3(out)
113 + out = self.bn3(out)
114 +
115 + if self.downsample is not None:
116 + identity = self.downsample(x)
117 +
118 + out += identity
119 + out = self.relu(out)
120 +
121 + return out
122 +
123 +
124 +class ResNet(nn.Module):
125 +
126 + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
127 + groups=1, width_per_group=64, replace_stride_with_dilation=None,
128 + norm_layer=None):
129 + super(ResNet, self).__init__()
130 + if norm_layer is None:
131 + norm_layer = nn.BatchNorm2d
132 + self._norm_layer = norm_layer
133 +
134 + self.inplanes = 64
135 + self.dilation = 1
136 + if replace_stride_with_dilation is None:
137 + # each element in the tuple indicates if we should replace
138 + # the 2x2 stride with a dilated convolution instead
139 + replace_stride_with_dilation = [False, False, False]
140 + if len(replace_stride_with_dilation) != 3:
141 + raise ValueError("replace_stride_with_dilation should be None "
142 + "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
143 + self.groups = groups
144 + self.base_width = width_per_group
145 + # change dimension 3->1 for grayscale input
146 + self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
147 + bias=False)
148 + self.bn1 = norm_layer(self.inplanes)
149 + self.relu = nn.ReLU(inplace=True)
150 + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
151 + self.layer1 = self._make_layer(block, 64, layers[0])
152 + self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
153 + dilate=replace_stride_with_dilation[0])
154 + self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
155 + dilate=replace_stride_with_dilation[1])
156 + self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
157 + dilate=replace_stride_with_dilation[2])
158 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
159 + self.fc = nn.Linear(512 * block.expansion, num_classes)
160 +
161 + for m in self.modules():
162 + if isinstance(m, nn.Conv2d):
163 + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
164 + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
165 + nn.init.constant_(m.weight, 1)
166 + nn.init.constant_(m.bias, 0)
167 +
168 + # Zero-initialize the last BN in each residual branch,
169 + # so that the residual branch starts with zeros, and each residual block behaves like an identity.
170 + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
171 + if zero_init_residual:
172 + for m in self.modules():
173 + if isinstance(m, Bottleneck):
174 + nn.init.constant_(m.bn3.weight, 0)
175 + elif isinstance(m, BasicBlock):
176 + nn.init.constant_(m.bn2.weight, 0)
177 +
178 + def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
179 + norm_layer = self._norm_layer
180 + downsample = None
181 + previous_dilation = self.dilation
182 + if dilate:
183 + self.dilation *= stride
184 + stride = 1
185 + if stride != 1 or self.inplanes != planes * block.expansion:
186 + downsample = nn.Sequential(
187 + conv1x1(self.inplanes, planes * block.expansion, stride),
188 + norm_layer(planes * block.expansion),
189 + )
190 +
191 + layers = []
192 + layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
193 + self.base_width, previous_dilation, norm_layer))
194 + self.inplanes = planes * block.expansion
195 + for _ in range(1, blocks):
196 + layers.append(block(self.inplanes, planes, groups=self.groups,
197 + base_width=self.base_width, dilation=self.dilation,
198 + norm_layer=norm_layer))
199 +
200 + return nn.Sequential(*layers)
201 +
202 + def _forward_impl(self, x):
203 + # See note [TorchScript super()]
204 + x = self.conv1(x)
205 + x = self.bn1(x)
206 + x = self.relu(x)
207 + x = self.maxpool(x)
208 +
209 + x = self.layer1(x)
210 + x = self.layer2(x)
211 + x = self.layer3(x)
212 + x = self.layer4(x)
213 +
214 + x = self.avgpool(x)
215 + x = torch.flatten(x, 1)
216 + x = self.fc(x)
217 +
218 + return x
219 +
220 + def forward(self, x):
221 + return self._forward_impl(x)
222 +
223 +
224 +def _resnet(arch, block, layers, pretrained, progress, **kwargs):
225 + model = ResNet(block, layers, **kwargs)
226 + # if pretrained:  # disabled: the 3-channel ImageNet weights no longer fit the 1-channel conv1
227 + # state_dict = load_state_dict_from_url(model_urls[arch],
228 + # progress=progress)
229 + # model.load_state_dict(state_dict)
230 + return model
231 +
232 +
233 +def resnet18(pretrained=False, progress=True, **kwargs):
234 + r"""ResNet-18 model from
235 + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
236 + Args:
237 + pretrained (bool): If True, returns a model pre-trained on ImageNet
238 + progress (bool): If True, displays a progress bar of the download to stderr
239 + """
240 + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
241 + **kwargs)
242 +
243 +
244 +def resnet34(pretrained=False, progress=True, **kwargs):
245 + r"""ResNet-34 model from
246 + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
247 + Args:
248 + pretrained (bool): If True, returns a model pre-trained on ImageNet
249 + progress (bool): If True, displays a progress bar of the download to stderr
250 + """
251 + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
252 + **kwargs)
253 +
254 +
255 +def resnet50(pretrained=False, progress=True, **kwargs):
256 + r"""ResNet-50 model from
257 + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
258 + Args:
259 + pretrained (bool): If True, returns a model pre-trained on ImageNet
260 + progress (bool): If True, displays a progress bar of the download to stderr
261 + """
262 + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
263 + **kwargs)
264 +
265 +
266 +def resnet101(pretrained=False, progress=True, **kwargs):
267 + r"""ResNet-101 model from
268 + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
269 + Args:
270 + pretrained (bool): If True, returns a model pre-trained on ImageNet
271 + progress (bool): If True, displays a progress bar of the download to stderr
272 + """
273 + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
274 + **kwargs)
275 +
276 +
277 +def resnet152(pretrained=False, progress=True, **kwargs):
278 + r"""ResNet-152 model from
279 + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
280 + Args:
281 + pretrained (bool): If True, returns a model pre-trained on ImageNet
282 + progress (bool): If True, displays a progress bar of the download to stderr
283 + """
284 + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
285 + **kwargs)
286 +
287 +
288 +def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
289 + r"""ResNeXt-50 32x4d model from
290 + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
291 + Args:
292 + pretrained (bool): If True, returns a model pre-trained on ImageNet
293 + progress (bool): If True, displays a progress bar of the download to stderr
294 + """
295 + kwargs['groups'] = 32
296 + kwargs['width_per_group'] = 4
297 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
298 + pretrained, progress, **kwargs)
299 +
300 +
301 +def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
302 + r"""ResNeXt-101 32x8d model from
303 + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
304 + Args:
305 + pretrained (bool): If True, returns a model pre-trained on ImageNet
306 + progress (bool): If True, displays a progress bar of the download to stderr
307 + """
308 + kwargs['groups'] = 32
309 + kwargs['width_per_group'] = 8
310 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
311 + pretrained, progress, **kwargs)
312 +
313 +
314 +def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
315 + r"""Wide ResNet-50-2 model from
316 + `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
317 + The model is the same as ResNet except for the bottleneck number of channels
318 + which is twice larger in every block. The number of channels in outer 1x1
319 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
320 + channels, and in Wide ResNet-50-2 has 2048-1024-2048.
321 + Args:
322 + pretrained (bool): If True, returns a model pre-trained on ImageNet
323 + progress (bool): If True, displays a progress bar of the download to stderr
324 + """
325 + kwargs['width_per_group'] = 64 * 2
326 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
327 + pretrained, progress, **kwargs)
328 +
329 +
330 +def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
331 + r"""Wide ResNet-101-2 model from
332 + `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
333 + The model is the same as ResNet except for the bottleneck number of channels
334 + which is twice larger in every block. The number of channels in outer 1x1
335 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
336 + channels, and in Wide ResNet-50-2 has 2048-1024-2048.
337 + Args:
338 + pretrained (bool): If True, returns a model pre-trained on ImageNet
339 + progress (bool): If True, displays a progress bar of the download to stderr
340 + """
341 + kwargs['width_per_group'] = 64 * 2
342 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
343 + pretrained, progress, **kwargs)
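The torchvision-style counterpart above (imported as networks.grayResNet2 in the BraTS utils below) ends in AdaptiveAvgPool2d((1, 1)), so inputs of any reasonable size flatten to 512*expansion features -- a sketch:

    model = resnet50(num_classes=10)
    x = torch.randn(2, 1, 240, 240)   # e.g. single-channel 240x240 BraTS slices
    print(model(x).shape)             # torch.Size([2, 10])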
@@ -6,6 +6,7 @@ import pickle as cp
6 import glob
7 import numpy as np
8 import pandas as pd
9 +
9 from natsort import natsorted
10 from PIL import Image
11 import torch
@@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split
21 from sklearn.model_selection import KFold
22
23 from networks import basenet
25 +from networks import grayResNet, grayResNet2
24
25 DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
26 TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
@@ -55,40 +57,6 @@ def split_dataset(args, dataset, k):
55
56 return Dm_indexes, Da_indexes
57
58 -def split_dataset2222(args, dataset, k):
59 - # load dataset
60 - X = list(range(len(dataset)))
61 -
62 - # split to k-fold
63 - #assert len(X) == len(Y)
64 -
65 - def _it_to_list(_it):
66 - return list(zip(*list(_it)))
67 -
68 - x_train = ()
69 - x_test = ()
70 -
71 - for i in range(k):
72 - #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1)
73 - xtr, xte = train_test_split(X, random_state=None, test_size=0.1)
74 - x_train.append(np.array(xtr))
75 - x_test.append(np.array(xte))
76 -
77 - y_train = np.array([0]* len(x_train))
78 - y_test = np.array([0]* len(x_test))
79 -
80 - x_train = tuple(x_train)
81 - x_test = tuple(x_test)
82 -
83 - trainset = (zip(x_train, y_train),)
84 - testset = (zip(x_test, y_test),)
85 -
86 - Dm_indexes, Da_indexes = trainset, testset
87 -
88 - print(type(Dm_indexes), np.shape(Dm_indexes))
89 - print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes)
90 -
91 - return Dm_indexes, Da_indexes
92
93 def concat_image_features(image, features, max_features=3):
94 _, h, w = image.shape
@@ -159,8 +127,22 @@ def parse_args(kwargs):
159
160
161 def select_model(args):
162 - if args.network in models.__dict__:
163 - backbone = models.__dict__[args.network]()
130 + # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(),
131 + #                'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()}
132 +
133 +
134 + # grayResNet2: map names to constructors so only the selected network is built
135 + resnet_dict = {'resnet18': grayResNet2.resnet18, 'resnet34': grayResNet2.resnet34,
136 + 'resnet50': grayResNet2.resnet50, 'resnet101': grayResNet2.resnet101, 'resnet152': grayResNet2.resnet152}
137 +
138 + if args.network in resnet_dict:
139 + backbone = resnet_dict[args.network]()
140 + # debugging (disabled): inspect the backbone layers
141 + # print("\nRESNET50 LAYERS\n")
142 + # for layer in backbone.children():
143 + # print(layer)
144 + # print("LAYER THE END\n")
145 +
164 model = basenet.BaseNet(backbone, args)
165 else:
166 Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
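One behavioral note on this last hunk: the registry keys are now the lowercase factory names ('resnet18' ... 'resnet152') rather than the earlier 'ResNet18'-style class names, so a config that still passes network='ResNet50' will miss the dict and fall through to the importlib branch; it needs to say 'resnet50'.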