Showing
6 changed files
with
863 additions
and
39 deletions
code/FAA2/cifar_utils.py
0 → 100644
| 1 | +import os | ||
| 2 | +import time | ||
| 3 | +import importlib | ||
| 4 | +import collections | ||
| 5 | +import pickle as cp | ||
| 6 | +import numpy as np | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +import torch | ||
| 10 | +import torchvision | ||
| 11 | +import torch.nn.functional as F | ||
| 12 | +import torchvision.models as models | ||
| 13 | +import torchvision.transforms as transforms | ||
| 14 | +from torch.utils.data import Subset | ||
| 15 | + | ||
| 16 | +from sklearn.model_selection import StratifiedShuffleSplit | ||
| 17 | + | ||
| 18 | +from networks import basenet | ||
| 19 | +from networks import grayResNet | ||
| 20 | + | ||
| 21 | + | ||
| 22 | +DATASET_PATH = './data/' | ||
| 23 | +current_epoch = 0 | ||
| 24 | + | ||
| 25 | + | ||
| 26 | +def split_dataset(args, dataset, k): | ||
| 27 | + # load dataset | ||
| 28 | + X = list(range(len(dataset))) | ||
| 29 | + Y = dataset.targets | ||
| 30 | + | ||
| 31 | + # split to k-fold | ||
| 32 | + assert len(X) == len(Y) | ||
| 33 | + | ||
| 34 | + def _it_to_list(_it): | ||
| 35 | + return list(zip(*list(_it))) | ||
| 36 | + | ||
| 37 | + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
| 38 | + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
| 39 | + | ||
| 40 | + return Dm_indexes, Da_indexes | ||
| 41 | + | ||
| 42 | + | ||
| 43 | +def concat_image_features(image, features, max_features=3): | ||
| 44 | + _, h, w = image.shape | ||
| 45 | + | ||
| 46 | + max_features = min(features.size(0), max_features) | ||
| 47 | + image_feature = image.clone() | ||
| 48 | + | ||
| 49 | + for i in range(max_features): | ||
| 50 | + feature = features[i:i+1] | ||
| 51 | + _min, _max = torch.min(feature), torch.max(feature) | ||
| 52 | + feature = (feature - _min) / (_max - _min + 1e-6) | ||
| 53 | + feature = torch.cat([feature]*3, 0) | ||
| 54 | + feature = feature.view(1, 3, feature.size(1), feature.size(2)) | ||
| 55 | + feature = F.upsample(feature, size=(h,w), mode="bilinear") | ||
| 56 | + feature = feature.view(3, h, w) | ||
| 57 | + image_feature = torch.cat((image_feature, feature), 2) | ||
| 58 | + | ||
| 59 | + return image_feature | ||
| 60 | + | ||
| 61 | + | ||
| 62 | +def get_model_name(args): | ||
| 63 | + from datetime import datetime | ||
| 64 | + now = datetime.now() | ||
| 65 | + date_time = now.strftime("%B_%d_%H:%M:%S") | ||
| 66 | + model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
| 67 | + return model_name | ||
| 68 | + | ||
| 69 | + | ||
| 70 | +def dict_to_namedtuple(d): | ||
| 71 | + Args = collections.namedtuple('Args', sorted(d.keys())) | ||
| 72 | + | ||
| 73 | + for k,v in d.items(): | ||
| 74 | + if type(v) is dict: | ||
| 75 | + d[k] = dict_to_namedtuple(v) | ||
| 76 | + | ||
| 77 | + elif type(v) is str: | ||
| 78 | + try: | ||
| 79 | + d[k] = eval(v) | ||
| 80 | + except: | ||
| 81 | + d[k] = v | ||
| 82 | + | ||
| 83 | + args = Args(**d) | ||
| 84 | + return args | ||
| 85 | + | ||
| 86 | + | ||
| 87 | +def parse_args(kwargs): | ||
| 88 | + # combine with default args | ||
| 89 | + kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'cifar10' | ||
| 90 | + kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet_cifar10' | ||
| 91 | + kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
| 92 | + kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.1 | ||
| 93 | + kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
| 94 | + kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
| 95 | + kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
| 96 | + kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
| 97 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 2000 | ||
| 98 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 2000 | ||
| 99 | + kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
| 100 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
| 101 | + kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
| 102 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 64000 | ||
| 103 | + kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False | ||
| 104 | + kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
| 105 | + | ||
| 106 | + # to named tuple | ||
| 107 | + args = dict_to_namedtuple(kwargs) | ||
| 108 | + return args, kwargs | ||
| 109 | + | ||
| 110 | + | ||
| 111 | +def select_model(args): | ||
| 112 | + resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(), | ||
| 113 | + 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()} | ||
| 114 | + #print("args.network: \n", args.network) | ||
| 115 | + if args.network in resnet_dict: | ||
| 116 | + backbone = resnet_dict[args.network] | ||
| 117 | + model = basenet.BaseNet(backbone, args) | ||
| 118 | + else: | ||
| 119 | + Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
| 120 | + model = Net(args) | ||
| 121 | + | ||
| 122 | + print(model) | ||
| 123 | + return model | ||
| 124 | + | ||
| 125 | + | ||
| 126 | +def select_optimizer(args, model): | ||
| 127 | + if args.optimizer == 'sgd': | ||
| 128 | + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
| 129 | + elif args.optimizer == 'rms': | ||
| 130 | + #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5) | ||
| 131 | + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
| 132 | + elif args.optimizer == 'adam': | ||
| 133 | + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
| 134 | + else: | ||
| 135 | + raise Exception('Unknown Optimizer') | ||
| 136 | + return optimizer | ||
| 137 | + | ||
| 138 | + | ||
| 139 | +def select_scheduler(args, optimizer): | ||
| 140 | + if not args.scheduler or args.scheduler == 'None': | ||
| 141 | + return None | ||
| 142 | + elif args.scheduler =='clr': | ||
| 143 | + return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
| 144 | + elif args.scheduler =='exp': | ||
| 145 | + return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
| 146 | + else: | ||
| 147 | + raise Exception('Unknown Scheduler') | ||
| 148 | + | ||
| 149 | + | ||
| 150 | +def get_dataset(args, transform, split='train'): | ||
| 151 | + assert split in ['train', 'val', 'test', 'trainval'] | ||
| 152 | + | ||
| 153 | + if args.dataset == 'cifar10': | ||
| 154 | + train = split in ['train', 'val', 'trainval'] | ||
| 155 | + dataset = torchvision.datasets.CIFAR10(DATASET_PATH, | ||
| 156 | + train=train, | ||
| 157 | + transform=transform, | ||
| 158 | + download=True) | ||
| 159 | + | ||
| 160 | + if split in ['train', 'val']: | ||
| 161 | + split_path = os.path.join(DATASET_PATH, | ||
| 162 | + 'cifar-10-batches-py', 'train_val_index.cp') | ||
| 163 | + | ||
| 164 | + if not os.path.exists(split_path): | ||
| 165 | + [train_index], [val_index] = split_dataset(args, dataset, k=1) | ||
| 166 | + split_index = {'train':train_index, 'val':val_index} | ||
| 167 | + cp.dump(split_index, open(split_path, 'wb')) | ||
| 168 | + | ||
| 169 | + split_index = cp.load(open(split_path, 'rb')) | ||
| 170 | + dataset = Subset(dataset, split_index[split]) | ||
| 171 | + | ||
| 172 | + elif args.dataset == 'imagenet': | ||
| 173 | + dataset = torchvision.datasets.ImageNet(DATASET_PATH, | ||
| 174 | + split=split, | ||
| 175 | + transform=transform, | ||
| 176 | + download=(split is 'val')) | ||
| 177 | + | ||
| 178 | + else: | ||
| 179 | + raise Exception('Unknown dataset') | ||
| 180 | + | ||
| 181 | + return dataset | ||
| 182 | + | ||
| 183 | + | ||
| 184 | +def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 185 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
| 186 | + batch_size=args.batch_size, | ||
| 187 | + shuffle=shuffle, | ||
| 188 | + num_workers=args.num_workers, | ||
| 189 | + pin_memory=pin_memory) | ||
| 190 | + return data_loader | ||
| 191 | + | ||
| 192 | + | ||
| 193 | +def get_inf_dataloader(args, dataset): | ||
| 194 | + global current_epoch | ||
| 195 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 196 | + | ||
| 197 | + while True: | ||
| 198 | + try: | ||
| 199 | + batch = next(data_loader) | ||
| 200 | + | ||
| 201 | + except StopIteration: | ||
| 202 | + current_epoch += 1 | ||
| 203 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 204 | + batch = next(data_loader) | ||
| 205 | + | ||
| 206 | + yield batch | ||
| 207 | + | ||
| 208 | + | ||
| 209 | +def get_train_transform(args, model, log_dir=None): | ||
| 210 | + if args.fast_auto_augment: | ||
| 211 | + assert args.dataset == 'cifar10' # TODO: FastAutoAugment for Imagenet | ||
| 212 | + | ||
| 213 | + from fast_auto_augment import fast_auto_augment | ||
| 214 | + if args.augment_path: | ||
| 215 | + transform = cp.load(open(args.augment_path, 'rb')) | ||
| 216 | + os.system('cp {} {}'.format( | ||
| 217 | + args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | ||
| 218 | + else: | ||
| 219 | + transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) | ||
| 220 | + if log_dir: | ||
| 221 | + cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | ||
| 222 | + | ||
| 223 | + elif args.dataset == 'cifar10': | ||
| 224 | + transform = transforms.Compose([ | ||
| 225 | + transforms.Pad(4), | ||
| 226 | + transforms.RandomCrop(32), | ||
| 227 | + transforms.RandomHorizontalFlip(), | ||
| 228 | + transforms.ToTensor() | ||
| 229 | + ]) | ||
| 230 | + | ||
| 231 | + elif args.dataset == 'imagenet': | ||
| 232 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
| 233 | + transform = transforms.Compose([ | ||
| 234 | + transforms.Resize([resize_h, resize_w]), | ||
| 235 | + transforms.RandomCrop(model.img_size), | ||
| 236 | + transforms.RandomHorizontalFlip(), | ||
| 237 | + transforms.ToTensor() | ||
| 238 | + ]) | ||
| 239 | + | ||
| 240 | + else: | ||
| 241 | + raise Exception('Unknown Dataset') | ||
| 242 | + | ||
| 243 | + print(transform) | ||
| 244 | + | ||
| 245 | + return transform | ||
| 246 | + | ||
| 247 | + | ||
| 248 | +def get_valid_transform(args, model): | ||
| 249 | + if args.dataset == 'cifar10': | ||
| 250 | + val_transform = transforms.Compose([ | ||
| 251 | + transforms.Resize(32), | ||
| 252 | + transforms.ToTensor() | ||
| 253 | + ]) | ||
| 254 | + | ||
| 255 | + elif args.dataset == 'imagenet': | ||
| 256 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
| 257 | + val_transform = transforms.Compose([ | ||
| 258 | + transforms.Resize([resize_h, resize_w]), | ||
| 259 | + transforms.ToTensor() | ||
| 260 | + ]) | ||
| 261 | + | ||
| 262 | + else: | ||
| 263 | + raise Exception('Unknown Dataset') | ||
| 264 | + | ||
| 265 | + return val_transform | ||
| 266 | + | ||
| 267 | + | ||
| 268 | +def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
| 269 | + model.train() | ||
| 270 | + images, target = batch | ||
| 271 | + | ||
| 272 | + if device: | ||
| 273 | + images = images.to(device) | ||
| 274 | + target = target.to(device) | ||
| 275 | + | ||
| 276 | + elif args.use_cuda: | ||
| 277 | + images = images.cuda(non_blocking=True) | ||
| 278 | + target = target.cuda(non_blocking=True) | ||
| 279 | + | ||
| 280 | + # compute output | ||
| 281 | + start_t = time.time() | ||
| 282 | + output, first = model(images) | ||
| 283 | + forward_t = time.time() - start_t | ||
| 284 | + loss = criterion(output, target) | ||
| 285 | + | ||
| 286 | + # measure accuracy and record loss | ||
| 287 | + acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 288 | + acc1 /= images.size(0) | ||
| 289 | + acc5 /= images.size(0) | ||
| 290 | + | ||
| 291 | + # compute gradient and do SGD step | ||
| 292 | + optimizer.zero_grad() | ||
| 293 | + start_t = time.time() | ||
| 294 | + loss.backward() | ||
| 295 | + backward_t = time.time() - start_t | ||
| 296 | + optimizer.step() | ||
| 297 | + if scheduler: scheduler.step() | ||
| 298 | + | ||
| 299 | + if writer and step % args.print_step == 0: | ||
| 300 | + n_imgs = min(images.size(0), 10) | ||
| 301 | + for j in range(n_imgs): | ||
| 302 | + writer.add_image('train/input_image', | ||
| 303 | + concat_image_features(images[j], first[j]), global_step=step) | ||
| 304 | + | ||
| 305 | + return acc1, acc5, loss, forward_t, backward_t | ||
| 306 | + | ||
| 307 | + | ||
| 308 | +def validate(args, model, criterion, valid_loader, step, writer, device=None): | ||
| 309 | + # switch to evaluate mode | ||
| 310 | + model.eval() | ||
| 311 | + | ||
| 312 | + acc1, acc5 = 0, 0 | ||
| 313 | + samples = 0 | ||
| 314 | + infer_t = 0 | ||
| 315 | + | ||
| 316 | + with torch.no_grad(): | ||
| 317 | + for i, (images, target) in enumerate(valid_loader): | ||
| 318 | + | ||
| 319 | + start_t = time.time() | ||
| 320 | + if device: | ||
| 321 | + images = images.to(device) | ||
| 322 | + target = target.to(device) | ||
| 323 | + | ||
| 324 | + elif args.use_cuda is not None: | ||
| 325 | + images = images.cuda(non_blocking=True) | ||
| 326 | + target = target.cuda(non_blocking=True) | ||
| 327 | + | ||
| 328 | + # compute output | ||
| 329 | + output, first = model(images) | ||
| 330 | + loss = criterion(output, target) | ||
| 331 | + infer_t += time.time() - start_t | ||
| 332 | + | ||
| 333 | + # measure accuracy and record loss | ||
| 334 | + _acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 335 | + acc1 += _acc1 | ||
| 336 | + acc5 += _acc5 | ||
| 337 | + samples += images.size(0) | ||
| 338 | + | ||
| 339 | + acc1 /= samples | ||
| 340 | + acc5 /= samples | ||
| 341 | + | ||
| 342 | + if writer: | ||
| 343 | + n_imgs = min(images.size(0), 10) | ||
| 344 | + for j in range(n_imgs): | ||
| 345 | + writer.add_image('valid/input_image', | ||
| 346 | + concat_image_features(images[j], first[j]), global_step=step) | ||
| 347 | + | ||
| 348 | + return acc1, acc5, loss, infer_t | ||
| 349 | + | ||
| 350 | + | ||
| 351 | +def accuracy(output, target, topk=(1,)): | ||
| 352 | + """Computes the accuracy over the k top predictions for the specified values of k""" | ||
| 353 | + with torch.no_grad(): | ||
| 354 | + maxk = max(topk) | ||
| 355 | + batch_size = target.size(0) | ||
| 356 | + | ||
| 357 | + _, pred = output.topk(maxk, 1, True, True) | ||
| 358 | + pred = pred.t() | ||
| 359 | + correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
| 360 | + | ||
| 361 | + res = [] | ||
| 362 | + for k in topk: | ||
| 363 | + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
| 364 | + res.append(correct_k) | ||
| 365 | + return res |
| ... | @@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None): | ... | @@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None): |
| 54 | if torch.cuda.device_count() > 1: | 54 | if torch.cuda.device_count() > 1: |
| 55 | print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | 55 | print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) |
| 56 | model = nn.DataParallel(model) | 56 | model = nn.DataParallel(model) |
| 57 | + elif torch.cuda.device_count() == 1: | ||
| 58 | + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
| 57 | 59 | ||
| 58 | start_t = time.time() | 60 | start_t = time.time() |
| 59 | for step in range(args.start_step, args.max_step): | 61 | for step in range(args.start_step, args.max_step): |
| 60 | batch = next(data_loader) | 62 | batch = next(data_loader) |
| 63 | + | ||
| 61 | _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) | 64 | _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) |
| 62 | 65 | ||
| 63 | if step % args.print_step == 0: | 66 | if step % args.print_step == 0: |
| ... | @@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
| 173 | device = torch.device('cuda:%d' % device_id) | 176 | device = torch.device('cuda:%d' % device_id) |
| 174 | _transform = [] | 177 | _transform = [] |
| 175 | 178 | ||
| 176 | - print('[+] Child %d training strated (GPU: %d)' % (k, device_id)) | 179 | + print('[+] Child %d training started (GPU: %d)' % (k, device_id)) |
| 177 | 180 | ||
| 178 | # train child model | 181 | # train child model |
| 179 | child_model = copy.deepcopy(model) | 182 | child_model = copy.deepcopy(model) |
| ... | @@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
| 188 | 191 | ||
| 189 | return _transform | 192 | return _transform |
| 190 | 193 | ||
| 191 | - | 194 | +#fast_auto_augment(args, model, K=4, B=1, num_process=4) |
| 192 | def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | 195 | def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): |
| 193 | args_str = json.dumps(args._asdict()) | 196 | args_str = json.dumps(args._asdict()) |
| 194 | dataset = get_dataset(args, None, 'trainval') | 197 | dataset = get_dataset(args, None, 'trainval') | ... | ... |
| ... | @@ -4,6 +4,12 @@ class BaseNet(nn.Module): | ... | @@ -4,6 +4,12 @@ class BaseNet(nn.Module): |
| 4 | def __init__(self, backbone, args): | 4 | def __init__(self, backbone, args): |
| 5 | super(BaseNet, self).__init__() | 5 | super(BaseNet, self).__init__() |
| 6 | 6 | ||
| 7 | + #testing | ||
| 8 | + for layer in backbone.children(): | ||
| 9 | + print("\nRESNET50 LAYERS\n") | ||
| 10 | + print(layer) | ||
| 11 | + | ||
| 12 | + | ||
| 7 | # Separate layers | 13 | # Separate layers |
| 8 | self.first = nn.Sequential(*list(backbone.children())[:1]) | 14 | self.first = nn.Sequential(*list(backbone.children())[:1]) |
| 9 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) | 15 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) |
| ... | @@ -14,6 +20,20 @@ class BaseNet(nn.Module): | ... | @@ -14,6 +20,20 @@ class BaseNet(nn.Module): |
| 14 | def forward(self, x): | 20 | def forward(self, x): |
| 15 | f = self.first(x) | 21 | f = self.first(x) |
| 16 | x = self.after(f) | 22 | x = self.after(f) |
| 17 | - x = x.reshape(x.size(0), -1) | ||
| 18 | x = self.fc(x) | 23 | x = self.fc(x) |
| 19 | return x, f | 24 | return x, f |
| 25 | + | ||
| 26 | + | ||
| 27 | +""" | ||
| 28 | + print("before reshape:\n", x.size()) | ||
| 29 | + #[128, 2048, 4, 4] | ||
| 30 | + # #cifar 내장[128, 2048, 1, 1] | ||
| 31 | + x = x.reshape(x.size(0), -1) | ||
| 32 | + print("after reshape:\n", x.size()) | ||
| 33 | + #[128, 32768] | ||
| 34 | + #cifar [128, 2048] | ||
| 35 | + #RuntimeError: size mismatch, m1: [128 x 32768], m2: [2048 x 10] | ||
| 36 | + print("fc :\n", self.fc) | ||
| 37 | + #Linear(in_features=2048, out_features=10, bias=True) | ||
| 38 | + #cifar Linear(in_features=2048, out_features=1000, bias=True) | ||
| 39 | +""" | ... | ... |
code/FAA2/networks/grayResNet.py
0 → 100644
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +import torch.nn.functional as F | ||
| 4 | + | ||
| 5 | + | ||
| 6 | +class BasicBlock(nn.Module): | ||
| 7 | + expansion = 1 | ||
| 8 | + | ||
| 9 | + def __init__(self, in_planes, planes, stride=1): | ||
| 10 | + super(BasicBlock, self).__init__() | ||
| 11 | + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
| 12 | + self.bn1 = nn.BatchNorm2d(planes) | ||
| 13 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) | ||
| 14 | + self.bn2 = nn.BatchNorm2d(planes) | ||
| 15 | + | ||
| 16 | + self.shortcut = nn.Sequential() | ||
| 17 | + if stride != 1 or in_planes != self.expansion*planes: | ||
| 18 | + self.shortcut = nn.Sequential( | ||
| 19 | + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | ||
| 20 | + nn.BatchNorm2d(self.expansion*planes) | ||
| 21 | + ) | ||
| 22 | + | ||
| 23 | + def forward(self, x): | ||
| 24 | + out = F.relu(self.bn1(self.conv1(x))) | ||
| 25 | + out = self.bn2(self.conv2(out)) | ||
| 26 | + out += self.shortcut(x) | ||
| 27 | + out = F.relu(out) | ||
| 28 | + return out | ||
| 29 | + | ||
| 30 | + | ||
| 31 | +class Bottleneck(nn.Module): | ||
| 32 | + expansion = 4 | ||
| 33 | + | ||
| 34 | + def __init__(self, in_planes, planes, stride=1): | ||
| 35 | + super(Bottleneck, self).__init__() | ||
| 36 | + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) | ||
| 37 | + self.bn1 = nn.BatchNorm2d(planes) | ||
| 38 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
| 39 | + self.bn2 = nn.BatchNorm2d(planes) | ||
| 40 | + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) | ||
| 41 | + self.bn3 = nn.BatchNorm2d(self.expansion*planes) | ||
| 42 | + | ||
| 43 | + self.shortcut = nn.Sequential() | ||
| 44 | + if stride != 1 or in_planes != self.expansion*planes: | ||
| 45 | + self.shortcut = nn.Sequential( | ||
| 46 | + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | ||
| 47 | + nn.BatchNorm2d(self.expansion*planes) | ||
| 48 | + ) | ||
| 49 | + | ||
| 50 | + def forward(self, x): | ||
| 51 | + out = F.relu(self.bn1(self.conv1(x))) | ||
| 52 | + out = F.relu(self.bn2(self.conv2(out))) | ||
| 53 | + out = self.bn3(self.conv3(out)) | ||
| 54 | + out += self.shortcut(x) | ||
| 55 | + out = F.relu(out) | ||
| 56 | + return out | ||
| 57 | + | ||
| 58 | + | ||
| 59 | +class ResNet(nn.Module): | ||
| 60 | + def __init__(self, block, num_blocks, num_classes=10): | ||
| 61 | + super(ResNet, self).__init__() | ||
| 62 | + self.in_planes = 64 | ||
| 63 | + | ||
| 64 | + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False) | ||
| 65 | + self.bn1 = nn.BatchNorm2d(64) | ||
| 66 | + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) | ||
| 67 | + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) | ||
| 68 | + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) | ||
| 69 | + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) | ||
| 70 | + self.linear = nn.Linear(512*block.expansion, num_classes) | ||
| 71 | + | ||
| 72 | + def _make_layer(self, block, planes, num_blocks, stride): | ||
| 73 | + strides = [stride] + [1]*(num_blocks-1) | ||
| 74 | + layers = [] | ||
| 75 | + for stride in strides: | ||
| 76 | + layers.append(block(self.in_planes, planes, stride)) | ||
| 77 | + self.in_planes = planes * block.expansion | ||
| 78 | + return nn.Sequential(*layers) | ||
| 79 | + | ||
| 80 | + def forward(self, x): | ||
| 81 | + out = F.relu(self.bn1(self.conv1(x))) | ||
| 82 | + out = self.layer1(out) | ||
| 83 | + out = self.layer2(out) | ||
| 84 | + out = self.layer3(out) | ||
| 85 | + out = self.layer4(out) | ||
| 86 | + out = F.avg_pool2d(out, 4) | ||
| 87 | + out = out.view(out.size(0), -1) | ||
| 88 | + out = self.linear(out) | ||
| 89 | + return out | ||
| 90 | + | ||
| 91 | + | ||
| 92 | +def ResNet18(): | ||
| 93 | + return ResNet(BasicBlock, [2,2,2,2]) | ||
| 94 | + | ||
| 95 | +def ResNet34(): | ||
| 96 | + return ResNet(BasicBlock, [3,4,6,3]) | ||
| 97 | + | ||
| 98 | +def ResNet50(): | ||
| 99 | + return ResNet(Bottleneck, [3,4,6,3]) | ||
| 100 | + | ||
| 101 | +def ResNet101(): | ||
| 102 | + return ResNet(Bottleneck, [3,4,23,3]) | ||
| 103 | + | ||
| 104 | +def ResNet152(): | ||
| 105 | + return ResNet(Bottleneck, [3,8,36,3]) | ||
| 106 | + | ||
| 107 | + | ||
| 108 | +def test(): | ||
| 109 | + net = ResNet18() | ||
| 110 | + y = net(torch.randn(1,3,32,32)) | ||
| 111 | + print(y.size()) |
code/FAA2/networks/grayResNet2.py
0 → 100644
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +#from .utils import load_state_dict_from_url | ||
| 4 | + | ||
| 5 | + | ||
| 6 | +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', | ||
| 7 | + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', | ||
| 8 | + 'wide_resnet50_2', 'wide_resnet101_2'] | ||
| 9 | + | ||
| 10 | + | ||
| 11 | +model_urls = { | ||
| 12 | + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', | ||
| 13 | + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', | ||
| 14 | + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', | ||
| 15 | + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', | ||
| 16 | + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', | ||
| 17 | + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', | ||
| 18 | + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', | ||
| 19 | + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', | ||
| 20 | + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | + | ||
| 24 | +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | ||
| 25 | + """3x3 convolution with padding""" | ||
| 26 | + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||
| 27 | + padding=dilation, groups=groups, bias=False, dilation=dilation) | ||
| 28 | + | ||
| 29 | + | ||
| 30 | +def conv1x1(in_planes, out_planes, stride=1): | ||
| 31 | + """1x1 convolution""" | ||
| 32 | + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||
| 33 | + | ||
| 34 | + | ||
| 35 | +class BasicBlock(nn.Module): | ||
| 36 | + expansion = 1 | ||
| 37 | + | ||
| 38 | + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||
| 39 | + base_width=64, dilation=1, norm_layer=None): | ||
| 40 | + super(BasicBlock, self).__init__() | ||
| 41 | + if norm_layer is None: | ||
| 42 | + norm_layer = nn.BatchNorm2d | ||
| 43 | + if groups != 1 or base_width != 64: | ||
| 44 | + raise ValueError('BasicBlock only supports groups=1 and base_width=64') | ||
| 45 | + if dilation > 1: | ||
| 46 | + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | ||
| 47 | + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||
| 48 | + self.conv1 = conv3x3(inplanes, planes, stride) | ||
| 49 | + self.bn1 = norm_layer(planes) | ||
| 50 | + self.relu = nn.ReLU(inplace=True) | ||
| 51 | + self.conv2 = conv3x3(planes, planes) | ||
| 52 | + self.bn2 = norm_layer(planes) | ||
| 53 | + self.downsample = downsample | ||
| 54 | + self.stride = stride | ||
| 55 | + | ||
| 56 | + def forward(self, x): | ||
| 57 | + identity = x | ||
| 58 | + | ||
| 59 | + out = self.conv1(x) | ||
| 60 | + out = self.bn1(out) | ||
| 61 | + out = self.relu(out) | ||
| 62 | + | ||
| 63 | + out = self.conv2(out) | ||
| 64 | + out = self.bn2(out) | ||
| 65 | + | ||
| 66 | + if self.downsample is not None: | ||
| 67 | + identity = self.downsample(x) | ||
| 68 | + | ||
| 69 | + out += identity | ||
| 70 | + out = self.relu(out) | ||
| 71 | + | ||
| 72 | + return out | ||
| 73 | + | ||
| 74 | + | ||
| 75 | +class Bottleneck(nn.Module): | ||
| 76 | + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) | ||
| 77 | + # while original implementation places the stride at the first 1x1 convolution(self.conv1) | ||
| 78 | + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. | ||
| 79 | + # This variant is also known as ResNet V1.5 and improves accuracy according to | ||
| 80 | + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. | ||
| 81 | + | ||
| 82 | + expansion = 4 | ||
| 83 | + | ||
| 84 | + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||
| 85 | + base_width=64, dilation=1, norm_layer=None): | ||
| 86 | + super(Bottleneck, self).__init__() | ||
| 87 | + if norm_layer is None: | ||
| 88 | + norm_layer = nn.BatchNorm2d | ||
| 89 | + width = int(planes * (base_width / 64.)) * groups | ||
| 90 | + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||
| 91 | + self.conv1 = conv1x1(inplanes, width) | ||
| 92 | + self.bn1 = norm_layer(width) | ||
| 93 | + self.conv2 = conv3x3(width, width, stride, groups, dilation) | ||
| 94 | + self.bn2 = norm_layer(width) | ||
| 95 | + self.conv3 = conv1x1(width, planes * self.expansion) | ||
| 96 | + self.bn3 = norm_layer(planes * self.expansion) | ||
| 97 | + self.relu = nn.ReLU(inplace=True) | ||
| 98 | + self.downsample = downsample | ||
| 99 | + self.stride = stride | ||
| 100 | + | ||
| 101 | + def forward(self, x): | ||
| 102 | + identity = x | ||
| 103 | + | ||
| 104 | + out = self.conv1(x) | ||
| 105 | + out = self.bn1(out) | ||
| 106 | + out = self.relu(out) | ||
| 107 | + | ||
| 108 | + out = self.conv2(out) | ||
| 109 | + out = self.bn2(out) | ||
| 110 | + out = self.relu(out) | ||
| 111 | + | ||
| 112 | + out = self.conv3(out) | ||
| 113 | + out = self.bn3(out) | ||
| 114 | + | ||
| 115 | + if self.downsample is not None: | ||
| 116 | + identity = self.downsample(x) | ||
| 117 | + | ||
| 118 | + out += identity | ||
| 119 | + out = self.relu(out) | ||
| 120 | + | ||
| 121 | + return out | ||
| 122 | + | ||
| 123 | + | ||
| 124 | +class ResNet(nn.Module): | ||
| 125 | + | ||
| 126 | + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, | ||
| 127 | + groups=1, width_per_group=64, replace_stride_with_dilation=None, | ||
| 128 | + norm_layer=None): | ||
| 129 | + super(ResNet, self).__init__() | ||
| 130 | + if norm_layer is None: | ||
| 131 | + norm_layer = nn.BatchNorm2d | ||
| 132 | + self._norm_layer = norm_layer | ||
| 133 | + | ||
| 134 | + self.inplanes = 64 | ||
| 135 | + self.dilation = 1 | ||
| 136 | + if replace_stride_with_dilation is None: | ||
| 137 | + # each element in the tuple indicates if we should replace | ||
| 138 | + # the 2x2 stride with a dilated convolution instead | ||
| 139 | + replace_stride_with_dilation = [False, False, False] | ||
| 140 | + if len(replace_stride_with_dilation) != 3: | ||
| 141 | + raise ValueError("replace_stride_with_dilation should be None " | ||
| 142 | + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | ||
| 143 | + self.groups = groups | ||
| 144 | + self.base_width = width_per_group | ||
| 145 | + # change dimension 3->1 for grayscale input | ||
| 146 | + self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, | ||
| 147 | + bias=False) | ||
| 148 | + self.bn1 = norm_layer(self.inplanes) | ||
| 149 | + self.relu = nn.ReLU(inplace=True) | ||
| 150 | + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
| 151 | + self.layer1 = self._make_layer(block, 64, layers[0]) | ||
| 152 | + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, | ||
| 153 | + dilate=replace_stride_with_dilation[0]) | ||
| 154 | + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, | ||
| 155 | + dilate=replace_stride_with_dilation[1]) | ||
| 156 | + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, | ||
| 157 | + dilate=replace_stride_with_dilation[2]) | ||
| 158 | + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||
| 159 | + self.fc = nn.Linear(512 * block.expansion, num_classes) | ||
| 160 | + | ||
| 161 | + for m in self.modules(): | ||
| 162 | + if isinstance(m, nn.Conv2d): | ||
| 163 | + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
| 164 | + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||
| 165 | + nn.init.constant_(m.weight, 1) | ||
| 166 | + nn.init.constant_(m.bias, 0) | ||
| 167 | + | ||
| 168 | + # Zero-initialize the last BN in each residual branch, | ||
| 169 | + # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||
| 170 | + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||
| 171 | + if zero_init_residual: | ||
| 172 | + for m in self.modules(): | ||
| 173 | + if isinstance(m, Bottleneck): | ||
| 174 | + nn.init.constant_(m.bn3.weight, 0) | ||
| 175 | + elif isinstance(m, BasicBlock): | ||
| 176 | + nn.init.constant_(m.bn2.weight, 0) | ||
| 177 | + | ||
| 178 | + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | ||
| 179 | + norm_layer = self._norm_layer | ||
| 180 | + downsample = None | ||
| 181 | + previous_dilation = self.dilation | ||
| 182 | + if dilate: | ||
| 183 | + self.dilation *= stride | ||
| 184 | + stride = 1 | ||
| 185 | + if stride != 1 or self.inplanes != planes * block.expansion: | ||
| 186 | + downsample = nn.Sequential( | ||
| 187 | + conv1x1(self.inplanes, planes * block.expansion, stride), | ||
| 188 | + norm_layer(planes * block.expansion), | ||
| 189 | + ) | ||
| 190 | + | ||
| 191 | + layers = [] | ||
| 192 | + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | ||
| 193 | + self.base_width, previous_dilation, norm_layer)) | ||
| 194 | + self.inplanes = planes * block.expansion | ||
| 195 | + for _ in range(1, blocks): | ||
| 196 | + layers.append(block(self.inplanes, planes, groups=self.groups, | ||
| 197 | + base_width=self.base_width, dilation=self.dilation, | ||
| 198 | + norm_layer=norm_layer)) | ||
| 199 | + | ||
| 200 | + return nn.Sequential(*layers) | ||
| 201 | + | ||
| 202 | + def _forward_impl(self, x): | ||
| 203 | + # See note [TorchScript super()] | ||
| 204 | + x = self.conv1(x) | ||
| 205 | + x = self.bn1(x) | ||
| 206 | + x = self.relu(x) | ||
| 207 | + x = self.maxpool(x) | ||
| 208 | + | ||
| 209 | + x = self.layer1(x) | ||
| 210 | + x = self.layer2(x) | ||
| 211 | + x = self.layer3(x) | ||
| 212 | + x = self.layer4(x) | ||
| 213 | + | ||
| 214 | + x = self.avgpool(x) | ||
| 215 | + x = torch.flatten(x, 1) | ||
| 216 | + x = self.fc(x) | ||
| 217 | + | ||
| 218 | + return x | ||
| 219 | + | ||
| 220 | + def forward(self, x): | ||
| 221 | + return self._forward_impl(x) | ||
| 222 | + | ||
| 223 | + | ||
| 224 | +def _resnet(arch, block, layers, pretrained, progress, **kwargs): | ||
| 225 | + model = ResNet(block, layers, **kwargs) | ||
| 226 | + # if pretrained: | ||
| 227 | + # state_dict = load_state_dict_from_url(model_urls[arch], | ||
| 228 | + # progress=progress) | ||
| 229 | + # model.load_state_dict(state_dict) | ||
| 230 | + return model | ||
| 231 | + | ||
| 232 | + | ||
| 233 | +def resnet18(pretrained=False, progress=True, **kwargs): | ||
| 234 | + r"""ResNet-18 model from | ||
| 235 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
| 236 | + Args: | ||
| 237 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 238 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 239 | + """ | ||
| 240 | + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, | ||
| 241 | + **kwargs) | ||
| 242 | + | ||
| 243 | + | ||
| 244 | +def resnet34(pretrained=False, progress=True, **kwargs): | ||
| 245 | + r"""ResNet-34 model from | ||
| 246 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
| 247 | + Args: | ||
| 248 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 249 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 250 | + """ | ||
| 251 | + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, | ||
| 252 | + **kwargs) | ||
| 253 | + | ||
| 254 | + | ||
| 255 | +def resnet50(pretrained=False, progress=True, **kwargs): | ||
| 256 | + r"""ResNet-50 model from | ||
| 257 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
| 258 | + Args: | ||
| 259 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 260 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 261 | + """ | ||
| 262 | + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, | ||
| 263 | + **kwargs) | ||
| 264 | + | ||
| 265 | + | ||
| 266 | +def resnet101(pretrained=False, progress=True, **kwargs): | ||
| 267 | + r"""ResNet-101 model from | ||
| 268 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
| 269 | + Args: | ||
| 270 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 271 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 272 | + """ | ||
| 273 | + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, | ||
| 274 | + **kwargs) | ||
| 275 | + | ||
| 276 | + | ||
| 277 | +def resnet152(pretrained=False, progress=True, **kwargs): | ||
| 278 | + r"""ResNet-152 model from | ||
| 279 | + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||
| 280 | + Args: | ||
| 281 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 282 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 283 | + """ | ||
| 284 | + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, | ||
| 285 | + **kwargs) | ||
| 286 | + | ||
| 287 | + | ||
| 288 | +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): | ||
| 289 | + r"""ResNeXt-50 32x4d model from | ||
| 290 | + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||
| 291 | + Args: | ||
| 292 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 293 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 294 | + """ | ||
| 295 | + kwargs['groups'] = 32 | ||
| 296 | + kwargs['width_per_group'] = 4 | ||
| 297 | + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], | ||
| 298 | + pretrained, progress, **kwargs) | ||
| 299 | + | ||
| 300 | + | ||
| 301 | +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): | ||
| 302 | + r"""ResNeXt-101 32x8d model from | ||
| 303 | + `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||
| 304 | + Args: | ||
| 305 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 306 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 307 | + """ | ||
| 308 | + kwargs['groups'] = 32 | ||
| 309 | + kwargs['width_per_group'] = 8 | ||
| 310 | + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], | ||
| 311 | + pretrained, progress, **kwargs) | ||
| 312 | + | ||
| 313 | + | ||
| 314 | +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): | ||
| 315 | + r"""Wide ResNet-50-2 model from | ||
| 316 | + `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||
| 317 | + The model is the same as ResNet except for the bottleneck number of channels | ||
| 318 | + which is twice larger in every block. The number of channels in outer 1x1 | ||
| 319 | + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 | ||
| 320 | + channels, and in Wide ResNet-50-2 has 2048-1024-2048. | ||
| 321 | + Args: | ||
| 322 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 323 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 324 | + """ | ||
| 325 | + kwargs['width_per_group'] = 64 * 2 | ||
| 326 | + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], | ||
| 327 | + pretrained, progress, **kwargs) | ||
| 328 | + | ||
| 329 | + | ||
| 330 | +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): | ||
| 331 | + r"""Wide ResNet-101-2 model from | ||
| 332 | + `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||
| 333 | + The model is the same as ResNet except for the bottleneck number of channels | ||
| 334 | + which is twice larger in every block. The number of channels in outer 1x1 | ||
| 335 | + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 | ||
| 336 | + channels, and in Wide ResNet-50-2 has 2048-1024-2048. | ||
| 337 | + Args: | ||
| 338 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
| 339 | + progress (bool): If True, displays a progress bar of the download to stderr | ||
| 340 | + """ | ||
| 341 | + kwargs['width_per_group'] = 64 * 2 | ||
| 342 | + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], | ||
| 343 | + pretrained, progress, **kwargs) |
| ... | @@ -6,6 +6,7 @@ import pickle as cp | ... | @@ -6,6 +6,7 @@ import pickle as cp |
| 6 | import glob | 6 | import glob |
| 7 | import numpy as np | 7 | import numpy as np |
| 8 | import pandas as pd | 8 | import pandas as pd |
| 9 | + | ||
| 9 | from natsort import natsorted | 10 | from natsort import natsorted |
| 10 | from PIL import Image | 11 | from PIL import Image |
| 11 | import torch | 12 | import torch |
| ... | @@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split | ... | @@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split |
| 21 | from sklearn.model_selection import KFold | 22 | from sklearn.model_selection import KFold |
| 22 | 23 | ||
| 23 | from networks import basenet | 24 | from networks import basenet |
| 25 | +from networks import grayResNet, grayResNet2 | ||
| 24 | 26 | ||
| 25 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 27 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
| 26 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 28 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
| ... | @@ -55,40 +57,6 @@ def split_dataset(args, dataset, k): | ... | @@ -55,40 +57,6 @@ def split_dataset(args, dataset, k): |
| 55 | 57 | ||
| 56 | return Dm_indexes, Da_indexes | 58 | return Dm_indexes, Da_indexes |
| 57 | 59 | ||
| 58 | -def split_dataset2222(args, dataset, k): | ||
| 59 | - # load dataset | ||
| 60 | - X = list(range(len(dataset))) | ||
| 61 | - | ||
| 62 | - # split to k-fold | ||
| 63 | - #assert len(X) == len(Y) | ||
| 64 | - | ||
| 65 | - def _it_to_list(_it): | ||
| 66 | - return list(zip(*list(_it))) | ||
| 67 | - | ||
| 68 | - x_train = () | ||
| 69 | - x_test = () | ||
| 70 | - | ||
| 71 | - for i in range(k): | ||
| 72 | - #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) | ||
| 73 | - xtr, xte = train_test_split(X, random_state=None, test_size=0.1) | ||
| 74 | - x_train.append(np.array(xtr)) | ||
| 75 | - x_test.append(np.array(xte)) | ||
| 76 | - | ||
| 77 | - y_train = np.array([0]* len(x_train)) | ||
| 78 | - y_test = np.array([0]* len(x_test)) | ||
| 79 | - | ||
| 80 | - x_train = tuple(x_train) | ||
| 81 | - x_test = tuple(x_test) | ||
| 82 | - | ||
| 83 | - trainset = (zip(x_train, y_train),) | ||
| 84 | - testset = (zip(x_test, y_test),) | ||
| 85 | - | ||
| 86 | - Dm_indexes, Da_indexes = trainset, testset | ||
| 87 | - | ||
| 88 | - print(type(Dm_indexes), np.shape(Dm_indexes)) | ||
| 89 | - print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes) | ||
| 90 | - | ||
| 91 | - return Dm_indexes, Da_indexes | ||
| 92 | 60 | ||
| 93 | def concat_image_features(image, features, max_features=3): | 61 | def concat_image_features(image, features, max_features=3): |
| 94 | _, h, w = image.shape | 62 | _, h, w = image.shape |
| ... | @@ -159,8 +127,22 @@ def parse_args(kwargs): | ... | @@ -159,8 +127,22 @@ def parse_args(kwargs): |
| 159 | 127 | ||
| 160 | 128 | ||
| 161 | def select_model(args): | 129 | def select_model(args): |
| 162 | - if args.network in models.__dict__: | 130 | + # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(), |
| 163 | - backbone = models.__dict__[args.network]() | 131 | + # 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()} |
| 132 | + | ||
| 133 | + | ||
| 134 | + # grayResNet2 | ||
| 135 | + resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
| 136 | + 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
| 137 | + | ||
| 138 | + if args.network in resnet_dict: | ||
| 139 | + backbone = resnet_dict[args.network] | ||
| 140 | + #testing | ||
| 141 | + # print("\nRESNET50 LAYERS\n") | ||
| 142 | + # for layer in backbone.children(): | ||
| 143 | + # print(layer) | ||
| 144 | + # print("LAYER THE END\n") | ||
| 145 | + | ||
| 164 | model = basenet.BaseNet(backbone, args) | 146 | model = basenet.BaseNet(backbone, args) |
| 165 | else: | 147 | else: |
| 166 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | 148 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ... | ... |
-
Please register or login to post a comment