Showing
21 changed files
with
0 additions
and
1239 deletions
| 1 | -# Byte-compiled / optimized / DLL files | ||
| 2 | -__pycache__/ | ||
| 3 | -*.py[cod] | ||
| 4 | -*$py.class | ||
| 5 | - | ||
| 6 | -# C extensions | ||
| 7 | -*.so | ||
| 8 | - | ||
| 9 | -# Distribution / packaging | ||
| 10 | -.Python | ||
| 11 | -build/ | ||
| 12 | -develop-eggs/ | ||
| 13 | -dist/ | ||
| 14 | -downloads/ | ||
| 15 | -eggs/ | ||
| 16 | -.eggs/ | ||
| 17 | -lib/ | ||
| 18 | -lib64/ | ||
| 19 | -parts/ | ||
| 20 | -sdist/ | ||
| 21 | -var/ | ||
| 22 | -wheels/ | ||
| 23 | -*.egg-info/ | ||
| 24 | -.installed.cfg | ||
| 25 | -*.egg | ||
| 26 | -MANIFEST | ||
| 27 | - | ||
| 28 | -# PyInstaller | ||
| 29 | -# Usually these files are written by a python script from a template | ||
| 30 | -# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
| 31 | -*.manifest | ||
| 32 | -*.spec | ||
| 33 | - | ||
| 34 | -# Installer logs | ||
| 35 | -pip-log.txt | ||
| 36 | -pip-delete-this-directory.txt | ||
| 37 | - | ||
| 38 | -# Unit test / coverage reports | ||
| 39 | -htmlcov/ | ||
| 40 | -.tox/ | ||
| 41 | -.coverage | ||
| 42 | -.coverage.* | ||
| 43 | -.cache | ||
| 44 | -nosetests.xml | ||
| 45 | -coverage.xml | ||
| 46 | -*.cover | ||
| 47 | -.hypothesis/ | ||
| 48 | -.pytest_cache/ | ||
| 49 | - | ||
| 50 | -# Translations | ||
| 51 | -*.mo | ||
| 52 | -*.pot | ||
| 53 | - | ||
| 54 | -# Django stuff: | ||
| 55 | -*.log | ||
| 56 | -local_settings.py | ||
| 57 | -db.sqlite3 | ||
| 58 | - | ||
| 59 | -# Flask stuff: | ||
| 60 | -instance/ | ||
| 61 | -.webassets-cache | ||
| 62 | - | ||
| 63 | -# Scrapy stuff: | ||
| 64 | -.scrapy | ||
| 65 | - | ||
| 66 | -# Sphinx documentation | ||
| 67 | -docs/_build/ | ||
| 68 | - | ||
| 69 | -# PyBuilder | ||
| 70 | -target/ | ||
| 71 | - | ||
| 72 | -# Jupyter Notebook | ||
| 73 | -.ipynb_checkpoints | ||
| 74 | - | ||
| 75 | -# pyenv | ||
| 76 | -.python-version | ||
| 77 | - | ||
| 78 | -# celery beat schedule file | ||
| 79 | -celerybeat-schedule | ||
| 80 | - | ||
| 81 | -# SageMath parsed files | ||
| 82 | -*.sage.py | ||
| 83 | - | ||
| 84 | -# Environments | ||
| 85 | -.env | ||
| 86 | -.venv | ||
| 87 | -env/ | ||
| 88 | -venv/ | ||
| 89 | -ENV/ | ||
| 90 | -env.bak/ | ||
| 91 | -venv.bak/ | ||
| 92 | - | ||
| 93 | -# Spyder project settings | ||
| 94 | -.spyderproject | ||
| 95 | -.spyproject | ||
| 96 | - | ||
| 97 | -# Rope project settings | ||
| 98 | -.ropeproject | ||
| 99 | - | ||
| 100 | -# mkdocs documentation | ||
| 101 | -/site | ||
| 102 | - | ||
| 103 | -# mypy | ||
| 104 | -.mypy_cache/ | ||
| 105 | - |
| 1 | -# Fast Autoaugment | ||
| 2 | -<img src="figures/faa.png" width=800px> | ||
| 3 | - | ||
| 4 | -A Pytorch Implementation of [Fast AutoAugment](https://arxiv.org/pdf/1905.00397.pdf) and [EfficientNet](https://arxiv.org/abs/1905.11946). | ||
| 5 | - | ||
| 6 | -## Prerequisite | ||
| 7 | -* torch==1.1.0 | ||
| 8 | -* torchvision==0.2.2 | ||
| 9 | -* hyperopt==0.1.2 | ||
| 10 | -* future==0.17.1 | ||
| 11 | -* tb-nightly==1.15.0a20190622 | ||
| 12 | - | ||
| 13 | -## Usage | ||
| 14 | -### Training | ||
| 15 | -#### CIFAR10 | ||
| 16 | -```bash | ||
| 17 | -# ResNet20 (w/o FastAutoAugment) | ||
| 18 | -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False | ||
| 19 | - | ||
| 20 | -# ResNet20 (w/ FastAutoAugment) | ||
| 21 | -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True | ||
| 22 | - | ||
| 23 | -# ResNet20 (w/ FastAutoAugment, Pre-found policy) | ||
| 24 | -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \ | ||
| 25 | - --augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp | ||
| 26 | - | ||
| 27 | -# ResNet32 (w/o FastAutoAugment) | ||
| 28 | -python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False | ||
| 29 | - | ||
| 30 | -# ResNet32 (w/ FastAutoAugment) | ||
| 31 | -python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True | ||
| 32 | - | ||
| 33 | -# EfficientNet (w/ FastAutoAugment) | ||
| 34 | -python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \ | ||
| 35 | - --network=efficientnet_cifar10 --activation=swish | ||
| 36 | -``` | ||
| 37 | - | ||
| 38 | -#### ImageNet (You can use any backbone networks in [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html)) | ||
| 39 | -```bash | ||
| 40 | - | ||
| 41 | -# BaseNet (w/o FastAutoAugment) | ||
| 42 | -python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50 | ||
| 43 | - | ||
| 44 | -# EfficientNet (w/ FastAutoAugment) (UnderConstruction) | ||
| 45 | -python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \ | ||
| 46 | - --network=efficientnet --activation=swish | ||
| 47 | -``` | ||
| 48 | - | ||
| 49 | -### Eval | ||
| 50 | -```bash | ||
| 51 | -# Single Image testing | ||
| 52 | -python eval.py --model_path=runs/ResNet_Scale3_Basline | ||
| 53 | - | ||
| 54 | -# 5-crops testing | ||
| 55 | -python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True | ||
| 56 | -``` | ||
| 57 | - | ||
| 58 | -## Experiments | ||
| 59 | -### Fast AutoAugment | ||
| 60 | -#### ResNet20 (CIFAR10) | ||
| 61 | -* Pre-trained model [[Download](https://drive.google.com/file/d/12D8050yGGiKWGt8_R8QTlkoQ6wq_icBn/view?usp=sharing)] | ||
| 62 | -* Validation Curve | ||
| 63 | -<img src="figures/resnet20_valid.png"> | ||
| 64 | - | ||
| 65 | -* Evaluation (Acc @1) | ||
| 66 | - | ||
| 67 | -| | Valid | Test(Single) | | ||
| 68 | -|----------------|-------|-------------| | ||
| 69 | -| ResNet20 | 90.70 | **91.45** | | ||
| 70 | -| ResNet20 + FAA |**92.46**| **91.45** | | ||
| 71 | - | ||
| 72 | -#### ResNet34 (CIFAR10) | ||
| 73 | -* Validation Curve | ||
| 74 | -<img src="figures/resnet34_valid.png"> | ||
| 75 | - | ||
| 76 | -* Evaluation (Acc @1) | ||
| 77 | - | ||
| 78 | -| | Valid | Test(Single) | | ||
| 79 | -|----------------|-------|-------------| | ||
| 80 | -| ResNet34 | 91.54 | 91.47 | | ||
| 81 | -| ResNet34 + FAA |**92.76**| **91.99** | | ||
| 82 | - | ||
| 83 | -### Found Policy [[Download](https://drive.google.com/file/d/1Ia_IxPY3-T7m8biyl3QpxV1s5EA5gRDF/view?usp=sharing)] | ||
| 84 | -<img src="figures/pm.png"> | ||
| 85 | - | ||
| 86 | -### Augmented images | ||
| 87 | -<img src="figures/augmented_images.png"> | ||
| 88 | -<img src="figures/augmented_images2.png"> |
| 1 | -import os | ||
| 2 | -import fire | ||
| 3 | -import json | ||
| 4 | -from pprint import pprint | ||
| 5 | - | ||
| 6 | -import torch | ||
| 7 | -import torch.nn as nn | ||
| 8 | - | ||
| 9 | -from utils import * | ||
| 10 | - | ||
| 11 | - | ||
| 12 | -def eval(model_path): | ||
| 13 | - print('\n[+] Parse arguments') | ||
| 14 | - kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 15 | - kwargs = json.loads(open(kwargs_path).read()) | ||
| 16 | - args, kwargs = parse_args(kwargs) | ||
| 17 | - pprint(args) | ||
| 18 | - device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 19 | - | ||
| 20 | - print('\n[+] Create network') | ||
| 21 | - model = select_model(args) | ||
| 22 | - optimizer = select_optimizer(args, model) | ||
| 23 | - criterion = nn.CrossEntropyLoss() | ||
| 24 | - if args.use_cuda: | ||
| 25 | - model = model.cuda() | ||
| 26 | - criterion = criterion.cuda() | ||
| 27 | - | ||
| 28 | - print('\n[+] Load model') | ||
| 29 | - weight_path = os.path.join(model_path, 'model', 'model.pt') | ||
| 30 | - model.load_state_dict(torch.load(weight_path)) | ||
| 31 | - | ||
| 32 | - print('\n[+] Load dataset') | ||
| 33 | - test_transform = get_valid_transform(args, model) | ||
| 34 | - test_dataset = get_dataset(args, test_transform, 'test') | ||
| 35 | - test_loader = iter(get_dataloader(args, test_dataset)) | ||
| 36 | - | ||
| 37 | - print('\n[+] Start testing') | ||
| 38 | - _test_res = validate(args, model, criterion, test_loader, step=0, writer=None) | ||
| 39 | - | ||
| 40 | - print('\n[+] Valid results') | ||
| 41 | - print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100)) | ||
| 42 | - print(' Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100)) | ||
| 43 | - print(' Loss : {:.3f}'.format(_test_res[2].data)) | ||
| 44 | - print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset))) | ||
| 45 | - | ||
| 46 | - | ||
| 47 | -if __name__ == '__main__': | ||
| 48 | - fire.Fire(eval) |
| 1 | -import copy | ||
| 2 | -import json | ||
| 3 | -import time | ||
| 4 | -import torch | ||
| 5 | -import random | ||
| 6 | -import torchvision.transforms as transforms | ||
| 7 | - | ||
| 8 | -from torch.utils.data import Subset | ||
| 9 | -from sklearn.model_selection import StratifiedShuffleSplit | ||
| 10 | -from concurrent.futures import ProcessPoolExecutor | ||
| 11 | - | ||
| 12 | -from transforms import * | ||
| 13 | -from hyperopt import fmin, tpe, hp, STATUS_OK, Trials | ||
| 14 | -from utils import * | ||
| 15 | - | ||
| 16 | - | ||
| 17 | -DEFALUT_CANDIDATES = [ | ||
| 18 | - ShearXY, | ||
| 19 | - TranslateXY, | ||
| 20 | - Rotate, | ||
| 21 | - AutoContrast, | ||
| 22 | - Invert, | ||
| 23 | - Equalize, | ||
| 24 | - Solarize, | ||
| 25 | - Posterize, | ||
| 26 | - Contrast, | ||
| 27 | - Color, | ||
| 28 | - Brightness, | ||
| 29 | - Sharpness, | ||
| 30 | - Cutout, | ||
| 31 | -# SamplePairing, | ||
| 32 | -] | ||
| 33 | - | ||
| 34 | - | ||
| 35 | -def train_child(args, model, dataset, subset_indx, device=None): | ||
| 36 | - optimizer = select_optimizer(args, model) | ||
| 37 | - scheduler = select_scheduler(args, optimizer) | ||
| 38 | - criterion = nn.CrossEntropyLoss() | ||
| 39 | - | ||
| 40 | - dataset.transform = transforms.Compose([ | ||
| 41 | - transforms.Resize(32), | ||
| 42 | - transforms.ToTensor()]) | ||
| 43 | - subset = Subset(dataset, subset_indx) | ||
| 44 | - data_loader = get_inf_dataloader(args, subset) | ||
| 45 | - | ||
| 46 | - if device: | ||
| 47 | - model = model.to(device) | ||
| 48 | - criterion = criterion.to(device) | ||
| 49 | - | ||
| 50 | - elif args.use_cuda: | ||
| 51 | - model = model.cuda() | ||
| 52 | - criterion = criterion.cuda() | ||
| 53 | - | ||
| 54 | - if torch.cuda.device_count() > 1: | ||
| 55 | - print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
| 56 | - model = nn.DataParallel(model) | ||
| 57 | - | ||
| 58 | - start_t = time.time() | ||
| 59 | - for step in range(args.start_step, args.max_step): | ||
| 60 | - batch = next(data_loader) | ||
| 61 | - _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) | ||
| 62 | - | ||
| 63 | - if step % args.print_step == 0: | ||
| 64 | - print('\n[+] Training step: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}\tDevice: {}'.format( | ||
| 65 | - step, args.max_step,(time.time()-start_t)/60, optimizer.param_groups[0]['lr'], device)) | ||
| 66 | - | ||
| 67 | - print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | ||
| 68 | - print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100)) | ||
| 69 | - print(' Loss : {}'.format(_train_res[2].data)) | ||
| 70 | - | ||
| 71 | - return _train_res | ||
| 72 | - | ||
| 73 | - | ||
| 74 | -def validate_child(args, model, dataset, subset_indx, transform, device=None): | ||
| 75 | - criterion = nn.CrossEntropyLoss() | ||
| 76 | - | ||
| 77 | - if device: | ||
| 78 | - model = model.to(device) | ||
| 79 | - criterion = criterion.to(device) | ||
| 80 | - | ||
| 81 | - elif args.use_cuda: | ||
| 82 | - model = model.cuda() | ||
| 83 | - criterion = criterion.cuda() | ||
| 84 | - | ||
| 85 | - dataset.transform = transform | ||
| 86 | - subset = Subset(dataset, subset_indx) | ||
| 87 | - data_loader = get_dataloader(args, subset, pin_memory=False) | ||
| 88 | - | ||
| 89 | - return validate(args, model, criterion, data_loader, 0, None, device) | ||
| 90 | - | ||
| 91 | - | ||
| 92 | -def get_next_subpolicy(transform_candidates, op_per_subpolicy=2): | ||
| 93 | - n_candidates = len(transform_candidates) | ||
| 94 | - subpolicy = [] | ||
| 95 | - | ||
| 96 | - for i in range(op_per_subpolicy): | ||
| 97 | - indx = random.randrange(n_candidates) | ||
| 98 | - prob = random.random() | ||
| 99 | - mag = random.random() | ||
| 100 | - subpolicy.append(transform_candidates[indx](prob, mag)) | ||
| 101 | - | ||
| 102 | - subpolicy = transforms.Compose([ | ||
| 103 | - *subpolicy, | ||
| 104 | - transforms.Resize(32), | ||
| 105 | - transforms.ToTensor()]) | ||
| 106 | - | ||
| 107 | - return subpolicy | ||
| 108 | - | ||
| 109 | - | ||
| 110 | -def search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device): | ||
| 111 | - subpolicies = [] | ||
| 112 | - | ||
| 113 | - for b in range(B): | ||
| 114 | - subpolicy = get_next_subpolicy(transform_candidates) | ||
| 115 | - val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device) | ||
| 116 | - subpolicies.append((subpolicy, val_res[2])) | ||
| 117 | - | ||
| 118 | - return subpolicies | ||
| 119 | - | ||
| 120 | - | ||
| 121 | -def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device): | ||
| 122 | - | ||
| 123 | - def _objective(sampled): | ||
| 124 | - subpolicy = [transform(prob, mag) | ||
| 125 | - for transform, prob, mag in sampled] | ||
| 126 | - | ||
| 127 | - subpolicy = transforms.Compose([ | ||
| 128 | - transforms.Resize(32), | ||
| 129 | - *subpolicy, | ||
| 130 | - transforms.ToTensor()]) | ||
| 131 | - | ||
| 132 | - val_res = validate_child(args, child_model, dataset, Da_indx, subpolicy, device) | ||
| 133 | - loss = val_res[2].cpu().numpy() | ||
| 134 | - return {'loss': loss, 'status': STATUS_OK } | ||
| 135 | - | ||
| 136 | - space = [(hp.choice('transform1', transform_candidates), hp.uniform('prob1', 0, 1.0), hp.uniform('mag1', 0, 1.0)), | ||
| 137 | - (hp.choice('transform2', transform_candidates), hp.uniform('prob2', 0, 1.0), hp.uniform('mag2', 0, 1.0))] | ||
| 138 | - | ||
| 139 | - trials = Trials() | ||
| 140 | - best = fmin(_objective, | ||
| 141 | - space=space, | ||
| 142 | - algo=tpe.suggest, | ||
| 143 | - max_evals=B, | ||
| 144 | - trials=trials) | ||
| 145 | - | ||
| 146 | - subpolicies = [] | ||
| 147 | - for t in trials.trials: | ||
| 148 | - vals = t['misc']['vals'] | ||
| 149 | - subpolicy = [transform_candidates[vals['transform1'][0]](vals['prob1'][0], vals['mag1'][0]), | ||
| 150 | - transform_candidates[vals['transform2'][0]](vals['prob2'][0], vals['mag2'][0])] | ||
| 151 | - subpolicy = transforms.Compose([ | ||
| 152 | - ## baseline augmentation | ||
| 153 | - transforms.Pad(4), | ||
| 154 | - transforms.RandomCrop(32), | ||
| 155 | - transforms.RandomHorizontalFlip(), | ||
| 156 | - ## policy | ||
| 157 | - *subpolicy, | ||
| 158 | - ## to tensor | ||
| 159 | - transforms.ToTensor()]) | ||
| 160 | - subpolicies.append((subpolicy, t['result']['loss'])) | ||
| 161 | - | ||
| 162 | - return subpolicies | ||
| 163 | - | ||
| 164 | - | ||
| 165 | -def get_topn_subpolicies(subpolicies, N=10): | ||
| 166 | - return sorted(subpolicies, key=lambda subpolicy: subpolicy[1])[:N] | ||
| 167 | - | ||
| 168 | - | ||
| 169 | -def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k): | ||
| 170 | - kwargs = json.loads(args_str) | ||
| 171 | - args, kwargs = parse_args(kwargs) | ||
| 172 | - device_id = k % torch.cuda.device_count() | ||
| 173 | - device = torch.device('cuda:%d' % device_id) | ||
| 174 | - _transform = [] | ||
| 175 | - | ||
| 176 | - print('[+] Child %d training strated (GPU: %d)' % (k, device_id)) | ||
| 177 | - | ||
| 178 | - # train child model | ||
| 179 | - child_model = copy.deepcopy(model) | ||
| 180 | - train_res = train_child(args, child_model, dataset, Dm_indx, device) | ||
| 181 | - | ||
| 182 | - # search sub policy | ||
| 183 | - for t in range(T): | ||
| 184 | - #subpolicies = search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device) | ||
| 185 | - subpolicies = search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device) | ||
| 186 | - subpolicies = get_topn_subpolicies(subpolicies, N) | ||
| 187 | - _transform.extend([subpolicy[0] for subpolicy in subpolicies]) | ||
| 188 | - | ||
| 189 | - return _transform | ||
| 190 | - | ||
| 191 | - | ||
| 192 | -def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | ||
| 193 | - args_str = json.dumps(args._asdict()) | ||
| 194 | - dataset = get_dataset(args, None, 'trainval') | ||
| 195 | - num_process = min(torch.cuda.device_count(), num_process) | ||
| 196 | - transform, futures = [], [] | ||
| 197 | - | ||
| 198 | - torch.multiprocessing.set_start_method('spawn', force=True) | ||
| 199 | - | ||
| 200 | - if not transform_candidates: | ||
| 201 | - transform_candidates = DEFALUT_CANDIDATES | ||
| 202 | - | ||
| 203 | - # split | ||
| 204 | - Dm_indexes, Da_indexes = split_dataset(args, dataset, K) | ||
| 205 | - | ||
| 206 | - with ProcessPoolExecutor(max_workers=num_process) as executor: | ||
| 207 | - for k, (Dm_indx, Da_indx) in enumerate(zip(Dm_indexes, Da_indexes)): | ||
| 208 | - future = executor.submit(process_fn, | ||
| 209 | - args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k) | ||
| 210 | - futures.append(future) | ||
| 211 | - | ||
| 212 | - for future in futures: | ||
| 213 | - transform.extend(future.result()) | ||
| 214 | - | ||
| 215 | - transform = transforms.RandomChoice(transform) | ||
| 216 | - | ||
| 217 | - return transform |
File mode changed
92.9 KB
410 KB
425 KB
423 KB
48.3 KB
229 KB
227 KB
| 1 | -from .basenet import BaseNet |
| 1 | -import torch.nn as nn | ||
| 2 | - | ||
| 3 | -class BaseNet(nn.Module): | ||
| 4 | - def __init__(self, backbone, args): | ||
| 5 | - super(BaseNet, self).__init__() | ||
| 6 | - | ||
| 7 | - # Separate layers | ||
| 8 | - self.first = nn.Sequential(*list(backbone.children())[:1]) | ||
| 9 | - self.after = nn.Sequential(*list(backbone.children())[1:-1]) | ||
| 10 | - self.fc = list(backbone.children())[-1] | ||
| 11 | - | ||
| 12 | - self.img_size = (224, 224) | ||
| 13 | - | ||
| 14 | - def forward(self, x): | ||
| 15 | - f = self.first(x) | ||
| 16 | - x = self.after(f) | ||
| 17 | - x = x.reshape(x.size(0), -1) | ||
| 18 | - x = self.fc(x) | ||
| 19 | - return x, f |
| 1 | -import math | ||
| 2 | -import torch.nn as nn | ||
| 3 | -import torch.nn.functional as F | ||
| 4 | - | ||
| 5 | - | ||
| 6 | -def round_fn(orig, multiplier): | ||
| 7 | - if not multiplier: | ||
| 8 | - return orig | ||
| 9 | - | ||
| 10 | - return int(math.ceil(multiplier * orig)) | ||
| 11 | - | ||
| 12 | - | ||
| 13 | -def get_activation_fn(activation): | ||
| 14 | - if activation == "swish": | ||
| 15 | - return Swish | ||
| 16 | - | ||
| 17 | - elif activation == "relu": | ||
| 18 | - return nn.ReLU | ||
| 19 | - | ||
| 20 | - else: | ||
| 21 | - raise Exception('Unkown activation %s' % activation) | ||
| 22 | - | ||
| 23 | - | ||
| 24 | -class Swish(nn.Module): | ||
| 25 | - """ Swish activation function, s(x) = x * sigmoid(x) """ | ||
| 26 | - | ||
| 27 | - def __init__(self, inplace=False): | ||
| 28 | - super().__init__() | ||
| 29 | - self.inplace = True | ||
| 30 | - | ||
| 31 | - def forward(self, x): | ||
| 32 | - if self.inplace: | ||
| 33 | - x.mul_(F.sigmoid(x)) | ||
| 34 | - return x | ||
| 35 | - else: | ||
| 36 | - return x * F.sigmoid(x) | ||
| 37 | - | ||
| 38 | - | ||
| 39 | -class ConvBlock(nn.Module): | ||
| 40 | - """ Conv + BatchNorm + Activation """ | ||
| 41 | - | ||
| 42 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
| 43 | - padding=0, stride=1, activation="swish"): | ||
| 44 | - super().__init__() | ||
| 45 | - self.fw = nn.Sequential( | ||
| 46 | - nn.Conv2d(in_channel, out_channel, kernel_size, | ||
| 47 | - padding=padding, stride=stride, bias=False), | ||
| 48 | - nn.BatchNorm2d(out_channel), | ||
| 49 | - get_activation_fn(activation)()) | ||
| 50 | - | ||
| 51 | - def forward(self, x): | ||
| 52 | - return self.fw(x) | ||
| 53 | - | ||
| 54 | - | ||
| 55 | -class DepthwiseConvBlock(nn.Module): | ||
| 56 | - """ DepthwiseConv2D + BatchNorm + Activation """ | ||
| 57 | - | ||
| 58 | - def __init__(self, in_channel, kernel_size, | ||
| 59 | - padding=0, stride=1, activation="swish"): | ||
| 60 | - super().__init__() | ||
| 61 | - self.fw = nn.Sequential( | ||
| 62 | - nn.Conv2d(in_channel, in_channel, kernel_size, | ||
| 63 | - padding=padding, stride=stride, groups=in_channel, bias=False), | ||
| 64 | - nn.BatchNorm2d(in_channel), | ||
| 65 | - get_activation_fn(activation)()) | ||
| 66 | - | ||
| 67 | - def forward(self, x): | ||
| 68 | - return self.fw(x) | ||
| 69 | - | ||
| 70 | - | ||
| 71 | -class MBConv(nn.Module): | ||
| 72 | - """ Inverted residual block """ | ||
| 73 | - | ||
| 74 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
| 75 | - stride=1, expand_ratio=1, activation="swish"): | ||
| 76 | - super().__init__() | ||
| 77 | - self.in_channel = in_channel | ||
| 78 | - self.out_channel = out_channel | ||
| 79 | - self.expand_ratio = expand_ratio | ||
| 80 | - self.stride = stride | ||
| 81 | - | ||
| 82 | - if expand_ratio != 1: | ||
| 83 | - self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1, | ||
| 84 | - activation=activation) | ||
| 85 | - | ||
| 86 | - self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size, | ||
| 87 | - padding=(kernel_size-1)//2, | ||
| 88 | - stride=stride, activation=activation) | ||
| 89 | - | ||
| 90 | - self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1, | ||
| 91 | - activation=activation) | ||
| 92 | - | ||
| 93 | - def forward(self, inputs): | ||
| 94 | - if self.expand_ratio != 1: | ||
| 95 | - x = self.expand(inputs) | ||
| 96 | - else: | ||
| 97 | - x = inputs | ||
| 98 | - | ||
| 99 | - x = self.dw_conv(x) | ||
| 100 | - x = self.pw_conv(x) | ||
| 101 | - | ||
| 102 | - if self.in_channel == self.out_channel and \ | ||
| 103 | - self.stride == 1: | ||
| 104 | - x = x + inputs | ||
| 105 | - | ||
| 106 | - return x | ||
| 107 | - | ||
| 108 | - | ||
| 109 | -class Net(nn.Module): | ||
| 110 | - """ EfficientNet """ | ||
| 111 | - | ||
| 112 | - def __init__(self, args): | ||
| 113 | - super(Net, self).__init__() | ||
| 114 | - pi = args.pi | ||
| 115 | - activation = args.activation | ||
| 116 | - num_classes = args.num_classes | ||
| 117 | - | ||
| 118 | - self.d = 1.2 ** pi | ||
| 119 | - self.w = 1.1 ** pi | ||
| 120 | - self.r = 1.15 ** pi | ||
| 121 | - self.img_size = (round_fn(224, self.r), round_fn(224, self.r)) | ||
| 122 | - | ||
| 123 | - self.stage1 = ConvBlock(3, round_fn(32, self.w), | ||
| 124 | - kernel_size=3, padding=1, stride=2, activation=activation) | ||
| 125 | - | ||
| 126 | - self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w), | ||
| 127 | - depth=round_fn(1, self.d), kernel_size=3, | ||
| 128 | - half_resolution=False, expand_ratio=1, activation=activation) | ||
| 129 | - | ||
| 130 | - self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w), | ||
| 131 | - depth=round_fn(2, self.d), kernel_size=3, | ||
| 132 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 133 | - | ||
| 134 | - self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w), | ||
| 135 | - depth=round_fn(2, self.d), kernel_size=5, | ||
| 136 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 137 | - | ||
| 138 | - self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w), | ||
| 139 | - depth=round_fn(3, self.d), kernel_size=3, | ||
| 140 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 141 | - | ||
| 142 | - self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w), | ||
| 143 | - depth=round_fn(3, self.d), kernel_size=5, | ||
| 144 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
| 145 | - | ||
| 146 | - self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w), | ||
| 147 | - depth=round_fn(4, self.d), kernel_size=5, | ||
| 148 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 149 | - | ||
| 150 | - self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w), | ||
| 151 | - depth=round_fn(1, self.d), kernel_size=3, | ||
| 152 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
| 153 | - | ||
| 154 | - self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w), | ||
| 155 | - kernel_size=1, activation=activation) | ||
| 156 | - | ||
| 157 | - self.fc = nn.Linear(round_fn(7*7*1280, self.w), num_classes) | ||
| 158 | - | ||
| 159 | - def make_layers(self, in_channel, out_channel, depth, kernel_size, | ||
| 160 | - half_resolution=False, expand_ratio=1, activation="swish"): | ||
| 161 | - blocks = [] | ||
| 162 | - for i in range(depth): | ||
| 163 | - stride = 2 if half_resolution and i==0 else 1 | ||
| 164 | - blocks.append( | ||
| 165 | - MBConv(in_channel, out_channel, kernel_size, | ||
| 166 | - stride=stride, expand_ratio=expand_ratio, activation=activation)) | ||
| 167 | - in_channel = out_channel | ||
| 168 | - | ||
| 169 | - return nn.Sequential(*blocks) | ||
| 170 | - | ||
| 171 | - def forward(self, x): | ||
| 172 | - assert x.size()[-2:] == self.img_size, \ | ||
| 173 | - 'Image size must be %r, but %r given' % (self.img_size, x.size()[-2]) | ||
| 174 | - | ||
| 175 | - x = self.stage1(x) | ||
| 176 | - x = self.stage2(x) | ||
| 177 | - x = self.stage3(x) | ||
| 178 | - x = self.stage4(x) | ||
| 179 | - x = self.stage5(x) | ||
| 180 | - x = self.stage6(x) | ||
| 181 | - x = self.stage7(x) | ||
| 182 | - x = self.stage8(x) | ||
| 183 | - x = self.stage9(x) | ||
| 184 | - x = x.reshape(x.size(0), -1) | ||
| 185 | - x = self.fc(x) | ||
| 186 | - return x, x |
| 1 | -import math | ||
| 2 | -import torch.nn as nn | ||
| 3 | -import torch.nn.functional as F | ||
| 4 | - | ||
| 5 | - | ||
| 6 | -def round_fn(orig, multiplier): | ||
| 7 | - if not multiplier: | ||
| 8 | - return orig | ||
| 9 | - | ||
| 10 | - return int(math.ceil(multiplier * orig)) | ||
| 11 | - | ||
| 12 | - | ||
| 13 | -def get_activation_fn(activation): | ||
| 14 | - if activation == "swish": | ||
| 15 | - return Swish | ||
| 16 | - | ||
| 17 | - elif activation == "relu": | ||
| 18 | - return nn.ReLU | ||
| 19 | - | ||
| 20 | - else: | ||
| 21 | - raise Exception('Unkown activation %s' % activation) | ||
| 22 | - | ||
| 23 | - | ||
| 24 | -class Swish(nn.Module): | ||
| 25 | - """ Swish activation function, s(x) = x * sigmoid(x) """ | ||
| 26 | - | ||
| 27 | - def __init__(self, inplace=False): | ||
| 28 | - super().__init__() | ||
| 29 | - self.inplace = True | ||
| 30 | - | ||
| 31 | - def forward(self, x): | ||
| 32 | - if self.inplace: | ||
| 33 | - x.mul_(F.sigmoid(x)) | ||
| 34 | - return x | ||
| 35 | - else: | ||
| 36 | - return x * F.sigmoid(x) | ||
| 37 | - | ||
| 38 | - | ||
| 39 | -class ConvBlock(nn.Module): | ||
| 40 | - """ Conv + BatchNorm + Activation """ | ||
| 41 | - | ||
| 42 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
| 43 | - padding=0, stride=1, activation="swish"): | ||
| 44 | - super().__init__() | ||
| 45 | - self.fw = nn.Sequential( | ||
| 46 | - nn.Conv2d(in_channel, out_channel, kernel_size, | ||
| 47 | - padding=padding, stride=stride, bias=False), | ||
| 48 | - nn.BatchNorm2d(out_channel), | ||
| 49 | - get_activation_fn(activation)()) | ||
| 50 | - | ||
| 51 | - def forward(self, x): | ||
| 52 | - return self.fw(x) | ||
| 53 | - | ||
| 54 | - | ||
| 55 | -class DepthwiseConvBlock(nn.Module): | ||
| 56 | - """ DepthwiseConv2D + BatchNorm + Activation """ | ||
| 57 | - | ||
| 58 | - def __init__(self, in_channel, kernel_size, | ||
| 59 | - padding=0, stride=1, activation="swish"): | ||
| 60 | - super().__init__() | ||
| 61 | - self.fw = nn.Sequential( | ||
| 62 | - nn.Conv2d(in_channel, in_channel, kernel_size, | ||
| 63 | - padding=padding, stride=stride, groups=in_channel, bias=False), | ||
| 64 | - nn.BatchNorm2d(in_channel), | ||
| 65 | - get_activation_fn(activation)()) | ||
| 66 | - | ||
| 67 | - def forward(self, x): | ||
| 68 | - return self.fw(x) | ||
| 69 | - | ||
| 70 | - | ||
| 71 | -class SEBlock(nn.Module): | ||
| 72 | - """ Squeeze and Excitation Block """ | ||
| 73 | - | ||
| 74 | - def __init__(self, in_channel, se_ratio=16): | ||
| 75 | - super().__init__() | ||
| 76 | - self.global_avgpool = nn.AdaptiveAvgPool2d((1,1)) | ||
| 77 | - inter_channel = in_channel // se_ratio | ||
| 78 | - | ||
| 79 | - self.reduce = nn.Sequential( | ||
| 80 | - nn.Conv2d(in_channel, inter_channel, | ||
| 81 | - kernel_size=1, padding=0, stride=1), | ||
| 82 | - nn.ReLU()) | ||
| 83 | - | ||
| 84 | - self.expand = nn.Sequential( | ||
| 85 | - nn.Conv2d(inter_channel, in_channel, | ||
| 86 | - kernel_size=1, padding=0, stride=1), | ||
| 87 | - nn.Sigmoid()) | ||
| 88 | - | ||
| 89 | - | ||
| 90 | - def forward(self, x): | ||
| 91 | - s = self.global_avgpool(x) | ||
| 92 | - s = self.reduce(s) | ||
| 93 | - s = self.expand(s) | ||
| 94 | - return x * s | ||
| 95 | - | ||
| 96 | - | ||
| 97 | -class MBConv(nn.Module): | ||
| 98 | - """ Inverted residual block """ | ||
| 99 | - | ||
| 100 | - def __init__(self, in_channel, out_channel, kernel_size, | ||
| 101 | - stride=1, expand_ratio=1, activation="swish", use_seblock=False): | ||
| 102 | - super().__init__() | ||
| 103 | - self.in_channel = in_channel | ||
| 104 | - self.out_channel = out_channel | ||
| 105 | - self.expand_ratio = expand_ratio | ||
| 106 | - self.stride = stride | ||
| 107 | - self.use_seblock = use_seblock | ||
| 108 | - | ||
| 109 | - if expand_ratio != 1: | ||
| 110 | - self.expand = ConvBlock(in_channel, in_channel*expand_ratio, 1, | ||
| 111 | - activation=activation) | ||
| 112 | - | ||
| 113 | - self.dw_conv = DepthwiseConvBlock(in_channel*expand_ratio, kernel_size, | ||
| 114 | - padding=(kernel_size-1)//2, | ||
| 115 | - stride=stride, activation=activation) | ||
| 116 | - | ||
| 117 | - if use_seblock: | ||
| 118 | - self.seblock = SEBlock(in_channel*expand_ratio) | ||
| 119 | - | ||
| 120 | - self.pw_conv = ConvBlock(in_channel*expand_ratio, out_channel, 1, | ||
| 121 | - activation=activation) | ||
| 122 | - | ||
| 123 | - def forward(self, inputs): | ||
| 124 | - if self.expand_ratio != 1: | ||
| 125 | - x = self.expand(inputs) | ||
| 126 | - else: | ||
| 127 | - x = inputs | ||
| 128 | - | ||
| 129 | - x = self.dw_conv(x) | ||
| 130 | - | ||
| 131 | - if self.use_seblock: | ||
| 132 | - x = self.seblock(x) | ||
| 133 | - | ||
| 134 | - x = self.pw_conv(x) | ||
| 135 | - | ||
| 136 | - if self.in_channel == self.out_channel and \ | ||
| 137 | - self.stride == 1: | ||
| 138 | - x = x + inputs | ||
| 139 | - | ||
| 140 | - return x | ||
| 141 | - | ||
| 142 | - | ||
| 143 | -class Net(nn.Module): | ||
| 144 | - """ EfficientNet """ | ||
| 145 | - | ||
| 146 | - def __init__(self, args): | ||
| 147 | - super(Net, self).__init__() | ||
| 148 | - pi = args.pi | ||
| 149 | - activation = args.activation | ||
| 150 | - num_classes = 10 | ||
| 151 | - | ||
| 152 | - self.d = 1.2 ** pi | ||
| 153 | - self.w = 1.1 ** pi | ||
| 154 | - self.r = 1.15 ** pi | ||
| 155 | - self.img_size = (round_fn(32, self.r), round_fn(32, self.r)) | ||
| 156 | - self.use_seblock = args.use_seblock | ||
| 157 | - | ||
| 158 | - self.stage1 = ConvBlock(3, round_fn(32, self.w), | ||
| 159 | - kernel_size=3, padding=1, stride=2, activation=activation) | ||
| 160 | - | ||
| 161 | - self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w), | ||
| 162 | - depth=round_fn(1, self.d), kernel_size=3, | ||
| 163 | - half_resolution=False, expand_ratio=1, activation=activation) | ||
| 164 | - | ||
| 165 | - self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w), | ||
| 166 | - depth=round_fn(2, self.d), kernel_size=3, | ||
| 167 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 168 | - | ||
| 169 | - self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w), | ||
| 170 | - depth=round_fn(2, self.d), kernel_size=5, | ||
| 171 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 172 | - | ||
| 173 | - self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w), | ||
| 174 | - depth=round_fn(3, self.d), kernel_size=3, | ||
| 175 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 176 | - | ||
| 177 | - self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w), | ||
| 178 | - depth=round_fn(3, self.d), kernel_size=5, | ||
| 179 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
| 180 | - | ||
| 181 | - self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w), | ||
| 182 | - depth=round_fn(4, self.d), kernel_size=5, | ||
| 183 | - half_resolution=True, expand_ratio=6, activation=activation) | ||
| 184 | - | ||
| 185 | - self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w), | ||
| 186 | - depth=round_fn(1, self.d), kernel_size=3, | ||
| 187 | - half_resolution=False, expand_ratio=6, activation=activation) | ||
| 188 | - | ||
| 189 | - self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w), | ||
| 190 | - kernel_size=1, activation=activation) | ||
| 191 | - | ||
| 192 | - self.fc = nn.Linear(round_fn(1280, self.w), num_classes) | ||
| 193 | - | ||
| 194 | - def make_layers(self, in_channel, out_channel, depth, kernel_size, | ||
| 195 | - half_resolution=False, expand_ratio=1, activation="swish"): | ||
| 196 | - blocks = [] | ||
| 197 | - for i in range(depth): | ||
| 198 | - stride = 2 if half_resolution and i==0 else 1 | ||
| 199 | - blocks.append( | ||
| 200 | - MBConv(in_channel, out_channel, kernel_size, | ||
| 201 | - stride=stride, expand_ratio=expand_ratio, activation=activation, use_seblock=self.use_seblock)) | ||
| 202 | - in_channel = out_channel | ||
| 203 | - | ||
| 204 | - return nn.Sequential(*blocks) | ||
| 205 | - | ||
| 206 | - def forward(self, x): | ||
| 207 | - assert x.size()[-2:] == self.img_size, \ | ||
| 208 | - 'Image size must be %r, but %r given' % (self.img_size, x.size()[-2]) | ||
| 209 | - | ||
| 210 | - s = self.stage1(x) | ||
| 211 | - x = self.stage2(s) | ||
| 212 | - x = self.stage3(x) | ||
| 213 | - x = self.stage4(x) | ||
| 214 | - x = self.stage5(x) | ||
| 215 | - x = self.stage6(x) | ||
| 216 | - x = self.stage7(x) | ||
| 217 | - x = self.stage8(x) | ||
| 218 | - x = self.stage9(x) | ||
| 219 | - x = x.reshape(x.size(0), -1) | ||
| 220 | - x = self.fc(x) | ||
| 221 | - return x, s |
| 1 | -import torch.nn as nn | ||
| 2 | - | ||
| 3 | - | ||
| 4 | -class ResidualBlock(nn.Module): | ||
| 5 | - def __init__(self, in_channel, out_channel, stride): | ||
| 6 | - super(ResidualBlock, self).__init__() | ||
| 7 | - self.in_channel = in_channel | ||
| 8 | - self.out_channel = out_channel | ||
| 9 | - self.stride = stride | ||
| 10 | - | ||
| 11 | - self.conv1 = nn.Sequential( | ||
| 12 | - nn.Conv2d(in_channel, out_channel, | ||
| 13 | - kernel_size=3, padding=1, stride=stride), | ||
| 14 | - nn.BatchNorm2d(out_channel)) | ||
| 15 | - | ||
| 16 | - self.relu = nn.ReLU(inplace=True) | ||
| 17 | - | ||
| 18 | - self.conv2 = nn.Sequential( | ||
| 19 | - nn.Conv2d(out_channel, out_channel, | ||
| 20 | - kernel_size=3, padding=1), | ||
| 21 | - nn.BatchNorm2d(out_channel)) | ||
| 22 | - | ||
| 23 | - if self.in_channel != self.out_channel or \ | ||
| 24 | - self.stride != 1: | ||
| 25 | - self.down = nn.Sequential( | ||
| 26 | - nn.Conv2d(in_channel, out_channel, | ||
| 27 | - kernel_size=1, stride=stride), | ||
| 28 | - nn.BatchNorm2d(out_channel)) | ||
| 29 | - | ||
| 30 | - def forward(self, b): | ||
| 31 | - t = self.conv1(b) | ||
| 32 | - t = self.relu(t) | ||
| 33 | - t = self.conv2(t) | ||
| 34 | - | ||
| 35 | - if self.in_channel != self.out_channel or \ | ||
| 36 | - self.stride != 1: | ||
| 37 | - b = self.down(b) | ||
| 38 | - | ||
| 39 | - t += b | ||
| 40 | - t = self.relu(t) | ||
| 41 | - | ||
| 42 | - return t | ||
| 43 | - | ||
| 44 | - | ||
| 45 | -class Net(nn.Module): | ||
| 46 | - def __init__(self, args): | ||
| 47 | - super(Net, self).__init__() | ||
| 48 | - scale = args.scale | ||
| 49 | - | ||
| 50 | - self.stem = nn.Sequential( | ||
| 51 | - nn.Conv2d(3, 16, | ||
| 52 | - kernel_size=3, padding=1), | ||
| 53 | - nn.BatchNorm2d(16), | ||
| 54 | - nn.ReLU(inplace=True)) | ||
| 55 | - | ||
| 56 | - self.layer1 = nn.Sequential(*[ | ||
| 57 | - ResidualBlock(16, 16, 1) for _ in range(2*scale)]) | ||
| 58 | - | ||
| 59 | - self.layer2 = nn.Sequential(*[ | ||
| 60 | - ResidualBlock(in_channel=(16 if i==0 else 32), | ||
| 61 | - out_channel=32, | ||
| 62 | - stride=(2 if i==0 else 1)) for i in range(2*scale)]) | ||
| 63 | - | ||
| 64 | - self.layer3 = nn.Sequential(*[ | ||
| 65 | - ResidualBlock(in_channel=(32 if i==0 else 64), | ||
| 66 | - out_channel=64, | ||
| 67 | - stride=(2 if i==0 else 1)) for i in range(2*scale)]) | ||
| 68 | - | ||
| 69 | - self.avg_pool = nn.AdaptiveAvgPool2d((1,1)) | ||
| 70 | - | ||
| 71 | - self.fc = nn.Linear(64, 10) | ||
| 72 | - | ||
| 73 | - for m in self.modules(): | ||
| 74 | - if isinstance(m, nn.Conv2d): | ||
| 75 | - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
| 76 | - | ||
| 77 | - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||
| 78 | - nn.init.constant_(m.weight, 1) | ||
| 79 | - nn.init.constant_(m.bias, 0) | ||
| 80 | - | ||
| 81 | - def forward(self, x): | ||
| 82 | - s = self.stem(x) | ||
| 83 | - x = self.layer1(s) | ||
| 84 | - x = self.layer2(x) | ||
| 85 | - x = self.layer3(x) | ||
| 86 | - x = self.avg_pool(x) | ||
| 87 | - x = x.reshape(x.size(0), -1) | ||
| 88 | - x = self.fc(x) | ||
| 89 | - | ||
| 90 | - return x, s |
| 1 | -import os | ||
| 2 | -import fire | ||
| 3 | -import time | ||
| 4 | -import json | ||
| 5 | -import random | ||
| 6 | -from pprint import pprint | ||
| 7 | - | ||
| 8 | -import torch.nn as nn | ||
| 9 | -import torch.backends.cudnn as cudnn | ||
| 10 | -from torch.utils.tensorboard import SummaryWriter | ||
| 11 | - | ||
| 12 | -from networks import * | ||
| 13 | -from utils import * | ||
| 14 | - | ||
| 15 | - | ||
| 16 | -def train(**kwargs): | ||
| 17 | - print('\n[+] Parse arguments') | ||
| 18 | - args, kwargs = parse_args(kwargs) | ||
| 19 | - pprint(args) | ||
| 20 | - device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 21 | - | ||
| 22 | - print('\n[+] Create log dir') | ||
| 23 | - model_name = get_model_name(args) | ||
| 24 | - log_dir = os.path.join('./runs', model_name) | ||
| 25 | - os.makedirs(os.path.join(log_dir, 'model')) | ||
| 26 | - json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | ||
| 27 | - writer = SummaryWriter(log_dir=log_dir) | ||
| 28 | - | ||
| 29 | - if args.seed is not None: | ||
| 30 | - random.seed(args.seed) | ||
| 31 | - torch.manual_seed(args.seed) | ||
| 32 | - cudnn.deterministic = True | ||
| 33 | - | ||
| 34 | - print('\n[+] Create network') | ||
| 35 | - model = select_model(args) | ||
| 36 | - optimizer = select_optimizer(args, model) | ||
| 37 | - scheduler = select_scheduler(args, optimizer) | ||
| 38 | - criterion = nn.CrossEntropyLoss() | ||
| 39 | - if args.use_cuda: | ||
| 40 | - model = model.cuda() | ||
| 41 | - criterion = criterion.cuda() | ||
| 42 | - #writer.add_graph(model) | ||
| 43 | - | ||
| 44 | - print('\n[+] Load dataset') | ||
| 45 | - transform = get_train_transform(args, model, log_dir) | ||
| 46 | - val_transform = get_valid_transform(args, model) | ||
| 47 | - train_dataset = get_dataset(args, transform, 'train') | ||
| 48 | - valid_dataset = get_dataset(args, val_transform, 'val') | ||
| 49 | - train_loader = iter(get_inf_dataloader(args, train_dataset)) | ||
| 50 | - max_epoch = len(train_dataset) // args.batch_size | ||
| 51 | - best_acc = -1 | ||
| 52 | - | ||
| 53 | - print('\n[+] Start training') | ||
| 54 | - if torch.cuda.device_count() > 1: | ||
| 55 | - print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
| 56 | - model = nn.DataParallel(model) | ||
| 57 | - | ||
| 58 | - start_t = time.time() | ||
| 59 | - for step in range(args.start_step, args.max_step): | ||
| 60 | - batch = next(train_loader) | ||
| 61 | - _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer) | ||
| 62 | - | ||
| 63 | - if step % args.print_step == 0: | ||
| 64 | - print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format( | ||
| 65 | - step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr'])) | ||
| 66 | - writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step) | ||
| 67 | - writer.add_scalar('train/acc1', _train_res[0], global_step=step) | ||
| 68 | - writer.add_scalar('train/acc5', _train_res[1], global_step=step) | ||
| 69 | - writer.add_scalar('train/loss', _train_res[2], global_step=step) | ||
| 70 | - writer.add_scalar('train/forward_time', _train_res[3], global_step=step) | ||
| 71 | - writer.add_scalar('train/backward_time', _train_res[4], global_step=step) | ||
| 72 | - print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100)) | ||
| 73 | - print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100)) | ||
| 74 | - print(' Loss : {}'.format(_train_res[2].data)) | ||
| 75 | - print(' FW Time : {:.3f}ms'.format(_train_res[3]*1000)) | ||
| 76 | - print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) | ||
| 77 | - | ||
| 78 | - if step % args.val_step == args.val_step-1: | ||
| 79 | - valid_loader = iter(get_dataloader(args, valid_dataset)) | ||
| 80 | - _valid_res = validate(args, model, criterion, valid_loader, step, writer) | ||
| 81 | - print('\n[+] Valid results') | ||
| 82 | - writer.add_scalar('valid/acc1', _valid_res[0], global_step=step) | ||
| 83 | - writer.add_scalar('valid/acc5', _valid_res[1], global_step=step) | ||
| 84 | - writer.add_scalar('valid/loss', _valid_res[2], global_step=step) | ||
| 85 | - print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100)) | ||
| 86 | - print(' Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100)) | ||
| 87 | - print(' Loss : {}'.format(_valid_res[2].data)) | ||
| 88 | - | ||
| 89 | - if _valid_res[0] > best_acc: | ||
| 90 | - best_acc = _valid_res[0] | ||
| 91 | - torch.save(model.state_dict(), os.path.join(log_dir, "model","model.pt")) | ||
| 92 | - print('\n[+] Model saved') | ||
| 93 | - | ||
| 94 | - writer.close() | ||
| 95 | - | ||
| 96 | - | ||
| 97 | -if __name__ == '__main__': | ||
| 98 | - fire.Fire(train) |
| 1 | -import numpy as np | ||
| 2 | -import torch.nn as nn | ||
| 3 | -import torchvision.transforms as transforms | ||
| 4 | - | ||
| 5 | -from abc import ABC, abstractmethod | ||
| 6 | -from PIL import Image, ImageOps, ImageEnhance | ||
| 7 | - | ||
| 8 | - | ||
| 9 | -class BaseTransform(ABC): | ||
| 10 | - | ||
| 11 | - def __init__(self, prob, mag): | ||
| 12 | - self.prob = prob | ||
| 13 | - self.mag = mag | ||
| 14 | - | ||
| 15 | - def __call__(self, img): | ||
| 16 | - return transforms.RandomApply([self.transform], self.prob)(img) | ||
| 17 | - | ||
| 18 | - def __repr__(self): | ||
| 19 | - return '%s(prob=%.2f, magnitude=%.2f)' % \ | ||
| 20 | - (self.__class__.__name__, self.prob, self.mag) | ||
| 21 | - | ||
| 22 | - @abstractmethod | ||
| 23 | - def transform(self, img): | ||
| 24 | - pass | ||
| 25 | - | ||
| 26 | - | ||
| 27 | -class ShearXY(BaseTransform): | ||
| 28 | - | ||
| 29 | - def transform(self, img): | ||
| 30 | - degrees = self.mag * 360 | ||
| 31 | - t = transforms.RandomAffine(0, shear=degrees, resample=Image.BILINEAR) | ||
| 32 | - return t(img) | ||
| 33 | - | ||
| 34 | - | ||
| 35 | -class TranslateXY(BaseTransform): | ||
| 36 | - | ||
| 37 | - def transform(self, img): | ||
| 38 | - translate = (self.mag, self.mag) | ||
| 39 | - t = transforms.RandomAffine(0, translate=translate, resample=Image.BILINEAR) | ||
| 40 | - return t(img) | ||
| 41 | - | ||
| 42 | - | ||
| 43 | -class Rotate(BaseTransform): | ||
| 44 | - | ||
| 45 | - def transform(self, img): | ||
| 46 | - degrees = self.mag * 360 | ||
| 47 | - t = transforms.RandomRotation(degrees, Image.BILINEAR) | ||
| 48 | - return t(img) | ||
| 49 | - | ||
| 50 | - | ||
| 51 | -class AutoContrast(BaseTransform): | ||
| 52 | - | ||
| 53 | - def transform(self, img): | ||
| 54 | - cutoff = int(self.mag * 49) | ||
| 55 | - return ImageOps.autocontrast(img, cutoff=cutoff) | ||
| 56 | - | ||
| 57 | - | ||
| 58 | -class Invert(BaseTransform): | ||
| 59 | - | ||
| 60 | - def transform(self, img): | ||
| 61 | - return ImageOps.invert(img) | ||
| 62 | - | ||
| 63 | - | ||
| 64 | -class Equalize(BaseTransform): | ||
| 65 | - | ||
| 66 | - def transform(self, img): | ||
| 67 | - return ImageOps.equalize(img) | ||
| 68 | - | ||
| 69 | - | ||
| 70 | -class Solarize(BaseTransform): | ||
| 71 | - | ||
| 72 | - def transform(self, img): | ||
| 73 | - threshold = (1-self.mag) * 255 | ||
| 74 | - return ImageOps.solarize(img, threshold) | ||
| 75 | - | ||
| 76 | - | ||
| 77 | -class Posterize(BaseTransform): | ||
| 78 | - | ||
| 79 | - def transform(self, img): | ||
| 80 | - bits = int((1-self.mag) * 8) | ||
| 81 | - return ImageOps.posterize(img, bits=bits) | ||
| 82 | - | ||
| 83 | - | ||
| 84 | -class Contrast(BaseTransform): | ||
| 85 | - | ||
| 86 | - def transform(self, img): | ||
| 87 | - factor = self.mag * 10 | ||
| 88 | - return ImageEnhance.Contrast(img).enhance(factor) | ||
| 89 | - | ||
| 90 | - | ||
| 91 | -class Color(BaseTransform): | ||
| 92 | - | ||
| 93 | - def transform(self, img): | ||
| 94 | - factor = self.mag * 10 | ||
| 95 | - return ImageEnhance.Color(img).enhance(factor) | ||
| 96 | - | ||
| 97 | - | ||
| 98 | -class Brightness(BaseTransform): | ||
| 99 | - | ||
| 100 | - def transform(self, img): | ||
| 101 | - factor = self.mag * 10 | ||
| 102 | - return ImageEnhance.Brightness(img).enhance(factor) | ||
| 103 | - | ||
| 104 | - | ||
| 105 | -class Sharpness(BaseTransform): | ||
| 106 | - | ||
| 107 | - def transform(self, img): | ||
| 108 | - factor = self.mag * 10 | ||
| 109 | - return ImageEnhance.Sharpness(img).enhance(factor) | ||
| 110 | - | ||
| 111 | - | ||
| 112 | -class Cutout(BaseTransform): | ||
| 113 | - | ||
| 114 | - def transform(self, img): | ||
| 115 | - n_holes = 1 | ||
| 116 | - length = 24 * self.mag | ||
| 117 | - cutout_op = CutoutOp(n_holes=n_holes, length=length) | ||
| 118 | - return cutout_op(img) | ||
| 119 | - | ||
| 120 | - | ||
| 121 | -class CutoutOp(object): | ||
| 122 | - """ | ||
| 123 | - https://github.com/uoguelph-mlrg/Cutout | ||
| 124 | - | ||
| 125 | - Randomly mask out one or more patches from an image. | ||
| 126 | - | ||
| 127 | - Args: | ||
| 128 | - n_holes (int): Number of patches to cut out of each image. | ||
| 129 | - length (int): The length (in pixels) of each square patch. | ||
| 130 | - """ | ||
| 131 | - def __init__(self, n_holes, length): | ||
| 132 | - self.n_holes = n_holes | ||
| 133 | - self.length = length | ||
| 134 | - | ||
| 135 | - def __call__(self, img): | ||
| 136 | - """ | ||
| 137 | - Args: | ||
| 138 | - img (Tensor): Tensor image of size (C, H, W). | ||
| 139 | - Returns: | ||
| 140 | - Tensor: Image with n_holes of dimension length x length cut out of it. | ||
| 141 | - """ | ||
| 142 | - w, h = img.size | ||
| 143 | - | ||
| 144 | - mask = np.ones((h, w, 1), np.uint8) | ||
| 145 | - | ||
| 146 | - for n in range(self.n_holes): | ||
| 147 | - y = np.random.randint(h) | ||
| 148 | - x = np.random.randint(w) | ||
| 149 | - | ||
| 150 | - y1 = np.clip(y - self.length // 2, 0, h).astype(int) | ||
| 151 | - y2 = np.clip(y + self.length // 2, 0, h).astype(int) | ||
| 152 | - x1 = np.clip(x - self.length // 2, 0, w).astype(int) | ||
| 153 | - x2 = np.clip(x + self.length // 2, 0, w).astype(int) | ||
| 154 | - | ||
| 155 | - mask[y1: y2, x1: x2, :] = 0. | ||
| 156 | - | ||
| 157 | - img = mask*np.asarray(img).astype(np.uint8) | ||
| 158 | - img = Image.fromarray(mask*np.asarray(img)) | ||
| 159 | - | ||
| 160 | - return img | ||
| 161 | - |
This diff is collapsed. Click to expand it.
-
Please register or login to post a comment