조현아

FAA ver2.0 for colab

Showing 21 changed files with 1636 additions and 0 deletions
1 +# Byte-compiled / optimized / DLL files
2 +__pycache__/
3 +*.py[cod]
4 +*$py.class
5 +
6 +# C extensions
7 +*.so
8 +
9 +# Distribution / packaging
10 +.Python
11 +build/
12 +develop-eggs/
13 +dist/
14 +downloads/
15 +eggs/
16 +.eggs/
17 +lib/
18 +lib64/
19 +parts/
20 +sdist/
21 +var/
22 +wheels/
23 +*.egg-info/
24 +.installed.cfg
25 +*.egg
26 +MANIFEST
27 +
28 +# PyInstaller
29 +# Usually these files are written by a python script from a template
30 +# before PyInstaller builds the exe, so as to inject date/other infos into it.
31 +*.manifest
32 +*.spec
33 +
34 +# Installer logs
35 +pip-log.txt
36 +pip-delete-this-directory.txt
37 +
38 +# Unit test / coverage reports
39 +htmlcov/
40 +.tox/
41 +.coverage
42 +.coverage.*
43 +.cache
44 +nosetests.xml
45 +coverage.xml
46 +*.cover
47 +.hypothesis/
48 +.pytest_cache/
49 +
50 +# Translations
51 +*.mo
52 +*.pot
53 +
54 +# Django stuff:
55 +*.log
56 +local_settings.py
57 +db.sqlite3
58 +
59 +# Flask stuff:
60 +instance/
61 +.webassets-cache
62 +
63 +# Scrapy stuff:
64 +.scrapy
65 +
66 +# Sphinx documentation
67 +docs/_build/
68 +
69 +# PyBuilder
70 +target/
71 +
72 +# Jupyter Notebook
73 +.ipynb_checkpoints
74 +
75 +# pyenv
76 +.python-version
77 +
78 +# celery beat schedule file
79 +celerybeat-schedule
80 +
81 +# SageMath parsed files
82 +*.sage.py
83 +
84 +# Environments
85 +.env
86 +.venv
87 +env/
88 +venv/
89 +ENV/
90 +env.bak/
91 +venv.bak/
92 +
93 +# Spyder project settings
94 +.spyderproject
95 +.spyproject
96 +
97 +# Rope project settings
98 +.ropeproject
99 +
100 +# mkdocs documentation
101 +/site
102 +
103 +# mypy
104 +.mypy_cache/
105 +
1 +# Fast AutoAugment
2 +<img src="figures/faa.png" width=800px>
3 +
4 +A PyTorch implementation of [Fast AutoAugment](https://arxiv.org/pdf/1905.00397.pdf) and [EfficientNet](https://arxiv.org/abs/1905.11946).
5 +
6 +## Prerequisites
7 +* torch==1.1.0
8 +* torchvision==0.2.2
9 +* hyperopt==0.1.2
10 +* future==0.17.1
11 +* tb-nightly==1.15.0a20190622
12 +
13 +## Usage
14 +### Training
15 +#### CIFAR10
16 +```bash
17 +# ResNet20 (w/o FastAutoAugment)
18 +python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False
19 +
20 +# ResNet20 (w/ FastAutoAugment)
21 +python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True
22 +
23 +# ResNet20 (w/ FastAutoAugment, Pre-found policy)
24 +python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \
25 + --augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp
26 +
27 +# ResNet32 (w/o FastAutoAugment)
28 +python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False
29 +
30 +# ResNet32 (w/ FastAutoAugment)
31 +python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True
32 +
33 +# EfficientNet (w/ FastAutoAugment)
34 +python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \
35 + --network=efficientnet_cifar10 --activation=swish
36 +```
37 +
38 +#### ImageNet (You can use any backbone networks in [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html))
39 +```bash
40 +
41 +# BaseNet (w/o FastAutoAugment)
42 +python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50
43 +
44 +# EfficientNet (w/ FastAutoAugment) (UnderConstruction)
45 +python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \
46 + --network=efficientnet --activation=swish
47 +```
48 +
49 +### Eval
50 +```bash
51 +# Single Image testing
52 +python eval.py --model_path=runs/ResNet_Scale3_Basline
53 +
54 +# 5-crops testing
55 +python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True
56 +```
57 +
58 +## Experiments
59 +### Fast AutoAugment
60 +#### ResNet20 (CIFAR10)
61 +* Pre-trained model [[Download](https://drive.google.com/file/d/12D8050yGGiKWGt8_R8QTlkoQ6wq_icBn/view?usp=sharing)]
62 +* Validation Curve
63 +<img src="figures/resnet20_valid.png">
64 +
65 +* Evaluation (Acc @1)
66 +
67 +| | Valid | Test(Single) |
68 +|----------------|-------|-------------|
69 +| ResNet20 | 90.70 | **91.45** |
70 +| ResNet20 + FAA |**92.46**| **91.45** |
71 +
72 +#### ResNet34 (CIFAR10)
73 +* Validation Curve
74 +<img src="figures/resnet34_valid.png">
75 +
76 +* Evaluation (Acc @1)
77 +
78 +| | Valid | Test(Single) |
79 +|----------------|-------|-------------|
80 +| ResNet34 | 91.54 | 91.47 |
81 +| ResNet34 + FAA |**92.76**| **91.99** |
82 +
83 +### Found Policy [[Download](https://drive.google.com/file/d/1Ia_IxPY3-T7m8biyl3QpxV1s5EA5gRDF/view?usp=sharing)]
84 +<img src="figures/pm.png">
85 +
86 +### Augmented images
87 +<img src="figures/augmented_images.png">
88 +<img src="figures/augmented_images2.png">
1 +import os
2 +import fire
3 +import json
4 +from pprint import pprint
5 +
6 +import torch
7 +import torch.nn as nn
8 +
9 +from utils import *
10 +
11 +
def eval(model_path):
    """Evaluate a saved model on the test split and print the metrics.

    Args:
        model_path: run directory containing ``kwargs.json`` and
            ``model/model.pt``.

    NOTE: the name shadows the builtin ``eval``; it is kept because it is the
    entry point handed to ``fire.Fire``.
    """
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Use a context manager so the file handle is closed deterministically.
    with open(kwargs_path) as kwargs_file:
        kwargs = json.load(kwargs_file)
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create network')
    model = select_model(args)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    print('\n[+] Load model')
    weight_path = os.path.join(model_path, 'model', 'model.pt')
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    state_dict = torch.load(weight_path, map_location=device)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; strip it so they load into the bare model built above.
    if state_dict and all(key.startswith('module.') for key in state_dict):
        state_dict = {key[len('module.'):]: value
                      for key, value in state_dict.items()}
    model.load_state_dict(state_dict)

    print('\n[+] Load dataset')
    test_transform = get_valid_transform(args, model)
    test_dataset = get_dataset(args, test_transform, 'test')
    test_loader = iter(get_dataloader(args, test_dataset))

    print('\n[+] Start testing')
    _test_res = validate(args, model, criterion, test_loader, step=0, writer=None)

    print('\n[+] Valid results')
    print('  Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100))
    print('  Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100))
    print('  Loss : {:.3f}'.format(_test_res[2].data))
    print('  Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset)))
46 +
47 +if __name__ == '__main__':
48 + fire.Fire(eval)
1 +import copy
2 +import json
3 +import time
4 +import torch
5 +import random
6 +import torchvision.transforms as transforms
7 +
8 +from torch.utils.data import Subset
9 +from sklearn.model_selection import StratifiedShuffleSplit
10 +from concurrent.futures import ProcessPoolExecutor
11 +
12 +from transforms import *
13 +from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
14 +from utils import *
15 +
16 +
# Transform classes sampled from when the caller of fast_auto_augment() does
# not pass its own `transform_candidates`.
# NOTE(review): the name is a misspelling of "DEFAULT"; it is kept because
# fast_auto_augment() refers to it by this exact name.
DEFALUT_CANDIDATES = [
    ShearXY,
    TranslateXY,
    Rotate,
    AutoContrast,
    Invert,
    Equalize,
    Solarize,
    Posterize,
    Contrast,
    Color,
    Brightness,
    Sharpness,
    Cutout,
#    SamplePairing,  (currently disabled)
]
33 +
34 +
def train_child(args, model, dataset, subset_indx, device=None):
    """Train `model` on the subset of `dataset` selected by `subset_indx`.

    Args:
        args: parsed run configuration (reads `start_step`, `max_step`,
            `print_step`, `use_cuda`).
        model: network to train; it is moved to `device` when one is given.
        dataset: dataset whose `transform` attribute is overwritten here with
            a plain Resize(32) + ToTensor pipeline.
        subset_indx: indices of the samples used for training.
        device: optional explicit device for the model and the criterion.

    Returns:
        The result of the last `train_step` call, or None when the step range
        is empty.
    """
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()

    dataset.transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor()])
    subset = Subset(dataset, subset_indx)
    data_loader = get_inf_dataloader(args, subset)

    if device:
        model = model.to(device)
        criterion = criterion.to(device)

    elif args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Only spread over all GPUs when the caller did not pin the child to one
    # device: DataParallel expects the module on its first device, and a
    # pinned child is meant to stay on the device it was given.
    if not device and torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # Defined up front so an empty step range returns None instead of raising
    # UnboundLocalError at the return statement.
    _train_res = None
    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(data_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device)

        if step % args.print_step == 0:
            print('\n[+] Training step: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}\tDevice: {}'.format(
                step, args.max_step,(time.time()-start_t)/60, optimizer.param_groups[0]['lr'], device))

            print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print('  Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print('  Loss : {}'.format(_train_res[2].data))

    return _train_res
72 +
73 +
def validate_child(args, model, dataset, subset_indx, transform, device=None):
    """Run `validate` for `model` on a subset of `dataset` under `transform`.

    The dataset's `transform` attribute is replaced by `transform` before the
    subset loader is built. Returns whatever `validate` returns.
    """
    loss_fn = nn.CrossEntropyLoss()

    if device:
        model, loss_fn = model.to(device), loss_fn.to(device)
    elif args.use_cuda:
        model, loss_fn = model.cuda(), loss_fn.cuda()

    dataset.transform = transform
    loader = get_dataloader(args, Subset(dataset, subset_indx), pin_memory=False)

    return validate(args, model, loss_fn, loader, 0, None, device)
90 +
91 +
def get_next_subpolicy(transform_candidates, op_per_subpolicy=2):
    """Draw a random sub-policy as a composed transform.

    For each of the `op_per_subpolicy` slots a candidate class is picked
    uniformly and instantiated with a random probability and magnitude, both
    drawn from `random.random()`. The chosen operations are followed by
    Resize(32) and ToTensor.
    """
    pool_size = len(transform_candidates)

    # Evaluation order inside the call expression is: index draw, then
    # probability draw, then magnitude draw.
    operations = [
        transform_candidates[random.randrange(pool_size)](random.random(), random.random())
        for _ in range(op_per_subpolicy)
    ]

    return transforms.Compose(
        operations + [transforms.Resize(32), transforms.ToTensor()])
108 +
109 +
def search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device):
    """Random search: sample `B` sub-policies and score each on `Da_indx`.

    Returns a list of `(subpolicy, score)` tuples where `score` is the third
    element of the `validate_child` result.
    """

    def _sample_and_score():
        candidate = get_next_subpolicy(transform_candidates)
        outcome = validate_child(args, child_model, dataset, Da_indx, candidate, device)
        return candidate, outcome[2]

    return [_sample_and_score() for _ in range(B)]
119 +
120 +
def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device):
    """Search two-operation sub-policies with hyperopt's TPE for `B` evaluations.

    During the search each sampled pair of operations is scored through
    `validate_child` with a Resize(32) -> ops -> ToTensor pipeline. Every
    trial is then rebuilt as Pad(4) -> RandomCrop(32) -> RandomHorizontalFlip
    -> ops -> ToTensor and returned together with its recorded loss.

    Returns:
        List of `(subpolicy, loss)` tuples, one per trial.
    """
    slots = (1, 2)  # suffixes of the hyperopt labels, one per operation

    def _objective(sampled):
        ops = [op_cls(prob, mag) for op_cls, prob, mag in sampled]
        pipeline = transforms.Compose(
            [transforms.Resize(32)] + ops + [transforms.ToTensor()])

        val_res = validate_child(args, child_model, dataset, Da_indx, pipeline, device)
        return {'loss': val_res[2].cpu().numpy(), 'status': STATUS_OK}

    space = [
        (hp.choice('transform%d' % slot, transform_candidates),
         hp.uniform('prob%d' % slot, 0, 1.0),
         hp.uniform('mag%d' % slot, 0, 1.0))
        for slot in slots
    ]

    trials = Trials()
    fmin(_objective,
         space=space,
         algo=tpe.suggest,
         max_evals=B,
         trials=trials)

    results = []
    for trial in trials.trials:
        vals = trial['misc']['vals']
        # hp.choice records the index of the chosen candidate.
        ops = [
            transform_candidates[vals['transform%d' % slot][0]](
                vals['prob%d' % slot][0], vals['mag%d' % slot][0])
            for slot in slots
        ]
        baseline = [
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
        ]
        composed = transforms.Compose(baseline + ops + [transforms.ToTensor()])
        results.append((composed, trial['result']['loss']))

    return results
163 +
164 +
def get_topn_subpolicies(subpolicies, N=10):
    """Return the `N` entries with the smallest second element.

    `subpolicies` is a sequence of `(subpolicy, score)` pairs. The input is
    not modified, and ties keep their original relative order.
    """
    ranked = list(subpolicies)
    ranked.sort(key=lambda entry: entry[1])
    return ranked[:N]
167 +
168 +
def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k):
    """Worker for one data split: train a child model, then search sub-policies.

    Args:
        args_str: JSON-encoded run arguments, decoded and re-parsed here.
        model: network that is deep-copied to create the child model.
        dataset: dataset shared by training and policy search.
        Dm_indx: indices used to train the child model.
        Da_indx: indices used to score candidate sub-policies.
        T: number of search rounds.
        transform_candidates: transform classes to sample from.
        B: number of evaluations per search round.
        N: number of best sub-policies kept per round.
        k: index of this split, used to pick a GPU round-robin.

    Returns:
        List of the selected sub-policy transforms, up to `T * N` entries.
    """
    kwargs = json.loads(args_str)
    args, kwargs = parse_args(kwargs)
    _transform = []

    n_gpus = torch.cuda.device_count()
    if n_gpus > 0:
        device_id = k % n_gpus
        device = torch.device('cuda:%d' % device_id)
        print('[+] Child %d training strated (GPU: %d)' % (k, device_id))
    else:
        # Without a GPU the modulo above would divide by zero; run on the CPU.
        device = torch.device('cpu')
        print('[+] Child %d training started (CPU)' % k)

    # train child model
    child_model = copy.deepcopy(model)
    train_child(args, child_model, dataset, Dm_indx, device)

    # search sub policy
    for _ in range(T):
        subpolicies = search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device)
        subpolicies = get_topn_subpolicies(subpolicies, N)
        _transform.extend([subpolicy[0] for subpolicy in subpolicies])

    return _transform
190 +
191 +
def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5):
    """Search an augmentation policy and return it as a RandomChoice transform.

    The 'trainval' dataset is split into `K` (Dm, Da) index pairs by
    `split_dataset`; each pair is handed to `process_fn` in a worker process,
    and the sub-policies found by all workers are pooled.

    Args:
        args: parsed run configuration; serialised to JSON for the workers.
        model: network passed to every worker.
        transform_candidates: transform classes to search over; falls back to
            `DEFALUT_CANDIDATES` when empty or None.
        K: number of splits passed to `split_dataset`.
        B: evaluations per search round.
        T: search rounds per split.
        N: best sub-policies kept per round.
        num_process: upper bound on the number of worker processes.

    Returns:
        `transforms.RandomChoice` over every selected sub-policy.
    """
    args_str = json.dumps(args._asdict())
    dataset = get_dataset(args, None, 'trainval')
    # One worker per GPU at most, but never fewer than one:
    # ProcessPoolExecutor rejects max_workers=0 on a machine with no GPU.
    num_process = max(1, min(torch.cuda.device_count(), num_process))
    transform, futures = [], []

    torch.multiprocessing.set_start_method('spawn', force=True)

    if not transform_candidates:
        transform_candidates = DEFALUT_CANDIDATES

    # split
    Dm_indexes, Da_indexes = split_dataset(args, dataset, K)

    with ProcessPoolExecutor(max_workers=num_process) as executor:
        for k, (Dm_indx, Da_indx) in enumerate(zip(Dm_indexes, Da_indexes)):
            future = executor.submit(process_fn,
                    args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k)
            futures.append(future)

        for future in futures:
            transform.extend(future.result())

    transform = transforms.RandomChoice(transform)

    return transform
1 +import torch.nn as nn
2 +
class BaseNet(nn.Module):
    """Wrap a backbone, split into its first child, the middle children and the last child.

    `forward` returns both the output of the last child and the activations
    of the first child.
    """

    def __init__(self, backbone, args):
        super(BaseNet, self).__init__()

        # Materialise the children once and slice the list.
        layers = list(backbone.children())
        self.first = nn.Sequential(*layers[:1])
        self.after = nn.Sequential(*layers[1:-1])
        self.fc = layers[-1]

        self.img_size = (224, 224)

    def forward(self, x):
        first_out = self.first(x)
        hidden = self.after(first_out)
        flat = hidden.reshape(hidden.size(0), -1)
        return self.fc(flat), first_out
1 +import math
2 +import torch.nn as nn
3 +import torch.nn.functional as F
4 +
5 +
def round_fn(orig, multiplier):
    """Scale `orig` by `multiplier` and round up to an int.

    A falsy multiplier (0, None) leaves `orig` untouched.
    """
    if multiplier:
        return int(math.ceil(multiplier * orig))
    return orig
11 +
12 +
def get_activation_fn(activation):
    """Map an activation name to its module class (not an instance).

    Raises:
        Exception: if `activation` is neither "swish" nor "relu".
    """
    if activation == "swish":
        return Swish
    if activation == "relu":
        return nn.ReLU
    raise Exception('Unkown activation %s' % activation)
22 +
23 +
class Swish(nn.Module):
    """ Swish activation function, s(x) = x * sigmoid(x)

    Args:
        inplace: when True the input tensor is overwritten with the result;
            when False (default) the input is left untouched and a new tensor
            is returned.
    """

    def __init__(self, inplace=False):
        super().__init__()
        # Honour the caller's choice; the flag used to be forced to True,
        # which silently mutated every input tensor.
        self.inplace = inplace

    def forward(self, x):
        # Tensor.sigmoid() replaces the deprecated F.sigmoid.
        if self.inplace:
            return x.mul_(x.sigmoid())
        return x * x.sigmoid()
37 +
38 +
class ConvBlock(nn.Module):
    """ Bias-free Conv2d followed by BatchNorm2d and the chosen activation """

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        conv = nn.Conv2d(in_channel, out_channel, kernel_size,
                         padding=padding, stride=stride, bias=False)
        norm = nn.BatchNorm2d(out_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(conv, norm, act)

    def forward(self, x):
        out = self.fw(x)
        return out
53 +
54 +
class DepthwiseConvBlock(nn.Module):
    """ Depthwise Conv2d (groups == channels) + BatchNorm2d + activation """

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        # groups=in_channel gives one filter per input channel.
        depthwise = nn.Conv2d(in_channel, in_channel, kernel_size,
                              padding=padding, stride=stride,
                              groups=in_channel, bias=False)
        norm = nn.BatchNorm2d(in_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(depthwise, norm, act)

    def forward(self, x):
        out = self.fw(x)
        return out
69 +
70 +
class MBConv(nn.Module):
    """ Inverted residual block: optional 1x1 expansion, depthwise conv, 1x1 projection.

    The input is added back to the output when the block keeps both the
    channel count and the resolution (stride 1).
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish"):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride

        hidden = in_channel * expand_ratio

        # The expansion conv only exists when it actually widens the input.
        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, hidden, 1, activation=activation)

        self.dw_conv = DepthwiseConvBlock(hidden, kernel_size,
                                          padding=(kernel_size - 1) // 2,
                                          stride=stride, activation=activation)

        self.pw_conv = ConvBlock(hidden, out_channel, 1, activation=activation)

    def forward(self, inputs):
        x = self.expand(inputs) if self.expand_ratio != 1 else inputs
        x = self.pw_conv(self.dw_conv(x))

        use_skip = self.stride == 1 and self.in_channel == self.out_channel
        if use_skip:
            x = x + inputs

        return x
107 +
108 +
class Net(nn.Module):
    """ EfficientNet

    Depth, width and input resolution are scaled by 1.2**pi, 1.1**pi and
    1.15**pi respectively, where `pi` is read from `args.pi`. The expected
    input size is stored in `self.img_size`.
    """

    # (base out_channels, base depth, kernel_size, half_resolution, expand_ratio)
    # for stage2 .. stage8, in order.
    _MB_STAGES = (
        (16, 1, 3, False, 1),
        (24, 2, 3, True, 6),
        (40, 2, 5, True, 6),
        (80, 3, 3, True, 6),
        (112, 3, 5, False, 6),
        (192, 4, 5, True, 6),
        (320, 1, 3, False, 6),
    )

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = args.num_classes

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        self.img_size = (round_fn(224, self.r), round_fn(224, self.r))

        in_channel = round_fn(32, self.w)
        self.stage1 = ConvBlock(3, in_channel,
                                kernel_size=3, padding=1, stride=2, activation=activation)

        for index, (out_base, depth_base, kernel_size, half_resolution, expand_ratio) \
                in enumerate(self._MB_STAGES, start=2):
            out_channel = round_fn(out_base, self.w)
            setattr(self, 'stage%d' % index,
                    self.make_layers(in_channel, out_channel,
                                     depth=round_fn(depth_base, self.d), kernel_size=kernel_size,
                                     half_resolution=half_resolution, expand_ratio=expand_ratio,
                                     activation=activation))
            in_channel = out_channel

        last_channel = round_fn(1280, self.w)
        self.stage9 = ConvBlock(in_channel, last_channel,
                                kernel_size=1, activation=activation)

        # Size the classifier from the real flattened feature map. Each
        # stride-2 convolution here (odd kernel, padding (k-1)//2) maps a side
        # of n to ceil(n / 2); stage1 and every half_resolution stage halve
        # once. For pi == 0 this gives 7 * 7 * 1280, as before; for other pi
        # the former round_fn(7*7*1280, w) did not match the features.
        side = self.img_size[0]
        n_halvings = 1 + sum(1 for stage in self._MB_STAGES if stage[3])
        for _ in range(n_halvings):
            side = (side + 1) // 2
        self.fc = nn.Linear(last_channel * side * side, num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
            half_resolution=False, expand_ratio=1, activation="swish"):
        """Stack `depth` MBConv blocks; only the first may downsample."""
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i==0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                    stride=stride, expand_ratio=expand_ratio, activation=activation))
            in_channel = out_channel

        return nn.Sequential(*blocks)

    def forward(self, x):
        # Report both spatial dimensions; the message used to print only one.
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, tuple(x.size()[-2:]))

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        # NOTE(review): both returned values are the logits here — confirm
        # whether callers expect an intermediate feature map as the second.
        return x, x
1 +import math
2 +import torch.nn as nn
3 +import torch.nn.functional as F
4 +
5 +
def round_fn(orig, multiplier):
    """Return ceil(multiplier * orig) as an int, or `orig` itself when `multiplier` is falsy."""
    if multiplier:
        scaled = multiplier * orig
        return int(math.ceil(scaled))
    return orig
11 +
12 +
def get_activation_fn(activation):
    """Return the activation module class registered under `activation`.

    Supported names are "swish" and "relu"; anything else raises Exception.
    """
    if activation == "swish":
        return Swish
    if activation == "relu":
        return nn.ReLU
    raise Exception('Unkown activation %s' % activation)
22 +
23 +
class Swish(nn.Module):
    """ Swish activation function, s(x) = x * sigmoid(x)

    Args:
        inplace: when True the input tensor is overwritten with the result;
            when False (default) the input is left untouched and a new tensor
            is returned.
    """

    def __init__(self, inplace=False):
        super().__init__()
        # Honour the caller's choice; the flag used to be forced to True,
        # which silently mutated every input tensor.
        self.inplace = inplace

    def forward(self, x):
        # Tensor.sigmoid() replaces the deprecated F.sigmoid.
        if self.inplace:
            return x.mul_(x.sigmoid())
        return x * x.sigmoid()
37 +
38 +
class ConvBlock(nn.Module):
    """ Bias-free Conv2d followed by BatchNorm2d and the chosen activation """

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        conv = nn.Conv2d(in_channel, out_channel, kernel_size,
                         padding=padding, stride=stride, bias=False)
        norm = nn.BatchNorm2d(out_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(conv, norm, act)

    def forward(self, x):
        out = self.fw(x)
        return out
53 +
54 +
class DepthwiseConvBlock(nn.Module):
    """ Depthwise Conv2d (groups == channels) + BatchNorm2d + activation """

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        # groups=in_channel gives one filter per input channel.
        depthwise = nn.Conv2d(in_channel, in_channel, kernel_size,
                              padding=padding, stride=stride,
                              groups=in_channel, bias=False)
        norm = nn.BatchNorm2d(in_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(depthwise, norm, act)

    def forward(self, x):
        out = self.fw(x)
        return out
69 +
70 +
class SEBlock(nn.Module):
    """ Squeeze and Excitation Block

    Global-average-pools the input to 1x1, passes it through a 1x1 conv
    bottleneck (ReLU, then Sigmoid) and rescales the input channel-wise with
    the resulting weights.

    Args:
        in_channel: number of input (and output) channels.
        se_ratio: channel reduction factor of the bottleneck.
    """

    def __init__(self, in_channel, se_ratio=16):
        super().__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d((1,1))
        # Keep at least one bottleneck channel: plain integer division gives
        # 0 whenever in_channel < se_ratio, i.e. an empty convolution.
        inter_channel = max(1, in_channel // se_ratio)

        self.reduce = nn.Sequential(
            nn.Conv2d(in_channel, inter_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.ReLU())

        self.expand = nn.Sequential(
            nn.Conv2d(inter_channel, in_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.Sigmoid())

    def forward(self, x):
        s = self.global_avgpool(x)
        s = self.reduce(s)
        s = self.expand(s)
        # s has spatial size 1x1 and broadcasts over the spatial dims of x.
        return x * s
95 +
96 +
class MBConv(nn.Module):
    """ Inverted residual block with an optional squeeze-and-excitation step.

    Pipeline: optional 1x1 expansion -> depthwise conv -> optional SEBlock ->
    1x1 projection. The input is added back when the block keeps both the
    channel count and the resolution (stride 1).
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish", use_seblock=False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride
        self.use_seblock = use_seblock

        hidden = in_channel * expand_ratio

        # The expansion conv only exists when it actually widens the input.
        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, hidden, 1, activation=activation)

        self.dw_conv = DepthwiseConvBlock(hidden, kernel_size,
                                          padding=(kernel_size - 1) // 2,
                                          stride=stride, activation=activation)

        if use_seblock:
            self.seblock = SEBlock(hidden)

        self.pw_conv = ConvBlock(hidden, out_channel, 1, activation=activation)

    def forward(self, inputs):
        x = self.expand(inputs) if self.expand_ratio != 1 else inputs
        x = self.dw_conv(x)

        if self.use_seblock:
            x = self.seblock(x)

        x = self.pw_conv(x)

        use_skip = self.stride == 1 and self.in_channel == self.out_channel
        if use_skip:
            x = x + inputs

        return x
141 +
142 +
class Net(nn.Module):
    """ EfficientNet

    Variant with a 32-pixel base resolution and a fixed 10-class head. Depth,
    width and input resolution are scaled by 1.2**pi, 1.1**pi and 1.15**pi
    respectively, where `pi` is read from `args.pi`. `args.use_seblock`
    toggles squeeze-and-excitation inside every MBConv block.
    """

    # (base out_channels, base depth, kernel_size, half_resolution, expand_ratio)
    # for stage2 .. stage8, in order.
    _MB_STAGES = (
        (16, 1, 3, False, 1),
        (24, 2, 3, True, 6),
        (40, 2, 5, True, 6),
        (80, 3, 3, True, 6),
        (112, 3, 5, False, 6),
        (192, 4, 5, True, 6),
        (320, 1, 3, False, 6),
    )

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = 10

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        self.img_size = (round_fn(32, self.r), round_fn(32, self.r))
        # Must be set before make_layers() is called, which reads it.
        self.use_seblock = args.use_seblock

        in_channel = round_fn(32, self.w)
        self.stage1 = ConvBlock(3, in_channel,
                                kernel_size=3, padding=1, stride=2, activation=activation)

        for index, (out_base, depth_base, kernel_size, half_resolution, expand_ratio) \
                in enumerate(self._MB_STAGES, start=2):
            out_channel = round_fn(out_base, self.w)
            setattr(self, 'stage%d' % index,
                    self.make_layers(in_channel, out_channel,
                                     depth=round_fn(depth_base, self.d), kernel_size=kernel_size,
                                     half_resolution=half_resolution, expand_ratio=expand_ratio,
                                     activation=activation))
            in_channel = out_channel

        last_channel = round_fn(1280, self.w)
        self.stage9 = ConvBlock(in_channel, last_channel,
                                kernel_size=1, activation=activation)

        # Size the classifier from the real flattened feature map. Each
        # stride-2 convolution here (odd kernel, padding (k-1)//2) maps a side
        # of n to ceil(n / 2); stage1 and every half_resolution stage halve
        # once. For pi == 0 the map is 1x1, so the width is unchanged; for
        # other pi the map can be larger and the old fixed width did not match.
        side = self.img_size[0]
        n_halvings = 1 + sum(1 for stage in self._MB_STAGES if stage[3])
        for _ in range(n_halvings):
            side = (side + 1) // 2
        self.fc = nn.Linear(last_channel * side * side, num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
            half_resolution=False, expand_ratio=1, activation="swish"):
        """Stack `depth` MBConv blocks; only the first may downsample."""
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i==0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                    stride=stride, expand_ratio=expand_ratio, activation=activation, use_seblock=self.use_seblock))
            in_channel = out_channel

        return nn.Sequential(*blocks)

    def forward(self, x):
        # Report both spatial dimensions; the message used to print only one.
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, tuple(x.size()[-2:]))

        s = self.stage1(x)
        x = self.stage2(s)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        # Returns the logits and the stage1 activations.
        return x, s
1 +import torch.nn as nn
2 +
3 +
class ResidualBlock(nn.Module):
    """Two 3x3 conv+BN layers with a shortcut connection.

    When the block changes the channel count or the resolution, the shortcut
    goes through a 1x1 conv + BN projection (`self.down`); otherwise the
    input is added unchanged.
    """

    def __init__(self, in_channel, out_channel, stride):
        super(ResidualBlock, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride

        first_conv = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, padding=1, stride=stride)
        self.conv1 = nn.Sequential(first_conv, nn.BatchNorm2d(out_channel))

        self.relu = nn.ReLU(inplace=True)

        second_conv = nn.Conv2d(out_channel, out_channel,
                                kernel_size=3, padding=1)
        self.conv2 = nn.Sequential(second_conv, nn.BatchNorm2d(out_channel))

        # The projection is created only when the shortcut shape differs.
        if not (self.in_channel == self.out_channel and self.stride == 1):
            projection = nn.Conv2d(in_channel, out_channel,
                                   kernel_size=1, stride=stride)
            self.down = nn.Sequential(projection, nn.BatchNorm2d(out_channel))

    def forward(self, b):
        t = self.conv2(self.relu(self.conv1(b)))

        needs_projection = self.stride != 1 or self.in_channel != self.out_channel
        if needs_projection:
            b = self.down(b)

        t += b
        return self.relu(t)
43 +
44 +
class Net(nn.Module):
    """Residual network: a conv stem, three stages of `2 * args.scale` blocks, and a 10-way head.

    Stage widths are 16, 32 and 64; the first block of stages two and three
    uses stride 2. `forward` returns the logits and the stem activations.
    """

    def __init__(self, args):
        super(Net, self).__init__()
        blocks_per_stage = 2 * args.scale

        def build_stage(in_c, out_c, first_stride):
            # Only the first block changes the width and the resolution.
            stage = []
            for i in range(blocks_per_stage):
                is_first = (i == 0)
                stage.append(ResidualBlock(in_c if is_first else out_c,
                                           out_c,
                                           first_stride if is_first else 1))
            return nn.Sequential(*stage)

        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))

        self.layer1 = build_stage(16, 16, 1)
        self.layer2 = build_stage(16, 32, 2)
        self.layer3 = build_stage(32, 64, 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))

        self.fc = nn.Linear(64, 10)

        # Weight initialisation: Kaiming-normal for convolutions, and unit
        # weight / zero bias for normalisation layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        s = self.stem(x)
        features = self.layer3(self.layer2(self.layer1(s)))
        pooled = self.avg_pool(features)
        logits = self.fc(pooled.reshape(pooled.size(0), -1))

        return logits, s
1 +import os
2 +import fire
3 +import time
4 +import json
5 +import random
6 +from pprint import pprint
7 +
8 +import torch.nn as nn
9 +import torch.backends.cudnn as cudnn
10 +from torch.utils.tensorboard import SummaryWriter
11 +
12 +from networks import *
13 +from utils import *
14 +
15 +
def train(**kwargs):
    """Train a network, log to TensorBoard and keep the best checkpoint.

    Keyword arguments are forwarded to `parse_args`. Logs and the parsed
    kwargs go to `./runs/<model_name>/`; the weights with the best validation
    Acc@1 go to `./runs/<model_name>/model/model.pt`.
    """
    print('\n[+] Parse arguments')
    args, kwargs = parse_args(kwargs)
    pprint(args)

    print('\n[+] Create log dir')
    model_name = get_model_name(args)
    log_dir = os.path.join('./runs', model_name)
    os.makedirs(os.path.join(log_dir, 'model'))
    # Close the file explicitly so kwargs.json is fully written before the
    # long training loop starts.
    with open(os.path.join(log_dir, 'kwargs.json'), 'w') as kwargs_file:
        json.dump(kwargs, kwargs_file)
    writer = SummaryWriter(log_dir=log_dir)

    try:
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True

        print('\n[+] Create network')
        model = select_model(args)
        optimizer = select_optimizer(args, model)
        scheduler = select_scheduler(args, optimizer)
        criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            model = model.cuda()
            criterion = criterion.cuda()

        print('\n[+] Load dataset')
        transform = get_train_transform(args, model, log_dir)
        val_transform = get_valid_transform(args, model)
        train_dataset = get_dataset(args, transform, 'train')
        valid_dataset = get_dataset(args, val_transform, 'val')
        train_loader = iter(get_inf_dataloader(args, train_dataset))
        # Number of optimisation steps in one pass over the training set.
        steps_per_epoch = max(1, len(train_dataset) // args.batch_size)
        # Total number of epochs covered by max_step (ceiling division).
        max_epoch = -(-args.max_step // steps_per_epoch)
        best_acc = -1

        print('\n[+] Start training')
        if torch.cuda.device_count() > 1:
            print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
            model = nn.DataParallel(model)

        start_t = time.time()
        for step in range(args.start_step, args.max_step):
            batch = next(train_loader)
            _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer)

            if step % args.print_step == 0:
                # 1-based epoch index; this name was previously undefined and
                # raised NameError on the first report.
                current_epoch = step // steps_per_epoch + 1
                print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                    step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr']))
                writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
                writer.add_scalar('train/acc1', _train_res[0], global_step=step)
                writer.add_scalar('train/acc5', _train_res[1], global_step=step)
                writer.add_scalar('train/loss', _train_res[2], global_step=step)
                writer.add_scalar('train/forward_time', _train_res[3], global_step=step)
                writer.add_scalar('train/backward_time', _train_res[4], global_step=step)
                print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
                print('  Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
                print('  Loss : {}'.format(_train_res[2].data))
                print('  FW Time : {:.3f}ms'.format(_train_res[3]*1000))
                print('  BW Time : {:.3f}ms'.format(_train_res[4]*1000))

            if step % args.val_step == args.val_step-1:
                valid_loader = iter(get_dataloader(args, valid_dataset))
                _valid_res = validate(args, model, criterion, valid_loader, step, writer)
                print('\n[+] Valid results')
                writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
                writer.add_scalar('valid/acc5', _valid_res[1], global_step=step)
                writer.add_scalar('valid/loss', _valid_res[2], global_step=step)
                print('  Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100))
                print('  Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100))
                print('  Loss : {}'.format(_valid_res[2].data))

                if _valid_res[0] > best_acc:
                    best_acc = _valid_res[0]
                    # Save the wrapped module's weights so the keys carry no
                    # 'module.' prefix and load into a bare model.
                    to_save = model.module if isinstance(model, nn.DataParallel) else model
                    torch.save(to_save.state_dict(), os.path.join(log_dir, "model","model.pt"))
                    print('\n[+] Model saved')
    finally:
        # Flush and close the TensorBoard writer even if training fails.
        writer.close()
95 +
96 +
# Script entry point: hand the train() function to python-fire so it is
# invoked from the command line.
if __name__ == '__main__':
    fire.Fire(train)
1 +import numpy as np
2 +import torch.nn as nn
3 +import torchvision.transforms as transforms
4 +
5 +from abc import ABC, abstractmethod
6 +from PIL import Image, ImageOps, ImageEnhance
7 +
8 +
class BaseTransform(ABC):
    """Abstract augmentation op that is applied with a given probability.

    Subclasses implement ``transform``; calling an instance wraps that method
    in ``transforms.RandomApply`` so that it is applied with probability
    ``prob``.
    """

    def __init__(self, prob, mag):
        # prob: probability handed to RandomApply.
        # mag: op-specific strength, interpreted by each subclass.
        self.prob, self.mag = prob, mag

    def __call__(self, img):
        gated = transforms.RandomApply([self.transform], self.prob)
        return gated(img)

    def __repr__(self):
        fields = (self.__class__.__name__, self.prob, self.mag)
        return '%s(prob=%.2f, magnitude=%.2f)' % fields

    @abstractmethod
    def transform(self, img):
        """Apply the op unconditionally to ``img`` and return the result."""
25 +
26 +
class ShearXY(BaseTransform):
    """Random shear; the shear argument is ``mag * 360`` degrees."""

    def transform(self, img):
        max_shear = 360 * self.mag
        affine = transforms.RandomAffine(
            0, shear=max_shear, resample=Image.BILINEAR)
        return affine(img)
33 +
34 +
class TranslateXY(BaseTransform):
    """Random translation; ``mag`` is used as the shift fraction on both axes."""

    def transform(self, img):
        shift = (self.mag, self.mag)
        affine = transforms.RandomAffine(
            0, translate=shift, resample=Image.BILINEAR)
        return affine(img)
41 +
42 +
class Rotate(BaseTransform):
    """Random rotation; the degrees argument is ``mag * 360``."""

    def transform(self, img):
        max_angle = 360 * self.mag
        rotation = transforms.RandomRotation(max_angle, Image.BILINEAR)
        return rotation(img)
49 +
50 +
class AutoContrast(BaseTransform):
    """Auto-contrast with a cutoff of ``int(mag * 49)``."""

    def transform(self, img):
        return ImageOps.autocontrast(img, cutoff=int(self.mag * 49))
56 +
57 +
class Invert(BaseTransform):
    """Invert the image; the magnitude is not used."""

    def transform(self, img):
        inverted = ImageOps.invert(img)
        return inverted
62 +
63 +
class Equalize(BaseTransform):
    """Equalize the image histogram; the magnitude is not used."""

    def transform(self, img):
        equalized = ImageOps.equalize(img)
        return equalized
68 +
69 +
class Solarize(BaseTransform):
    """Solarize with threshold ``(1 - mag) * 255``.

    A larger magnitude therefore gives a lower threshold.
    """

    def transform(self, img):
        return ImageOps.solarize(img, 255 * (1 - self.mag))
75 +
76 +
class Posterize(BaseTransform):
    """Reduce the number of bits kept per channel.

    ``mag == 0`` keeps all 8 bits; larger magnitudes keep fewer.
    """

    def transform(self, img):
        # ImageOps.posterize documents ``bits`` as 1..8. Without the clamp,
        # mag > 0.875 produced bits == 0, which masks every pixel to black.
        bits = min(8, max(1, int((1 - self.mag) * 8)))
        return ImageOps.posterize(img, bits=bits)
82 +
83 +
class Contrast(BaseTransform):
    """Adjust contrast with enhancement factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Contrast(img)
        return enhancer.enhance(10 * self.mag)
89 +
90 +
class Color(BaseTransform):
    """Adjust colour balance with enhancement factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Color(img)
        return enhancer.enhance(10 * self.mag)
96 +
97 +
class Brightness(BaseTransform):
    """Adjust brightness with enhancement factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Brightness(img)
        return enhancer.enhance(10 * self.mag)
103 +
104 +
class Sharpness(BaseTransform):
    """Adjust sharpness with enhancement factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Sharpness(img)
        return enhancer.enhance(10 * self.mag)
110 +
111 +
class Cutout(BaseTransform):
    """Mask out a single square patch whose side length is ``24 * mag``."""

    def transform(self, img):
        eraser = CutoutOp(n_holes=1, length=24 * self.mag)
        return eraser(img)
119 +
120 +
class CutoutOp(object):
    """
    https://github.com/uoguelph-mlrg/Cutout

    Randomly mask out one or more square patches from a PIL image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int or float): Side length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to mask; its ``size`` is read as (W, H).
        Returns:
            PIL.Image: Image with n_holes patches of up to length x length
            pixels set to zero. Patches are clipped at the image border.
        """
        w, h = img.size
        pixels = np.asarray(img).astype(np.uint8)

        # Build the mask in 2-D and add a channel axis only when the pixel
        # array has one. The previous (h, w, 1) mask broadcast incorrectly
        # against 2-D (single-channel) arrays.
        mask = np.ones((h, w), np.uint8)
        half = self.length // 2

        for _ in range(self.n_holes):
            # Draw y before x to keep the original random-number order.
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = int(np.clip(y - half, 0, h))
            y2 = int(np.clip(y + half, 0, h))
            x1 = int(np.clip(x - half, 0, w))
            x2 = int(np.clip(x + half, 0, w))

            mask[y1:y2, x1:x2] = 0

        if pixels.ndim == 3:
            mask = mask[:, :, np.newaxis]

        # Apply the mask once; the original multiplied by it twice.
        return Image.fromarray(mask * pixels)
161 +
1 +import os
2 +import time
3 +import importlib
4 +import collections
5 +import pickle as cp
6 +import glob
7 +import numpy as np
8 +
9 +import torch
10 +import torchvision
11 +import torch.nn.functional as F
12 +import torchvision.models as models
13 +import torchvision.transforms as transforms
14 +from torch.utils.data import Subset
15 +from torch.utils.data import Dataset, DataLoader
16 +
17 +from sklearn.model_selection import StratifiedShuffleSplit
18 +
19 +
# Root directory for the torchvision datasets. get_dataset() reads this name
# in its cifar10/imagenet branches, but it was never defined, so those
# branches raised NameError.
# NOTE(review): './data/' is an assumed default - point it at the real data root.
DATASET_PATH = './data/'
# Paths passed to CustomDataset for the BraTS train / validation splits.
TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
# Incremented by get_inf_dataloader() each time the training loader is exhausted.
current_epoch = 0
23 +
24 +
def split_dataset(args, dataset, k):
    """Create ``k`` stratified 90/10 shuffle splits of ``dataset``.

    Args:
        args: Needs ``seed``, used as the splitter's ``random_state``.
        dataset: Needs ``__len__`` and a ``targets`` attribute with one label
            per sample.
        k: Number of splits to generate.

    Returns:
        (Dm_indexes, Da_indexes): two tuples of length ``k`` holding, per
        split, the index arrays of the 90% part and of the 10% part.
    """
    sample_ids = list(range(len(dataset)))
    labels = dataset.targets
    assert len(sample_ids) == len(labels)

    splitter = StratifiedShuffleSplit(
        n_splits=k, random_state=args.seed, test_size=0.1)
    # Transpose the per-split (train, test) pairs into two parallel tuples.
    Dm_indexes, Da_indexes = zip(*splitter.split(sample_ids, labels))

    return Dm_indexes, Da_indexes
40 +
41 +
def concat_image_features(image, features, max_features=3):
    """Tile an image and some of its feature maps side by side.

    Args:
        image: Tensor indexed as (C, H, W). The width-wise concatenation
            requires C == 3, because each feature map is expanded to 3 channels.
        features: Tensor indexed as (N, h, w) holding N feature maps.
        max_features: Upper bound on how many feature maps are appended.

    Returns:
        Tensor of shape (3, H, W * (1 + n)) with n = min(N, max_features):
        a copy of the image followed by each feature map min-max normalised
        to [0, 1], repeated over 3 channels and resized to (H, W).
    """
    _, h, w = image.shape

    n_shown = min(features.size(0), max_features)
    panels = [image.clone()]

    for i in range(n_shown):
        feature = features[i:i + 1]
        _min, _max = torch.min(feature), torch.max(feature)
        # The epsilon avoids division by zero for a constant feature map.
        feature = (feature - _min) / (_max - _min + 1e-6)
        feature = torch.cat([feature] * 3, 0)
        feature = feature.view(1, 3, feature.size(1), feature.size(2))
        # F.upsample is deprecated and only forwards to F.interpolate.
        # align_corners=False is the default it used, now stated explicitly.
        feature = F.interpolate(feature, size=(h, w), mode="bilinear",
                                align_corners=False)
        panels.append(feature.view(3, h, w))

    return torch.cat(panels, 2)
59 +
60 +
def get_model_name(args):
    """Build a run name of the form ``<timestamp>__<network>__<seed>``.

    The timestamp is the current local time formatted as
    ``%B_%d_%H:%M:%S``; ``args`` must provide ``network`` and ``seed``.
    """
    from datetime import datetime

    stamp = datetime.now().strftime("%B_%d_%H:%M:%S")
    return '__'.join((stamp, args.network, str(args.seed)))
67 +
68 +
def dict_to_namedtuple(d):
    """Convert a (possibly nested) dict into an ``Args`` namedtuple.

    Nested dicts are converted recursively. String values that are Python
    literals ('1e-3', 'True', 'None', '[1, 2]') are replaced by the parsed
    value; any other string is kept unchanged.

    Note: ``d`` is updated in place with the converted values, exactly as
    before; ``parse_args`` hands the same dict back to its caller.

    Args:
        d: Mapping whose keys are valid identifiers.

    Returns:
        An ``Args`` namedtuple with one field per key (fields sorted by name).
    """
    import ast

    Args = collections.namedtuple('Args', sorted(d.keys()))

    for k, v in d.items():
        if isinstance(v, dict):
            d[k] = dict_to_namedtuple(v)

        elif isinstance(v, str):
            # literal_eval replaces eval(): argument strings must never run
            # code or resolve to a module-level name such as 'os' or 'models'.
            try:
                d[k] = ast.literal_eval(v)
            except (ValueError, SyntaxError, TypeError,
                    MemoryError, RecursionError):
                d[k] = v

    return Args(**d)
84 +
85 +
def parse_args(kwargs):
    """Fill in default settings and convert them to a namedtuple.

    Missing keys are added to ``kwargs`` in place. ``use_cuda`` is further
    and-ed with ``torch.cuda.is_available()``.

    Returns:
        (args, kwargs): the namedtuple built by ``dict_to_namedtuple`` and
        the same, now completed, dict.
    """
    defaults = (
        ('dataset', 'cifar10'),
        ('network', 'resnet_cifar10'),
        ('optimizer', 'adam'),
        ('learning_rate', 0.1),
        ('seed', None),
        ('use_cuda', True),
        ('num_workers', 4),
        ('print_step', 2000),
        ('val_step', 2000),
        ('scheduler', 'exp'),
        ('batch_size', 128),
        ('start_step', 0),
        ('max_step', 64000),
        ('fast_auto_augment', False),
        ('augment_path', None),
    )
    for key, fallback in defaults:
        kwargs.setdefault(key, fallback)

    # CUDA is only used when it is both requested and available.
    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()

    args = dict_to_namedtuple(kwargs)
    return args, kwargs
108 +
109 +
def select_model(args):
    """Instantiate the network named by ``args.network`` and print it.

    Names found in ``torchvision.models`` are built there and wrapped in
    ``BaseNet``; any other name is resolved to the ``Net`` class of the
    module ``networks.<name>``.
    """
    name = args.network

    if name not in models.__dict__:
        module = importlib.import_module('networks.{}'.format(name))
        model = getattr(module, 'Net')(args)
    else:
        backbone = models.__dict__[name]()
        # NOTE(review): BaseNet is not imported anywhere in the visible part
        # of this file - confirm it is in scope before relying on this branch.
        model = BaseNet(backbone, args)

    print(model)
    return model
120 +
121 +
def select_optimizer(args, model):
    """Create the optimizer named by ``args.optimizer`` for ``model``.

    Supported names: 'sgd' (momentum 0.9, weight decay 1e-4), 'rms'
    (RMSprop with only the learning rate set) and 'adam'. All use
    ``args.learning_rate``.

    Raises:
        Exception: if the name is not one of the above.
    """
    name = args.optimizer
    if name not in ('sgd', 'rms', 'adam'):
        raise Exception('Unknown Optimizer')

    params = model.parameters()
    lr = args.learning_rate

    if name == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0001)
    if name == 'rms':
        return torch.optim.RMSprop(params, lr=lr)
    return torch.optim.Adam(params, lr=lr)
133 +
134 +
def select_scheduler(args, optimizer):
    """Create the learning-rate scheduler named by ``args.scheduler``.

    A falsy value or the string 'None' disables scheduling (returns None).
    'clr' gives a triangular2 CyclicLR between 0.01 and 0.015; 'exp' gives an
    ExponentialLR with gamma 0.9999283.

    Raises:
        Exception: for any other name.
    """
    kind = args.scheduler
    if not kind or kind == 'None':
        return None

    schedulers = torch.optim.lr_scheduler
    if kind == 'clr':
        return schedulers.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2',
                                   step_size_up=250000, cycle_momentum=False)
    if kind == 'exp':
        return schedulers.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1)

    raise Exception('Unknown Scheduler')
144 +
145 +
class CustomDataset(Dataset):
    """Dataset over an array loaded with ``np.load``.

    Args:
        path: Location handed to ``np.load``; the first axis of the loaded
            array indexes the samples.
        transform: Optional callable applied to each sample on access.

    NOTE(review): items are returned without a label, whereas train_step()
    and validate() unpack ``(images, target)`` - confirm how targets are
    meant to be supplied for this dataset.
    """

    def __init__(self, path, transform = None):
        self.path = path
        self.transform = transform
        self.img = np.load(path)
        self.len = self.img.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Fetch the sample first. The original never assigned ``img`` and
        # tested the non-existent attribute ``self.transforms``, so every
        # access raised.
        img = self.img[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img
160 +
def get_dataset(args, transform, split='train'):
    """Build the dataset selected by ``args.dataset`` for one split.

    Args:
        args: Needs ``dataset`` ('cifar10', 'imagenet' or 'BraTS'); the
            cifar10 branch also passes ``args`` to ``split_dataset``.
        transform: Transform attached to the dataset.
        split: One of 'train', 'val', 'test', 'trainval'.

    Returns:
        The dataset. For cifar10 'train'/'val' it is a ``Subset`` of the
        training set, using an index split cached on disk next to the data.

    Raises:
        Exception: if ``args.dataset`` is not recognised.

    NOTE(review): the cifar10 and imagenet branches read the module-level
    name DATASET_PATH - make sure it is defined in this module.
    """
    assert split in ['train', 'val', 'test', 'trainval']

    if args.dataset == 'cifar10':
        train = split in ['train', 'val', 'trainval']
        dataset = torchvision.datasets.CIFAR10(DATASET_PATH,
                                               train=train,
                                               transform=transform,
                                               download=True)

        if split in ['train', 'val']:
            split_path = os.path.join(DATASET_PATH,
                                      'cifar-10-batches-py', 'train_val_index.cp')

            # Create the train/val index once; later runs reuse the file so
            # both splits stay consistent.
            if not os.path.exists(split_path):
                [train_index], [val_index] = split_dataset(args, dataset, k=1)
                split_index = {'train': train_index, 'val': val_index}
                with open(split_path, 'wb') as f:
                    cp.dump(split_index, f)

            with open(split_path, 'rb') as f:
                split_index = cp.load(f)
            dataset = Subset(dataset, split_index[split])

    elif args.dataset == 'imagenet':
        # Compare by value: ``split is 'val'`` relied on string interning.
        dataset = torchvision.datasets.ImageNet(DATASET_PATH,
                                                split=split,
                                                transform=transform,
                                                download=(split == 'val'))

    elif args.dataset == 'BraTS':
        # Every split other than 'train' uses the validation data.
        if split == 'train':
            dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform)
        else:
            dataset = CustomDataset(VAL_DATASET_PATH, transform=transform)

    else:
        raise Exception('Unknown dataset')

    return dataset
200 +
201 +
def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
    """Wrap ``dataset`` in a DataLoader.

    Batch size and worker count come from ``args.batch_size`` and
    ``args.num_workers``; ``shuffle`` and ``pin_memory`` are forwarded as given.
    """
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    return DataLoader(dataset, **loader_options)
209 +
210 +
def get_inf_dataloader(args, dataset):
    """Yield shuffled batches from ``dataset`` forever.

    Whenever the underlying loader is exhausted, the module-level
    ``current_epoch`` counter is incremented and a fresh shuffled loader is
    started.
    """
    global current_epoch

    exhausted = object()  # sentinel: distinguishes "no more batches" from any batch
    stream = iter(get_dataloader(args, dataset, shuffle=True))

    while True:
        batch = next(stream, exhausted)

        if batch is exhausted:
            current_epoch += 1
            stream = iter(get_dataloader(args, dataset, shuffle=True))
            batch = next(stream)

        yield batch
225 +
226 +
def get_train_transform(args, model, log_dir=None):
    """Build the training-time transform.

    With ``args.fast_auto_augment`` set (BraTS only), the transform is either
    loaded from ``args.augment_path`` or searched with ``fast_auto_augment``.
    When ``log_dir`` is given, it is stored there as 'augmentation.cp'.
    Otherwise a fixed per-dataset pipeline is returned.

    Args:
        args: Needs ``fast_auto_augment``, ``augment_path`` and ``dataset``.
        model: Passed to the policy search; the imagenet/BraTS pipelines also
            read ``model.img_size``.
        log_dir: Optional directory that receives a copy of the policy.

    Returns:
        The transform, which is also printed.

    Raises:
        Exception: if the dataset is not recognised.
    """
    if args.fast_auto_augment:
        assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet

        from fast_auto_augment import fast_auto_augment
        if args.augment_path:
            # NOTE(security): unpickling runs arbitrary code - only load
            # policy files from a trusted source.
            with open(args.augment_path, 'rb') as f:
                transform = cp.load(f)
            if log_dir:
                import shutil
                # shutil.copy replaces the shell `cp`, which split paths that
                # contain spaces (e.g. ".../My Drive/...").
                try:
                    shutil.copy(args.augment_path,
                                os.path.join(log_dir, 'augmentation.cp'))
                except shutil.SameFileError:
                    # The policy already lives in log_dir; nothing to copy.
                    pass
        else:
            transform = fast_auto_augment(args, model, K=4, B=1, num_process=4)
            if log_dir:
                with open(os.path.join(log_dir, 'augmentation.cp'), 'wb') as f:
                    cp.dump(transform, f)

    elif args.dataset == 'cifar10':
        transform = transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

    elif args.dataset == 'imagenet':
        resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875)
        transform = transforms.Compose([
            transforms.Resize([resize_h, resize_w]),
            transforms.RandomCrop(model.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

    elif args.dataset == 'BraTS':
        resize_h, resize_w = 256, 256
        transform = transforms.Compose([
            transforms.Resize([resize_h, resize_w]),
            transforms.RandomCrop(model.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
    else:
        raise Exception('Unknown Dataset')

    print(transform)

    return transform
272 +
273 +
def get_valid_transform(args, model):
    """Build the validation-time transform: a resize followed by ToTensor.

    The resize target is 32 for cifar10, ``[img_size[0], int(img_size[1] *
    1.875)]`` taken from ``model`` for imagenet, and ``[256, 256]`` for BraTS.

    Raises:
        Exception: if ``args.dataset`` is not recognised.
    """
    name = args.dataset

    if name == 'cifar10':
        target_size = 32
    elif name == 'imagenet':
        target_size = [model.img_size[0], int(model.img_size[1]*1.875)]
    elif name == 'BraTS':
        target_size = [256, 256]
    else:
        raise Exception('Unknown Dataset')

    return transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor()
    ])
297 +
298 +
def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    """Run one optimisation step on ``batch``.

    The batch is moved to ``device`` when given, otherwise to CUDA when
    ``args.use_cuda`` is set. The model is expected to return a pair
    ``(output, first)``; ``first`` is only used for image logging.

    Returns:
        (acc1, acc5, loss, forward_t, backward_t): top-1 and top-5 accuracy
        as fractions of the batch, the loss, and the wall-clock seconds spent
        in the forward and backward passes.
    """
    model.train()
    inputs, labels = batch

    if device:
        inputs = inputs.to(device)
        labels = labels.to(device)
    elif args.use_cuda:
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

    # Forward pass (timed) and loss.
    tic = time.time()
    output, first = model(inputs)
    forward_t = time.time() - tic
    loss = criterion(output, labels)

    # accuracy() returns counts of correct samples; turn them into fractions.
    acc1, acc5 = accuracy(output, labels, topk=(1, 5))
    acc1 /= inputs.size(0)
    acc5 /= inputs.size(0)

    # Backward pass (timed), then parameter and learning-rate updates.
    optimizer.zero_grad()
    tic = time.time()
    loss.backward()
    backward_t = time.time() - tic
    optimizer.step()
    if scheduler:
        scheduler.step()

    # On print steps, log up to 10 inputs next to their feature maps.
    if writer and step % args.print_step == 0:
        for j in range(min(inputs.size(0), 10)):
            panel = concat_image_features(inputs[j], first[j])
            writer.add_image('train/input_image', panel, global_step=step)

    return acc1, acc5, loss, forward_t, backward_t
337 +
338 +
def validate(args, model, criterion, valid_loader, step, writer, device=None):
    """Evaluate ``model`` on every batch of ``valid_loader``.

    Batches are moved to ``device`` when given, otherwise to CUDA when
    ``args.use_cuda`` is set. The model is expected to return a pair
    ``(output, first)``.

    Returns:
        (acc1, acc5, loss, infer_t): top-1 / top-5 accuracy over all samples,
        the loss of the LAST batch only, and the accumulated wall-clock
        seconds spent on transfer, inference and loss computation.

    NOTE(review): an empty ``valid_loader`` ends in a division by zero, and
    the returned loss is not averaged over batches.
    """
    # switch to evaluate mode
    model.eval()

    acc1, acc5 = 0, 0
    samples = 0
    infer_t = 0

    with torch.no_grad():
        for images, target in valid_loader:

            start_t = time.time()
            if device:
                images = images.to(device)
                target = target.to(device)

            # Plain truth test, as in train_step(). The former
            # ``is not None`` was always true for a bool, so .cuda() was
            # called even with use_cuda == False.
            elif args.use_cuda:
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output, first = model(images)
            loss = criterion(output, target)
            infer_t += time.time() - start_t

            # measure accuracy and record loss
            _acc1, _acc5 = accuracy(output, target, topk=(1, 5))
            acc1 += _acc1
            acc5 += _acc5
            samples += images.size(0)

    acc1 /= samples
    acc5 /= samples

    # Log up to 10 images of the last batch next to their feature maps.
    if writer:
        n_imgs = min(images.size(0), 10)
        for j in range(n_imgs):
            writer.add_image('valid/input_image',
                    concat_image_features(images[j], first[j]), global_step=step)

    return acc1, acc5, loss, infer_t
380 +
381 +
def accuracy(output, target, topk=(1,)):
    """Count correct predictions among the top-k scores for each k in ``topk``.

    Args:
        output: Score tensor of shape (batch, num_classes).
        target: Tensor of ``batch`` class indices.
        topk: The k values to evaluate.

    Returns:
        A list with one 1-element float tensor per k holding the NUMBER of
        samples whose target is among the k highest scores (callers divide by
        the batch size). A k larger than ``num_classes`` counts every sample
        as correct.
    """
    with torch.no_grad():
        # Never request more predictions than there are classes; topk()
        # raises otherwise (e.g. top-5 on a 2-class problem).
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # reshape() instead of view(): the transposed tensor is not
        # contiguous, and view(-1) fails on it for k > 1 in newer torch.
        return [correct[:k].reshape(-1).float().sum(0, keepdim=True)
                for k in topk]
397 +