조현아

rm ori FAA2

Showing 21 changed files with 0 additions and 1636 deletions
1 -# Byte-compiled / optimized / DLL files
2 -__pycache__/
3 -*.py[cod]
4 -*$py.class
5 -
6 -# C extensions
7 -*.so
8 -
9 -# Distribution / packaging
10 -.Python
11 -build/
12 -develop-eggs/
13 -dist/
14 -downloads/
15 -eggs/
16 -.eggs/
17 -lib/
18 -lib64/
19 -parts/
20 -sdist/
21 -var/
22 -wheels/
23 -*.egg-info/
24 -.installed.cfg
25 -*.egg
26 -MANIFEST
27 -
28 -# PyInstaller
29 -# Usually these files are written by a python script from a template
30 -# before PyInstaller builds the exe, so as to inject date/other infos into it.
31 -*.manifest
32 -*.spec
33 -
34 -# Installer logs
35 -pip-log.txt
36 -pip-delete-this-directory.txt
37 -
38 -# Unit test / coverage reports
39 -htmlcov/
40 -.tox/
41 -.coverage
42 -.coverage.*
43 -.cache
44 -nosetests.xml
45 -coverage.xml
46 -*.cover
47 -.hypothesis/
48 -.pytest_cache/
49 -
50 -# Translations
51 -*.mo
52 -*.pot
53 -
54 -# Django stuff:
55 -*.log
56 -local_settings.py
57 -db.sqlite3
58 -
59 -# Flask stuff:
60 -instance/
61 -.webassets-cache
62 -
63 -# Scrapy stuff:
64 -.scrapy
65 -
66 -# Sphinx documentation
67 -docs/_build/
68 -
69 -# PyBuilder
70 -target/
71 -
72 -# Jupyter Notebook
73 -.ipynb_checkpoints
74 -
75 -# pyenv
76 -.python-version
77 -
78 -# celery beat schedule file
79 -celerybeat-schedule
80 -
81 -# SageMath parsed files
82 -*.sage.py
83 -
84 -# Environments
85 -.env
86 -.venv
87 -env/
88 -venv/
89 -ENV/
90 -env.bak/
91 -venv.bak/
92 -
93 -# Spyder project settings
94 -.spyderproject
95 -.spyproject
96 -
97 -# Rope project settings
98 -.ropeproject
99 -
100 -# mkdocs documentation
101 -/site
102 -
103 -# mypy
104 -.mypy_cache/
105 -
1 -# Fast Autoaugment
2 -<img src="figures/faa.png" width=800px>
3 -
4 -A PyTorch implementation of [Fast AutoAugment](https://arxiv.org/pdf/1905.00397.pdf) and [EfficientNet](https://arxiv.org/abs/1905.11946).
5 -
6 -## Prerequisite
7 -* torch==1.1.0
8 -* torchvision==0.2.2
9 -* hyperopt==0.1.2
10 -* future==0.17.1
11 -* tb-nightly==1.15.0a20190622
12 -
13 -## Usage
14 -### Training
15 -#### CIFAR10
16 -```bash
17 -# ResNet20 (w/o FastAutoAugment)
18 -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False
19 -
20 -# ResNet20 (w/ FastAutoAugment)
21 -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True
22 -
23 -# ResNet20 (w/ FastAutoAugment, Pre-found policy)
24 -python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \
25 - --augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp
26 -
27 -# ResNet32 (w/o FastAutoAugment)
28 -python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False
29 -
30 -# ResNet32 (w/ FastAutoAugment)
31 -python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True
32 -
33 -# EfficientNet (w/ FastAutoAugment)
34 -python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \
35 - --network=efficientnet_cifar10 --activation=swish
36 -```
37 -
38 -#### ImageNet (You can use any backbone networks in [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html))
39 -```bash
40 -
41 -# BaseNet (w/o FastAutoAugment)
42 -python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50
43 -
44 -# EfficientNet (w/ FastAutoAugment) (under construction)
45 -python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \
46 - --network=efficientnet --activation=swish
47 -```
48 -
49 -### Eval
50 -```bash
51 -# Single Image testing
52 -python eval.py --model_path=runs/ResNet_Scale3_Basline
53 -
54 -# 5-crops testing
55 -python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True
56 -```
57 -
58 -## Experiments
59 -### Fast AutoAugment
60 -#### ResNet20 (CIFAR10)
61 -* Pre-trained model [[Download](https://drive.google.com/file/d/12D8050yGGiKWGt8_R8QTlkoQ6wq_icBn/view?usp=sharing)]
62 -* Validation Curve
63 -<img src="figures/resnet20_valid.png">
64 -
65 -* Evaluation (Acc @1)
66 -
67 -| | Valid | Test(Single) |
68 -|----------------|-------|-------------|
69 -| ResNet20 | 90.70 | **91.45** |
70 -| ResNet20 + FAA |**92.46**| **91.45** |
71 -
72 -#### ResNet34 (CIFAR10)
73 -* Validation Curve
74 -<img src="figures/resnet34_valid.png">
75 -
76 -* Evaluation (Acc @1)
77 -
78 -| | Valid | Test(Single) |
79 -|----------------|-------|-------------|
80 -| ResNet34 | 91.54 | 91.47 |
81 -| ResNet34 + FAA |**92.76**| **91.99** |
82 -
83 -### Found Policy [[Download](https://drive.google.com/file/d/1Ia_IxPY3-T7m8biyl3QpxV1s5EA5gRDF/view?usp=sharing)]
84 -<img src="figures/pm.png">
85 -
86 -### Augmented images
87 -<img src="figures/augmented_images.png">
88 -<img src="figures/augmented_images2.png">
1 -import os
2 -import fire
3 -import json
4 -from pprint import pprint
5 -
6 -import torch
7 -import torch.nn as nn
8 -
9 -from utils import *
10 -
11 -
def eval(model_path):
    """Evaluate a saved model on the test split and print its metrics.

    Args:
        model_path: Run directory containing ``kwargs.json`` (the arguments
            the model was trained with) and ``model/model.pt`` (the saved
            state dict).
    """
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Context manager so the handle is closed deterministically.
    with open(kwargs_path) as kwargs_file:
        kwargs = json.load(kwargs_file)
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create network')
    model = select_model(args)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    print('\n[+] Load model')
    weight_path = os.path.join(model_path, 'model', 'model.pt')
    # map_location lets a checkpoint written on a GPU be loaded on a
    # CPU-only machine (and vice versa).
    state_dict = torch.load(weight_path, map_location=device)
    # A state dict saved from an nn.DataParallel wrapper prefixes every key
    # with 'module.'; strip it so the bare model can load the weights.
    if state_dict and all(key.startswith('module.') for key in state_dict):
        state_dict = {key[len('module.'):]: value
                      for key, value in state_dict.items()}
    model.load_state_dict(state_dict)

    print('\n[+] Load dataset')
    test_transform = get_valid_transform(args, model)
    test_dataset = get_dataset(args, test_transform, 'test')
    test_loader = iter(get_dataloader(args, test_dataset))

    print('\n[+] Start testing')
    _test_res = validate(args, model, criterion, test_loader, step=0, writer=None)

    print('\n[+] Valid results')
    print('  Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100))
    print('  Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100))
    print('  Loss : {:.3f}'.format(_test_res[2].data))
    # _test_res[3] is divided by the dataset size to report a per-image time.
    print('  Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset)))


if __name__ == '__main__':
    fire.Fire(eval)
1 -import copy
2 -import json
3 -import time
4 -import torch
5 -import random
6 -import torchvision.transforms as transforms
7 -
8 -from torch.utils.data import Subset
9 -from sklearn.model_selection import StratifiedShuffleSplit
10 -from concurrent.futures import ProcessPoolExecutor
11 -
12 -from transforms import *
13 -from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
14 -from utils import *
15 -
16 -
# Default pool of augmentation operation classes that the policy search
# samples from when the caller supplies no candidates of its own.
# NOTE(review): the constant name is misspelled ("DEFALUT"); it is kept
# because other code in this file refers to it by this exact name.
DEFALUT_CANDIDATES = [
    ShearXY,
    TranslateXY,
    Rotate,
    AutoContrast,
    Invert,
    Equalize,
    Solarize,
    Posterize,
    Contrast,
    Color,
    Brightness,
    Sharpness,
    Cutout,
#    SamplePairing,
]
33 -
34 -
def train_child(args, model, dataset, subset_indx, device=None):
    """Train ``model`` on a subset of ``dataset`` without augmentation.

    Args:
        args: Parsed run arguments (reads ``use_cuda``, ``start_step``,
            ``max_step`` and ``print_step``).
        model: Network to train; it is moved to the target device.
        dataset: Dataset whose ``transform`` attribute is overwritten here.
        subset_indx: Indices of the samples to train on.
        device: Optional explicit device; when falsy, ``args.use_cuda``
            decides whether the model is moved to CUDA.

    Returns:
        The result of the last ``train_step`` call, or ``None`` when the step
        range is empty.
    """
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()

    # The child model sees plain images: resize and tensor conversion only.
    dataset.transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor()])
    subset = Subset(dataset, subset_indx)
    data_loader = get_inf_dataloader(args, subset)

    if device:
        model = model.to(device)
        criterion = criterion.to(device)

    elif args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # Initialised up front so an empty step range returns None instead of
    # raising UnboundLocalError at the return statement.
    _train_res = None

    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(data_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device)

        if step % args.print_step == 0:
            print('\n[+] Training step: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}\tDevice: {}'.format(
                step, args.max_step,(time.time()-start_t)/60, optimizer.param_groups[0]['lr'], device))

            print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print('  Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print('  Loss : {}'.format(_train_res[2].data))

    return _train_res
72 -
73 -
def validate_child(args, model, dataset, subset_indx, transform, device=None):
    """Run validation for ``model`` on a subset of ``dataset`` under ``transform``.

    The dataset's ``transform`` attribute is replaced by ``transform`` as a
    side effect. Returns whatever ``validate`` returns.
    """
    loss_fn = nn.CrossEntropyLoss()

    # An explicit device wins; otherwise fall back to the CUDA flag.
    if device:
        model, loss_fn = model.to(device), loss_fn.to(device)
    elif args.use_cuda:
        model, loss_fn = model.cuda(), loss_fn.cuda()

    dataset.transform = transform
    loader = get_dataloader(args, Subset(dataset, subset_indx), pin_memory=False)

    return validate(args, model, loss_fn, loader, 0, None, device)
90 -
91 -
def get_next_subpolicy(transform_candidates, op_per_subpolicy=2):
    """Build a random sub-policy as a composed transform.

    For each of ``op_per_subpolicy`` slots, a candidate class is picked
    uniformly at random and instantiated with a random probability and a
    random magnitude, both drawn from ``random.random()``. The sampled
    operations are followed by a resize to 32 and tensor conversion.
    """
    pool_size = len(transform_candidates)
    operations = []

    for _ in range(op_per_subpolicy):
        # Draw order (index, probability, magnitude) matters for seeded runs.
        op_cls = transform_candidates[random.randrange(pool_size)]
        probability = random.random()
        magnitude = random.random()
        operations.append(op_cls(probability, magnitude))

    pipeline = operations + [transforms.Resize(32), transforms.ToTensor()]
    return transforms.Compose(pipeline)
108 -
109 -
def search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device):
    """Random search: sample ``B`` sub-policies and score each on ``Da_indx``.

    Returns a list of ``(subpolicy, score)`` pairs, where the score is
    element 2 of the ``validate_child`` result.
    """
    scored = []

    for _ in range(B):
        candidate = get_next_subpolicy(transform_candidates)
        outcome = validate_child(args, child_model, dataset, Da_indx, candidate, device)
        scored.append((candidate, outcome[2]))

    return scored
119 -
120 -
def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device):
    """Search two-operation sub-policies with hyperopt's TPE.

    Each trial picks two operations (class, probability, magnitude) and is
    scored by the child model's validation loss on ``Da_indx``. After ``B``
    evaluations every trial is rebuilt as a training-time transform and
    returned together with its loss.

    Returns:
        List of ``(transform, loss)`` pairs, one per trial.
    """

    def _objective(sampled):
        # Scoring pipeline: resize first, then the sampled operations.
        ops = [op_cls(prob, mag) for op_cls, prob, mag in sampled]
        pipeline = transforms.Compose(
            [transforms.Resize(32)] + ops + [transforms.ToTensor()])

        outcome = validate_child(args, child_model, dataset, Da_indx, pipeline, device)
        return {'loss': outcome[2].cpu().numpy(), 'status': STATUS_OK}

    first_op = (hp.choice('transform1', transform_candidates),
                hp.uniform('prob1', 0, 1.0),
                hp.uniform('mag1', 0, 1.0))
    second_op = (hp.choice('transform2', transform_candidates),
                 hp.uniform('prob2', 0, 1.0),
                 hp.uniform('mag2', 0, 1.0))
    space = [first_op, second_op]

    trials = Trials()
    fmin(_objective,
         space=space,
         algo=tpe.suggest,
         max_evals=B,
         trials=trials)

    scored = []
    for trial in trials.trials:
        picked = trial['misc']['vals']
        # hp.choice records the index of the chosen candidate.
        first_cls = transform_candidates[picked['transform1'][0]]
        second_cls = transform_candidates[picked['transform2'][0]]
        ops = [first_cls(picked['prob1'][0], picked['mag1'][0]),
               second_cls(picked['prob2'][0], picked['mag2'][0])]

        final_transform = transforms.Compose([
            # baseline augmentation
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            # searched policy
            *ops,
            # tensor conversion
            transforms.ToTensor()])
        scored.append((final_transform, trial['result']['loss']))

    return scored
163 -
164 -
def get_topn_subpolicies(subpolicies, N=10):
    """Return the ``N`` entries with the smallest score (element 1), ascending.

    The sort is stable, so ties keep their input order.
    """
    ranked = list(subpolicies)
    ranked.sort(key=lambda entry: entry[1])
    return ranked[:N]
167 -
168 -
def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k):
    """Handle one data split: train a child model, then search sub-policies.

    Args:
        args_str: JSON-encoded run arguments, decoded through ``parse_args``.
        model: Template network; a deep copy of it is trained.
        dataset: Dataset shared by training and policy search.
        Dm_indx: Indices used to train the child model.
        Da_indx: Indices used to score candidate sub-policies.
        T: Number of search rounds.
        transform_candidates: Operation classes to search over.
        B: Evaluations per search round.
        N: Number of best sub-policies kept per round.
        k: Split index; also selects the GPU (``k`` modulo the GPU count).

    Returns:
        List with the kept transforms of all rounds.
    """
    args, _ = parse_args(json.loads(args_str))
    gpu_index = k % torch.cuda.device_count()
    device = torch.device('cuda:%d' % gpu_index)

    print('[+] Child %d training strated (GPU: %d)' % (k, gpu_index))

    # Train a private copy so the template model stays untouched.
    child_model = copy.deepcopy(model)
    train_child(args, child_model, dataset, Dm_indx, device)

    # Each round runs a TPE search and keeps its N lowest-loss sub-policies.
    found = []
    for _ in range(T):
        candidates = search_subpolicies_hyperopt(
            args, transform_candidates, child_model, dataset, Da_indx, B, device)
        best = get_topn_subpolicies(candidates, N)
        found.extend(entry[0] for entry in best)

    return found
190 -
191 -
def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5):
    """Search augmentation sub-policies over ``K`` splits in worker processes.

    The 'trainval' dataset is split with ``split_dataset``. One ``process_fn``
    job per split is submitted to a process pool whose size is capped by the
    number of visible CUDA devices. The transforms found by all jobs are
    merged into one ``RandomChoice`` transform.

    Note: this forces the torch multiprocessing start method to 'spawn'.
    """
    args_str = json.dumps(args._asdict())
    dataset = get_dataset(args, None, 'trainval')
    num_process = min(torch.cuda.device_count(), num_process)

    torch.multiprocessing.set_start_method('spawn', force=True)

    candidates = transform_candidates if transform_candidates else DEFALUT_CANDIDATES

    # Per-split index lists: Dm trains the child model, Da scores policies.
    Dm_indexes, Da_indexes = split_dataset(args, dataset, K)

    found = []
    with ProcessPoolExecutor(max_workers=num_process) as pool:
        pending = [
            pool.submit(process_fn,
                        args_str, model, dataset, dm, da, T, candidates, B, N, fold)
            for fold, (dm, da) in enumerate(zip(Dm_indexes, Da_indexes))
        ]

        for job in pending:
            found.extend(job.result())

    return transforms.RandomChoice(found)
1 -import torch.nn as nn
2 -
class BaseNet(nn.Module):
    """Wrap a backbone so that ``forward`` also returns its first layer's output.

    The backbone's children are split into the first child, the middle
    children, and the last child, which is used as the classifier head.
    """

    def __init__(self, backbone, args):
        super().__init__()

        # Materialise the children once and slice the list three ways.
        layers = list(backbone.children())
        self.first = nn.Sequential(*layers[:1])
        self.after = nn.Sequential(*layers[1:-1])
        self.fc = layers[-1]

        self.img_size = (224, 224)

    def forward(self, x):
        """Return ``(head output, first-layer output)``."""
        stem_out = self.first(x)
        hidden = self.after(stem_out)
        # Flatten everything but the batch dimension before the head.
        flat = hidden.reshape(hidden.size(0), -1)
        return self.fc(flat), stem_out
1 -import math
2 -import torch.nn as nn
3 -import torch.nn.functional as F
4 -
5 -
def round_fn(orig, multiplier):
    """Scale ``orig`` by ``multiplier`` and round up to an int.

    A falsy multiplier (``None`` or ``0``) returns ``orig`` unchanged.
    """
    if multiplier:
        return int(math.ceil(multiplier * orig))

    return orig
11 -
12 -
def get_activation_fn(activation):
    """Return the activation module class registered under ``activation``.

    Args:
        activation: Either ``"swish"`` or ``"relu"``.

    Returns:
        The module class (not an instance); callers instantiate it.

    Raises:
        ValueError: If the name is not recognised. ``ValueError`` subclasses
            ``Exception``, so handlers written for the previous generic
            ``Exception`` still catch it.
    """
    if activation == "swish":
        return Swish

    if activation == "relu":
        return nn.ReLU

    raise ValueError('Unknown activation %s' % activation)
22 -
23 -
class Swish(nn.Module):
    """Swish activation function, s(x) = x * sigmoid(x).

    Args:
        inplace: When True the input tensor is overwritten with the result
            and returned; when False (the default) a new tensor is returned
            and the input is left untouched.
    """

    def __init__(self, inplace=False):
        super().__init__()
        # Honour the caller's choice; this was previously hard-coded to True,
        # which silently ignored the argument.
        self.inplace = inplace

    def forward(self, x):
        # Tensor.sigmoid() replaces the deprecated F.sigmoid.
        if self.inplace:
            x.mul_(x.sigmoid())
            return x

        return x * x.sigmoid()
37 -
38 -
class ConvBlock(nn.Module):
    """Bias-free convolution, then batch norm, then the chosen activation."""

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        conv = nn.Conv2d(in_channel, out_channel, kernel_size,
                         padding=padding, stride=stride, bias=False)
        norm = nn.BatchNorm2d(out_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.fw(x)
53 -
54 -
class DepthwiseConvBlock(nn.Module):
    """Depthwise convolution, then batch norm, then the chosen activation.

    The convolution uses ``groups=in_channel`` and keeps the channel count.
    """

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        depthwise = nn.Conv2d(in_channel, in_channel, kernel_size,
                              padding=padding, stride=stride,
                              groups=in_channel, bias=False)
        norm = nn.BatchNorm2d(in_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(depthwise, norm, act)

    def forward(self, x):
        return self.fw(x)
69 -
70 -
class MBConv(nn.Module):
    """Inverted residual block.

    Optional 1x1 expansion, depthwise convolution, 1x1 projection. The input
    is added back when the block keeps both the channel count and the
    resolution (stride 1).
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish"):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride

        hidden = in_channel * expand_ratio

        # The expansion layer only exists when it would widen the input.
        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, hidden, 1,
                                    activation=activation)

        self.dw_conv = DepthwiseConvBlock(hidden, kernel_size,
                                          padding=(kernel_size-1)//2,
                                          stride=stride, activation=activation)

        self.pw_conv = ConvBlock(hidden, out_channel, 1,
                                 activation=activation)

    def forward(self, inputs):
        x = inputs if self.expand_ratio == 1 else self.expand(inputs)
        x = self.pw_conv(self.dw_conv(x))

        # Skip connection only when shapes match.
        if self.in_channel != self.out_channel or self.stride != 1:
            return x

        return x + inputs
107 -
108 -
class Net(nn.Module):
    """EfficientNet.

    Depth, width and input resolution are scaled from the base network by
    ``1.2 ** pi``, ``1.1 ** pi`` and ``1.15 ** pi`` respectively, where
    ``pi`` comes from ``args.pi``.

    Args:
        args: Parsed arguments; reads ``pi``, ``activation`` and
            ``num_classes``.
    """

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = args.num_classes

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        self.img_size = (round_fn(224, self.r), round_fn(224, self.r))

        self.stage1 = ConvBlock(3, round_fn(32, self.w),
                                kernel_size=3, padding=1, stride=2, activation=activation)

        self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=1, activation=activation)

        self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
                                       depth=round_fn(2, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
                                       depth=round_fn(2, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
                                       depth=round_fn(3, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
                                       depth=round_fn(3, self.d), kernel_size=5,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
                                       depth=round_fn(4, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
                                kernel_size=1, activation=activation)

        # The flattened feature size is derived from the real output geometry.
        # Stages 1, 3, 4, 5 and 7 each halve the resolution (stride 2 with
        # "same"-style padding gives ceil(n / 2)). For pi == 0 this is
        # 7 * 7 * 1280, identical to the previous hard-coded value; for
        # pi > 0 the old expression did not match the tensor reaching `fc`.
        side = self.img_size[0]
        for _ in range(5):
            side = (side + 1) // 2
        self.fc = nn.Linear(side * side * round_fn(1280, self.w), num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
                    half_resolution=False, expand_ratio=1, activation="swish"):
        """Stack ``depth`` MBConv blocks; only the first may use stride 2."""
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i == 0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                       stride=stride, expand_ratio=expand_ratio, activation=activation))
            # Every block after the first maps out_channel -> out_channel.
            in_channel = out_channel

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the class scores twice, as ``(scores, scores)``."""
        # Report both spatial dimensions; previously only the height was shown.
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, tuple(x.size()[-2:]))

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        # NOTE(review): the second element duplicates the scores here. Confirm
        # with the callers whether an early feature map is expected instead.
        return x, x
1 -import math
2 -import torch.nn as nn
3 -import torch.nn.functional as F
4 -
5 -
def round_fn(orig, multiplier):
    """Return ``ceil(multiplier * orig)`` as an int.

    ``orig`` is returned as-is when ``multiplier`` is falsy (``None``/``0``).
    """
    scaled = int(math.ceil(multiplier * orig)) if multiplier else orig
    return scaled
11 -
12 -
def get_activation_fn(activation):
    """Look up an activation module class by name.

    Args:
        activation: ``"swish"`` or ``"relu"``.

    Returns:
        The module class itself; the caller is expected to instantiate it.

    Raises:
        ValueError: For any other name. This is still an ``Exception``
            subclass, so existing broad handlers keep working.
    """
    if activation == "swish":
        return Swish

    if activation == "relu":
        return nn.ReLU

    raise ValueError('Unknown activation %s' % activation)
22 -
23 -
class Swish(nn.Module):
    """Swish activation function, s(x) = x * sigmoid(x).

    Args:
        inplace: If True, the result is written into the input tensor, which
            is then returned. If False (default), the input is not modified.
    """

    def __init__(self, inplace=False):
        super().__init__()
        # Store the requested mode; the flag used to be forced to True
        # regardless of the argument.
        self.inplace = inplace

    def forward(self, x):
        # Tensor.sigmoid() is used instead of the deprecated F.sigmoid.
        gate = x.sigmoid()
        if self.inplace:
            x.mul_(gate)
            return x

        return x * gate
37 -
38 -
class ConvBlock(nn.Module):
    """Conv2d (no bias), BatchNorm2d and an activation, applied in sequence."""

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        layers = [
            nn.Conv2d(in_channel, out_channel, kernel_size,
                      padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
            get_activation_fn(activation)(),
        ]
        self.fw = nn.Sequential(*layers)

    def forward(self, x):
        return self.fw(x)
53 -
54 -
class DepthwiseConvBlock(nn.Module):
    """Per-channel (``groups=in_channel``) Conv2d, BatchNorm2d, activation."""

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        layers = [
            nn.Conv2d(in_channel, in_channel, kernel_size,
                      padding=padding, stride=stride,
                      groups=in_channel, bias=False),
            nn.BatchNorm2d(in_channel),
            get_activation_fn(activation)(),
        ]
        self.fw = nn.Sequential(*layers)

    def forward(self, x):
        return self.fw(x)
69 -
70 -
class SEBlock(nn.Module):
    """Squeeze and Excitation Block.

    Global-average-pools the input to 1x1, passes it through a two-layer
    1x1-conv bottleneck (ReLU, then Sigmoid) and rescales the input channels
    by the resulting gate.

    Args:
        in_channel: Number of input (and output) channels.
        se_ratio: Channel reduction factor of the bottleneck.
    """

    def __init__(self, in_channel, se_ratio=16):
        super().__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Keep at least one bottleneck channel; plain integer division
        # yields 0 whenever in_channel < se_ratio.
        inter_channel = max(1, in_channel // se_ratio)

        self.reduce = nn.Sequential(
            nn.Conv2d(in_channel, inter_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.ReLU())

        self.expand = nn.Sequential(
            nn.Conv2d(inter_channel, in_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.Sigmoid())

    def forward(self, x):
        s = self.global_avgpool(x)
        s = self.reduce(s)
        s = self.expand(s)
        # The per-channel gate broadcasts over the spatial dimensions.
        return x * s
95 -
96 -
class MBConv(nn.Module):
    """Inverted residual block with an optional squeeze-and-excitation step.

    Order of operations: optional 1x1 expansion, depthwise convolution,
    optional SE block, 1x1 projection. The input is added back when the
    channel count is unchanged and the stride is 1.
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish", use_seblock=False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride
        self.use_seblock = use_seblock

        hidden = in_channel * expand_ratio

        # Sub-modules are created in execution order; optional ones are
        # only registered when enabled.
        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, hidden, 1,
                                    activation=activation)

        self.dw_conv = DepthwiseConvBlock(hidden, kernel_size,
                                          padding=(kernel_size-1)//2,
                                          stride=stride, activation=activation)

        if use_seblock:
            self.seblock = SEBlock(hidden)

        self.pw_conv = ConvBlock(hidden, out_channel, 1,
                                 activation=activation)

    def forward(self, inputs):
        x = self.expand(inputs) if self.expand_ratio != 1 else inputs
        x = self.dw_conv(x)

        if self.use_seblock:
            x = self.seblock(x)

        x = self.pw_conv(x)

        # Residual connection only when input and output shapes agree.
        shapes_match = self.in_channel == self.out_channel and self.stride == 1
        return x + inputs if shapes_match else x
141 -
142 -
class Net(nn.Module):
    """EfficientNet variant for 32x32 inputs with 10 output classes.

    Depth, width and input resolution are scaled by ``1.2 ** pi``,
    ``1.1 ** pi`` and ``1.15 ** pi`` respectively.

    Args:
        args: Parsed arguments; reads ``pi``, ``activation`` and
            ``use_seblock``.
    """

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = 10

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        self.img_size = (round_fn(32, self.r), round_fn(32, self.r))
        # Must be set before make_layers() runs, which reads it.
        self.use_seblock = args.use_seblock

        self.stage1 = ConvBlock(3, round_fn(32, self.w),
                                kernel_size=3, padding=1, stride=2, activation=activation)

        self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=1, activation=activation)

        self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
                                       depth=round_fn(2, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
                                       depth=round_fn(2, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
                                       depth=round_fn(3, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
                                       depth=round_fn(3, self.d), kernel_size=5,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
                                       depth=round_fn(4, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)

        self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=6, activation=activation)

        self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
                                kernel_size=1, activation=activation)

        # The flattened feature size follows the real output geometry.
        # Stages 1, 3, 4, 5 and 7 each halve the resolution (ceil(n / 2)).
        # For pi == 0 a 32x32 input collapses to 1x1, so this equals the
        # previous `round_fn(1280, w)`; for pi > 0 the final map is larger
        # than 1x1 and the old size did not match the flattened tensor.
        side = self.img_size[0]
        for _ in range(5):
            side = (side + 1) // 2
        self.fc = nn.Linear(side * side * round_fn(1280, self.w), num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
                    half_resolution=False, expand_ratio=1, activation="swish"):
        """Stack ``depth`` MBConv blocks; only the first may use stride 2."""
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i == 0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                       stride=stride, expand_ratio=expand_ratio,
                       activation=activation, use_seblock=self.use_seblock))
            # Every block after the first maps out_channel -> out_channel.
            in_channel = out_channel

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(class scores, stage1 output)``."""
        # Report both spatial dimensions; previously only the height was shown.
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, tuple(x.size()[-2:]))

        s = self.stage1(x)
        x = self.stage2(s)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x, s
1 -import torch.nn as nn
2 -
3 -
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a shortcut connection.

    When the block changes the channel count or uses a stride other than 1,
    the shortcut goes through a 1x1 conv + batch norm (``self.down``) so the
    shapes match; otherwise the input is added directly.
    """

    def __init__(self, in_channel, out_channel, stride):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel,
                      kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(out_channel))

        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel))

        # The projection is only registered when it is actually needed.
        if self._needs_projection():
            self.down = nn.Sequential(
                nn.Conv2d(in_channel, out_channel,
                          kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channel))

    def _needs_projection(self):
        """True when the shortcut must be reshaped to match the main path."""
        return not (self.in_channel == self.out_channel and self.stride == 1)

    def forward(self, b):
        out = self.conv2(self.relu(self.conv1(b)))

        shortcut = self.down(b) if self._needs_projection() else b

        out += shortcut
        return self.relu(out)
43 -
44 -
class Net(nn.Module):
    """Residual network with a 16-channel stem, three stages and 10 outputs.

    Each stage stacks ``2 * args.scale`` residual blocks with 16, 32 and 64
    channels; the first block of stages 2 and 3 uses stride 2.
    """

    def __init__(self, args):
        super().__init__()
        blocks_per_stage = 2 * args.scale

        self.stem = nn.Sequential(
            nn.Conv2d(3, 16,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))

        self.layer1 = nn.Sequential(
            *(ResidualBlock(16, 16, 1) for _ in range(blocks_per_stage)))
        self.layer2 = self._downsampling_stage(16, 32, blocks_per_stage)
        self.layer3 = self._downsampling_stage(32, 64, blocks_per_stage)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(64, 10)

        # Kaiming init for convolutions; norm layers start as identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    @staticmethod
    def _downsampling_stage(in_channel, out_channel, n_blocks):
        """Build a stage whose first block changes channels and uses stride 2."""
        blocks = [
            ResidualBlock(in_channel=(in_channel if idx == 0 else out_channel),
                          out_channel=out_channel,
                          stride=(2 if idx == 0 else 1))
            for idx in range(n_blocks)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(class scores, stem output)``."""
        s = self.stem(x)
        features = self.layer3(self.layer2(self.layer1(s)))
        pooled = self.avg_pool(features)
        flat = pooled.reshape(pooled.size(0), -1)

        return self.fc(flat), s
1 -import os
2 -import fire
3 -import time
4 -import json
5 -import random
6 -from pprint import pprint
7 -
8 -import torch.nn as nn
9 -import torch.backends.cudnn as cudnn
10 -from torch.utils.tensorboard import SummaryWriter
11 -
12 -from networks import *
13 -from utils import *
14 -
15 -
def train(**kwargs):
    """Train a model, log to TensorBoard and keep the best checkpoint.

    Args:
        **kwargs: Command-line options, turned into the run arguments by
            ``parse_args``.

    Side effects:
        Creates ``./runs/<model_name>/`` (fails if it already exists), writes
        ``kwargs.json`` and TensorBoard logs there, and saves the state dict
        with the best validation Acc@1 to ``model/model.pt``.
    """
    print('\n[+] Parse arguments')
    args, kwargs = parse_args(kwargs)
    pprint(args)

    print('\n[+] Create log dir')
    model_name = get_model_name(args)
    log_dir = os.path.join('./runs', model_name)
    # No exist_ok: an existing run directory raises instead of being reused.
    os.makedirs(os.path.join(log_dir, 'model'))
    # Context manager so the file is flushed and closed right away.
    with open(os.path.join(log_dir, 'kwargs.json'), 'w') as kwargs_file:
        json.dump(kwargs, kwargs_file)
    writer = SummaryWriter(log_dir=log_dir)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    print('\n[+] Create network')
    model = select_model(args)
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    print('\n[+] Load dataset')
    transform = get_train_transform(args, model, log_dir)
    val_transform = get_valid_transform(args, model)
    train_dataset = get_dataset(args, transform, 'train')
    valid_dataset = get_dataset(args, val_transform, 'val')
    train_loader = iter(get_inf_dataloader(args, train_dataset))
    # Epoch bookkeeping for the progress line. `current_epoch` was referenced
    # but never defined, which raised NameError on the first log step.
    steps_per_epoch = max(1, len(train_dataset) // args.batch_size)
    max_epoch = -(-args.max_step // steps_per_epoch)  # ceiling division
    best_acc = -1

    print('\n[+] Start training')
    if torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    start_t = time.time()
    try:
        for step in range(args.start_step, args.max_step):
            batch = next(train_loader)
            _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer)

            if step % args.print_step == 0:
                current_epoch = step // steps_per_epoch
                print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                    step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr']))
                writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
                writer.add_scalar('train/acc1', _train_res[0], global_step=step)
                writer.add_scalar('train/acc5', _train_res[1], global_step=step)
                writer.add_scalar('train/loss', _train_res[2], global_step=step)
                writer.add_scalar('train/forward_time', _train_res[3], global_step=step)
                writer.add_scalar('train/backward_time', _train_res[4], global_step=step)
                print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
                print('  Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
                print('  Loss : {}'.format(_train_res[2].data))
                print('  FW Time : {:.3f}ms'.format(_train_res[3]*1000))
                print('  BW Time : {:.3f}ms'.format(_train_res[4]*1000))

            # Validate on the last step of every val_step-sized window.
            if step % args.val_step == args.val_step-1:
                valid_loader = iter(get_dataloader(args, valid_dataset))
                _valid_res = validate(args, model, criterion, valid_loader, step, writer)
                print('\n[+] Valid results')
                writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
                writer.add_scalar('valid/acc5', _valid_res[1], global_step=step)
                writer.add_scalar('valid/loss', _valid_res[2], global_step=step)
                print('  Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100))
                print('  Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100))
                print('  Loss : {}'.format(_valid_res[2].data))

                if _valid_res[0] > best_acc:
                    best_acc = _valid_res[0]
                    # Save the unwrapped network so checkpoint keys carry no
                    # 'module.' prefix and load into a bare model.
                    to_save = model.module if isinstance(model, nn.DataParallel) else model
                    torch.save(to_save.state_dict(), os.path.join(log_dir, "model","model.pt"))
                    print('\n[+] Model saved')
    finally:
        # Flush pending TensorBoard events even if training aborts.
        writer.close()


if __name__ == '__main__':
    fire.Fire(train)
1 -import numpy as np
2 -import torch.nn as nn
3 -import torchvision.transforms as transforms
4 -
5 -from abc import ABC, abstractmethod
6 -from PIL import Image, ImageOps, ImageEnhance
7 -
8 -
class BaseTransform(ABC):
    """Abstract augmentation op that is applied to an image with probability ``prob``.

    Subclasses implement ``transform``; ``mag`` is stored as given and is
    interpreted by each subclass.
    """

    def __init__(self, prob, mag):
        self.prob, self.mag = prob, mag

    def __call__(self, img):
        # RandomApply decides on every call whether ``transform`` runs on ``img``.
        gated = transforms.RandomApply([self.transform], self.prob)
        return gated(img)

    def __repr__(self):
        return '%s(prob=%.2f, magnitude=%.2f)' % (
            self.__class__.__name__, self.prob, self.mag)

    @abstractmethod
    def transform(self, img):
        """Apply the operation to ``img`` and return the result."""
25 -
26 -
class ShearXY(BaseTransform):
    """Random shear; the shear argument is the magnitude scaled by 360."""

    def transform(self, img):
        shear_op = transforms.RandomAffine(
            0, shear=self.mag * 360, resample=Image.BILINEAR)
        return shear_op(img)
33 -
34 -
class TranslateXY(BaseTransform):
    """Random translation; the magnitude is used for both entries of ``translate``."""

    def transform(self, img):
        shift_op = transforms.RandomAffine(
            0, translate=(self.mag, self.mag), resample=Image.BILINEAR)
        return shift_op(img)
41 -
42 -
class Rotate(BaseTransform):
    """Random rotation; the degrees argument is the magnitude scaled by 360."""

    def transform(self, img):
        rotate_op = transforms.RandomRotation(self.mag * 360, Image.BILINEAR)
        return rotate_op(img)
49 -
50 -
class AutoContrast(BaseTransform):
    """Auto-contrast whose ``cutoff`` is ``int(mag * 49)``."""

    def transform(self, img):
        return ImageOps.autocontrast(img, cutoff=int(self.mag * 49))
56 -
57 -
class Invert(BaseTransform):
    """Invert the image; the magnitude is not used."""

    def transform(self, img):
        inverted = ImageOps.invert(img)
        return inverted
62 -
63 -
class Equalize(BaseTransform):
    """Equalize the image histogram; the magnitude is not used."""

    def transform(self, img):
        equalized = ImageOps.equalize(img)
        return equalized
68 -
69 -
class Solarize(BaseTransform):
    """Solarize with a threshold that falls from 255 to 0 as the magnitude goes 0 -> 1."""

    def transform(self, img):
        return ImageOps.solarize(img, 255 * (1 - self.mag))
75 -
76 -
class Posterize(BaseTransform):
    """Posterize, passing ``int((1 - mag) * 8)`` as the ``bits`` argument."""

    def transform(self, img):
        kept_bits = int(8 * (1 - self.mag))
        return ImageOps.posterize(img, bits=kept_bits)
82 -
83 -
class Contrast(BaseTransform):
    """Contrast enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Contrast(img)
        return enhancer.enhance(self.mag * 10)
89 -
90 -
class Color(BaseTransform):
    """Color enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Color(img)
        return enhancer.enhance(self.mag * 10)
96 -
97 -
class Brightness(BaseTransform):
    """Brightness enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Brightness(img)
        return enhancer.enhance(self.mag * 10)
103 -
104 -
class Sharpness(BaseTransform):
    """Sharpness enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Sharpness(img)
        return enhancer.enhance(self.mag * 10)
110 -
111 -
class Cutout(BaseTransform):
    """Cut a single square hole of side ``24 * mag`` out of the image."""

    def transform(self, img):
        hole_cutter = CutoutOp(n_holes=1, length=24 * self.mag)
        return hole_cutter(img)
119 -
120 -
class CutoutOp(object):
    """
    https://github.com/uoguelph-mlrg/Cutout

    Randomly mask out one or more square patches from a PIL image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int or float): Side length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (PIL.Image.Image): Input image; ``img.size`` is read as
                (width, height).
        Returns:
            PIL.Image.Image: New image with ``n_holes`` patches set to zero.
                Pixel data is cast to uint8 before masking.
        """
        w, h = img.size
        half = self.length // 2

        # 2-D mask; a channel axis is added below only when the image has one.
        mask = np.ones((h, w), np.uint8)

        for _ in range(self.n_holes):
            # Patch centre; the patch is clipped where it crosses the border.
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = int(np.clip(y - half, 0, h))
            y2 = int(np.clip(y + half, 0, h))
            x1 = int(np.clip(x - half, 0, w))
            x2 = int(np.clip(x + half, 0, w))

            mask[y1:y2, x1:x2] = 0

        pixels = np.asarray(img).astype(np.uint8)
        if pixels.ndim == 3:
            # (h, w, 1) broadcasts over the channels of an (h, w, C) array.
            # Single-channel images stay 2-D so the product keeps shape (h, w).
            mask = mask[:, :, np.newaxis]

        return Image.fromarray(mask * pixels)
161 -
1 -import os
2 -import time
3 -import importlib
4 -import collections
5 -import pickle as cp
6 -import glob
7 -import numpy as np
8 -
9 -import torch
10 -import torchvision
11 -import torch.nn.functional as F
12 -import torchvision.models as models
13 -import torchvision.transforms as transforms
14 -from torch.utils.data import Subset
15 -from torch.utils.data import Dataset, DataLoader
16 -
17 -from sklearn.model_selection import StratifiedShuffleSplit
18 -
19 -
20 -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
21 -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
22 -current_epoch = 0
23 -
24 -
def split_dataset(args, dataset, k):
    """Draw ``k`` stratified shuffle splits (10% held out) over the dataset indices.

    Args:
        args: Settings object; ``args.seed`` is used as the splitter's random state.
        dataset: Object supporting ``len()`` and exposing ``targets`` (one label
            per sample).
        k: Number of splits to generate.

    Returns:
        tuple: ``(Dm_indexes, Da_indexes)`` - the first and second index arrays of
        every split, each gathered into a tuple with one entry per split.
    """
    sample_ids = list(range(len(dataset)))
    labels = dataset.targets
    assert len(sample_ids) == len(labels)

    splitter = StratifiedShuffleSplit(
        n_splits=k, random_state=args.seed, test_size=0.1)

    # Transpose the per-split (first, second) pairs into two parallel tuples.
    Dm_indexes, Da_indexes = zip(*splitter.split(sample_ids, labels))
    return Dm_indexes, Da_indexes
40 -
41 -
def concat_image_features(image, features, max_features=3):
    """Lay an image and up to ``max_features`` of its feature maps side by side.

    Each selected feature map is min-max normalised, replicated to three
    channels, resized to the image's height and width with bilinear
    interpolation, and appended along the width axis.

    Args:
        image (Tensor): Image of shape (3, H, W); three channels are required
            because the panels are concatenated with 3-channel feature maps.
        features (Tensor): Feature maps of shape (N, h, w).
        max_features (int): Upper bound on the number of feature maps appended.

    Returns:
        Tensor: New tensor of shape (3, H, W * (1 + n)) where
        ``n = min(N, max_features)``.
    """
    _, h, w = image.shape
    n_maps = min(features.size(0), max_features)

    panels = [image]
    for i in range(n_maps):
        fmap = features[i:i + 1]
        lo, hi = torch.min(fmap), torch.max(fmap)
        # The epsilon keeps constant maps from dividing by zero.
        fmap = (fmap - lo) / (hi - lo + 1e-6)
        fmap = torch.cat([fmap] * 3, 0).unsqueeze(0)
        # F.upsample is deprecated; interpolate(align_corners=False) is what
        # it resolved to, without the deprecation warnings.
        fmap = F.interpolate(fmap, size=(h, w), mode="bilinear",
                             align_corners=False)
        panels.append(fmap.view(3, h, w))

    if len(panels) == 1:
        # Nothing to append: still hand back a copy, never the caller's tensor.
        return image.clone()

    # One concat instead of re-copying the growing strip on every iteration.
    return torch.cat(panels, 2)
59 -
60 -
def get_model_name(args):
    """Build a run name of the form ``<timestamp>__<network>__<seed>``.

    The timestamp is the current local time formatted as
    ``%B_%d_%H:%M:%S``; ``args.network`` and ``str(args.seed)`` follow.
    """
    from datetime import datetime

    stamp = datetime.now().strftime("%B_%d_%H:%M:%S")
    return '__'.join((stamp, args.network, str(args.seed)))
67 -
68 -
def dict_to_namedtuple(d):
    """Convert a (possibly nested) dict into an ``Args`` namedtuple.

    The dict is modified in place: nested dicts are replaced by namedtuples
    and string values are replaced by their evaluated Python value when they
    parse as an expression (e.g. ``'1e-3'`` -> ``0.001``); strings that fail
    to evaluate are kept unchanged.

    Args:
        d (dict): Mapping whose keys are valid namedtuple field names.

    Returns:
        Args: Namedtuple with one field per key of ``d`` (fields sorted by name).
    """
    Args = collections.namedtuple('Args', sorted(d.keys()))

    # Only values of existing keys are reassigned, so iterating while
    # updating is safe here.
    for k, v in d.items():
        if isinstance(v, dict):
            d[k] = dict_to_namedtuple(v)

        elif isinstance(v, str):
            # NOTE(review): eval() runs arbitrary expressions and resolves any
            # name visible in this module; only feed it trusted settings.
            try:
                d[k] = eval(v)
            except Exception:
                # Not an expression: keep the raw string. Catching Exception
                # (not a bare except) lets Ctrl-C / SystemExit through.
                d[k] = v

    return Args(**d)
84 -
85 -
def parse_args(kwargs):
    """Fill in default settings and convert them to a namedtuple.

    Missing keys are added to ``kwargs`` in place, and ``use_cuda`` is
    and-ed with ``torch.cuda.is_available()``.

    Args:
        kwargs (dict): User-supplied settings.

    Returns:
        tuple: ``(args, kwargs)`` - the namedtuple built by
        ``dict_to_namedtuple`` and the same (updated) dict.
    """
    defaults = (
        ('dataset', 'cifar10'),
        ('network', 'resnet_cifar10'),
        ('optimizer', 'adam'),
        ('learning_rate', 0.1),
        ('seed', None),
        ('use_cuda', True),
        ('num_workers', 4),
        ('print_step', 2000),
        ('val_step', 2000),
        ('scheduler', 'exp'),
        ('batch_size', 128),
        ('start_step', 0),
        ('max_step', 64000),
        ('fast_auto_augment', False),
        ('augment_path', None),
    )
    for key, fallback in defaults:
        kwargs.setdefault(key, fallback)

    # CUDA is only used when requested and actually available.
    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()

    args = dict_to_namedtuple(kwargs)
    return args, kwargs
108 -
109 -
def select_model(args):
    """Instantiate the network named by ``args.network``, print it and return it.

    Names found in ``torchvision.models`` are built with no arguments and
    wrapped in ``BaseNet``; any other name is resolved to the ``Net`` class
    of the module ``networks.<name>``.
    """
    # NOTE(review): BaseNet is not among this module's visible imports -
    # confirm it is in scope before relying on the torchvision branch.
    if args.network not in models.__dict__:
        module = importlib.import_module('networks.{}'.format(args.network))
        model = getattr(module, 'Net')(args)
    else:
        backbone = models.__dict__[args.network]()
        model = BaseNet(backbone, args)

    print(model)
    return model
120 -
121 -
def select_optimizer(args, model):
    """Create the optimizer named by ``args.optimizer`` for ``model``'s parameters.

    Supported names: ``'sgd'`` (momentum 0.9, weight decay 1e-4), ``'rms'``
    and ``'adam'``; all use ``args.learning_rate``.

    Raises:
        Exception: If the optimizer name is not recognised.
    """
    kind = args.optimizer
    if kind not in ('sgd', 'rms', 'adam'):
        raise Exception('Unknown Optimizer')

    params = model.parameters()
    lr = args.learning_rate
    if kind == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0001)
    if kind == 'rms':
        # Only the learning rate is set; momentum/weight decay are left at
        # the library defaults.
        return torch.optim.RMSprop(params, lr=lr)
    return torch.optim.Adam(params, lr=lr)
133 -
134 -
def select_scheduler(args, optimizer):
    """Create the learning-rate scheduler named by ``args.scheduler``.

    Returns ``None`` when the name is falsy or the string ``'None'``.
    ``'clr'`` yields a CyclicLR and ``'exp'`` an ExponentialLR, both with
    the fixed hyper-parameters below.

    Raises:
        Exception: If the scheduler name is not recognised.
    """
    name = args.scheduler
    if not name or name == 'None':
        return None

    schedulers = torch.optim.lr_scheduler
    if name == 'clr':
        return schedulers.CyclicLR(
            optimizer, 0.01, 0.015, mode='triangular2',
            step_size_up=250000, cycle_momentum=False)
    if name == 'exp':
        return schedulers.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1)

    raise Exception('Unknown Scheduler')
144 -
145 -
class CustomDataset(Dataset):
    """Dataset backed by a single array loaded with ``np.load``.

    Args:
        path: Location passed to ``np.load``; the first axis of the loaded
            array indexes the samples.
        transform (callable, optional): Applied to each sample on access.
    """

    def __init__(self, path, transform = None):
        self.path = path
        self.transform = transform
        self.img = np.load(path)
        self.len = self.img.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transform`` when one is set."""
        # The original read the nonexistent ``self.transforms`` and an
        # undefined ``img``; fetch the sample from the loaded array first.
        # NOTE(review): only the image is returned (no label) - confirm this
        # matches what the training loop unpacks from a batch.
        img = self.img[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img
160 -
def get_dataset(args, transform, split='train'):
    """Build the dataset selected by ``args.dataset`` for the given split.

    Args:
        args: Settings object; ``args.dataset`` is one of ``'cifar10'``,
            ``'imagenet'`` or ``'BraTS'``.
        transform: Transform handed to the dataset constructor.
        split (str): One of ``'train'``, ``'val'``, ``'test'``, ``'trainval'``.

    Returns:
        The dataset object (a ``Subset`` for the CIFAR-10 train/val splits).

    Raises:
        Exception: If ``args.dataset`` is not recognised.
    """
    assert split in ['train', 'val', 'test', 'trainval']

    # NOTE(review): DATASET_PATH is not defined among this module's visible
    # constants (only TRAIN_DATASET_PATH / VAL_DATASET_PATH are); the cifar10
    # and imagenet branches need it - confirm where it comes from.
    if args.dataset == 'cifar10':
        train = split in ['train', 'val', 'trainval']
        dataset = torchvision.datasets.CIFAR10(DATASET_PATH,
                                               train=train,
                                               transform=transform,
                                               download=True)

        if split in ['train', 'val']:
            # The train/val index split is created once and cached on disk so
            # every later call sees the same partition.
            split_path = os.path.join(DATASET_PATH,
                                      'cifar-10-batches-py', 'train_val_index.cp')

            if not os.path.exists(split_path):
                [train_index], [val_index] = split_dataset(args, dataset, k=1)
                split_index = {'train': train_index, 'val': val_index}
                with open(split_path, 'wb') as f:
                    cp.dump(split_index, f)

            with open(split_path, 'rb') as f:
                split_index = cp.load(f)
            dataset = Subset(dataset, split_index[split])

    elif args.dataset == 'imagenet':
        # ``==`` rather than ``is``: identity of string literals is an
        # implementation detail and must not decide the download flag.
        dataset = torchvision.datasets.ImageNet(DATASET_PATH,
                                                split=split,
                                                transform=transform,
                                                download=(split == 'val'))

    elif args.dataset == 'BraTS':
        # Every split other than 'train' reads the validation data.
        if split in ['train']:
            dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform)
        else:
            dataset = CustomDataset(VAL_DATASET_PATH, transform=transform)

    else:
        raise Exception('Unknown dataset')

    return dataset
200 -
201 -
def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
    """Wrap ``dataset`` in a ``DataLoader``.

    Batch size and worker count come from ``args.batch_size`` and
    ``args.num_workers``; ``shuffle`` and ``pin_memory`` are forwarded as given.
    """
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)
209 -
210 -
def get_inf_dataloader(args, dataset):
    """Yield shuffled batches from ``dataset`` forever.

    Whenever the current loader is exhausted, the module-level
    ``current_epoch`` counter is incremented and a fresh shuffled loader is
    started.
    """
    global current_epoch

    exhausted = object()  # unique marker: the loader has no batches left
    batches = iter(get_dataloader(args, dataset, shuffle=True))

    while True:
        batch = next(batches, exhausted)
        if batch is exhausted:
            current_epoch += 1
            batches = iter(get_dataloader(args, dataset, shuffle=True))
            batch = next(batches)

        yield batch
225 -
226 -
def get_train_transform(args, model, log_dir=None):
    """Build the training-time transform for ``args.dataset``.

    With ``args.fast_auto_augment`` set, the transform is either unpickled
    from ``args.augment_path`` or searched with ``fast_auto_augment``; in
    both cases a copy is stored as ``augmentation.cp`` in ``log_dir`` when a
    log directory is given. Otherwise a fixed per-dataset pipeline is used.

    Args:
        args: Settings object (``fast_auto_augment``, ``augment_path``,
            ``dataset`` are read).
        model: Network; ``model.img_size`` is read for the imagenet and BraTS
            pipelines and the model is passed on to the policy search.
        log_dir (str, optional): Directory that receives ``augmentation.cp``.

    Returns:
        The transform object (also printed).

    Raises:
        Exception: If ``args.dataset`` is not recognised.
    """
    if args.fast_auto_augment:
        assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet

        from fast_auto_augment import fast_auto_augment
        if args.augment_path:
            # NOTE(review): unpickling executes arbitrary code - only load
            # policy files from trusted sources.
            with open(args.augment_path, 'rb') as f:
                transform = cp.load(f)
            if log_dir:
                # shutil.copy replaces the shell ``cp`` string, which broke on
                # paths containing spaces and was not portable. The copy stays
                # best-effort, as the ignored shell exit status was.
                import shutil
                try:
                    shutil.copy(args.augment_path,
                                os.path.join(log_dir, 'augmentation.cp'))
                except OSError as err:
                    print('[!] Could not copy augmentation policy: {}'.format(err))
        else:
            transform = fast_auto_augment(args, model, K=4, B=1, num_process=4)
            if log_dir:
                with open(os.path.join(log_dir, 'augmentation.cp'), 'wb') as f:
                    cp.dump(transform, f)

    elif args.dataset == 'cifar10':
        transform = transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

    elif args.dataset == 'imagenet':
        resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875)
        transform = transforms.Compose([
            transforms.Resize([resize_h, resize_w]),
            transforms.RandomCrop(model.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

    elif args.dataset == 'BraTS':
        resize_h, resize_w = 256, 256
        transform = transforms.Compose([
            transforms.Resize([resize_h, resize_w]),
            transforms.RandomCrop(model.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
    else:
        raise Exception('Unknown Dataset')

    print(transform)

    return transform
272 -
273 -
def get_valid_transform(args, model):
    """Build the validation transform for ``args.dataset``: a resize then ``ToTensor``.

    The resize target is 32 for cifar10, ``[img_size[0], int(img_size[1]*1.875)]``
    (from ``model.img_size``) for imagenet, and ``[256, 256]`` for BraTS.

    Raises:
        Exception: If ``args.dataset`` is not recognised.
    """
    name = args.dataset

    if name == 'cifar10':
        steps = [transforms.Resize(32)]
    elif name == 'imagenet':
        target = [model.img_size[0], int(model.img_size[1]*1.875)]
        steps = [transforms.Resize(target)]
    elif name == 'BraTS':
        steps = [transforms.Resize([256, 256])]
    else:
        raise Exception('Unknown Dataset')

    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
297 -
298 -
def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    """Run one optimisation step on ``batch``.

    Args:
        args: Settings object (``use_cuda`` and ``print_step`` are read).
        model: Network whose forward pass returns ``(output, first)``.
        optimizer: Optimizer stepped once.
        scheduler: Optional LR scheduler, stepped once when given.
        criterion: Loss function called as ``criterion(output, target)``.
        batch: ``(images, target)`` pair.
        step (int): Global step, used for image logging.
        writer: Optional summary writer; images are logged every
            ``args.print_step`` steps.
        device: Optional device; takes precedence over ``args.use_cuda``.

    Returns:
        tuple: ``(acc1, acc5, loss, forward_t, backward_t)`` where the
        accuracies are correct counts divided by the batch size and the times
        are in seconds.
    """
    model.train()
    images, target = batch

    if device:
        images, target = images.to(device), target.to(device)
    elif args.use_cuda:
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

    # Forward pass (timed), then the loss.
    tic = time.time()
    output, first = model(images)
    forward_t = time.time() - tic
    loss = criterion(output, target)

    # accuracy() returns correct counts; turn them into per-batch fractions.
    batch_n = images.size(0)
    acc1, acc5 = (hits / batch_n for hits in accuracy(output, target, topk=(1, 5)))

    # Backward pass (timed) and parameter update.
    optimizer.zero_grad()
    tic = time.time()
    loss.backward()
    backward_t = time.time() - tic
    optimizer.step()
    if scheduler:
        scheduler.step()

    if writer and step % args.print_step == 0:
        # Log at most ten inputs next to their first-layer feature maps.
        for j in range(min(images.size(0), 10)):
            panel = concat_image_features(images[j], first[j])
            writer.add_image('train/input_image', panel, global_step=step)

    return acc1, acc5, loss, forward_t, backward_t
337 -
338 -
def validate(args, model, criterion, valid_loader, step, writer, device=None):
    """Evaluate ``model`` over ``valid_loader``.

    Args:
        args: Settings object (``use_cuda`` is read).
        model: Network whose forward pass returns ``(output, first)``.
        criterion: Loss function called as ``criterion(output, target)``.
        valid_loader: Iterable of ``(images, target)`` batches.
        step (int): Global step, used for image logging.
        writer: Optional summary writer; images of the last batch are logged.
        device: Optional device; takes precedence over ``args.use_cuda``.

    Returns:
        tuple: ``(acc1, acc5, loss, infer_t)`` - top-1/top-5 accuracy as
        fractions of all samples, the loss of the LAST batch only, and the
        accumulated transfer + inference time in seconds.
    """
    # switch to evaluate mode
    model.eval()

    acc1, acc5 = 0, 0
    samples = 0
    infer_t = 0

    with torch.no_grad():
        for images, target in valid_loader:

            start_t = time.time()
            if device:
                images = images.to(device)
                target = target.to(device)

            # use_cuda is a bool, so the former ``is not None`` test was always
            # true and forced .cuda() even on CPU-only runs; test its value,
            # as train_step does.
            elif args.use_cuda:
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output, first = model(images)
            loss = criterion(output, target)
            infer_t += time.time() - start_t

            # measure accuracy and record loss
            _acc1, _acc5 = accuracy(output, target, topk=(1, 5))
            acc1 += _acc1
            acc5 += _acc5
            samples += images.size(0)

    # NOTE(review): an empty loader leaves ``samples`` at 0 and ``loss``
    # unbound - callers are assumed to pass a non-empty loader.
    acc1 /= samples
    acc5 /= samples

    if writer:
        # Log at most ten inputs of the last batch with their feature maps.
        n_imgs = min(images.size(0), 10)
        for j in range(n_imgs):
            writer.add_image('valid/input_image',
                             concat_image_features(images[j], first[j]), global_step=step)

    return acc1, acc5, loss, infer_t
380 -
381 -
def accuracy(output, target, topk=(1,)):
    """Count the samples whose target is among the k highest scores, for each k.

    Args:
        output (Tensor): Scores of shape (batch, classes).
        target (Tensor): Class indices of shape (batch,).
        topk (tuple of int): Values of k to evaluate. A k larger than the
            number of classes is treated as "all classes".

    Returns:
        list of Tensor: One float tensor of shape (1,) per k holding the
        number (not the fraction) of correct samples.
    """
    with torch.no_grad():
        # topk() rejects k > number of classes, so clamp; slicing below then
        # simply takes every available row for such k.
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # reshape (not view): ``correct`` can be non-contiguous because it is
        # derived from a transposed tensor, and view(-1) then raises.
        return [correct[:k].reshape(-1).float().sum(0, keepdim=True)
                for k in topk]
397 -