Showing
21 changed files
with
1636 additions
and
0 deletions
1 | +# Byte-compiled / optimized / DLL files | ||
2 | +__pycache__/ | ||
3 | +*.py[cod] | ||
4 | +*$py.class | ||
5 | + | ||
6 | +# C extensions | ||
7 | +*.so | ||
8 | + | ||
9 | +# Distribution / packaging | ||
10 | +.Python | ||
11 | +build/ | ||
12 | +develop-eggs/ | ||
13 | +dist/ | ||
14 | +downloads/ | ||
15 | +eggs/ | ||
16 | +.eggs/ | ||
17 | +lib/ | ||
18 | +lib64/ | ||
19 | +parts/ | ||
20 | +sdist/ | ||
21 | +var/ | ||
22 | +wheels/ | ||
23 | +*.egg-info/ | ||
24 | +.installed.cfg | ||
25 | +*.egg | ||
26 | +MANIFEST | ||
27 | + | ||
28 | +# PyInstaller | ||
29 | +# Usually these files are written by a python script from a template | ||
30 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
31 | +*.manifest | ||
32 | +*.spec | ||
33 | + | ||
34 | +# Installer logs | ||
35 | +pip-log.txt | ||
36 | +pip-delete-this-directory.txt | ||
37 | + | ||
38 | +# Unit test / coverage reports | ||
39 | +htmlcov/ | ||
40 | +.tox/ | ||
41 | +.coverage | ||
42 | +.coverage.* | ||
43 | +.cache | ||
44 | +nosetests.xml | ||
45 | +coverage.xml | ||
46 | +*.cover | ||
47 | +.hypothesis/ | ||
48 | +.pytest_cache/ | ||
49 | + | ||
50 | +# Translations | ||
51 | +*.mo | ||
52 | +*.pot | ||
53 | + | ||
54 | +# Django stuff: | ||
55 | +*.log | ||
56 | +local_settings.py | ||
57 | +db.sqlite3 | ||
58 | + | ||
59 | +# Flask stuff: | ||
60 | +instance/ | ||
61 | +.webassets-cache | ||
62 | + | ||
63 | +# Scrapy stuff: | ||
64 | +.scrapy | ||
65 | + | ||
66 | +# Sphinx documentation | ||
67 | +docs/_build/ | ||
68 | + | ||
69 | +# PyBuilder | ||
70 | +target/ | ||
71 | + | ||
72 | +# Jupyter Notebook | ||
73 | +.ipynb_checkpoints | ||
74 | + | ||
75 | +# pyenv | ||
76 | +.python-version | ||
77 | + | ||
78 | +# celery beat schedule file | ||
79 | +celerybeat-schedule | ||
80 | + | ||
81 | +# SageMath parsed files | ||
82 | +*.sage.py | ||
83 | + | ||
84 | +# Environments | ||
85 | +.env | ||
86 | +.venv | ||
87 | +env/ | ||
88 | +venv/ | ||
89 | +ENV/ | ||
90 | +env.bak/ | ||
91 | +venv.bak/ | ||
92 | + | ||
93 | +# Spyder project settings | ||
94 | +.spyderproject | ||
95 | +.spyproject | ||
96 | + | ||
97 | +# Rope project settings | ||
98 | +.ropeproject | ||
99 | + | ||
100 | +# mkdocs documentation | ||
101 | +/site | ||
102 | + | ||
103 | +# mypy | ||
104 | +.mypy_cache/ | ||
105 | + |
1 | +# Fast Autoaugment | ||
2 | +<img src="figures/faa.png" width=800px> | ||
3 | + | ||
4 | +A Pytorch Implementation of [Fast AutoAugment](https://arxiv.org/pdf/1905.00397.pdf) and [EfficientNet](https://arxiv.org/abs/1905.11946). | ||
5 | + | ||
6 | +## Prerequisite | ||
7 | +* torch==1.1.0 | ||
8 | +* torchvision==0.2.2 | ||
9 | +* hyperopt==0.1.2 | ||
10 | +* future==0.17.1 | ||
11 | +* tb-nightly==1.15.0a20190622 | ||
12 | + | ||
13 | +## Usage | ||
14 | +### Training | ||
15 | +#### CIFAR10 | ||
16 | +```bash | ||
17 | +# ResNet20 (w/o FastAutoAugment) | ||
18 | +python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False | ||
19 | + | ||
20 | +# ResNet20 (w/ FastAutoAugment) | ||
21 | +python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True | ||
22 | + | ||
23 | +# ResNet20 (w/ FastAutoAugment, Pre-found policy) | ||
24 | +python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \ | ||
25 | + --augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp | ||
26 | + | ||
27 | +# ResNet32 (w/o FastAutoAugment) | ||
28 | +python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False | ||
29 | + | ||
30 | +# ResNet32 (w/ FastAutoAugment) | ||
31 | +python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True | ||
32 | + | ||
33 | +# EfficientNet (w/ FastAutoAugment) | ||
34 | +python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \ | ||
35 | + --network=efficientnet_cifar10 --activation=swish | ||
36 | +``` | ||
37 | + | ||
38 | +#### ImageNet (You can use any backbone networks in [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html)) | ||
39 | +```bash | ||
40 | + | ||
41 | +# BaseNet (w/o FastAutoAugment) | ||
42 | +python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50 | ||
43 | + | ||
44 | +# EfficientNet (w/ FastAutoAugment) (UnderConstruction) | ||
45 | +python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \ | ||
46 | + --network=efficientnet --activation=swish | ||
47 | +``` | ||
48 | + | ||
49 | +### Eval | ||
50 | +```bash | ||
51 | +# Single Image testing | ||
52 | +python eval.py --model_path=runs/ResNet_Scale3_Basline | ||
53 | + | ||
54 | +# 5-crops testing | ||
55 | +python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True | ||
56 | +``` | ||
57 | + | ||
58 | +## Experiments | ||
59 | +### Fast AutoAugment | ||
60 | +#### ResNet20 (CIFAR10) | ||
61 | +* Pre-trained model [[Download](https://drive.google.com/file/d/12D8050yGGiKWGt8_R8QTlkoQ6wq_icBn/view?usp=sharing)] | ||
62 | +* Validation Curve | ||
63 | +<img src="figures/resnet20_valid.png"> | ||
64 | + | ||
65 | +* Evaluation (Acc @1) | ||
66 | + | ||
67 | +| | Valid | Test(Single) | | ||
68 | +|----------------|-------|-------------| | ||
69 | +| ResNet20 | 90.70 | **91.45** | | ||
70 | +| ResNet20 + FAA |**92.46**| **91.45** | | ||
71 | + | ||
72 | +#### ResNet34 (CIFAR10) | ||
73 | +* Validation Curve | ||
74 | +<img src="figures/resnet34_valid.png"> | ||
75 | + | ||
76 | +* Evaluation (Acc @1) | ||
77 | + | ||
78 | +| | Valid | Test(Single) | | ||
79 | +|----------------|-------|-------------| | ||
80 | +| ResNet34 | 91.54 | 91.47 | | ||
81 | +| ResNet34 + FAA |**92.76**| **91.99** | | ||
82 | + | ||
83 | +### Found Policy [[Download](https://drive.google.com/file/d/1Ia_IxPY3-T7m8biyl3QpxV1s5EA5gRDF/view?usp=sharing)] | ||
84 | +<img src="figures/pm.png"> | ||
85 | + | ||
86 | +### Augmented images | ||
87 | +<img src="figures/augmented_images.png"> | ||
88 | +<img src="figures/augmented_images2.png"> |
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | + | ||
6 | +import torch | ||
7 | +import torch.nn as nn | ||
8 | + | ||
9 | +from utils import * | ||
10 | + | ||
11 | + | ||
12 | +def eval(model_path): | ||
13 | + print('\n[+] Parse arguments') | ||
14 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
15 | + kwargs = json.loads(open(kwargs_path).read()) | ||
16 | + args, kwargs = parse_args(kwargs) | ||
17 | + pprint(args) | ||
18 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
19 | + | ||
20 | + print('\n[+] Create network') | ||
21 | + model = select_model(args) | ||
22 | + optimizer = select_optimizer(args, model) | ||
23 | + criterion = nn.CrossEntropyLoss() | ||
24 | + if args.use_cuda: | ||
25 | + model = model.cuda() | ||
26 | + criterion = criterion.cuda() | ||
27 | + | ||
28 | + print('\n[+] Load model') | ||
29 | + weight_path = os.path.join(model_path, 'model', 'model.pt') | ||
30 | + model.load_state_dict(torch.load(weight_path)) | ||
31 | + | ||
32 | + print('\n[+] Load dataset') | ||
33 | + test_transform = get_valid_transform(args, model) | ||
34 | + test_dataset = get_dataset(args, test_transform, 'test') | ||
35 | + test_loader = iter(get_dataloader(args, test_dataset)) | ||
36 | + | ||
37 | + print('\n[+] Start testing') | ||
38 | + _test_res = validate(args, model, criterion, test_loader, step=0, writer=None) | ||
39 | + | ||
40 | + print('\n[+] Valid results') | ||
41 | + print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100)) | ||
42 | + print(' Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100)) | ||
43 | + print(' Loss : {:.3f}'.format(_test_res[2].data)) | ||
44 | + print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset))) | ||
45 | + | ||
46 | + | ||
47 | +if __name__ == '__main__': | ||
48 | + fire.Fire(eval) |
1 | +import copy | ||
2 | +import json | ||
3 | +import time | ||
4 | +import torch | ||
5 | +import random | ||
6 | +import torchvision.transforms as transforms | ||
7 | + | ||
8 | +from torch.utils.data import Subset | ||
9 | +from sklearn.model_selection import StratifiedShuffleSplit | ||
10 | +from concurrent.futures import ProcessPoolExecutor | ||
11 | + | ||
12 | +from transforms import * | ||
13 | +from hyperopt import fmin, tpe, hp, STATUS_OK, Trials | ||
14 | +from utils import * | ||
15 | + | ||
16 | + | ||
17 | +DEFALUT_CANDIDATES = [ | ||
18 | + ShearXY, | ||
19 | + TranslateXY, | ||
20 | + Rotate, | ||
21 | + AutoContrast, | ||
22 | + Invert, | ||
23 | + Equalize, | ||
24 | + Solarize, | ||
25 | + Posterize, | ||
26 | + Contrast, | ||
27 | + Color, | ||
28 | + Brightness, | ||
29 | + Sharpness, | ||
30 | + Cutout, | ||
31 | +# SamplePairing, | ||
32 | +] | ||
33 | + | ||
34 | + | ||
35 | +def train_child(args, model, dataset, subset_indx, device=None): | ||
36 | + optimizer = select_optimizer(args, model) | ||
37 | + scheduler = select_scheduler(args, optimizer) | ||
38 | + criterion = nn.CrossEntropyLoss() | ||
39 | + | ||
40 | + dataset.transform = transforms.Compose([ | ||
41 | + transforms.Resize(32), | ||
42 | + transforms.ToTensor()]) | ||
43 | + subset = Subset(dataset, subset_indx) | ||
44 | + data_loader = get_inf_dataloader(args, subset) | ||
45 | + | ||
46 | + if device: | ||
47 | + model = model.to(device) | ||
48 | + criterion = criterion.to(device) | ||
49 | + | ||
50 | + elif args.use_cuda: | ||
51 | + model = model.cuda() | ||
52 | + criterion = criterion.cuda() | ||
53 | + | ||
54 | + if torch.cuda.device_count() > 1: | ||
55 | + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
56 | + model = nn.DataParallel(model) | ||
57 | + | ||
58 | + start_t = time.time() | ||
59 | + for step in range(args.start_step, args.max_step): | ||
60 | + batch = next(data_loader) | ||
61 | + _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) | ||
62 | + | ||
63 | + if step % args.print_step == 0: | ||
64 | + print('\n[+] Training step: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}\tDevice: {}'.format( | ||
65 | + step, args.max_step,(time.time()-start_t)/60, optimizer.param_groups[0]['lr'], device)) | ||
66 | + | ||
67 | + print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | ||
68 | + print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100)) | ||
69 | + print(' Loss : {}'.format(_train_res[2].data)) | ||
70 | + | ||
71 | + return _train_res | ||
72 | + | ||
73 | + | ||
74 | +def validate_child(args, model, dataset, subset_indx, transform, device=None): | ||
75 | + criterion = nn.CrossEntropyLoss() | ||
76 | + | ||
77 | + if device: | ||
78 | + model = model.to(device) | ||
79 | + criterion = criterion.to(device) | ||
80 | + | ||
81 | + elif args.use_cuda: | ||
82 | + model = model.cuda() | ||
83 | + criterion = criterion.cuda() | ||
84 | + | ||
85 | + dataset.transform = transform | ||
86 | + subset = Subset(dataset, subset_indx) | ||
87 | + data_loader = get_dataloader(args, subset, pin_memory=False) | ||
88 | + | ||
89 | + return validate(args, model, criterion, data_loader, 0, None, device) | ||
90 | + | ||
91 | + | ||
92 | +def get_next_subpolicy(transform_candidates, op_per_subpolicy=2): | ||
93 | + n_candidates = len(transform_candidates) | ||
94 | + subpolicy = [] | ||
95 | + | ||
96 | + for i in range(op_per_subpolicy): | ||
97 | + indx = random.randrange(n_candidates) | ||
98 | + prob = random.random() | ||
99 | + mag = random.random() | ||
100 | + subpolicy.append(transform_candidates[indx](prob, mag)) | ||
101 | + | ||
102 | + subpolicy = transforms.Compose([ | ||
103 | + *subpolicy, | ||
104 | + transforms.Resize(32), | ||
105 | + transforms.ToTensor()]) | ||
106 | + | ||
107 | + return subpolicy | ||
108 | + | ||
109 | + | ||
110 | +def search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device): | ||
111 | + subpolicies = [] | ||
112 | + | ||
113 | + for b in range(B): | ||
114 | + subpolicy = get_next_subpolicy(transform_candidates) | ||
115 | + val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device) | ||
116 | + subpolicies.append((subpolicy, val_res[2])) | ||
117 | + | ||
118 | + return subpolicies | ||
119 | + | ||
120 | + | ||
121 | +def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device): | ||
122 | + | ||
123 | + def _objective(sampled): | ||
124 | + subpolicy = [transform(prob, mag) | ||
125 | + for transform, prob, mag in sampled] | ||
126 | + | ||
127 | + subpolicy = transforms.Compose([ | ||
128 | + transforms.Resize(32), | ||
129 | + *subpolicy, | ||
130 | + transforms.ToTensor()]) | ||
131 | + | ||
132 | + val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device) | ||
133 | + loss = val_res[2].cpu().numpy() | ||
134 | + return {'loss': loss, 'status': STATUS_OK } | ||
135 | + | ||
136 | + space = [(hp.choice('transform1', transform_candidates), hp.uniform('prob1', 0, 1.0), hp.uniform('mag1', 0, 1.0)), | ||
137 | + (hp.choice('transform2', transform_candidates), hp.uniform('prob2', 0, 1.0), hp.uniform('mag2', 0, 1.0))] | ||
138 | + | ||
139 | + trials = Trials() | ||
140 | + best = fmin(_objective, | ||
141 | + space=space, | ||
142 | + algo=tpe.suggest, | ||
143 | + max_evals=B, | ||
144 | + trials=trials) | ||
145 | + | ||
146 | + subpolicies = [] | ||
147 | + for t in trials.trials: | ||
148 | + vals = t['misc']['vals'] | ||
149 | + subpolicy = [transform_candidates[vals['transform1'][0]](vals['prob1'][0], vals['mag1'][0]), | ||
150 | + transform_candidates[vals['transform2'][0]](vals['prob2'][0], vals['mag2'][0])] | ||
151 | + subpolicy = transforms.Compose([ | ||
152 | + ## baseline augmentation | ||
153 | + transforms.Pad(4), | ||
154 | + transforms.RandomCrop(32), | ||
155 | + transforms.RandomHorizontalFlip(), | ||
156 | + ## policy | ||
157 | + *subpolicy, | ||
158 | + ## to tensor | ||
159 | + transforms.ToTensor()]) | ||
160 | + subpolicies.append((subpolicy, t['result']['loss'])) | ||
161 | + | ||
162 | + return subpolicies | ||
163 | + | ||
164 | + | ||
165 | +def get_topn_subpolicies(subpolicies, N=10): | ||
166 | + return sorted(subpolicies, key=lambda subpolicy: subpolicy[1])[:N] | ||
167 | + | ||
168 | + | ||
169 | +def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k): | ||
170 | + kwargs = json.loads(args_str) | ||
171 | + args, kwargs = parse_args(kwargs) | ||
172 | + device_id = k % torch.cuda.device_count() | ||
173 | + device = torch.device('cuda:%d' % device_id) | ||
174 | + _transform = [] | ||
175 | + | ||
176 | + print('[+] Child %d training strated (GPU: %d)' % (k, device_id)) | ||
177 | + | ||
178 | + # train child model | ||
179 | + child_model = copy.deepcopy(model) | ||
180 | + train_res = train_child(args, child_model, dataset, Dm_indx, device) | ||
181 | + | ||
182 | + # search sub policy | ||
183 | + for t in range(T): | ||
184 | + #subpolicies = search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device) | ||
185 | + subpolicies = search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device) | ||
186 | + subpolicies = get_topn_subpolicies(subpolicies, N) | ||
187 | + _transform.extend([subpolicy[0] for subpolicy in subpolicies]) | ||
188 | + | ||
189 | + return _transform | ||
190 | + | ||
191 | + | ||
192 | +def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | ||
193 | + args_str = json.dumps(args._asdict()) | ||
194 | + dataset = get_dataset(args, None, 'trainval') | ||
195 | + num_process = min(torch.cuda.device_count(), num_process) | ||
196 | + transform, futures = [], [] | ||
197 | + | ||
198 | + torch.multiprocessing.set_start_method('spawn', force=True) | ||
199 | + | ||
200 | + if not transform_candidates: | ||
201 | + transform_candidates = DEFALUT_CANDIDATES | ||
202 | + | ||
203 | + # split | ||
204 | + Dm_indexes, Da_indexes = split_dataset(args, dataset, K) | ||
205 | + | ||
206 | + with ProcessPoolExecutor(max_workers=num_process) as executor: | ||
207 | + for k, (Dm_indx, Da_indx) in enumerate(zip(Dm_indexes, Da_indexes)): | ||
208 | + future = executor.submit(process_fn, | ||
209 | + args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k) | ||
210 | + futures.append(future) | ||
211 | + | ||
212 | + for future in futures: | ||
213 | + transform.extend(future.result()) | ||
214 | + | ||
215 | + transform = transforms.RandomChoice(transform) | ||
216 | + | ||
217 | + return transform |
File mode changed
92.9 KB
410 KB
425 KB
423 KB
48.3 KB
229 KB
227 KB
1 | +from .basenet import BaseNet |
1 | +import torch.nn as nn | ||
2 | + | ||
3 | +class BaseNet(nn.Module): | ||
4 | + def __init__(self, backbone, args): | ||
5 | + super(BaseNet, self).__init__() | ||
6 | + | ||
7 | + # Separate layers | ||
8 | + self.first = nn.Sequential(*list(backbone.children())[:1]) | ||
9 | + self.after = nn.Sequential(*list(backbone.children())[1:-1]) | ||
10 | + self.fc = list(backbone.children())[-1] | ||
11 | + | ||
12 | + self.img_size = (224, 224) | ||
13 | + | ||
14 | + def forward(self, x): | ||
15 | + f = self.first(x) | ||
16 | + x = self.after(f) | ||
17 | + x = x.reshape(x.size(0), -1) | ||
18 | + x = self.fc(x) | ||
19 | + return x, f |
1 | +import math | ||
2 | +import torch.nn as nn | ||
3 | +import torch.nn.functional as F | ||
4 | + | ||
5 | + | ||
6 | +def round_fn(orig, multiplier): | ||
7 | + if not multiplier: | ||
8 | + return orig | ||
9 | + | ||
10 | + return int(math.ceil(multiplier * orig)) | ||
11 | + | ||
12 | + | ||
13 | +def get_activation_fn(activation): | ||
14 | + if activation == "swish": | ||
15 | + return Swish | ||
16 | + | ||
17 | + elif activation == "relu": | ||
18 | + return nn.ReLU | ||
19 | + | ||
20 | + else: | ||
21 | + raise Exception('Unkown activation %s' % activation) | ||
22 | + | ||
23 | + | ||
24 | +class Swish(nn.Module): | ||
25 | + """ Swish activation function, s(x) = x * sigmoid(x) """ | ||
26 | + | ||
27 | + def __init__(self, inplace=False): | ||
28 | + super().__init__() | ||
29 | + self.inplace = True | ||
30 | + | ||
31 | + def forward(self, x): | ||
32 | + if self.inplace: | ||
33 | + x.mul_(F.sigmoid(x)) | ||
34 | + return x | ||
35 | + else: | ||
36 | + return x * F.sigmoid(x) | ||
37 | + | ||
38 | + | ||
39 | +class ConvBlock(nn.Module): | ||
40 | + """ Conv + BatchNorm + Activation """ | ||
41 | + | ||
42 | + def __init__(self, in_channel, out_channel, kernel_size, | ||
43 | + padding=0, stride=1, activation="swish"): | ||
44 | + super().__init__() | ||
45 | + self.fw = nn.Sequential( | ||
46 | + nn.Conv2d(in_channel, out_channel, kernel_size, | ||
47 | + padding=padding, stride=stride, bias=False), | ||
48 | + nn.BatchNorm2d(out_channel), | ||
49 | + get_activation_fn(activation)()) | ||
50 | + | ||
51 | + def forward(self, x): | ||
52 | + return self.fw(x) | ||
53 | + | ||
54 | + | ||
55 | +class DepthwiseConvBlock(nn.Module): | ||
56 | + """ DepthwiseConv2D + BatchNorm + Activation """ | ||
57 | + | ||
58 | + def __init__(self, in_channel, kernel_size, | ||
59 | + padding=0, stride=1, activation="swish"): | ||
60 | + super().__init__() | ||
61 | + self.fw = nn.Sequential( | ||
62 | + nn.Conv2d(in_channel, in_channel, kernel_size, | ||
63 | + padding=padding, stride=stride, groups=in_channel, bias=False), | ||
64 | + nn.BatchNorm2d(in_channel), | ||
65 | + get_activation_fn(activation)()) | ||
66 | + | ||
67 | + def forward(self, x): | ||
68 | + return self.fw(x) | ||
69 | + | ||
70 | + | ||
71 | +class MBConv(nn.Module): | ||
72 | + """ Inverted residual block """ | ||
73 | + | ||
74 | + def __init__(self, in_channel, out_channel, kernel_size, | ||
75 | + stride=1, expand_ratio=1, activation="swish"): | ||
76 | + super().__init__() | ||
77 | + self.in_channel = in_channel | ||
78 | + self.out_channel = out_channel | ||
79 | + self.expand_ratio = expand_ratio | ||
80 | + self.stride = stride | ||
81 | + | ||
82 | + if expand_ratio != 1: | ||
83 | + self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1, | ||
84 | + activation=activation) | ||
85 | + | ||
86 | + self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size, | ||
87 | + padding=(kernel_size-1)//2, | ||
88 | + stride=stride, activation=activation) | ||
89 | + | ||
90 | + self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1, | ||
91 | + activation=activation) | ||
92 | + | ||
93 | + def forward(self, inputs): | ||
94 | + if self.expand_ratio != 1: | ||
95 | + x = self.expand(inputs) | ||
96 | + else: | ||
97 | + x = inputs | ||
98 | + | ||
99 | + x = self.dw_conv(x) | ||
100 | + x = self.pw_conv(x) | ||
101 | + | ||
102 | + if self.in_channel == self.out_channel and \ | ||
103 | + self.stride == 1: | ||
104 | + x = x + inputs | ||
105 | + | ||
106 | + return x | ||
107 | + | ||
108 | + | ||
109 | +class Net(nn.Module): | ||
110 | + """ EfficientNet """ | ||
111 | + | ||
112 | + def __init__(self, args): | ||
113 | + super(Net, self).__init__() | ||
114 | + pi = args.pi | ||
115 | + activation = args.activation | ||
116 | + num_classes = args.num_classes | ||
117 | + | ||
118 | + self.d = 1.2 ** pi | ||
119 | + self.w = 1.1 ** pi | ||
120 | + self.r = 1.15 ** pi | ||
121 | + self.img_size = (round_fn(224, self.r), round_fn(224, self.r)) | ||
122 | + | ||
123 | + self.stage1 = ConvBlock(3, round_fn(32, self.w), | ||
124 | + kernel_size=3, padding=1, stride=2, activation=activation) | ||
125 | + | ||
126 | + self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w), | ||
127 | + depth=round_fn(1, self.d), kernel_size=3, | ||
128 | + half_resolution=False, expand_ratio=1, activation=activation) | ||
129 | + | ||
130 | + self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w), | ||
131 | + depth=round_fn(2, self.d), kernel_size=3, | ||
132 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
133 | + | ||
134 | + self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w), | ||
135 | + depth=round_fn(2, self.d), kernel_size=5, | ||
136 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
137 | + | ||
138 | + self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w), | ||
139 | + depth=round_fn(3, self.d), kernel_size=3, | ||
140 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
141 | + | ||
142 | + self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w), | ||
143 | + depth=round_fn(3, self.d), kernel_size=5, | ||
144 | + half_resolution=False, expand_ratio=6, activation=activation) | ||
145 | + | ||
146 | + self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w), | ||
147 | + depth=round_fn(4, self.d), kernel_size=5, | ||
148 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
149 | + | ||
150 | + self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w), | ||
151 | + depth=round_fn(1, self.d), kernel_size=3, | ||
152 | + half_resolution=False, expand_ratio=6, activation=activation) | ||
153 | + | ||
154 | + self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w), | ||
155 | + kernel_size=1, activation=activation) | ||
156 | + | ||
157 | + self.fc = nn.Linear(round_fn(7*7*1280, self.w), num_classes) | ||
158 | + | ||
159 | + def make_layers(self, in_channel, out_channel, depth, kernel_size, | ||
160 | + half_resolution=False, expand_ratio=1, activation="swish"): | ||
161 | + blocks = [] | ||
162 | + for i in range(depth): | ||
163 | + stride = 2 if half_resolution and i==0 else 1 | ||
164 | + blocks.append( | ||
165 | + MBConv(in_channel, out_channel, kernel_size, | ||
166 | + stride=stride, expand_ratio=expand_ratio, activation=activation)) | ||
167 | + in_channel = out_channel | ||
168 | + | ||
169 | + return nn.Sequential(*blocks) | ||
170 | + | ||
171 | + def forward(self, x): | ||
172 | + assert x.size()[-2:] == self.img_size, \ | ||
173 | + 'Image size must be %r, but %r given' % (self.img_size, x.size()[-2]) | ||
174 | + | ||
175 | + x = self.stage1(x) | ||
176 | + x = self.stage2(x) | ||
177 | + x = self.stage3(x) | ||
178 | + x = self.stage4(x) | ||
179 | + x = self.stage5(x) | ||
180 | + x = self.stage6(x) | ||
181 | + x = self.stage7(x) | ||
182 | + x = self.stage8(x) | ||
183 | + x = self.stage9(x) | ||
184 | + x = x.reshape(x.size(0), -1) | ||
185 | + x = self.fc(x) | ||
186 | + return x, x |
1 | +import math | ||
2 | +import torch.nn as nn | ||
3 | +import torch.nn.functional as F | ||
4 | + | ||
5 | + | ||
6 | +def round_fn(orig, multiplier): | ||
7 | + if not multiplier: | ||
8 | + return orig | ||
9 | + | ||
10 | + return int(math.ceil(multiplier * orig)) | ||
11 | + | ||
12 | + | ||
13 | +def get_activation_fn(activation): | ||
14 | + if activation == "swish": | ||
15 | + return Swish | ||
16 | + | ||
17 | + elif activation == "relu": | ||
18 | + return nn.ReLU | ||
19 | + | ||
20 | + else: | ||
21 | + raise Exception('Unkown activation %s' % activation) | ||
22 | + | ||
23 | + | ||
24 | +class Swish(nn.Module): | ||
25 | + """ Swish activation function, s(x) = x * sigmoid(x) """ | ||
26 | + | ||
27 | + def __init__(self, inplace=False): | ||
28 | + super().__init__() | ||
29 | + self.inplace = True | ||
30 | + | ||
31 | + def forward(self, x): | ||
32 | + if self.inplace: | ||
33 | + x.mul_(F.sigmoid(x)) | ||
34 | + return x | ||
35 | + else: | ||
36 | + return x * F.sigmoid(x) | ||
37 | + | ||
38 | + | ||
39 | +class ConvBlock(nn.Module): | ||
40 | + """ Conv + BatchNorm + Activation """ | ||
41 | + | ||
42 | + def __init__(self, in_channel, out_channel, kernel_size, | ||
43 | + padding=0, stride=1, activation="swish"): | ||
44 | + super().__init__() | ||
45 | + self.fw = nn.Sequential( | ||
46 | + nn.Conv2d(in_channel, out_channel, kernel_size, | ||
47 | + padding=padding, stride=stride, bias=False), | ||
48 | + nn.BatchNorm2d(out_channel), | ||
49 | + get_activation_fn(activation)()) | ||
50 | + | ||
51 | + def forward(self, x): | ||
52 | + return self.fw(x) | ||
53 | + | ||
54 | + | ||
55 | +class DepthwiseConvBlock(nn.Module): | ||
56 | + """ DepthwiseConv2D + BatchNorm + Activation """ | ||
57 | + | ||
58 | + def __init__(self, in_channel, kernel_size, | ||
59 | + padding=0, stride=1, activation="swish"): | ||
60 | + super().__init__() | ||
61 | + self.fw = nn.Sequential( | ||
62 | + nn.Conv2d(in_channel, in_channel, kernel_size, | ||
63 | + padding=padding, stride=stride, groups=in_channel, bias=False), | ||
64 | + nn.BatchNorm2d(in_channel), | ||
65 | + get_activation_fn(activation)()) | ||
66 | + | ||
67 | + def forward(self, x): | ||
68 | + return self.fw(x) | ||
69 | + | ||
70 | + | ||
71 | +class SEBlock(nn.Module): | ||
72 | + """ Squeeze and Excitation Block """ | ||
73 | + | ||
74 | + def __init__(self, in_channel, se_ratio=16): | ||
75 | + super().__init__() | ||
76 | + self.global_avgpool = nn.AdaptiveAvgPool2d((1,1)) | ||
77 | + inter_channel = in_channel // se_ratio | ||
78 | + | ||
79 | + self.reduce = nn.Sequential( | ||
80 | + nn.Conv2d(in_channel, inter_channel, | ||
81 | + kernel_size=1, padding=0, stride=1), | ||
82 | + nn.ReLU()) | ||
83 | + | ||
84 | + self.expand = nn.Sequential( | ||
85 | + nn.Conv2d(inter_channel, in_channel, | ||
86 | + kernel_size=1, padding=0, stride=1), | ||
87 | + nn.Sigmoid()) | ||
88 | + | ||
89 | + | ||
90 | + def forward(self, x): | ||
91 | + s = self.global_avgpool(x) | ||
92 | + s = self.reduce(s) | ||
93 | + s = self.expand(s) | ||
94 | + return x * s | ||
95 | + | ||
96 | + | ||
97 | +class MBConv(nn.Module): | ||
98 | + """ Inverted residual block """ | ||
99 | + | ||
100 | + def __init__(self, in_channel, out_channel, kernel_size, | ||
101 | + stride=1, expand_ratio=1, activation="swish", use_seblock=False): | ||
102 | + super().__init__() | ||
103 | + self.in_channel = in_channel | ||
104 | + self.out_channel = out_channel | ||
105 | + self.expand_ratio = expand_ratio | ||
106 | + self.stride = stride | ||
107 | + self.use_seblock = use_seblock | ||
108 | + | ||
109 | + if expand_ratio != 1: | ||
110 | + self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1, | ||
111 | + activation=activation) | ||
112 | + | ||
113 | + self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size, | ||
114 | + padding=(kernel_size-1)//2, | ||
115 | + stride=stride, activation=activation) | ||
116 | + | ||
117 | + if use_seblock: | ||
118 | + self.seblock = SEBlock(in_channel*expand_ratio) | ||
119 | + | ||
120 | + self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1, | ||
121 | + activation=activation) | ||
122 | + | ||
123 | + def forward(self, inputs): | ||
124 | + if self.expand_ratio != 1: | ||
125 | + x = self.expand(inputs) | ||
126 | + else: | ||
127 | + x = inputs | ||
128 | + | ||
129 | + x = self.dw_conv(x) | ||
130 | + | ||
131 | + if self.use_seblock: | ||
132 | + x = self.seblock(x) | ||
133 | + | ||
134 | + x = self.pw_conv(x) | ||
135 | + | ||
136 | + if self.in_channel == self.out_channel and \ | ||
137 | + self.stride == 1: | ||
138 | + x = x + inputs | ||
139 | + | ||
140 | + return x | ||
141 | + | ||
142 | + | ||
143 | +class Net(nn.Module): | ||
144 | + """ EfficientNet """ | ||
145 | + | ||
146 | + def __init__(self, args): | ||
147 | + super(Net, self).__init__() | ||
148 | + pi = args.pi | ||
149 | + activation = args.activation | ||
150 | + num_classes = 10 | ||
151 | + | ||
152 | + self.d = 1.2 ** pi | ||
153 | + self.w = 1.1 ** pi | ||
154 | + self.r = 1.15 ** pi | ||
155 | + self.img_size = (round_fn(32, self.r), round_fn(32, self.r)) | ||
156 | + self.use_seblock = args.use_seblock | ||
157 | + | ||
158 | + self.stage1 = ConvBlock(3, round_fn(32, self.w), | ||
159 | + kernel_size=3, padding=1, stride=2, activation=activation) | ||
160 | + | ||
161 | + self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w), | ||
162 | + depth=round_fn(1, self.d), kernel_size=3, | ||
163 | + half_resolution=False, expand_ratio=1, activation=activation) | ||
164 | + | ||
165 | + self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w), | ||
166 | + depth=round_fn(2, self.d), kernel_size=3, | ||
167 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
168 | + | ||
169 | + self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w), | ||
170 | + depth=round_fn(2, self.d), kernel_size=5, | ||
171 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
172 | + | ||
173 | + self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w), | ||
174 | + depth=round_fn(3, self.d), kernel_size=3, | ||
175 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
176 | + | ||
177 | + self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w), | ||
178 | + depth=round_fn(3, self.d), kernel_size=5, | ||
179 | + half_resolution=False, expand_ratio=6, activation=activation) | ||
180 | + | ||
181 | + self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w), | ||
182 | + depth=round_fn(4, self.d), kernel_size=5, | ||
183 | + half_resolution=True, expand_ratio=6, activation=activation) | ||
184 | + | ||
185 | + self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w), | ||
186 | + depth=round_fn(1, self.d), kernel_size=3, | ||
187 | + half_resolution=False, expand_ratio=6, activation=activation) | ||
188 | + | ||
189 | + self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w), | ||
190 | + kernel_size=1, activation=activation) | ||
191 | + | ||
192 | + self.fc = nn.Linear(round_fn(1280, self.w), num_classes) | ||
193 | + | ||
194 | + def make_layers(self, in_channel, out_channel, depth, kernel_size, | ||
195 | + half_resolution=False, expand_ratio=1, activation="swish"): | ||
196 | + blocks = [] | ||
197 | + for i in range(depth): | ||
198 | + stride = 2 if half_resolution and i==0 else 1 | ||
199 | + blocks.append( | ||
200 | + MBConv(in_channel, out_channel, kernel_size, | ||
201 | + stride=stride, expand_ratio=expand_ratio, activation=activation, use_seblock=self.use_seblock)) | ||
202 | + in_channel = out_channel | ||
203 | + | ||
204 | + return nn.Sequential(*blocks) | ||
205 | + | ||
206 | + def forward(self, x): | ||
207 | + assert x.size()[-2:] == self.img_size, \ | ||
208 | + 'Image size must be %r, but %r given' % (self.img_size, x.size()[-2]) | ||
209 | + | ||
210 | + s = self.stage1(x) | ||
211 | + x = self.stage2(s) | ||
212 | + x = self.stage3(x) | ||
213 | + x = self.stage4(x) | ||
214 | + x = self.stage5(x) | ||
215 | + x = self.stage6(x) | ||
216 | + x = self.stage7(x) | ||
217 | + x = self.stage8(x) | ||
218 | + x = self.stage9(x) | ||
219 | + x = x.reshape(x.size(0), -1) | ||
220 | + x = self.fc(x) | ||
221 | + return x, s |
1 | +import torch.nn as nn | ||
2 | + | ||
3 | + | ||
4 | +class ResidualBlock(nn.Module): | ||
5 | + def __init__(self, in_channel, out_channel, stride): | ||
6 | + super(ResidualBlock, self).__init__() | ||
7 | + self.in_channel = in_channel | ||
8 | + self.out_channel = out_channel | ||
9 | + self.stride = stride | ||
10 | + | ||
11 | + self.conv1 = nn.Sequential( | ||
12 | + nn.Conv2d(in_channel, out_channel, | ||
13 | + kernel_size=3, padding=1, stride=stride), | ||
14 | + nn.BatchNorm2d(out_channel)) | ||
15 | + | ||
16 | + self.relu = nn.ReLU(inplace=True) | ||
17 | + | ||
18 | + self.conv2 = nn.Sequential( | ||
19 | + nn.Conv2d(out_channel, out_channel, | ||
20 | + kernel_size=3, padding=1), | ||
21 | + nn.BatchNorm2d(out_channel)) | ||
22 | + | ||
23 | + if self.in_channel != self.out_channel or \ | ||
24 | + self.stride != 1: | ||
25 | + self.down = nn.Sequential( | ||
26 | + nn.Conv2d(in_channel, out_channel, | ||
27 | + kernel_size=1, stride=stride), | ||
28 | + nn.BatchNorm2d(out_channel)) | ||
29 | + | ||
30 | + def forward(self, b): | ||
31 | + t = self.conv1(b) | ||
32 | + t = self.relu(t) | ||
33 | + t = self.conv2(t) | ||
34 | + | ||
35 | + if self.in_channel != self.out_channel or \ | ||
36 | + self.stride != 1: | ||
37 | + b = self.down(b) | ||
38 | + | ||
39 | + t += b | ||
40 | + t = self.relu(t) | ||
41 | + | ||
42 | + return t | ||
43 | + | ||
44 | + | ||
45 | +class Net(nn.Module): | ||
46 | + def __init__(self, args): | ||
47 | + super(Net, self).__init__() | ||
48 | + scale = args.scale | ||
49 | + | ||
50 | + self.stem = nn.Sequential( | ||
51 | + nn.Conv2d(3, 16, | ||
52 | + kernel_size=3, padding=1), | ||
53 | + nn.BatchNorm2d(16), | ||
54 | + nn.ReLU(inplace=True)) | ||
55 | + | ||
56 | + self.layer1 = nn.Sequential(*[ | ||
57 | + ResidualBlock(16, 16, 1) for _ in range(2*scale)]) | ||
58 | + | ||
59 | + self.layer2 = nn.Sequential(*[ | ||
60 | + ResidualBlock(in_channel=(16 if i==0 else 32), | ||
61 | + out_channel=32, | ||
62 | + stride=(2 if i==0 else 1)) for i in range(2*scale)]) | ||
63 | + | ||
64 | + self.layer3 = nn.Sequential(*[ | ||
65 | + ResidualBlock(in_channel=(32 if i==0 else 64), | ||
66 | + out_channel=64, | ||
67 | + stride=(2 if i==0 else 1)) for i in range(2*scale)]) | ||
68 | + | ||
69 | + self.avg_pool = nn.AdaptiveAvgPool2d((1,1)) | ||
70 | + | ||
71 | + self.fc = nn.Linear(64, 10) | ||
72 | + | ||
73 | + for m in self.modules(): | ||
74 | + if isinstance(m, nn.Conv2d): | ||
75 | + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
76 | + | ||
77 | + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||
78 | + nn.init.constant_(m.weight, 1) | ||
79 | + nn.init.constant_(m.bias, 0) | ||
80 | + | ||
81 | + def forward(self, x): | ||
82 | + s = self.stem(x) | ||
83 | + x = self.layer1(s) | ||
84 | + x = self.layer2(x) | ||
85 | + x = self.layer3(x) | ||
86 | + x = self.avg_pool(x) | ||
87 | + x = x.reshape(x.size(0), -1) | ||
88 | + x = self.fc(x) | ||
89 | + | ||
90 | + return x, s |
1 | +import os | ||
2 | +import fire | ||
3 | +import time | ||
4 | +import json | ||
5 | +import random | ||
6 | +from pprint import pprint | ||
7 | + | ||
8 | +import torch.nn as nn | ||
9 | +import torch.backends.cudnn as cudnn | ||
10 | +from torch.utils.tensorboard import SummaryWriter | ||
11 | + | ||
12 | +from networks import * | ||
13 | +from utils import * | ||
14 | + | ||
15 | + | ||
16 | +def train(**kwargs): | ||
17 | + print('\n[+] Parse arguments') | ||
18 | + args, kwargs = parse_args(kwargs) | ||
19 | + pprint(args) | ||
20 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
21 | + | ||
22 | + print('\n[+] Create log dir') | ||
23 | + model_name = get_model_name(args) | ||
24 | + log_dir = os.path.join('./runs', model_name) | ||
25 | + os.makedirs(os.path.join(log_dir, 'model')) | ||
26 | + json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | ||
27 | + writer = SummaryWriter(log_dir=log_dir) | ||
28 | + | ||
29 | + if args.seed is not None: | ||
30 | + random.seed(args.seed) | ||
31 | + torch.manual_seed(args.seed) | ||
32 | + cudnn.deterministic = True | ||
33 | + | ||
34 | + print('\n[+] Create network') | ||
35 | + model = select_model(args) | ||
36 | + optimizer = select_optimizer(args, model) | ||
37 | + scheduler = select_scheduler(args, optimizer) | ||
38 | + criterion = nn.CrossEntropyLoss() | ||
39 | + if args.use_cuda: | ||
40 | + model = model.cuda() | ||
41 | + criterion = criterion.cuda() | ||
42 | + #writer.add_graph(model) | ||
43 | + | ||
44 | + print('\n[+] Load dataset') | ||
45 | + transform = get_train_transform(args, model, log_dir) | ||
46 | + val_transform = get_valid_transform(args, model) | ||
47 | + train_dataset = get_dataset(args, transform, 'train') | ||
48 | + valid_dataset = get_dataset(args, val_transform, 'val') | ||
49 | + train_loader = iter(get_inf_dataloader(args, train_dataset)) | ||
50 | + max_epoch = len(train_dataset) // args.batch_size | ||
51 | + best_acc = -1 | ||
52 | + | ||
53 | + print('\n[+] Start training') | ||
54 | + if torch.cuda.device_count() > 1: | ||
55 | + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
56 | + model = nn.DataParallel(model) | ||
57 | + | ||
58 | + start_t = time.time() | ||
59 | + for step in range(args.start_step, args.max_step): | ||
60 | + batch = next(train_loader) | ||
61 | + _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer) | ||
62 | + | ||
63 | + if step % args.print_step == 0: | ||
64 | + print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format( | ||
65 | + step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr'])) | ||
66 | + writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step) | ||
67 | + writer.add_scalar('train/acc1', _train_res[0], global_step=step) | ||
68 | + writer.add_scalar('train/acc5', _train_res[1], global_step=step) | ||
69 | + writer.add_scalar('train/loss', _train_res[2], global_step=step) | ||
70 | + writer.add_scalar('train/forward_time', _train_res[3], global_step=step) | ||
71 | + writer.add_scalar('train/backward_time', _train_res[4], global_step=step) | ||
72 | + print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | ||
73 | + print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100)) | ||
74 | + print(' Loss : {}'.format(_train_res[2].data)) | ||
75 | + print(' FW Time : {:.3f}ms'.format(_train_res[3]*1000)) | ||
76 | + print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) | ||
77 | + | ||
78 | + if step % args.val_step == args.val_step-1: | ||
79 | + valid_loader = iter(get_dataloader(args, valid_dataset)) | ||
80 | + _valid_res = validate(args, model, criterion, valid_loader, step, writer) | ||
81 | + print('\n[+] Valid results') | ||
82 | + writer.add_scalar('valid/acc1', _valid_res[0], global_step=step) | ||
83 | + writer.add_scalar('valid/acc5', _valid_res[1], global_step=step) | ||
84 | + writer.add_scalar('valid/loss', _valid_res[2], global_step=step) | ||
85 | + print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100)) | ||
86 | + print(' Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100)) | ||
87 | + print(' Loss : {}'.format(_valid_res[2].data)) | ||
88 | + | ||
89 | + if _valid_res[0] > best_acc: | ||
90 | + best_acc = _valid_res[0] | ||
91 | + torch.save(model.state_dict(), os.path.join(log_dir, "model","model.pt")) | ||
92 | + print('\n[+] Model saved') | ||
93 | + | ||
94 | + writer.close() | ||
95 | + | ||
96 | + | ||
97 | +if __name__ == '__main__': | ||
98 | + fire.Fire(train) |
1 | +import numpy as np | ||
2 | +import torch.nn as nn | ||
3 | +import torchvision.transforms as transforms | ||
4 | + | ||
5 | +from abc import ABC, abstractmethod | ||
6 | +from PIL import Image, ImageOps, ImageEnhance | ||
7 | + | ||
8 | + | ||
9 | +class BaseTransform(ABC): | ||
10 | + | ||
11 | + def __init__(self, prob, mag): | ||
12 | + self.prob = prob | ||
13 | + self.mag = mag | ||
14 | + | ||
15 | + def __call__(self, img): | ||
16 | + return transforms.RandomApply([self.transform], self.prob)(img) | ||
17 | + | ||
18 | + def __repr__(self): | ||
19 | + return '%s(prob=%.2f, magnitude=%.2f)' % \ | ||
20 | + (self.__class__.__name__, self.prob, self.mag) | ||
21 | + | ||
22 | + @abstractmethod | ||
23 | + def transform(self, img): | ||
24 | + pass | ||
25 | + | ||
26 | + | ||
27 | +class ShearXY(BaseTransform): | ||
28 | + | ||
29 | + def transform(self, img): | ||
30 | + degrees = self.mag * 360 | ||
31 | + t = transforms.RandomAffine(0, shear=degrees, resample=Image.BILINEAR) | ||
32 | + return t(img) | ||
33 | + | ||
34 | + | ||
35 | +class TranslateXY(BaseTransform): | ||
36 | + | ||
37 | + def transform(self, img): | ||
38 | + translate = (self.mag, self.mag) | ||
39 | + t = transforms.RandomAffine(0, translate=translate, resample=Image.BILINEAR) | ||
40 | + return t(img) | ||
41 | + | ||
42 | + | ||
43 | +class Rotate(BaseTransform): | ||
44 | + | ||
45 | + def transform(self, img): | ||
46 | + degrees = self.mag * 360 | ||
47 | + t = transforms.RandomRotation(degrees, Image.BILINEAR) | ||
48 | + return t(img) | ||
49 | + | ||
50 | + | ||
51 | +class AutoContrast(BaseTransform): | ||
52 | + | ||
53 | + def transform(self, img): | ||
54 | + cutoff = int(self.mag * 49) | ||
55 | + return ImageOps.autocontrast(img, cutoff=cutoff) | ||
56 | + | ||
57 | + | ||
58 | +class Invert(BaseTransform): | ||
59 | + | ||
60 | + def transform(self, img): | ||
61 | + return ImageOps.invert(img) | ||
62 | + | ||
63 | + | ||
64 | +class Equalize(BaseTransform): | ||
65 | + | ||
66 | + def transform(self, img): | ||
67 | + return ImageOps.equalize(img) | ||
68 | + | ||
69 | + | ||
70 | +class Solarize(BaseTransform): | ||
71 | + | ||
72 | + def transform(self, img): | ||
73 | + threshold = (1-self.mag) * 255 | ||
74 | + return ImageOps.solarize(img, threshold) | ||
75 | + | ||
76 | + | ||
77 | +class Posterize(BaseTransform): | ||
78 | + | ||
79 | + def transform(self, img): | ||
80 | + bits = int((1-self.mag) * 8) | ||
81 | + return ImageOps.posterize(img, bits=bits) | ||
82 | + | ||
83 | + | ||
84 | +class Contrast(BaseTransform): | ||
85 | + | ||
86 | + def transform(self, img): | ||
87 | + factor = self.mag * 10 | ||
88 | + return ImageEnhance.Contrast(img).enhance(factor) | ||
89 | + | ||
90 | + | ||
91 | +class Color(BaseTransform): | ||
92 | + | ||
93 | + def transform(self, img): | ||
94 | + factor = self.mag * 10 | ||
95 | + return ImageEnhance.Color(img).enhance(factor) | ||
96 | + | ||
97 | + | ||
98 | +class Brightness(BaseTransform): | ||
99 | + | ||
100 | + def transform(self, img): | ||
101 | + factor = self.mag * 10 | ||
102 | + return ImageEnhance.Brightness(img).enhance(factor) | ||
103 | + | ||
104 | + | ||
105 | +class Sharpness(BaseTransform): | ||
106 | + | ||
107 | + def transform(self, img): | ||
108 | + factor = self.mag * 10 | ||
109 | + return ImageEnhance.Sharpness(img).enhance(factor) | ||
110 | + | ||
111 | + | ||
112 | +class Cutout(BaseTransform): | ||
113 | + | ||
114 | + def transform(self, img): | ||
115 | + n_holes = 1 | ||
116 | + length = 24 * self.mag | ||
117 | + cutout_op = CutoutOp(n_holes=n_holes, length=length) | ||
118 | + return cutout_op(img) | ||
119 | + | ||
120 | + | ||
121 | +class CutoutOp(object): | ||
122 | + """ | ||
123 | + https://github.com/uoguelph-mlrg/Cutout | ||
124 | + | ||
125 | + Randomly mask out one or more patches from an image. | ||
126 | + | ||
127 | + Args: | ||
128 | + n_holes (int): Number of patches to cut out of each image. | ||
129 | + length (int): The length (in pixels) of each square patch. | ||
130 | + """ | ||
131 | + def __init__(self, n_holes, length): | ||
132 | + self.n_holes = n_holes | ||
133 | + self.length = length | ||
134 | + | ||
135 | + def __call__(self, img): | ||
136 | + """ | ||
137 | + Args: | ||
138 | + img (Tensor): Tensor image of size (C, H, W). | ||
139 | + Returns: | ||
140 | + Tensor: Image with n_holes of dimension length x length cut out of it. | ||
141 | + """ | ||
142 | + w, h = img.size | ||
143 | + | ||
144 | + mask = np.ones((h, w, 1), np.uint8) | ||
145 | + | ||
146 | + for n in range(self.n_holes): | ||
147 | + y = np.random.randint(h) | ||
148 | + x = np.random.randint(w) | ||
149 | + | ||
150 | + y1 = np.clip(y - self.length // 2, 0, h).astype(int) | ||
151 | + y2 = np.clip(y + self.length // 2, 0, h).astype(int) | ||
152 | + x1 = np.clip(x - self.length // 2, 0, w).astype(int) | ||
153 | + x2 = np.clip(x + self.length // 2, 0, w).astype(int) | ||
154 | + | ||
155 | + mask[y1: y2, x1: x2, :] = 0. | ||
156 | + | ||
157 | + img = mask*np.asarray(img).astype(np.uint8) | ||
158 | + img = Image.fromarray(mask*np.asarray(img)) | ||
159 | + | ||
160 | + return img | ||
161 | + |
1 | +import os | ||
2 | +import time | ||
3 | +import importlib | ||
4 | +import collections | ||
5 | +import pickle as cp | ||
6 | +import glob | ||
7 | +import numpy as np | ||
8 | + | ||
9 | +import torch | ||
10 | +import torchvision | ||
11 | +import torch.nn.functional as F | ||
12 | +import torchvision.models as models | ||
13 | +import torchvision.transforms as transforms | ||
14 | +from torch.utils.data import Subset | ||
15 | +from torch.utils.data import Dataset, DataLoader | ||
16 | + | ||
17 | +from sklearn.model_selection import StratifiedShuffleSplit | ||
18 | + | ||
19 | + | ||
20 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | ||
21 | +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' | ||
22 | +current_epoch = 0 | ||
23 | + | ||
24 | + | ||
25 | +def split_dataset(args, dataset, k): | ||
26 | + # load dataset | ||
27 | + X = list(range(len(dataset))) | ||
28 | + Y = dataset.targets | ||
29 | + | ||
30 | + # split to k-fold | ||
31 | + assert len(X) == len(Y) | ||
32 | + | ||
33 | + def _it_to_list(_it): | ||
34 | + return list(zip(*list(_it))) | ||
35 | + | ||
36 | + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
37 | + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
38 | + | ||
39 | + return Dm_indexes, Da_indexes | ||
40 | + | ||
41 | + | ||
42 | +def concat_image_features(image, features, max_features=3): | ||
43 | + _, h, w = image.shape | ||
44 | + | ||
45 | + max_features = min(features.size(0), max_features) | ||
46 | + image_feature = image.clone() | ||
47 | + | ||
48 | + for i in range(max_features): | ||
49 | + feature = features[i:i+1] | ||
50 | + _min, _max = torch.min(feature), torch.max(feature) | ||
51 | + feature = (feature - _min) / (_max - _min + 1e-6) | ||
52 | + feature = torch.cat([feature]*3, 0) | ||
53 | + feature = feature.view(1, 3, feature.size(1), feature.size(2)) | ||
54 | + feature = F.upsample(feature, size=(h,w), mode="bilinear") | ||
55 | + feature = feature.view(3, h, w) | ||
56 | + image_feature = torch.cat((image_feature, feature), 2) | ||
57 | + | ||
58 | + return image_feature | ||
59 | + | ||
60 | + | ||
61 | +def get_model_name(args): | ||
62 | + from datetime import datetime | ||
63 | + now = datetime.now() | ||
64 | + date_time = now.strftime("%B_%d_%H:%M:%S") | ||
65 | + model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
66 | + return model_name | ||
67 | + | ||
68 | + | ||
69 | +def dict_to_namedtuple(d): | ||
70 | + Args = collections.namedtuple('Args', sorted(d.keys())) | ||
71 | + | ||
72 | + for k,v in d.items(): | ||
73 | + if type(v) is dict: | ||
74 | + d[k] = dict_to_namedtuple(v) | ||
75 | + | ||
76 | + elif type(v) is str: | ||
77 | + try: | ||
78 | + d[k] = eval(v) | ||
79 | + except: | ||
80 | + d[k] = v | ||
81 | + | ||
82 | + args = Args(**d) | ||
83 | + return args | ||
84 | + | ||
85 | + | ||
86 | +def parse_args(kwargs): | ||
87 | + # combine with default args | ||
88 | + kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'cifar10' | ||
89 | + kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet_cifar10' | ||
90 | + kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
91 | + kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.1 | ||
92 | + kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
93 | + kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
94 | + kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
95 | + kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
96 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 2000 | ||
97 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 2000 | ||
98 | + kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
99 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
100 | + kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
101 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 64000 | ||
102 | + kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False | ||
103 | + kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
104 | + | ||
105 | + # to named tuple | ||
106 | + args = dict_to_namedtuple(kwargs) | ||
107 | + return args, kwargs | ||
108 | + | ||
109 | + | ||
110 | +def select_model(args): | ||
111 | + if args.network in models.__dict__: | ||
112 | + backbone = models.__dict__[args.network]() | ||
113 | + model = BaseNet(backbone, args) | ||
114 | + else: | ||
115 | + Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
116 | + model = Net(args) | ||
117 | + | ||
118 | + print(model) | ||
119 | + return model | ||
120 | + | ||
121 | + | ||
122 | +def select_optimizer(args, model): | ||
123 | + if args.optimizer == 'sgd': | ||
124 | + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
125 | + elif args.optimizer == 'rms': | ||
126 | + #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5) | ||
127 | + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
128 | + elif args.optimizer == 'adam': | ||
129 | + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
130 | + else: | ||
131 | + raise Exception('Unknown Optimizer') | ||
132 | + return optimizer | ||
133 | + | ||
134 | + | ||
135 | +def select_scheduler(args, optimizer): | ||
136 | + if not args.scheduler or args.scheduler == 'None': | ||
137 | + return None | ||
138 | + elif args.scheduler =='clr': | ||
139 | + return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
140 | + elif args.scheduler =='exp': | ||
141 | + return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
142 | + else: | ||
143 | + raise Exception('Unknown Scheduler') | ||
144 | + | ||
145 | + | ||
146 | +class CustomDataset(Dataset): | ||
147 | + def __init__(self, path, transform = None): | ||
148 | + self.path = path | ||
149 | + self.transform = transform | ||
150 | + self.img = np.load(path) | ||
151 | + self.len = self.img.shape[0] | ||
152 | + | ||
153 | + def __len__(self): | ||
154 | + return self.len | ||
155 | + | ||
156 | + def __getitem__(self, idx): | ||
157 | + if self.transforms is not None: | ||
158 | + img = self.transforms(img) | ||
159 | + return img | ||
160 | + | ||
161 | +def get_dataset(args, transform, split='train'): | ||
162 | + assert split in ['train', 'val', 'test', 'trainval'] | ||
163 | + | ||
164 | + if args.dataset == 'cifar10': | ||
165 | + train = split in ['train', 'val', 'trainval'] | ||
166 | + dataset = torchvision.datasets.CIFAR10(DATASET_PATH, | ||
167 | + train=train, | ||
168 | + transform=transform, | ||
169 | + download=True) | ||
170 | + | ||
171 | + if split in ['train', 'val']: | ||
172 | + split_path = os.path.join(DATASET_PATH, | ||
173 | + 'cifar-10-batches-py', 'train_val_index.cp') | ||
174 | + | ||
175 | + if not os.path.exists(split_path): | ||
176 | + [train_index], [val_index] = split_dataset(args, dataset, k=1) | ||
177 | + split_index = {'train':train_index, 'val':val_index} | ||
178 | + cp.dump(split_index, open(split_path, 'wb')) | ||
179 | + | ||
180 | + split_index = cp.load(open(split_path, 'rb')) | ||
181 | + dataset = Subset(dataset, split_index[split]) | ||
182 | + | ||
183 | + elif args.dataset == 'imagenet': | ||
184 | + dataset = torchvision.datasets.ImageNet(DATASET_PATH, | ||
185 | + split=split, | ||
186 | + transform=transform, | ||
187 | + download=(split is 'val')) | ||
188 | + | ||
189 | + elif args.dataset == 'BraTS': | ||
190 | + if split in ['train']: | ||
191 | + dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) | ||
192 | + else: | ||
193 | + dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) | ||
194 | + | ||
195 | + | ||
196 | + else: | ||
197 | + raise Exception('Unknown dataset') | ||
198 | + | ||
199 | + return dataset | ||
200 | + | ||
201 | + | ||
202 | +def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
203 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
204 | + batch_size=args.batch_size, | ||
205 | + shuffle=shuffle, | ||
206 | + num_workers=args.num_workers, | ||
207 | + pin_memory=pin_memory) | ||
208 | + return data_loader | ||
209 | + | ||
210 | + | ||
211 | +def get_inf_dataloader(args, dataset): | ||
212 | + global current_epoch | ||
213 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
214 | + | ||
215 | + while True: | ||
216 | + try: | ||
217 | + batch = next(data_loader) | ||
218 | + | ||
219 | + except StopIteration: | ||
220 | + current_epoch += 1 | ||
221 | + data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
222 | + batch = next(data_loader) | ||
223 | + | ||
224 | + yield batch | ||
225 | + | ||
226 | + | ||
227 | +def get_train_transform(args, model, log_dir=None): | ||
228 | + if args.fast_auto_augment: | ||
229 | + assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet | ||
230 | + | ||
231 | + from fast_auto_augment import fast_auto_augment | ||
232 | + if args.augment_path: | ||
233 | + transform = cp.load(open(args.augment_path, 'rb')) | ||
234 | + os.system('cp {} {}'.format( | ||
235 | + args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | ||
236 | + else: | ||
237 | + transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) | ||
238 | + if log_dir: | ||
239 | + cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | ||
240 | + | ||
241 | + elif args.dataset == 'cifar10': | ||
242 | + transform = transforms.Compose([ | ||
243 | + transforms.Pad(4), | ||
244 | + transforms.RandomCrop(32), | ||
245 | + transforms.RandomHorizontalFlip(), | ||
246 | + transforms.ToTensor() | ||
247 | + ]) | ||
248 | + | ||
249 | + elif args.dataset == 'imagenet': | ||
250 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
251 | + transform = transforms.Compose([ | ||
252 | + transforms.Resize([resize_h, resize_w]), | ||
253 | + transforms.RandomCrop(model.img_size), | ||
254 | + transforms.RandomHorizontalFlip(), | ||
255 | + transforms.ToTensor() | ||
256 | + ]) | ||
257 | + | ||
258 | + elif args.dataset == 'BraTS': | ||
259 | + resize_h, resize_w = 256, 256 | ||
260 | + transform = transforms.Compose([ | ||
261 | + transforms.Resize([resize_h, resize_w]), | ||
262 | + transforms.RandomCrop(model.img_size), | ||
263 | + transforms.RandomHorizontalFlip(), | ||
264 | + transforms.ToTensor() | ||
265 | + ]) | ||
266 | + else: | ||
267 | + raise Exception('Unknown Dataset') | ||
268 | + | ||
269 | + print(transform) | ||
270 | + | ||
271 | + return transform | ||
272 | + | ||
273 | + | ||
274 | +def get_valid_transform(args, model): | ||
275 | + if args.dataset == 'cifar10': | ||
276 | + val_transform = transforms.Compose([ | ||
277 | + transforms.Resize(32), | ||
278 | + transforms.ToTensor() | ||
279 | + ]) | ||
280 | + | ||
281 | + elif args.dataset == 'imagenet': | ||
282 | + resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
283 | + val_transform = transforms.Compose([ | ||
284 | + transforms.Resize([resize_h, resize_w]), | ||
285 | + transforms.ToTensor() | ||
286 | + ]) | ||
287 | + elif args.dataset == 'BraTS': | ||
288 | + resize_h, resize_w = 256, 256 | ||
289 | + val_transform = transforms.Compose([ | ||
290 | + transforms.Resize([resize_h, resize_w]), | ||
291 | + transforms.ToTensor() | ||
292 | + ]) | ||
293 | + else: | ||
294 | + raise Exception('Unknown Dataset') | ||
295 | + | ||
296 | + return val_transform | ||
297 | + | ||
298 | + | ||
299 | +def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
300 | + model.train() | ||
301 | + images, target = batch | ||
302 | + | ||
303 | + if device: | ||
304 | + images = images.to(device) | ||
305 | + target = target.to(device) | ||
306 | + | ||
307 | + elif args.use_cuda: | ||
308 | + images = images.cuda(non_blocking=True) | ||
309 | + target = target.cuda(non_blocking=True) | ||
310 | + | ||
311 | + # compute output | ||
312 | + start_t = time.time() | ||
313 | + output, first = model(images) | ||
314 | + forward_t = time.time() - start_t | ||
315 | + loss = criterion(output, target) | ||
316 | + | ||
317 | + # measure accuracy and record loss | ||
318 | + acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
319 | + acc1 /= images.size(0) | ||
320 | + acc5 /= images.size(0) | ||
321 | + | ||
322 | + # compute gradient and do SGD step | ||
323 | + optimizer.zero_grad() | ||
324 | + start_t = time.time() | ||
325 | + loss.backward() | ||
326 | + backward_t = time.time() - start_t | ||
327 | + optimizer.step() | ||
328 | + if scheduler: scheduler.step() | ||
329 | + | ||
330 | + if writer and step % args.print_step == 0: | ||
331 | + n_imgs = min(images.size(0), 10) | ||
332 | + for j in range(n_imgs): | ||
333 | + writer.add_image('train/input_image', | ||
334 | + concat_image_features(images[j], first[j]), global_step=step) | ||
335 | + | ||
336 | + return acc1, acc5, loss, forward_t, backward_t | ||
337 | + | ||
338 | + | ||
339 | +def validate(args, model, criterion, valid_loader, step, writer, device=None): | ||
340 | + # switch to evaluate mode | ||
341 | + model.eval() | ||
342 | + | ||
343 | + acc1, acc5 = 0, 0 | ||
344 | + samples = 0 | ||
345 | + infer_t = 0 | ||
346 | + | ||
347 | + with torch.no_grad(): | ||
348 | + for i, (images, target) in enumerate(valid_loader): | ||
349 | + | ||
350 | + start_t = time.time() | ||
351 | + if device: | ||
352 | + images = images.to(device) | ||
353 | + target = target.to(device) | ||
354 | + | ||
355 | + elif args.use_cuda is not None: | ||
356 | + images = images.cuda(non_blocking=True) | ||
357 | + target = target.cuda(non_blocking=True) | ||
358 | + | ||
359 | + # compute output | ||
360 | + output, first = model(images) | ||
361 | + loss = criterion(output, target) | ||
362 | + infer_t += time.time() - start_t | ||
363 | + | ||
364 | + # measure accuracy and record loss | ||
365 | + _acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
366 | + acc1 += _acc1 | ||
367 | + acc5 += _acc5 | ||
368 | + samples += images.size(0) | ||
369 | + | ||
370 | + acc1 /= samples | ||
371 | + acc5 /= samples | ||
372 | + | ||
373 | + if writer: | ||
374 | + n_imgs = min(images.size(0), 10) | ||
375 | + for j in range(n_imgs): | ||
376 | + writer.add_image('valid/input_image', | ||
377 | + concat_image_features(images[j], first[j]), global_step=step) | ||
378 | + | ||
379 | + return acc1, acc5, loss, infer_t | ||
380 | + | ||
381 | + | ||
382 | +def accuracy(output, target, topk=(1,)): | ||
383 | + """Computes the accuracy over the k top predictions for the specified values of k""" | ||
384 | + with torch.no_grad(): | ||
385 | + maxk = max(topk) | ||
386 | + batch_size = target.size(0) | ||
387 | + | ||
388 | + _, pred = output.topk(maxk, 1, True, True) | ||
389 | + pred = pred.t() | ||
390 | + correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
391 | + | ||
392 | + res = [] | ||
393 | + for k in topk: | ||
394 | + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
395 | + res.append(correct_k) | ||
396 | + return res | ||
397 | + |
-
Please register or login to post a comment