Showing
21 changed files
with
0 additions
and
1636 deletions
1 | -# Byte-compiled / optimized / DLL files | ||
2 | -__pycache__/ | ||
3 | -*.py[cod] | ||
4 | -*$py.class | ||
5 | - | ||
6 | -# C extensions | ||
7 | -*.so | ||
8 | - | ||
9 | -# Distribution / packaging | ||
10 | -.Python | ||
11 | -build/ | ||
12 | -develop-eggs/ | ||
13 | -dist/ | ||
14 | -downloads/ | ||
15 | -eggs/ | ||
16 | -.eggs/ | ||
17 | -lib/ | ||
18 | -lib64/ | ||
19 | -parts/ | ||
20 | -sdist/ | ||
21 | -var/ | ||
22 | -wheels/ | ||
23 | -*.egg-info/ | ||
24 | -.installed.cfg | ||
25 | -*.egg | ||
26 | -MANIFEST | ||
27 | - | ||
28 | -# PyInstaller | ||
29 | -# Usually these files are written by a python script from a template | ||
30 | -# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
31 | -*.manifest | ||
32 | -*.spec | ||
33 | - | ||
34 | -# Installer logs | ||
35 | -pip-log.txt | ||
36 | -pip-delete-this-directory.txt | ||
37 | - | ||
38 | -# Unit test / coverage reports | ||
39 | -htmlcov/ | ||
40 | -.tox/ | ||
41 | -.coverage | ||
42 | -.coverage.* | ||
43 | -.cache | ||
44 | -nosetests.xml | ||
45 | -coverage.xml | ||
46 | -*.cover | ||
47 | -.hypothesis/ | ||
48 | -.pytest_cache/ | ||
49 | - | ||
50 | -# Translations | ||
51 | -*.mo | ||
52 | -*.pot | ||
53 | - | ||
54 | -# Django stuff: | ||
55 | -*.log | ||
56 | -local_settings.py | ||
57 | -db.sqlite3 | ||
58 | - | ||
59 | -# Flask stuff: | ||
60 | -instance/ | ||
61 | -.webassets-cache | ||
62 | - | ||
63 | -# Scrapy stuff: | ||
64 | -.scrapy | ||
65 | - | ||
66 | -# Sphinx documentation | ||
67 | -docs/_build/ | ||
68 | - | ||
69 | -# PyBuilder | ||
70 | -target/ | ||
71 | - | ||
72 | -# Jupyter Notebook | ||
73 | -.ipynb_checkpoints | ||
74 | - | ||
75 | -# pyenv | ||
76 | -.python-version | ||
77 | - | ||
78 | -# celery beat schedule file | ||
79 | -celerybeat-schedule | ||
80 | - | ||
81 | -# SageMath parsed files | ||
82 | -*.sage.py | ||
83 | - | ||
84 | -# Environments | ||
85 | -.env | ||
86 | -.venv | ||
87 | -env/ | ||
88 | -venv/ | ||
89 | -ENV/ | ||
90 | -env.bak/ | ||
91 | -venv.bak/ | ||
92 | - | ||
93 | -# Spyder project settings | ||
94 | -.spyderproject | ||
95 | -.spyproject | ||
96 | - | ||
97 | -# Rope project settings | ||
98 | -.ropeproject | ||
99 | - | ||
100 | -# mkdocs documentation | ||
101 | -/site | ||
102 | - | ||
103 | -# mypy | ||
104 | -.mypy_cache/ | ||
105 | - |
1 | -# Fast Autoaugment | ||
2 | -<img src="figures/faa.png" width=800px> | ||
3 | - | ||
4 | -A Pytorch Implementation of [Fast AutoAugment](https://arxiv.org/pdf/1905.00397.pdf) and [EfficientNet](https://arxiv.org/abs/1905.11946). | ||
5 | - | ||
6 | -## Prerequisite | ||
7 | -* torch==1.1.0 | ||
8 | -* torchvision==0.2.2 | ||
9 | -* hyperopt==0.1.2 | ||
10 | -* future==0.17.1 | ||
11 | -* tb-nightly==1.15.0a20190622 | ||
12 | - | ||
13 | -## Usage | ||
14 | -### Training | ||
15 | -#### CIFAR10 | ||
16 | -```bash | ||
17 | -# ResNet20 (w/o FastAutoAugment) | ||
18 | -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False | ||
19 | - | ||
20 | -# ResNet20 (w/ FastAutoAugment) | ||
21 | -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True | ||
22 | - | ||
23 | -# ResNet20 (w/ FastAutoAugment, Pre-found policy) | ||
24 | -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \ | ||
25 | - --augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp | ||
26 | - | ||
27 | -# ResNet32 (w/o FastAutoAugment) | ||
28 | -python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False | ||
29 | - | ||
30 | -# ResNet32 (w/ FastAutoAugment) | ||
31 | -python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True | ||
32 | - | ||
33 | -# EfficientNet (w/ FastAutoAugment) | ||
34 | -python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \ | ||
35 | - --network=efficientnet_cifar10 --activation=swish | ||
36 | -``` | ||
37 | - | ||
38 | -#### ImageNet (You can use any backbone networks in [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html)) | ||
39 | -```bash | ||
40 | - | ||
41 | -# BaseNet (w/o FastAutoAugment) | ||
42 | -python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50 | ||
43 | - | ||
44 | -# EfficientNet (w/ FastAutoAugment) (UnderConstruction) | ||
45 | -python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \ | ||
46 | - --network=efficientnet --activation=swish | ||
47 | -``` | ||
48 | - | ||
49 | -### Eval | ||
50 | -```bash | ||
51 | -# Single Image testing | ||
52 | -python eval.py --model_path=runs/ResNet_Scale3_Basline | ||
53 | - | ||
54 | -# 5-crops testing | ||
55 | -python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True | ||
56 | -``` | ||
57 | - | ||
58 | -## Experiments | ||
59 | -### Fast AutoAugment | ||
60 | -#### ResNet20 (CIFAR10) | ||
61 | -* Pre-trained model [[Download](https://drive.google.com/file/d/12D8050yGGiKWGt8_R8QTlkoQ6wq_icBn/view?usp=sharing)] | ||
62 | -* Validation Curve | ||
63 | -<img src="figures/resnet20_valid.png"> | ||
64 | - | ||
65 | -* Evaluation (Acc @1) | ||
66 | - | ||
67 | -| | Valid | Test(Single) | | ||
68 | -|----------------|-------|-------------| | ||
69 | -| ResNet20 | 90.70 | **91.45** | | ||
70 | -| ResNet20 + FAA |**92.46**| **91.45** | | ||
71 | - | ||
72 | -#### ResNet34 (CIFAR10) | ||
73 | -* Validation Curve | ||
74 | -<img src="figures/resnet34_valid.png"> | ||
75 | - | ||
76 | -* Evaluation (Acc @1) | ||
77 | - | ||
78 | -| | Valid | Test(Single) | | ||
79 | -|----------------|-------|-------------| | ||
80 | -| ResNet34 | 91.54 | 91.47 | | ||
81 | -| ResNet34 + FAA |**92.76**| **91.99** | | ||
82 | - | ||
83 | -### Found Policy [[Download](https://drive.google.com/file/d/1Ia_IxPY3-T7m8biyl3QpxV1s5EA5gRDF/view?usp=sharing)] | ||
84 | -<img src="figures/pm.png"> | ||
85 | - | ||
86 | -### Augmented images | ||
87 | -<img src="figures/augmented_images.png"> | ||
88 | -<img src="figures/augmented_images2.png"> |
1 | -import os | ||
2 | -import fire | ||
3 | -import json | ||
4 | -from pprint import pprint | ||
5 | - | ||
6 | -import torch | ||
7 | -import torch.nn as nn | ||
8 | - | ||
9 | -from utils import * | ||
10 | - | ||
11 | - | ||
12 | -def eval(model_path): | ||
13 | - print('\n[+] Parse arguments') | ||
14 | - kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
15 | - kwargs = json.loads(open(kwargs_path).read()) | ||
16 | - args, kwargs = parse_args(kwargs) | ||
17 | - pprint(args) | ||
18 | - device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
19 | - | ||
20 | - print('\n[+] Create network') | ||
21 | - model = select_model(args) | ||
22 | - optimizer = select_optimizer(args, model) | ||
23 | - criterion = nn.CrossEntropyLoss() | ||
24 | - if args.use_cuda: | ||
25 | - model = model.cuda() | ||
26 | - criterion = criterion.cuda() | ||
27 | - | ||
28 | - print('\n[+] Load model') | ||
29 | - weight_path = os.path.join(model_path, 'model', 'model.pt') | ||
30 | - model.load_state_dict(torch.load(weight_path)) | ||
31 | - | ||
32 | - print('\n[+] Load dataset') | ||
33 | - test_transform = get_valid_transform(args, model) | ||
34 | - test_dataset = get_dataset(args, test_transform, 'test') | ||
35 | - test_loader = iter(get_dataloader(args, test_dataset)) | ||
36 | - | ||
37 | - print('\n[+] Start testing') | ||
38 | - _test_res = validate(args, model, criterion, test_loader, step=0, writer=None) | ||
39 | - | ||
40 | - print('\n[+] Valid results') | ||
41 | - print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100)) | ||
42 | - print(' Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100)) | ||
43 | - print(' Loss : {:.3f}'.format(_test_res[2].data)) | ||
44 | - print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset))) | ||
45 | - | ||
46 | - | ||
47 | -if __name__ == '__main__': | ||
48 | - fire.Fire(eval) |
1 | -import copy | ||
2 | -import json | ||
3 | -import time | ||
4 | -import torch | ||
5 | -import random | ||
6 | -import torchvision.transforms as transforms | ||
7 | - | ||
8 | -from torch.utils.data import Subset | ||
9 | -from sklearn.model_selection import StratifiedShuffleSplit | ||
10 | -from concurrent.futures import ProcessPoolExecutor | ||
11 | - | ||
12 | -from transforms import * | ||
13 | -from hyperopt import fmin, tpe, hp, STATUS_OK, Trials | ||
14 | -from utils import * | ||
15 | - | ||
16 | - | ||
17 | -DEFALUT_CANDIDATES = [ | ||
18 | - ShearXY, | ||
19 | - TranslateXY, | ||
20 | - Rotate, | ||
21 | - AutoContrast, | ||
22 | - Invert, | ||
23 | - Equalize, | ||
24 | - Solarize, | ||
25 | - Posterize, | ||
26 | - Contrast, | ||
27 | - Color, | ||
28 | - Brightness, | ||
29 | - Sharpness, | ||
30 | - Cutout, | ||
31 | -# SamplePairing, | ||
32 | -] | ||
33 | - | ||
34 | - | ||
35 | -def train_child(args, model, dataset, subset_indx, device=None): | ||
36 | - optimizer = select_optimizer(args, model) | ||
37 | - scheduler = select_scheduler(args, optimizer) | ||
38 | - criterion = nn.CrossEntropyLoss() | ||
39 | - | ||
40 | - dataset.transform = transforms.Compose([ | ||
41 | - transforms.Resize(32), | ||
42 | - transforms.ToTensor()]) | ||
43 | - subset = Subset(dataset, subset_indx) | ||
44 | - data_loader = get_inf_dataloader(args, subset) | ||
45 | - | ||
46 | - if device: | ||
47 | - model = model.to(device) | ||
48 | - criterion = criterion.to(device) | ||
49 | - | ||
50 | - elif args.use_cuda: | ||
51 | - model = model.cuda() | ||
52 | - criterion = criterion.cuda() | ||
53 | - | ||
54 | - if torch.cuda.device_count() > 1: | ||
55 | - print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
56 | - model = nn.DataParallel(model) | ||
57 | - | ||
58 | - start_t = time.time() | ||
59 | - for step in range(args.start_step, args.max_step): | ||
60 | - batch = next(data_loader) | ||
61 | - _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) | ||
62 | - | ||
63 | - if step % args.print_step == 0: | ||
64 | - print('\n[+] Training step: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}\tDevice: {}'.format( | ||
65 | - step, args.max_step,(time.time()-start_t)/60, optimizer.param_groups[0]['lr'], device)) | ||
66 | - | ||
67 | - print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | ||
68 | - print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100)) | ||
69 | - print(' Loss : {}'.format(_train_res[2].data)) | ||
70 | - | ||
71 | - return _train_res | ||
72 | - | ||
73 | - | ||
74 | -def validate_child(args, model, dataset, subset_indx, transform, device=None): | ||
75 | - criterion = nn.CrossEntropyLoss() | ||
76 | - | ||
77 | - if device: | ||
78 | - model = model.to(device) | ||
79 | - criterion = criterion.to(device) | ||
80 | - | ||
81 | - elif args.use_cuda: | ||
82 | - model = model.cuda() | ||
83 | - criterion = criterion.cuda() | ||
84 | - | ||
85 | - dataset.transform = transform | ||
86 | - subset = Subset(dataset, subset_indx) | ||
87 | - data_loader = get_dataloader(args, subset, pin_memory=False) | ||
88 | - | ||
89 | - return validate(args, model, criterion, data_loader, 0, None, device) | ||
90 | - | ||
91 | - | ||
92 | -def get_next_subpolicy(transform_candidates, op_per_subpolicy=2): | ||
93 | - n_candidates = len(transform_candidates) | ||
94 | - subpolicy = [] | ||
95 | - | ||
96 | - for i in range(op_per_subpolicy): | ||
97 | - indx = random.randrange(n_candidates) | ||
98 | - prob = random.random() | ||
99 | - mag = random.random() | ||
100 | - subpolicy.append(transform_candidates[indx](prob, mag)) | ||
101 | - | ||
102 | - subpolicy = transforms.Compose([ | ||
103 | - *subpolicy, | ||
104 | - transforms.Resize(32), | ||
105 | - transforms.ToTensor()]) | ||
106 | - | ||
107 | - return subpolicy | ||
108 | - | ||
109 | - | ||
110 | -def search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device): | ||
111 | - subpolicies = [] | ||
112 | - | ||
113 | - for b in range(B): | ||
114 | - subpolicy = get_next_subpolicy(transform_candidates) | ||
115 | - val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device) | ||
116 | - subpolicies.append((subpolicy, val_res[2])) | ||
117 | - | ||
118 | - return subpolicies | ||
119 | - | ||
120 | - | ||
121 | -def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device): | ||
122 | - | ||
123 | - def _objective(sampled): | ||
124 | - subpolicy = [transform(prob, mag) | ||
125 | - for transform, prob, mag in sampled] | ||
126 | - | ||
127 | - subpolicy = transforms.Compose([ | ||
128 | - transforms.Resize(32), | ||
129 | - *subpolicy, | ||
130 | - transforms.ToTensor()]) | ||
131 | - | ||
132 | - val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device) | ||
133 | - loss = val_res[2].cpu().numpy() | ||
134 | - return {'loss': loss, 'status': STATUS_OK } | ||
135 | - | ||
136 | - space = [(hp.choice('transform1', transform_candidates), hp.uniform('prob1', 0, 1.0), hp.uniform('mag1', 0, 1.0)), | ||
137 | - (hp.choice('transform2', transform_candidates), hp.uniform('prob2', 0, 1.0), hp.uniform('mag2', 0, 1.0))] | ||
138 | - | ||
139 | - trials = Trials() | ||
140 | - best = fmin(_objective, | ||
141 | - space=space, | ||
142 | - algo=tpe.suggest, | ||
143 | - max_evals=B, | ||
144 | - trials=trials) | ||
145 | - | ||
146 | - subpolicies = [] | ||
147 | - for t in trials.trials: | ||
148 | - vals = t['misc']['vals'] | ||
149 | - subpolicy = [transform_candidates[vals['transform1'][0]](vals['prob1'][0], vals['mag1'][0]), | ||
150 | - transform_candidates[vals['transform2'][0]](vals['prob2'][0], vals['mag2'][0])] | ||
151 | - subpolicy = transforms.Compose([ | ||
152 | - ## baseline augmentation | ||
153 | - transforms.Pad(4), | ||
154 | - transforms.RandomCrop(32), | ||
155 | - transforms.RandomHorizontalFlip(), | ||
156 | - ## policy | ||
157 | - *subpolicy, | ||
158 | - ## to tensor | ||
159 | - transforms.ToTensor()]) | ||
160 | - subpolicies.append((subpolicy, t['result']['loss'])) | ||
161 | - | ||
162 | - return subpolicies | ||
163 | - | ||
164 | - | ||
165 | -def get_topn_subpolicies(subpolicies, N=10): | ||
166 | - return sorted(subpolicies, key=lambda subpolicy: subpolicy[1])[:N] | ||
167 | - | ||
168 | - | ||
169 | -def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k): | ||
170 | - kwargs = json.loads(args_str) | ||
171 | - args, kwargs = parse_args(kwargs) | ||
172 | - device_id = k % torch.cuda.device_count() | ||
173 | - device = torch.device('cuda:%d' % device_id) | ||
174 | - _transform = [] | ||
175 | - | ||
176 | - print('[+] Child %d training strated (GPU: %d)' % (k, device_id)) | ||
177 | - | ||
178 | - # train child model | ||
179 | - child_model = copy.deepcopy(model) | ||
180 | - train_res = train_child(args, child_model, dataset, Dm_indx, device) | ||
181 | - | ||
182 | - # search sub policy | ||
183 | - for t in range(T): | ||
184 | - #subpolicies = search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device) | ||
185 | - subpolicies = search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device) | ||
186 | - subpolicies = get_topn_subpolicies(subpolicies, N) | ||
187 | - _transform.extend([subpolicy[0] for subpolicy in subpolicies]) | ||
188 | - | ||
189 | - return _transform | ||
190 | - | ||
191 | - | ||
192 | -def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | ||
193 | - args_str = json.dumps(args._asdict()) | ||
194 | - dataset = get_dataset(args, None, 'trainval') | ||
195 | - num_process = min(torch.cuda.device_count(), num_process) | ||
196 | - transform, futures = [], [] | ||
197 | - | ||
198 | - torch.multiprocessing.set_start_method('spawn', force=True) | ||
199 | - | ||
200 | - if not transform_candidates: | ||
201 | - transform_candidates = DEFALUT_CANDIDATES | ||
202 | - | ||
203 | - # split | ||
204 | - Dm_indexes, Da_indexes = split_dataset(args, dataset, K) | ||
205 | - | ||
206 | - with ProcessPoolExecutor(max_workers=num_process) as executor: | ||
207 | - for k, (Dm_indx, Da_indx) in enumerate(zip(Dm_indexes, Da_indexes)): | ||
208 | - future = executor.submit(process_fn, | ||
209 | - args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k) | ||
210 | - futures.append(future) | ||
211 | - | ||
212 | - for future in futures: | ||
213 | - transform.extend(future.result()) | ||
214 | - | ||
215 | - transform = transforms.RandomChoice(transform) | ||
216 | - | ||
217 | - return transform |
File mode changed

92.9 KB

410 KB

425 KB

423 KB

48.3 KB

229 KB

227 KB
1 | -from .basenet import BaseNet |
1 | -import torch.nn as nn | ||
2 | - | ||
3 | -class BaseNet(nn.Module): | ||
4 | - def __init__(self, backbone, args): | ||
5 | - super(BaseNet, self).__init__() | ||
6 | - | ||
7 | - # Separate layers | ||
8 | - self.first = nn.Sequential(*list(backbone.children())[:1]) | ||
9 | - self.after = nn.Sequential(*list(backbone.children())[1:-1]) | ||
10 | - self.fc = list(backbone.children())[-1] | ||
11 | - | ||
12 | - self.img_size = (224, 224) | ||
13 | - | ||
14 | - def forward(self, x): | ||
15 | - f = self.first(x) | ||
16 | - x = self.after(f) | ||
17 | - x = x.reshape(x.size(0), -1) | ||
18 | - x = self.fc(x) | ||
19 | - return x, f |
1 | -import math | ||
2 | -import torch.nn as nn | ||
3 | -import torch.nn.functional as F | ||
4 | - | ||
5 | - | ||
6 | -def round_fn(orig, multiplier): | ||
7 | - if not multiplier: | ||
8 | - return orig | ||
9 | - | ||
10 | - return int(math.ceil(multiplier * orig)) | ||
11 | - | ||
12 | - | ||
13 | -def get_activation_fn(activation): | ||
14 | - if activation == "swish": | ||
15 | - return Swish | ||
16 | - | ||
17 | - elif activation == "relu": | ||
18 | - return nn.ReLU | ||
19 | - | ||
20 | - else: | ||
21 | - raise Exception('Unkown activation %s' % activation) | ||
22 | - | ||
23 | - | ||
24 | -class Swish(nn.Module): | ||
25 | - """ Swish activation function, s(x) = x * sigmoid(x) """ | ||
26 | - | ||
27 | - def __init__(self, inplace=False): | ||
28 | - super().__init__() | ||
29 | - self.inplace = True | ||
30 | - | ||
31 | - def forward(self, x): | ||
32 | - if self.inplace: | ||
33 | - x.mul_(F.sigmoid(x)) | ||
34 | - return x | ||
35 | - else: | ||
36 | - return x * F.sigmoid(x) | ||
37 | - | ||
38 | - | ||
39 | -class ConvBlock(nn.Module): | ||
40 | - """ Conv + BatchNorm + Activation """ | ||
41 | - | ||
42 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
43 | - padding=0, stride=1, activation="swish"): | ||
44 | - super().__init__() | ||
45 | - self.fw = nn.Sequential( | ||
46 | - nn.Conv2d(in_channel, out_channel, kernel_size, | ||
47 | - padding=padding, stride=stride, bias=False), | ||
48 | - nn.BatchNorm2d(out_channel), | ||
49 | - get_activation_fn(activation)()) | ||
50 | - | ||
51 | - def forward(self, x): | ||
52 | - return self.fw(x) | ||
53 | - | ||
54 | - | ||
55 | -class DepthwiseConvBlock(nn.Module): | ||
56 | - """ DepthwiseConv2D + BatchNorm + Activation """ | ||
57 | - | ||
58 | - def __init__(self, in_channel, kernel_size, | ||
59 | - padding=0, stride=1, activation="swish"): | ||
60 | - super().__init__() | ||
61 | - self.fw = nn.Sequential( | ||
62 | - nn.Conv2d(in_channel, in_channel, kernel_size, | ||
63 | - padding=padding, stride=stride, groups=in_channel, bias=False), | ||
64 | - nn.BatchNorm2d(in_channel), | ||
65 | - get_activation_fn(activation)()) | ||
66 | - | ||
67 | - def forward(self, x): | ||
68 | - return self.fw(x) | ||
69 | - | ||
70 | - | ||
71 | -class MBConv(nn.Module): | ||
72 | - """ Inverted residual block """ | ||
73 | - | ||
74 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
75 | - stride=1, expand_ratio=1, activation="swish"): | ||
76 | - super().__init__() | ||
77 | - self.in_channel = in_channel | ||
78 | - self.out_channel = out_channel | ||
79 | - self.expand_ratio = expand_ratio | ||
80 | - self.stride = stride | ||
81 | - | ||
82 | - if expand_ratio != 1: | ||
83 | - self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1, | ||
84 | - activation=activation) | ||
85 | - | ||
86 | - self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size, | ||
87 | - padding=(kernel_size-1)//2, | ||
88 | - stride=stride, activation=activation) | ||
89 | - | ||
90 | - self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1, | ||
91 | - activation=activation) | ||
92 | - | ||
93 | - def forward(self, inputs): | ||
94 | - if self.expand_ratio != 1: | ||
95 | - x = self.expand(inputs) | ||
96 | - else: | ||
97 | - x = inputs | ||
98 | - | ||
99 | - x = self.dw_conv(x) | ||
100 | - x = self.pw_conv(x) | ||
101 | - | ||
102 | - if self.in_channel == self.out_channel and \ | ||
103 | - self.stride == 1: | ||
104 | - x = x + inputs | ||
105 | - | ||
106 | - return x | ||
107 | - | ||
108 | - | ||
109 | -class Net(nn.Module): | ||
110 | - """ EfficientNet """ | ||
111 | - | ||
112 | - def __init__(self, args): | ||
113 | - super(Net, self).__init__() | ||
114 | - pi = args.pi | ||
115 | - activation = args.activation | ||
116 | - num_classes = args.num_classes | ||
117 | - | ||
118 | - self.d = 1.2 ** pi | ||
119 | - self.w = 1.1 ** pi | ||
120 | - self.r = 1.15 ** pi | ||
121 | - self.img_size = (round_fn(224, self.r), round_fn(224, self.r)) | ||
122 | - | ||
123 | - self.stage1 = ConvBlock(3, round_fn(32, self.w), | ||
124 | - kernel_size=3, padding=1, stride=2, activation=activation) | ||
125 | - | ||
126 | - self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w), | ||
127 | - depth=round_fn(1, self.d), kernel_size=3, | ||
128 | - half_resolution=False, expand_ratio=1, activation=activation) | ||
129 | - | ||
130 | - self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w), | ||
131 | - depth=round_fn(2, self.d), kernel_size=3, | ||
132 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
133 | - | ||
134 | - self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w), | ||
135 | - depth=round_fn(2, self.d), kernel_size=5, | ||
136 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
137 | - | ||
138 | - self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w), | ||
139 | - depth=round_fn(3, self.d), kernel_size=3, | ||
140 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
141 | - | ||
142 | - self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w), | ||
143 | - depth=round_fn(3, self.d), kernel_size=5, | ||
144 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
145 | - | ||
146 | - self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w), | ||
147 | - depth=round_fn(4, self.d), kernel_size=5, | ||
148 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
149 | - | ||
150 | - self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w), | ||
151 | - depth=round_fn(1, self.d), kernel_size=3, | ||
152 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
153 | - | ||
154 | - self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w), | ||
155 | - kernel_size=1, activation=activation) | ||
156 | - | ||
157 | - self.fc = nn.Linear(round_fn(7*7*1280, self.w), num_classes) | ||
158 | - | ||
159 | - def make_layers(self, in_channel, out_channel, depth, kernel_size, | ||
160 | - half_resolution=False, expand_ratio=1, activation="swish"): | ||
161 | - blocks = [] | ||
162 | - for i in range(depth): | ||
163 | - stride = 2 if half_resolution and i==0 else 1 | ||
164 | - blocks.append( | ||
165 | - MBConv(in_channel, out_channel, kernel_size, | ||
166 | - stride=stride, expand_ratio=expand_ratio, activation=activation)) | ||
167 | - in_channel = out_channel | ||
168 | - | ||
169 | - return nn.Sequential(*blocks) | ||
170 | - | ||
171 | - def forward(self, x): | ||
172 | - assert x.size()[-2:] == self.img_size, \ | ||
173 | - 'Image size must be %r, but %r given' % (self.img_size, x.size()[-2]) | ||
174 | - | ||
175 | - x = self.stage1(x) | ||
176 | - x = self.stage2(x) | ||
177 | - x = self.stage3(x) | ||
178 | - x = self.stage4(x) | ||
179 | - x = self.stage5(x) | ||
180 | - x = self.stage6(x) | ||
181 | - x = self.stage7(x) | ||
182 | - x = self.stage8(x) | ||
183 | - x = self.stage9(x) | ||
184 | - x = x.reshape(x.size(0), -1) | ||
185 | - x = self.fc(x) | ||
186 | - return x, x |
1 | -import math | ||
2 | -import torch.nn as nn | ||
3 | -import torch.nn.functional as F | ||
4 | - | ||
5 | - | ||
6 | -def round_fn(orig, multiplier): | ||
7 | - if not multiplier: | ||
8 | - return orig | ||
9 | - | ||
10 | - return int(math.ceil(multiplier * orig)) | ||
11 | - | ||
12 | - | ||
13 | -def get_activation_fn(activation): | ||
14 | - if activation == "swish": | ||
15 | - return Swish | ||
16 | - | ||
17 | - elif activation == "relu": | ||
18 | - return nn.ReLU | ||
19 | - | ||
20 | - else: | ||
21 | - raise Exception('Unkown activation %s' % activation) | ||
22 | - | ||
23 | - | ||
24 | -class Swish(nn.Module): | ||
25 | - """ Swish activation function, s(x) = x * sigmoid(x) """ | ||
26 | - | ||
27 | - def __init__(self, inplace=False): | ||
28 | - super().__init__() | ||
29 | - self.inplace = True | ||
30 | - | ||
31 | - def forward(self, x): | ||
32 | - if self.inplace: | ||
33 | - x.mul_(F.sigmoid(x)) | ||
34 | - return x | ||
35 | - else: | ||
36 | - return x * F.sigmoid(x) | ||
37 | - | ||
38 | - | ||
39 | -class ConvBlock(nn.Module): | ||
40 | - """ Conv + BatchNorm + Activation """ | ||
41 | - | ||
42 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
43 | - padding=0, stride=1, activation="swish"): | ||
44 | - super().__init__() | ||
45 | - self.fw = nn.Sequential( | ||
46 | - nn.Conv2d(in_channel, out_channel, kernel_size, | ||
47 | - padding=padding, stride=stride, bias=False), | ||
48 | - nn.BatchNorm2d(out_channel), | ||
49 | - get_activation_fn(activation)()) | ||
50 | - | ||
51 | - def forward(self, x): | ||
52 | - return self.fw(x) | ||
53 | - | ||
54 | - | ||
55 | -class DepthwiseConvBlock(nn.Module): | ||
56 | - """ DepthwiseConv2D + BatchNorm + Activation """ | ||
57 | - | ||
58 | - def __init__(self, in_channel, kernel_size, | ||
59 | - padding=0, stride=1, activation="swish"): | ||
60 | - super().__init__() | ||
61 | - self.fw = nn.Sequential( | ||
62 | - nn.Conv2d(in_channel, in_channel, kernel_size, | ||
63 | - padding=padding, stride=stride, groups=in_channel, bias=False), | ||
64 | - nn.BatchNorm2d(in_channel), | ||
65 | - get_activation_fn(activation)()) | ||
66 | - | ||
67 | - def forward(self, x): | ||
68 | - return self.fw(x) | ||
69 | - | ||
70 | - | ||
71 | -class SEBlock(nn.Module): | ||
72 | - """ Squeeze and Excitation Block """ | ||
73 | - | ||
74 | - def __init__(self, in_channel, se_ratio=16): | ||
75 | - super().__init__() | ||
76 | - self.global_avgpool = nn.AdaptiveAvgPool2d((1,1)) | ||
77 | - inter_channel = in_channel // se_ratio | ||
78 | - | ||
79 | - self.reduce = nn.Sequential( | ||
80 | - nn.Conv2d(in_channel, inter_channel, | ||
81 | - kernel_size=1, padding=0, stride=1), | ||
82 | - nn.ReLU()) | ||
83 | - | ||
84 | - self.expand = nn.Sequential( | ||
85 | - nn.Conv2d(inter_channel, in_channel, | ||
86 | - kernel_size=1, padding=0, stride=1), | ||
87 | - nn.Sigmoid()) | ||
88 | - | ||
89 | - | ||
90 | - def forward(self, x): | ||
91 | - s = self.global_avgpool(x) | ||
92 | - s = self.reduce(s) | ||
93 | - s = self.expand(s) | ||
94 | - return x * s | ||
95 | - | ||
96 | - | ||
97 | -class MBConv(nn.Module): | ||
98 | - """ Inverted residual block """ | ||
99 | - | ||
100 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
101 | - stride=1, expand_ratio=1, activation="swish", use_seblock=False): | ||
102 | - super().__init__() | ||
103 | - self.in_channel = in_channel | ||
104 | - self.out_channel = out_channel | ||
105 | - self.expand_ratio = expand_ratio | ||
106 | - self.stride = stride | ||
107 | - self.use_seblock = use_seblock | ||
108 | - | ||
109 | - if expand_ratio != 1: | ||
110 | - self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1, | ||
111 | - activation=activation) | ||
112 | - | ||
113 | - self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size, | ||
114 | - padding=(kernel_size-1)//2, | ||
115 | - stride=stride, activation=activation) | ||
116 | - | ||
117 | - if use_seblock: | ||
118 | - self.seblock = SEBlock(in_channel*expand_ratio) | ||
119 | - | ||
120 | - self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1, | ||
121 | - activation=activation) | ||
122 | - | ||
123 | - def forward(self, inputs): | ||
124 | - if self.expand_ratio != 1: | ||
125 | - x = self.expand(inputs) | ||
126 | - else: | ||
127 | - x = inputs | ||
128 | - | ||
129 | - x = self.dw_conv(x) | ||
130 | - | ||
131 | - if self.use_seblock: | ||
132 | - x = self.seblock(x) | ||
133 | - | ||
134 | - x = self.pw_conv(x) | ||
135 | - | ||
136 | - if self.in_channel == self.out_channel and \ | ||
137 | - self.stride == 1: | ||
138 | - x = x + inputs | ||
139 | - | ||
140 | - return x | ||
141 | - | ||
142 | - | ||
143 | -class Net(nn.Module): | ||
144 | - """ EfficientNet """ | ||
145 | - | ||
146 | - def __init__(self, args): | ||
147 | - super(Net, self).__init__() | ||
148 | - pi = args.pi | ||
149 | - activation = args.activation | ||
150 | - num_classes = 10 | ||
151 | - | ||
152 | - self.d = 1.2 ** pi | ||
153 | - self.w = 1.1 ** pi | ||
154 | - self.r = 1.15 ** pi | ||
155 | - self.img_size = (round_fn(32, self.r), round_fn(32, self.r)) | ||
156 | - self.use_seblock = args.use_seblock | ||
157 | - | ||
158 | - self.stage1 = ConvBlock(3, round_fn(32, self.w), | ||
159 | - kernel_size=3, padding=1, stride=2, activation=activation) | ||
160 | - | ||
161 | - self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w), | ||
162 | - depth=round_fn(1, self.d), kernel_size=3, | ||
163 | - half_resolution=False, expand_ratio=1, activation=activation) | ||
164 | - | ||
165 | - self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w), | ||
166 | - depth=round_fn(2, self.d), kernel_size=3, | ||
167 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
168 | - | ||
169 | - self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w), | ||
170 | - depth=round_fn(2, self.d), kernel_size=5, | ||
171 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
172 | - | ||
173 | - self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w), | ||
174 | - depth=round_fn(3, self.d), kernel_size=3, | ||
175 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
176 | - | ||
177 | - self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w), | ||
178 | - depth=round_fn(3, self.d), kernel_size=5, | ||
179 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
180 | - | ||
181 | - self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w), | ||
182 | - depth=round_fn(4, self.d), kernel_size=5, | ||
183 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
184 | - | ||
185 | - self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w), | ||
186 | - depth=round_fn(1, self.d), kernel_size=3, | ||
187 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
188 | - | ||
189 | - self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w), | ||
190 | - kernel_size=1, activation=activation) | ||
191 | - | ||
192 | - self.fc = nn.Linear(round_fn(1280, self.w), num_classes) | ||
193 | - | ||
194 | - def make_layers(self, in_channel, out_channel, depth, kernel_size, | ||
195 | - half_resolution=False, expand_ratio=1, activation="swish"): | ||
196 | - blocks = [] | ||
197 | - for i in range(depth): | ||
198 | - stride = 2 if half_resolution and i==0 else 1 | ||
199 | - blocks.append( | ||
200 | - MBConv(in_channel, out_channel, kernel_size, | ||
201 | - stride=stride, expand_ratio=expand_ratio, activation=activation, use_seblock=self.use_seblock)) | ||
202 | - in_channel = out_channel | ||
203 | - | ||
204 | - return nn.Sequential(*blocks) | ||
205 | - | ||
206 | - def forward(self, x): | ||
207 | - assert x.size()[-2:] == self.img_size, \ | ||
208 | - 'Image size must be %r, but %r given' % (self.img_size, x.size()[-2]) | ||
209 | - | ||
210 | - s = self.stage1(x) | ||
211 | - x = self.stage2(s) | ||
212 | - x = self.stage3(x) | ||
213 | - x = self.stage4(x) | ||
214 | - x = self.stage5(x) | ||
215 | - x = self.stage6(x) | ||
216 | - x = self.stage7(x) | ||
217 | - x = self.stage8(x) | ||
218 | - x = self.stage9(x) | ||
219 | - x = x.reshape(x.size(0), -1) | ||
220 | - x = self.fc(x) | ||
221 | - return x, s |
1 | -import torch.nn as nn | ||
2 | - | ||
3 | - | ||
4 | -class ResidualBlock(nn.Module): | ||
5 | - def __init__(self, in_channel, out_channel, stride): | ||
6 | - super(ResidualBlock, self).__init__() | ||
7 | - self.in_channel = in_channel | ||
8 | - self.out_channel = out_channel | ||
9 | - self.stride = stride | ||
10 | - | ||
11 | - self.conv1 = nn.Sequential( | ||
12 | - nn.Conv2d(in_channel, out_channel, | ||
13 | - kernel_size=3, padding=1, stride=stride), | ||
14 | - nn.BatchNorm2d(out_channel)) | ||
15 | - | ||
16 | - self.relu = nn.ReLU(inplace=True) | ||
17 | - | ||
18 | - self.conv2 = nn.Sequential( | ||
19 | - nn.Conv2d(out_channel, out_channel, | ||
20 | - kernel_size=3, padding=1), | ||
21 | - nn.BatchNorm2d(out_channel)) | ||
22 | - | ||
23 | - if self.in_channel != self.out_channel or \ | ||
24 | - self.stride != 1: | ||
25 | - self.down = nn.Sequential( | ||
26 | - nn.Conv2d(in_channel, out_channel, | ||
27 | - kernel_size=1, stride=stride), | ||
28 | - nn.BatchNorm2d(out_channel)) | ||
29 | - | ||
30 | - def forward(self, b): | ||
31 | - t = self.conv1(b) | ||
32 | - t = self.relu(t) | ||
33 | - t = self.conv2(t) | ||
34 | - | ||
35 | - if self.in_channel != self.out_channel or \ | ||
36 | - self.stride != 1: | ||
37 | - b = self.down(b) | ||
38 | - | ||
39 | - t += b | ||
40 | - t = self.relu(t) | ||
41 | - | ||
42 | - return t | ||
43 | - | ||
44 | - | ||
45 | -class Net(nn.Module): | ||
46 | - def __init__(self, args): | ||
47 | - super(Net, self).__init__() | ||
48 | - scale = args.scale | ||
49 | - | ||
50 | - self.stem = nn.Sequential( | ||
51 | - nn.Conv2d(3, 16, | ||
52 | - kernel_size=3, padding=1), | ||
53 | - nn.BatchNorm2d(16), | ||
54 | - nn.ReLU(inplace=True)) | ||
55 | - | ||
56 | - self.layer1 = nn.Sequential(*[ | ||
57 | - ResidualBlock(16, 16, 1) for _ in range(2*scale)]) | ||
58 | - | ||
59 | - self.layer2 = nn.Sequential(*[ | ||
60 | - ResidualBlock(in_channel=(16 if i==0 else 32), | ||
61 | - out_channel=32, | ||
62 | - stride=(2 if i==0 else 1)) for i in range(2*scale)]) | ||
63 | - | ||
64 | - self.layer3 = nn.Sequential(*[ | ||
65 | - ResidualBlock(in_channel=(32 if i==0 else 64), | ||
66 | - out_channel=64, | ||
67 | - stride=(2 if i==0 else 1)) for i in range(2*scale)]) | ||
68 | - | ||
69 | - self.avg_pool = nn.AdaptiveAvgPool2d((1,1)) | ||
70 | - | ||
71 | - self.fc = nn.Linear(64, 10) | ||
72 | - | ||
73 | - for m in self.modules(): | ||
74 | - if isinstance(m, nn.Conv2d): | ||
75 | - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
76 | - | ||
77 | - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||
78 | - nn.init.constant_(m.weight, 1) | ||
79 | - nn.init.constant_(m.bias, 0) | ||
80 | - | ||
81 | - def forward(self, x): | ||
82 | - s = self.stem(x) | ||
83 | - x = self.layer1(s) | ||
84 | - x = self.layer2(x) | ||
85 | - x = self.layer3(x) | ||
86 | - x = self.avg_pool(x) | ||
87 | - x = x.reshape(x.size(0), -1) | ||
88 | - x = self.fc(x) | ||
89 | - | ||
90 | - return x, s |
1 | -import os | ||
2 | -import fire | ||
3 | -import time | ||
4 | -import json | ||
5 | -import random | ||
6 | -from pprint import pprint | ||
7 | - | ||
8 | -import torch.nn as nn | ||
9 | -import torch.backends.cudnn as cudnn | ||
10 | -from torch.utils.tensorboard import SummaryWriter | ||
11 | - | ||
12 | -from networks import * | ||
13 | -from utils import * | ||
14 | - | ||
15 | - | ||
16 | -def train(**kwargs): | ||
17 | - print('\n[+] Parse arguments') | ||
18 | - args, kwargs = parse_args(kwargs) | ||
19 | - pprint(args) | ||
20 | - device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
21 | - | ||
22 | - print('\n[+] Create log dir') | ||
23 | - model_name = get_model_name(args) | ||
24 | - log_dir = os.path.join('./runs', model_name) | ||
25 | - os.makedirs(os.path.join(log_dir, 'model')) | ||
26 | - json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | ||
27 | - writer = SummaryWriter(log_dir=log_dir) | ||
28 | - | ||
29 | - if args.seed is not None: | ||
30 | - random.seed(args.seed) | ||
31 | - torch.manual_seed(args.seed) | ||
32 | - cudnn.deterministic = True | ||
33 | - | ||
34 | - print('\n[+] Create network') | ||
35 | - model = select_model(args) | ||
36 | - optimizer = select_optimizer(args, model) | ||
37 | - scheduler = select_scheduler(args, optimizer) | ||
38 | - criterion = nn.CrossEntropyLoss() | ||
39 | - if args.use_cuda: | ||
40 | - model = model.cuda() | ||
41 | - criterion = criterion.cuda() | ||
42 | - #writer.add_graph(model) | ||
43 | - | ||
44 | - print('\n[+] Load dataset') | ||
45 | - transform = get_train_transform(args, model, log_dir) | ||
46 | - val_transform = get_valid_transform(args, model) | ||
47 | - train_dataset = get_dataset(args, transform, 'train') | ||
48 | - valid_dataset = get_dataset(args, val_transform, 'val') | ||
49 | - train_loader = iter(get_inf_dataloader(args, train_dataset)) | ||
50 | - max_epoch = len(train_dataset) // args.batch_size | ||
51 | - best_acc = -1 | ||
52 | - | ||
53 | - print('\n[+] Start training') | ||
54 | - if torch.cuda.device_count() > 1: | ||
55 | - print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
56 | - model = nn.DataParallel(model) | ||
57 | - | ||
58 | - start_t = time.time() | ||
59 | - for step in range(args.start_step, args.max_step): | ||
60 | - batch = next(train_loader) | ||
61 | - _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer) | ||
62 | - | ||
63 | - if step % args.print_step == 0: | ||
64 | - print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format( | ||
65 | - step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr'])) | ||
66 | - writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step) | ||
67 | - writer.add_scalar('train/acc1', _train_res[0], global_step=step) | ||
68 | - writer.add_scalar('train/acc5', _train_res[1], global_step=step) | ||
69 | - writer.add_scalar('train/loss', _train_res[2], global_step=step) | ||
70 | - writer.add_scalar('train/forward_time', _train_res[3], global_step=step) | ||
71 | - writer.add_scalar('train/backward_time', _train_res[4], global_step=step) | ||
72 | - print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | ||
73 | - print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100)) | ||
74 | - print(' Loss : {}'.format(_train_res[2].data)) | ||
75 | - print(' FW Time : {:.3f}ms'.format(_train_res[3]*1000)) | ||
76 | - print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) | ||
77 | - | ||
78 | - if step % args.val_step == args.val_step-1: | ||
79 | - valid_loader = iter(get_dataloader(args, valid_dataset)) | ||
80 | - _valid_res = validate(args, model, criterion, valid_loader, step, writer) | ||
81 | - print('\n[+] Valid results') | ||
82 | - writer.add_scalar('valid/acc1', _valid_res[0], global_step=step) | ||
83 | - writer.add_scalar('valid/acc5', _valid_res[1], global_step=step) | ||
84 | - writer.add_scalar('valid/loss', _valid_res[2], global_step=step) | ||
85 | - print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100)) | ||
86 | - print(' Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100)) | ||
87 | - print(' Loss : {}'.format(_valid_res[2].data)) | ||
88 | - | ||
89 | - if _valid_res[0] > best_acc: | ||
90 | - best_acc = _valid_res[0] | ||
91 | - torch.save(model.state_dict(), os.path.join(log_dir, "model","model.pt")) | ||
92 | - print('\n[+] Model saved') | ||
93 | - | ||
94 | - writer.close() | ||
95 | - | ||
96 | - | ||
97 | -if __name__ == '__main__': | ||
98 | - fire.Fire(train) |
1 | -import numpy as np | ||
2 | -import torch.nn as nn | ||
3 | -import torchvision.transforms as transforms | ||
4 | - | ||
5 | -from abc import ABC, abstractmethod | ||
6 | -from PIL import Image, ImageOps, ImageEnhance | ||
7 | - | ||
8 | - | ||
9 | -class BaseTransform(ABC): | ||
10 | - | ||
11 | - def __init__(self, prob, mag): | ||
12 | - self.prob = prob | ||
13 | - self.mag = mag | ||
14 | - | ||
15 | - def __call__(self, img): | ||
16 | - return transforms.RandomApply([self.transform], self.prob)(img) | ||
17 | - | ||
18 | - def __repr__(self): | ||
19 | - return '%s(prob=%.2f, magnitude=%.2f)' % \ | ||
20 | - (self.__class__.__name__, self.prob, self.mag) | ||
21 | - | ||
22 | - @abstractmethod | ||
23 | - def transform(self, img): | ||
24 | - pass | ||
25 | - | ||
26 | - | ||
27 | -class ShearXY(BaseTransform): | ||
28 | - | ||
29 | - def transform(self, img): | ||
30 | - degrees = self.mag * 360 | ||
31 | - t = transforms.RandomAffine(0, shear=degrees, resample=Image.BILINEAR) | ||
32 | - return t(img) | ||
33 | - | ||
34 | - | ||
35 | -class TranslateXY(BaseTransform): | ||
36 | - | ||
37 | - def transform(self, img): | ||
38 | - translate = (self.mag, self.mag) | ||
39 | - t = transforms.RandomAffine(0, translate=translate, resample=Image.BILINEAR) | ||
40 | - return t(img) | ||
41 | - | ||
42 | - | ||
43 | -class Rotate(BaseTransform): | ||
44 | - | ||
45 | - def transform(self, img): | ||
46 | - degrees = self.mag * 360 | ||
47 | - t = transforms.RandomRotation(degrees, Image.BILINEAR) | ||
48 | - return t(img) | ||
49 | - | ||
50 | - | ||
51 | -class AutoContrast(BaseTransform): | ||
52 | - | ||
53 | - def transform(self, img): | ||
54 | - cutoff = int(self.mag * 49) | ||
55 | - return ImageOps.autocontrast(img, cutoff=cutoff) | ||
56 | - | ||
57 | - | ||
58 | -class Invert(BaseTransform): | ||
59 | - | ||
60 | - def transform(self, img): | ||
61 | - return ImageOps.invert(img) | ||
62 | - | ||
63 | - | ||
64 | -class Equalize(BaseTransform): | ||
65 | - | ||
66 | - def transform(self, img): | ||
67 | - return ImageOps.equalize(img) | ||
68 | - | ||
69 | - | ||
70 | -class Solarize(BaseTransform): | ||
71 | - | ||
72 | - def transform(self, img): | ||
73 | - threshold = (1-self.mag) * 255 | ||
74 | - return ImageOps.solarize(img, threshold) | ||
75 | - | ||
76 | - | ||
77 | -class Posterize(BaseTransform): | ||
78 | - | ||
79 | - def transform(self, img): | ||
80 | - bits = int((1-self.mag) * 8) | ||
81 | - return ImageOps.posterize(img, bits=bits) | ||
82 | - | ||
83 | - | ||
84 | -class Contrast(BaseTransform): | ||
85 | - | ||
86 | - def transform(self, img): | ||
87 | - factor = self.mag * 10 | ||
88 | - return ImageEnhance.Contrast(img).enhance(factor) | ||
89 | - | ||
90 | - | ||
91 | -class Color(BaseTransform): | ||
92 | - | ||
93 | - def transform(self, img): | ||
94 | - factor = self.mag * 10 | ||
95 | - return ImageEnhance.Color(img).enhance(factor) | ||
96 | - | ||
97 | - | ||
98 | -class Brightness(BaseTransform): | ||
99 | - | ||
100 | - def transform(self, img): | ||
101 | - factor = self.mag * 10 | ||
102 | - return ImageEnhance.Brightness(img).enhance(factor) | ||
103 | - | ||
104 | - | ||
105 | -class Sharpness(BaseTransform): | ||
106 | - | ||
107 | - def transform(self, img): | ||
108 | - factor = self.mag * 10 | ||
109 | - return ImageEnhance.Sharpness(img).enhance(factor) | ||
110 | - | ||
111 | - | ||
112 | -class Cutout(BaseTransform): | ||
113 | - | ||
114 | - def transform(self, img): | ||
115 | - n_holes = 1 | ||
116 | - length = 24 * self.mag | ||
117 | - cutout_op = CutoutOp(n_holes=n_holes, length=length) | ||
118 | - return cutout_op(img) | ||
119 | - | ||
120 | - | ||
121 | -class CutoutOp(object): | ||
122 | - """ | ||
123 | - https://github.com/uoguelph-mlrg/Cutout | ||
124 | - | ||
125 | - Randomly mask out one or more patches from an image. | ||
126 | - | ||
127 | - Args: | ||
128 | - n_holes (int): Number of patches to cut out of each image. | ||
129 | - length (int): The length (in pixels) of each square patch. | ||
130 | - """ | ||
131 | - def __init__(self, n_holes, length): | ||
132 | - self.n_holes = n_holes | ||
133 | - self.length = length | ||
134 | - | ||
135 | - def __call__(self, img): | ||
136 | - """ | ||
137 | - Args: | ||
138 | - img (Tensor): Tensor image of size (C, H, W). | ||
139 | - Returns: | ||
140 | - Tensor: Image with n_holes of dimension length x length cut out of it. | ||
141 | - """ | ||
142 | - w, h = img.size | ||
143 | - | ||
144 | - mask = np.ones((h, w, 1), np.uint8) | ||
145 | - | ||
146 | - for n in range(self.n_holes): | ||
147 | - y = np.random.randint(h) | ||
148 | - x = np.random.randint(w) | ||
149 | - | ||
150 | - y1 = np.clip(y - self.length // 2, 0, h).astype(int) | ||
151 | - y2 = np.clip(y + self.length // 2, 0, h).astype(int) | ||
152 | - x1 = np.clip(x - self.length // 2, 0, w).astype(int) | ||
153 | - x2 = np.clip(x + self.length // 2, 0, w).astype(int) | ||
154 | - | ||
155 | - mask[y1: y2, x1: x2, :] = 0. | ||
156 | - | ||
157 | - img = mask*np.asarray(img).astype(np.uint8) | ||
158 | - img = Image.fromarray(mask*np.asarray(img)) | ||
159 | - | ||
160 | - return img | ||
161 | - |
1 | -import os | ||
2 | -import time | ||
3 | -import importlib | ||
4 | -import collections | ||
5 | -import pickle as cp | ||
6 | -import glob | ||
7 | -import numpy as np | ||
8 | - | ||
9 | -import torch | ||
10 | -import torchvision | ||
11 | -import torch.nn.functional as F | ||
12 | -import torchvision.models as models | ||
13 | -import torchvision.transforms as transforms | ||
14 | -from torch.utils.data import Subset | ||
15 | -from torch.utils.data import Dataset, DataLoader | ||
16 | - | ||
17 | -from sklearn.model_selection import StratifiedShuffleSplit | ||
18 | - | ||
19 | - | ||
20 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | ||
21 | -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' | ||
22 | -current_epoch = 0 | ||
23 | - | ||
24 | - | ||
25 | -def split_dataset(args, dataset, k): | ||
26 | - # load dataset | ||
27 | - X = list(range(len(dataset))) | ||
28 | - Y = dataset.targets | ||
29 | - | ||
30 | - # split to k-fold | ||
31 | - assert len(X) == len(Y) | ||
32 | - | ||
33 | - def _it_to_list(_it): | ||
34 | - return list(zip(*list(_it))) | ||
35 | - | ||
36 | - sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
37 | - Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
38 | - | ||
39 | - return Dm_indexes, Da_indexes | ||
40 | - | ||
41 | - | ||
42 | -def concat_image_features(image, features, max_features=3): | ||
43 | - _, h, w = image.shape | ||
44 | - | ||
45 | - max_features = min(features.size(0), max_features) | ||
46 | - image_feature = image.clone() | ||
47 | - | ||
48 | - for i in range(max_features): | ||
49 | - feature = features[i:i+1] | ||
50 | - _min, _max = torch.min(feature), torch.max(feature) | ||
51 | - feature = (feature - _min) / (_max - _min + 1e-6) | ||
52 | - feature = torch.cat([feature]*3, 0) | ||
53 | - feature = feature.view(1, 3, feature.size(1), feature.size(2)) | ||
54 | - feature = F.upsample(feature, size=(h,w), mode="bilinear") | ||
55 | - feature = feature.view(3, h, w) | ||
56 | - image_feature = torch.cat((image_feature, feature), 2) | ||
57 | - | ||
58 | - return image_feature | ||
59 | - | ||
60 | - | ||
61 | -def get_model_name(args): | ||
62 | - from datetime import datetime | ||
63 | - now = datetime.now() | ||
64 | - date_time = now.strftime("%B_%d_%H:%M:%S") | ||
65 | - model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
66 | - return model_name | ||
67 | - | ||
68 | - | ||
69 | -def dict_to_namedtuple(d): | ||
70 | - Args = collections.namedtuple('Args', sorted(d.keys())) | ||
71 | - | ||
72 | - for k,v in d.items(): | ||
73 | - if type(v) is dict: | ||
74 | - d[k] = dict_to_namedtuple(v) | ||
75 | - | ||
76 | - elif type(v) is str: | ||
77 | - try: | ||
78 | - d[k] = eval(v) | ||
79 | - except: | ||
80 | - d[k] = v | ||
81 | - | ||
82 | - args = Args(**d) | ||
83 | - return args | ||
84 | - | ||
85 | - | ||
86 | -def parse_args(kwargs): | ||
87 | - # combine with default args | ||
88 | - kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'cifar10' | ||
89 | - kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet_cifar10' | ||
90 | - kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
91 | - kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.1 | ||
92 | - kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
93 | - kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
94 | - kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
95 | - kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
96 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 2000 | ||
97 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 2000 | ||
98 | - kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
99 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
100 | - kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
101 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 64000 | ||
102 | - kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False | ||
103 | - kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
104 | - | ||
105 | - # to named tuple | ||
106 | - args = dict_to_namedtuple(kwargs) | ||
107 | - return args, kwargs | ||
108 | - | ||
109 | - | ||
110 | -def select_model(args): | ||
111 | - if args.network in models.__dict__: | ||
112 | - backbone = models.__dict__[args.network]() | ||
113 | - model = BaseNet(backbone, args) | ||
114 | - else: | ||
115 | - Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
116 | - model = Net(args) | ||
117 | - | ||
118 | - print(model) | ||
119 | - return model | ||
120 | - | ||
121 | - | ||
122 | -def select_optimizer(args, model): | ||
123 | - if args.optimizer == 'sgd': | ||
124 | - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
125 | - elif args.optimizer == 'rms': | ||
126 | - #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5) | ||
127 | - optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
128 | - elif args.optimizer == 'adam': | ||
129 | - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
130 | - else: | ||
131 | - raise Exception('Unknown Optimizer') | ||
132 | - return optimizer | ||
133 | - | ||
134 | - | ||
135 | -def select_scheduler(args, optimizer): | ||
136 | - if not args.scheduler or args.scheduler == 'None': | ||
137 | - return None | ||
138 | - elif args.scheduler =='clr': | ||
139 | - return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
140 | - elif args.scheduler =='exp': | ||
141 | - return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
142 | - else: | ||
143 | - raise Exception('Unknown Scheduler') | ||
144 | - | ||
145 | - | ||
146 | -class CustomDataset(Dataset): | ||
147 | - def __init__(self, path, transform = None): | ||
148 | - self.path = path | ||
149 | - self.transform = transform | ||
150 | - self.img = np.load(path) | ||
151 | - self.len = self.img.shape[0] | ||
152 | - | ||
153 | - def __len__(self): | ||
154 | - return self.len | ||
155 | - | ||
156 | - def __getitem__(self, idx): | ||
157 | - if self.transforms is not None: | ||
158 | - img = self.transforms(img) | ||
159 | - return img | ||
160 | - | ||
161 | -def get_dataset(args, transform, split='train'): | ||
162 | - assert split in ['train', 'val', 'test', 'trainval'] | ||
163 | - | ||
164 | - if args.dataset == 'cifar10': | ||
165 | - train = split in ['train', 'val', 'trainval'] | ||
166 | - dataset = torchvision.datasets.CIFAR10(DATASET_PATH, | ||
167 | - train=train, | ||
168 | - transform=transform, | ||
169 | - download=True) | ||
170 | - | ||
171 | - if split in ['train', 'val']: | ||
172 | - split_path = os.path.join(DATASET_PATH, | ||
173 | - 'cifar-10-batches-py', 'train_val_index.cp') | ||
174 | - | ||
175 | - if not os.path.exists(split_path): | ||
176 | - [train_index], [val_index] = split_dataset(args, dataset, k=1) | ||
177 | - split_index = {'train':train_index, 'val':val_index} | ||
178 | - cp.dump(split_index, open(split_path, 'wb')) | ||
179 | - | ||
180 | - split_index = cp.load(open(split_path, 'rb')) | ||
181 | - dataset = Subset(dataset, split_index[split]) | ||
182 | - | ||
183 | - elif args.dataset == 'imagenet': | ||
184 | - dataset = torchvision.datasets.ImageNet(DATASET_PATH, | ||
185 | - split=split, | ||
186 | - transform=transform, | ||
187 | - download=(split is 'val')) | ||
188 | - | ||
189 | - elif args.dataset == 'BraTS': | ||
190 | - if split in ['train']: | ||
191 | - dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) | ||
192 | - else: | ||
193 | - dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) | ||
194 | - | ||
195 | - | ||
196 | - else: | ||
197 | - raise Exception('Unknown dataset') | ||
198 | - | ||
199 | - return dataset | ||
200 | - | ||
201 | - | ||
202 | -def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
203 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
204 | - batch_size=args.batch_size, | ||
205 | - shuffle=shuffle, | ||
206 | - num_workers=args.num_workers, | ||
207 | - pin_memory=pin_memory) | ||
208 | - return data_loader | ||
209 | - | ||
210 | - | ||
211 | -def get_inf_dataloader(args, dataset): | ||
212 | - global current_epoch | ||
213 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
214 | - | ||
215 | - while True: | ||
216 | - try: | ||
217 | - batch = next(data_loader) | ||
218 | - | ||
219 | - except StopIteration: | ||
220 | - current_epoch += 1 | ||
221 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
222 | - batch = next(data_loader) | ||
223 | - | ||
224 | - yield batch | ||
225 | - | ||
226 | - | ||
227 | -def get_train_transform(args, model, log_dir=None): | ||
228 | - if args.fast_auto_augment: | ||
229 | - assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet | ||
230 | - | ||
231 | - from fast_auto_augment import fast_auto_augment | ||
232 | - if args.augment_path: | ||
233 | - transform = cp.load(open(args.augment_path, 'rb')) | ||
234 | - os.system('cp {} {}'.format( | ||
235 | - args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | ||
236 | - else: | ||
237 | - transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) | ||
238 | - if log_dir: | ||
239 | - cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | ||
240 | - | ||
241 | - elif args.dataset == 'cifar10': | ||
242 | - transform = transforms.Compose([ | ||
243 | - transforms.Pad(4), | ||
244 | - transforms.RandomCrop(32), | ||
245 | - transforms.RandomHorizontalFlip(), | ||
246 | - transforms.ToTensor() | ||
247 | - ]) | ||
248 | - | ||
249 | - elif args.dataset == 'imagenet': | ||
250 | - resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
251 | - transform = transforms.Compose([ | ||
252 | - transforms.Resize([resize_h, resize_w]), | ||
253 | - transforms.RandomCrop(model.img_size), | ||
254 | - transforms.RandomHorizontalFlip(), | ||
255 | - transforms.ToTensor() | ||
256 | - ]) | ||
257 | - | ||
258 | - elif args.dataset == 'BraTS': | ||
259 | - resize_h, resize_w = 256, 256 | ||
260 | - transform = transforms.Compose([ | ||
261 | - transforms.Resize([resize_h, resize_w]), | ||
262 | - transforms.RandomCrop(model.img_size), | ||
263 | - transforms.RandomHorizontalFlip(), | ||
264 | - transforms.ToTensor() | ||
265 | - ]) | ||
266 | - else: | ||
267 | - raise Exception('Unknown Dataset') | ||
268 | - | ||
269 | - print(transform) | ||
270 | - | ||
271 | - return transform | ||
272 | - | ||
273 | - | ||
274 | -def get_valid_transform(args, model): | ||
275 | - if args.dataset == 'cifar10': | ||
276 | - val_transform = transforms.Compose([ | ||
277 | - transforms.Resize(32), | ||
278 | - transforms.ToTensor() | ||
279 | - ]) | ||
280 | - | ||
281 | - elif args.dataset == 'imagenet': | ||
282 | - resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875) | ||
283 | - val_transform = transforms.Compose([ | ||
284 | - transforms.Resize([resize_h, resize_w]), | ||
285 | - transforms.ToTensor() | ||
286 | - ]) | ||
287 | - elif args.dataset == 'BraTS': | ||
288 | - resize_h, resize_w = 256, 256 | ||
289 | - val_transform = transforms.Compose([ | ||
290 | - transforms.Resize([resize_h, resize_w]), | ||
291 | - transforms.ToTensor() | ||
292 | - ]) | ||
293 | - else: | ||
294 | - raise Exception('Unknown Dataset') | ||
295 | - | ||
296 | - return val_transform | ||
297 | - | ||
298 | - | ||
299 | -def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
300 | - model.train() | ||
301 | - images, target = batch | ||
302 | - | ||
303 | - if device: | ||
304 | - images = images.to(device) | ||
305 | - target = target.to(device) | ||
306 | - | ||
307 | - elif args.use_cuda: | ||
308 | - images = images.cuda(non_blocking=True) | ||
309 | - target = target.cuda(non_blocking=True) | ||
310 | - | ||
311 | - # compute output | ||
312 | - start_t = time.time() | ||
313 | - output, first = model(images) | ||
314 | - forward_t = time.time() - start_t | ||
315 | - loss = criterion(output, target) | ||
316 | - | ||
317 | - # measure accuracy and record loss | ||
318 | - acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
319 | - acc1 /= images.size(0) | ||
320 | - acc5 /= images.size(0) | ||
321 | - | ||
322 | - # compute gradient and do SGD step | ||
323 | - optimizer.zero_grad() | ||
324 | - start_t = time.time() | ||
325 | - loss.backward() | ||
326 | - backward_t = time.time() - start_t | ||
327 | - optimizer.step() | ||
328 | - if scheduler: scheduler.step() | ||
329 | - | ||
330 | - if writer and step % args.print_step == 0: | ||
331 | - n_imgs = min(images.size(0), 10) | ||
332 | - for j in range(n_imgs): | ||
333 | - writer.add_image('train/input_image', | ||
334 | - concat_image_features(images[j], first[j]), global_step=step) | ||
335 | - | ||
336 | - return acc1, acc5, loss, forward_t, backward_t | ||
337 | - | ||
338 | - | ||
339 | -def validate(args, model, criterion, valid_loader, step, writer, device=None): | ||
340 | - # switch to evaluate mode | ||
341 | - model.eval() | ||
342 | - | ||
343 | - acc1, acc5 = 0, 0 | ||
344 | - samples = 0 | ||
345 | - infer_t = 0 | ||
346 | - | ||
347 | - with torch.no_grad(): | ||
348 | - for i, (images, target) in enumerate(valid_loader): | ||
349 | - | ||
350 | - start_t = time.time() | ||
351 | - if device: | ||
352 | - images = images.to(device) | ||
353 | - target = target.to(device) | ||
354 | - | ||
355 | - elif args.use_cuda is not None: | ||
356 | - images = images.cuda(non_blocking=True) | ||
357 | - target = target.cuda(non_blocking=True) | ||
358 | - | ||
359 | - # compute output | ||
360 | - output, first = model(images) | ||
361 | - loss = criterion(output, target) | ||
362 | - infer_t += time.time() - start_t | ||
363 | - | ||
364 | - # measure accuracy and record loss | ||
365 | - _acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
366 | - acc1 += _acc1 | ||
367 | - acc5 += _acc5 | ||
368 | - samples += images.size(0) | ||
369 | - | ||
370 | - acc1 /= samples | ||
371 | - acc5 /= samples | ||
372 | - | ||
373 | - if writer: | ||
374 | - n_imgs = min(images.size(0), 10) | ||
375 | - for j in range(n_imgs): | ||
376 | - writer.add_image('valid/input_image', | ||
377 | - concat_image_features(images[j], first[j]), global_step=step) | ||
378 | - | ||
379 | - return acc1, acc5, loss, infer_t | ||
380 | - | ||
381 | - | ||
382 | -def accuracy(output, target, topk=(1,)): | ||
383 | - """Computes the accuracy over the k top predictions for the specified values of k""" | ||
384 | - with torch.no_grad(): | ||
385 | - maxk = max(topk) | ||
386 | - batch_size = target.size(0) | ||
387 | - | ||
388 | - _, pred = output.topk(maxk, 1, True, True) | ||
389 | - pred = pred.t() | ||
390 | - correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
391 | - | ||
392 | - res = [] | ||
393 | - for k in topk: | ||
394 | - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
395 | - res.append(correct_k) | ||
396 | - return res | ||
397 | - |
-
Please register or login to post a comment