조현아

FAA ver2.0 for colab

Showing 21 changed files with 1636 additions and 0 deletions
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# Fast Autoaugment
<img src="figures/faa.png" width=800px>
A PyTorch implementation of [Fast AutoAugment](https://arxiv.org/pdf/1905.00397.pdf) and [EfficientNet](https://arxiv.org/abs/1905.11946).
## Prerequisite
* torch==1.1.0
* torchvision==0.2.2
* hyperopt==0.1.2
* future==0.17.1
* tb-nightly==1.15.0a20190622
## Usage
### Training
#### CIFAR10
```bash
# ResNet20 (w/o FastAutoAugment)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False
# ResNet20 (w/ FastAutoAugment)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True
# ResNet20 (w/ FastAutoAugment, Pre-found policy)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \
--augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp
# ResNet32 (w/o FastAutoAugment)
python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False
# ResNet32 (w/ FastAutoAugment)
python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True
# EfficientNet (w/ FastAutoAugment)
python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \
--network=efficientnet_cifar10 --activation=swish
```
#### ImageNet (You can use any backbone networks in [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html))
```bash
# BaseNet (w/o FastAutoAugment)
python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50
# EfficientNet (w/ FastAutoAugment) (under construction)
python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \
--network=efficientnet --activation=swish
```
### Eval
```bash
# Single Image testing
python eval.py --model_path=runs/ResNet_Scale3_Basline
# 5-crops testing
python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True
```
## Experiments
### Fast AutoAugment
#### ResNet20 (CIFAR10)
* Pre-trained model [[Download](https://drive.google.com/file/d/12D8050yGGiKWGt8_R8QTlkoQ6wq_icBn/view?usp=sharing)]
* Validation Curve
<img src="figures/resnet20_valid.png">
* Evaluation (Acc @1)
| | Valid | Test(Single) |
|----------------|-------|-------------|
| ResNet20 | 90.70 | **91.45** |
| ResNet20 + FAA |**92.46**| **91.45** |
#### ResNet34 (CIFAR10)
* Validation Curve
<img src="figures/resnet34_valid.png">
* Evaluation (Acc @1)
| | Valid | Test(Single) |
|----------------|-------|-------------|
| ResNet34 | 91.54 | 91.47 |
| ResNet34 + FAA |**92.76**| **91.99** |
### Found Policy [[Download](https://drive.google.com/file/d/1Ia_IxPY3-T7m8biyl3QpxV1s5EA5gRDF/view?usp=sharing)]
<img src="figures/pm.png">
### Augmented images
<img src="figures/augmented_images.png">
<img src="figures/augmented_images2.png">
import os
import fire
import json
from pprint import pprint
import torch
import torch.nn as nn
from utils import *
def eval(model_path):
    """Evaluate a saved model on the test split and print its metrics.

    NOTE: the name shadows the built-in ``eval``; it is kept because it is
    the entry point handed to ``fire`` at the bottom of this script.

    Args:
        model_path: Run directory containing ``kwargs.json`` and
            ``model/model.pt``.
    """
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Context manager so the handle is closed as soon as the JSON is parsed.
    with open(kwargs_path) as kwargs_file:
        kwargs = json.load(kwargs_file)
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create network')
    model = select_model(args)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    print('\n[+] Load model')
    weight_path = os.path.join(model_path, 'model', 'model.pt')
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(weight_path, map_location=device))

    print('\n[+] Load dataset')
    test_transform = get_valid_transform(args, model)
    test_dataset = get_dataset(args, test_transform, 'test')
    test_loader = iter(get_dataloader(args, test_dataset))

    print('\n[+] Start testing')
    _test_res = validate(args, model, criterion, test_loader, step=0, writer=None)

    print('\n[+] Valid results')
    print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100))
    print(' Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100))
    print(' Loss : {:.3f}'.format(_test_res[2].data))
    # _test_res[3] is divided by the dataset size to report a per-image time.
    print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset)))


if __name__ == '__main__':
    fire.Fire(eval)
import copy
import json
import time
import torch
import random
import torchvision.transforms as transforms
from torch.utils.data import Subset
from sklearn.model_selection import StratifiedShuffleSplit
from concurrent.futures import ProcessPoolExecutor
from transforms import *
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from utils import *
# Transform classes used as the search space when the caller supplies none.
# NOTE(review): the name is misspelt ("DEFALUT"); it is kept unchanged because
# other code refers to it by this spelling.
DEFALUT_CANDIDATES = [
    ShearXY,
    TranslateXY,
    Rotate,
    AutoContrast,
    Invert,
    Equalize,
    Solarize,
    Posterize,
    Contrast,
    Color,
    Brightness,
    Sharpness,
    Cutout,
    # SamplePairing,
]
def train_child(args, model, dataset, subset_indx, device=None):
    """Train ``model`` on a subset of ``dataset`` without policy augmentation.

    Args:
        args: Parsed run configuration; ``use_cuda``, ``start_step``,
            ``max_step`` and ``print_step`` are read here.
        model: Network to train.
        dataset: Dataset whose ``transform`` attribute is overwritten with a
            plain resize + to-tensor pipeline (the change persists).
        subset_indx: Indices of the samples to train on.
        device: Optional device the model and loss are moved to.

    Returns:
        The result of the last ``train_step`` call, or ``None`` when
        ``args.start_step >= args.max_step`` and no step is executed.
    """
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()

    dataset.transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor()])
    subset = Subset(dataset, subset_indx)
    data_loader = get_inf_dataloader(args, subset)

    if device:
        model = model.to(device)
        criterion = criterion.to(device)
    elif args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
        # Only spread over all GPUs when no explicit device was requested:
        # DataParallel expects the module on its first device, which does not
        # hold for a child pinned to e.g. cuda:1.
        if torch.cuda.device_count() > 1:
            print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
            model = nn.DataParallel(model)

    # Defined up front so an empty step range returns None instead of raising
    # UnboundLocalError.
    _train_res = None
    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(data_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device)

        if step % args.print_step == 0:
            print('\n[+] Training step: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}\tDevice: {}'.format(
                step, args.max_step,(time.time()-start_t)/60, optimizer.param_groups[0]['lr'], device))
            print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print(' Loss : {}'.format(_train_res[2].data))

    return _train_res
def validate_child(args, model, dataset, subset_indx, transform, device=None):
    """Run ``validate`` for ``model`` on a subset of ``dataset``.

    Args:
        args: Parsed run configuration (``use_cuda`` is read here).
        model: Network to evaluate.
        dataset: Dataset whose ``transform`` attribute is replaced by
            ``transform`` (the change persists after the call).
        subset_indx: Indices of the samples to evaluate on.
        transform: Transform pipeline to evaluate under.
        device: Optional device the model and loss are moved to.

    Returns:
        Whatever ``validate`` returns.
    """
    loss_fn = nn.CrossEntropyLoss()
    if device:
        model, loss_fn = model.to(device), loss_fn.to(device)
    elif args.use_cuda:
        model, loss_fn = model.cuda(), loss_fn.cuda()

    dataset.transform = transform
    loader = get_dataloader(args, Subset(dataset, subset_indx), pin_memory=False)
    return validate(args, model, loss_fn, loader, 0, None, device)
def get_next_subpolicy(transform_candidates, op_per_subpolicy=2):
    """Sample a random sub-policy and wrap it in a transform pipeline.

    Each operation is a uniformly chosen candidate class instantiated with a
    random probability and a random magnitude, both drawn from ``[0, 1)``.

    Args:
        transform_candidates: Sequence of transform classes taking
            ``(prob, mag)``.
        op_per_subpolicy: Number of operations to sample.

    Returns:
        A ``transforms.Compose`` of the sampled operations followed by
        ``Resize(32)`` and ``ToTensor()``.
    """
    pool_size = len(transform_candidates)
    # Evaluation order per operation: class index first, then prob, then mag.
    sampled_ops = [
        transform_candidates[random.randrange(pool_size)](random.random(), random.random())
        for _ in range(op_per_subpolicy)
    ]
    return transforms.Compose(
        sampled_ops + [transforms.Resize(32), transforms.ToTensor()])
def search_subpolicies(args, transform_candidates, child_model, dataset, Da_indx, B, device):
    """Random search: score ``B`` randomly sampled sub-policies.

    Returns:
        A list of ``(subpolicy, score)`` tuples, where ``score`` is element 2
        of the ``validate_child`` result for that sub-policy.
    """
    scored = []
    for _ in range(B):
        candidate = get_next_subpolicy(transform_candidates)
        result = validate_child(args, child_model, dataset, Da_indx, candidate, device)
        scored.append((candidate, result[2]))
    return scored
def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device):
    """Search two-operation sub-policies with hyperopt's TPE algorithm.

    Every trial samples two ``(transform, prob, mag)`` triples, scores the
    resulting pipeline with ``validate_child`` and records element 2 of that
    result as the loss.

    Returns:
        A list of ``(subpolicy, loss)`` tuples, one per trial. The returned
        pipelines prepend pad / random-crop / horizontal-flip to the two
        sampled operations, unlike the pipeline scored during the search,
        which starts with ``Resize(32)``.
    """
    def _objective(sampled):
        ops = [op_cls(prob, mag) for op_cls, prob, mag in sampled]
        candidate = transforms.Compose(
            [transforms.Resize(32)] + ops + [transforms.ToTensor()])
        val_res = validate_child(args, child_model, dataset, Da_indx, candidate, device)
        return {'loss': val_res[2].cpu().numpy(), 'status': STATUS_OK }

    first_slot = (hp.choice('transform1', transform_candidates),
                  hp.uniform('prob1', 0, 1.0),
                  hp.uniform('mag1', 0, 1.0))
    second_slot = (hp.choice('transform2', transform_candidates),
                   hp.uniform('prob2', 0, 1.0),
                   hp.uniform('mag2', 0, 1.0))

    trials = Trials()
    fmin(_objective,
         space=[first_slot, second_slot],
         algo=tpe.suggest,
         max_evals=B,
         trials=trials)

    found = []
    for trial in trials.trials:
        vals = trial['misc']['vals']
        # hp.choice records the index of the chosen candidate.
        op1 = transform_candidates[vals['transform1'][0]](vals['prob1'][0], vals['mag1'][0])
        op2 = transform_candidates[vals['transform2'][0]](vals['prob2'][0], vals['mag2'][0])
        pipeline = transforms.Compose([
            # baseline augmentation
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            # policy
            op1,
            op2,
            # to tensor
            transforms.ToTensor()])
        found.append((pipeline, trial['result']['loss']))
    return found
def get_topn_subpolicies(subpolicies, N=10):
    """Return the ``N`` entries with the smallest score (element 1).

    The input is not modified; ties keep their original relative order.
    """
    ranked = list(subpolicies)
    ranked.sort(key=lambda entry: entry[1])
    return ranked[:N]
def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k):
    """Worker: train one child model and search sub-policies for fold ``k``.

    Args:
        args_str: JSON-encoded run configuration (re-parsed in the worker).
        model: Network that is deep-copied to create the child model.
        dataset: Dataset shared by training and policy evaluation.
        Dm_indx: Sample indices used to train the child model.
        Da_indx: Sample indices used to score candidate sub-policies.
        T: Number of search rounds.
        transform_candidates: Transform classes forming the search space.
        B: Number of hyperopt evaluations per round.
        N: Number of best sub-policies kept per round.
        k: Fold index; also selects the GPU (round-robin).

    Returns:
        A list with the ``N`` best sub-policies of each of the ``T`` rounds.
    """
    kwargs = json.loads(args_str)
    args, kwargs = parse_args(kwargs)

    n_gpus = torch.cuda.device_count()
    if n_gpus > 0:
        device_id = k % n_gpus
        device = torch.device('cuda:%d' % device_id)
        print('[+] Child %d training strated (GPU: %d)' % (k, device_id))
    else:
        # No visible GPU: `k % 0` would raise ZeroDivisionError, so fall back
        # to the CPU. NOTE(review): assumes the train/validate helpers accept
        # a CPU device - confirm.
        device = torch.device('cpu')
        print('[+] Child %d training strated (CPU)' % k)

    # train child model
    child_model = copy.deepcopy(model)
    train_child(args, child_model, dataset, Dm_indx, device)

    # search sub policy
    _transform = []
    for _ in range(T):
        subpolicies = search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset, Da_indx, B, device)
        subpolicies = get_topn_subpolicies(subpolicies, N)
        _transform.extend([subpolicy[0] for subpolicy in subpolicies])
    return _transform


def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5):
    """Run the Fast AutoAugment search over ``K`` folds in parallel processes.

    Args:
        args: Parsed run configuration (a namedtuple; serialised with
            ``_asdict`` so it can be sent to worker processes).
        model: Network that each worker deep-copies into a child model.
        transform_candidates: Transform classes to search over; defaults to
            ``DEFALUT_CANDIDATES`` when empty or ``None``.
        K: Number of dataset splits (one worker task per split).
        B: Number of hyperopt evaluations per search round.
        T: Number of search rounds per split.
        N: Number of best sub-policies kept per round.
        num_process: Upper bound on worker processes; additionally capped by
            the number of visible GPUs.

    Returns:
        A ``transforms.RandomChoice`` over all sub-policies that were found.
    """
    args_str = json.dumps(args._asdict())
    dataset = get_dataset(args, None, 'trainval')
    # Keep at least one worker: device_count() is 0 on a CPU-only host and
    # ProcessPoolExecutor rejects max_workers=0.
    num_process = max(1, min(torch.cuda.device_count(), num_process))
    torch.multiprocessing.set_start_method('spawn', force=True)

    if not transform_candidates:
        transform_candidates = DEFALUT_CANDIDATES

    # split
    Dm_indexes, Da_indexes = split_dataset(args, dataset, K)

    transform = []
    with ProcessPoolExecutor(max_workers=num_process) as executor:
        futures = [
            executor.submit(process_fn,
                            args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidates, B, N, k)
            for k, (Dm_indx, Da_indx) in enumerate(zip(Dm_indexes, Da_indexes))
        ]
        # Collect in submission order so the result order follows the folds.
        for future in futures:
            transform.extend(future.result())

    return transforms.RandomChoice(transform)
import torch.nn as nn
class BaseNet(nn.Module):
    """Wrap a backbone so that forward also returns the first layer's output.

    The backbone's children are split into the first child, all middle
    children, and the last child (used as the classifier ``fc``).
    """

    def __init__(self, backbone, args):
        super().__init__()

        # Separate layers
        layers = list(backbone.children())
        self.first = nn.Sequential(*layers[:1])
        self.after = nn.Sequential(*layers[1:-1])
        self.fc = layers[-1]

        self.img_size = (224, 224)

    def forward(self, x):
        """Return ``(logits, first_layer_output)`` for the batch ``x``."""
        f = self.first(x)
        hidden = self.after(f)
        # Flatten everything but the batch dimension before the classifier.
        logits = self.fc(hidden.reshape(hidden.size(0), -1))
        return logits, f
import math
import torch.nn as nn
import torch.nn.functional as F
def round_fn(orig, multiplier):
    """Scale ``orig`` by ``multiplier`` and round up to an int.

    A falsy multiplier (``0``, ``None``) returns ``orig`` unchanged.
    """
    return int(math.ceil(multiplier * orig)) if multiplier else orig
def get_activation_fn(activation):
    """Return the activation module class for the given name.

    Args:
        activation: ``"swish"`` or ``"relu"``.

    Returns:
        The class (not an instance): ``Swish`` or ``nn.ReLU``.

    Raises:
        ValueError: If ``activation`` is not a supported name. ``ValueError``
            is a subclass of ``Exception``, so existing handlers still match.
    """
    if activation == "swish":
        return Swish
    if activation == "relu":
        return nn.ReLU
    raise ValueError('Unknown activation %s' % activation)
class Swish(nn.Module):
    """ Swish activation function, s(x) = x * sigmoid(x)

    Args:
        inplace: When True the input tensor itself is overwritten with the
            result; when False (default) a new tensor is returned.
    """

    def __init__(self, inplace=False):
        super().__init__()
        # Honour the argument; it used to be overwritten with a constant True.
        self.inplace = inplace

    def forward(self, x):
        # Tensor.sigmoid() replaces the deprecated F.sigmoid.
        if self.inplace:
            x.mul_(x.sigmoid())
            return x
        return x * x.sigmoid()
class ConvBlock(nn.Module):
    """ Conv + BatchNorm + Activation

    The convolution has no bias; the activation is instantiated from the
    class returned by ``get_activation_fn``.
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        conv = nn.Conv2d(in_channel, out_channel, kernel_size,
                         padding=padding, stride=stride, bias=False)
        norm = nn.BatchNorm2d(out_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.fw(x)
class DepthwiseConvBlock(nn.Module):
    """ DepthwiseConv2D + BatchNorm + Activation

    ``groups`` equals the channel count, so every channel is convolved on
    its own; the channel count is unchanged.
    """

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        depthwise = nn.Conv2d(in_channel, in_channel, kernel_size,
                              padding=padding, stride=stride,
                              groups=in_channel, bias=False)
        norm = nn.BatchNorm2d(in_channel)
        act = get_activation_fn(activation)()
        self.fw = nn.Sequential(depthwise, norm, act)

    def forward(self, x):
        return self.fw(x)
class MBConv(nn.Module):
    """ Inverted residual block

    Optional 1x1 expansion (skipped when ``expand_ratio == 1``), depthwise
    convolution, then 1x1 projection. The input is added back when the
    stride is 1 and the channel count is unchanged.
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish"):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride

        hidden = in_channel * expand_ratio
        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, hidden, 1,
                                    activation=activation)

        # Padding of (k - 1) // 2 keeps the spatial size at stride 1.
        self.dw_conv = DepthwiseConvBlock(hidden, kernel_size,
                                          padding=(kernel_size - 1) // 2,
                                          stride=stride, activation=activation)
        self.pw_conv = ConvBlock(hidden, out_channel, 1,
                                 activation=activation)

    def forward(self, inputs):
        x = inputs if self.expand_ratio == 1 else self.expand(inputs)
        x = self.pw_conv(self.dw_conv(x))

        if self.stride == 1 and self.in_channel == self.out_channel:
            x = x + inputs
        return x
class Net(nn.Module):
    """ EfficientNet

    Width, depth and resolution are scaled by ``1.1 ** pi``, ``1.2 ** pi``
    and ``1.15 ** pi``. ``args`` must provide ``pi``, ``activation`` and
    ``num_classes``.
    """

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = args.num_classes

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        side = round_fn(224, self.r)
        self.img_size = (side, side)

        self.stage1 = ConvBlock(3, round_fn(32, self.w),
                                kernel_size=3, padding=1, stride=2, activation=activation)
        self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=1, activation=activation)
        self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
                                       depth=round_fn(2, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
                                       depth=round_fn(2, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
                                       depth=round_fn(3, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
                                       depth=round_fn(3, self.d), kernel_size=5,
                                       half_resolution=False, expand_ratio=6, activation=activation)
        self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
                                       depth=round_fn(4, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=6, activation=activation)
        self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
                                kernel_size=1, activation=activation)

        # Five stride-2 layers (stage1, 3, 4, 5, 7) with odd kernels and
        # (k - 1) // 2 padding each map a side of n to ceil(n / 2). Deriving
        # the flattened size from that keeps fc consistent for any pi; for
        # pi == 0 it is the original 7 * 7 * 1280.
        feat_side = side
        for _ in range(5):
            feat_side = (feat_side + 1) // 2
        self.fc = nn.Linear(feat_side * feat_side * round_fn(1280, self.w), num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
                    half_resolution=False, expand_ratio=1, activation="swish"):
        """Stack ``depth`` MBConv blocks; only the first one may use stride 2."""
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i == 0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                       stride=stride, expand_ratio=expand_ratio, activation=activation))
            in_channel = out_channel
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(logits, stage1_output)``, matching the other networks."""
        # The message now reports both spatial dimensions (was x.size()[-2]).
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, tuple(x.size()[-2:]))

        s = self.stage1(x)
        x = self.stage2(s)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        # Previously returned the logits twice; the sibling networks return
        # the first-stage feature map as the second element.
        return x, s
import math
import torch.nn as nn
import torch.nn.functional as F
def round_fn(orig, multiplier):
    """Return ``ceil(multiplier * orig)`` as an int.

    ``orig`` is returned untouched when ``multiplier`` is falsy.
    """
    if multiplier:
        scaled = multiplier * orig
        return int(math.ceil(scaled))
    return orig
def get_activation_fn(activation):
    """Map an activation name to its module class.

    Args:
        activation: ``"swish"`` or ``"relu"``.

    Returns:
        ``Swish`` or ``nn.ReLU`` (the class, to be instantiated by the caller).

    Raises:
        ValueError: For any other name. Being a subclass of ``Exception``,
            it is still caught by handlers written for the previous type.
    """
    if activation == "swish":
        return Swish
    if activation == "relu":
        return nn.ReLU
    raise ValueError('Unknown activation %s' % activation)
class Swish(nn.Module):
    """ Swish activation function, s(x) = x * sigmoid(x)

    Args:
        inplace: If True, the result is written into the input tensor and
            that same tensor is returned; otherwise (default) the input is
            left untouched.
    """

    def __init__(self, inplace=False):
        super().__init__()
        # Store the caller's choice; it was previously forced to True.
        self.inplace = inplace

    def forward(self, x):
        # Tensor.sigmoid() is used instead of the deprecated F.sigmoid.
        gate = x.sigmoid()
        if self.inplace:
            x.mul_(gate)
            return x
        return x * gate
class ConvBlock(nn.Module):
    """ Conv + BatchNorm + Activation

    Bias-free convolution followed by batch normalisation and the activation
    selected by name through ``get_activation_fn``.
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        stack = [
            nn.Conv2d(in_channel, out_channel, kernel_size,
                      padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
            get_activation_fn(activation)(),
        ]
        self.fw = nn.Sequential(*stack)

    def forward(self, x):
        return self.fw(x)
class DepthwiseConvBlock(nn.Module):
    """ DepthwiseConv2D + BatchNorm + Activation

    One convolution group per channel (``groups=in_channel``); input and
    output channel counts are equal.
    """

    def __init__(self, in_channel, kernel_size,
                 padding=0, stride=1, activation="swish"):
        super().__init__()
        stack = [
            nn.Conv2d(in_channel, in_channel, kernel_size,
                      padding=padding, stride=stride,
                      groups=in_channel, bias=False),
            nn.BatchNorm2d(in_channel),
            get_activation_fn(activation)(),
        ]
        self.fw = nn.Sequential(*stack)

    def forward(self, x):
        return self.fw(x)
class SEBlock(nn.Module):
    """ Squeeze and Excitation Block

    Global-average-pools the input to 1x1, passes it through a 1x1-conv
    bottleneck (ReLU, then Sigmoid) and rescales the input channel-wise.

    Args:
        in_channel: Number of input (and output) channels.
        se_ratio: Channel reduction factor of the bottleneck.
    """

    def __init__(self, in_channel, se_ratio=16):
        super().__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d((1,1))
        # Keep at least one bottleneck channel: with in_channel < se_ratio the
        # integer division gives 0 and the block would be degenerate.
        inter_channel = max(1, in_channel // se_ratio)
        self.reduce = nn.Sequential(
            nn.Conv2d(in_channel, inter_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.ReLU())
        self.expand = nn.Sequential(
            nn.Conv2d(inter_channel, in_channel,
                      kernel_size=1, padding=0, stride=1),
            nn.Sigmoid())

    def forward(self, x):
        s = self.global_avgpool(x)
        s = self.reduce(s)
        s = self.expand(s)
        # s has one gate per channel and broadcasts over the spatial dims.
        return x * s
class MBConv(nn.Module):
    """ Inverted residual block

    Optional 1x1 expansion (skipped when ``expand_ratio == 1``), depthwise
    convolution, optional squeeze-and-excitation, then 1x1 projection. The
    input is added back when the stride is 1 and the channel count is
    unchanged.
    """

    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, expand_ratio=1, activation="swish", use_seblock=False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.expand_ratio = expand_ratio
        self.stride = stride
        self.use_seblock = use_seblock

        hidden = in_channel * expand_ratio
        if expand_ratio != 1:
            self.expand = ConvBlock(in_channel, hidden, 1,
                                    activation=activation)

        # Padding of (k - 1) // 2 keeps the spatial size at stride 1.
        self.dw_conv = DepthwiseConvBlock(hidden, kernel_size,
                                          padding=(kernel_size - 1) // 2,
                                          stride=stride, activation=activation)

        if use_seblock:
            self.seblock = SEBlock(hidden)

        self.pw_conv = ConvBlock(hidden, out_channel, 1,
                                 activation=activation)

    def forward(self, inputs):
        x = inputs if self.expand_ratio == 1 else self.expand(inputs)
        x = self.dw_conv(x)
        if self.use_seblock:
            x = self.seblock(x)
        x = self.pw_conv(x)

        if self.stride == 1 and self.in_channel == self.out_channel:
            x = x + inputs
        return x
class Net(nn.Module):
    """ EfficientNet

    Variant with a 32-pixel base resolution and a fixed 10-class head.
    ``args`` must provide ``pi``, ``activation`` and ``use_seblock``.
    """

    def __init__(self, args):
        super(Net, self).__init__()
        pi = args.pi
        activation = args.activation
        num_classes = 10

        self.d = 1.2 ** pi
        self.w = 1.1 ** pi
        self.r = 1.15 ** pi
        self.img_size = (round_fn(32, self.r), round_fn(32, self.r))
        self.use_seblock = args.use_seblock

        self.stage1 = ConvBlock(3, round_fn(32, self.w),
                                kernel_size=3, padding=1, stride=2, activation=activation)
        self.stage2 = self.make_layers(round_fn(32, self.w), round_fn(16, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=1, activation=activation)
        self.stage3 = self.make_layers(round_fn(16, self.w), round_fn(24, self.w),
                                       depth=round_fn(2, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage4 = self.make_layers(round_fn(24, self.w), round_fn(40, self.w),
                                       depth=round_fn(2, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage5 = self.make_layers(round_fn(40, self.w), round_fn(80, self.w),
                                       depth=round_fn(3, self.d), kernel_size=3,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage6 = self.make_layers(round_fn(80, self.w), round_fn(112, self.w),
                                       depth=round_fn(3, self.d), kernel_size=5,
                                       half_resolution=False, expand_ratio=6, activation=activation)
        self.stage7 = self.make_layers(round_fn(112, self.w), round_fn(192, self.w),
                                       depth=round_fn(4, self.d), kernel_size=5,
                                       half_resolution=True, expand_ratio=6, activation=activation)
        self.stage8 = self.make_layers(round_fn(192, self.w), round_fn(320, self.w),
                                       depth=round_fn(1, self.d), kernel_size=3,
                                       half_resolution=False, expand_ratio=6, activation=activation)
        self.stage9 = ConvBlock(round_fn(320, self.w), round_fn(1280, self.w),
                                kernel_size=1, activation=activation)
        self.fc = nn.Linear(round_fn(1280, self.w), num_classes)

    def make_layers(self, in_channel, out_channel, depth, kernel_size,
                    half_resolution=False, expand_ratio=1, activation="swish"):
        """Stack ``depth`` MBConv blocks; only the first one may use stride 2."""
        blocks = []
        for i in range(depth):
            stride = 2 if half_resolution and i == 0 else 1
            blocks.append(
                MBConv(in_channel, out_channel, kernel_size,
                       stride=stride, expand_ratio=expand_ratio, activation=activation,
                       use_seblock=self.use_seblock))
            in_channel = out_channel
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(logits, stage1_output)`` for the batch ``x``."""
        # The message now reports both spatial dimensions (was x.size()[-2]).
        assert x.size()[-2:] == self.img_size, \
            'Image size must be %r, but %r given' % (self.img_size, tuple(x.size()[-2:]))

        s = self.stage1(x)
        x = self.stage2(s)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        # Five stride-2 layers reduce a 32-pixel input to 1x1, where this
        # pooling is the identity. For pi > 0 the map is larger and pooling
        # keeps the flattened width equal to fc's input size.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x, s
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a shortcut connection.

    When the channel count or the stride changes, the shortcut goes through
    a 1x1 conv + batch-norm projection (``self.down``); otherwise the input
    is added unchanged.
    """

    def __init__(self, in_channel, out_channel, stride):
        super(ResidualBlock, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride

        first_conv = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, padding=1, stride=stride)
        self.conv1 = nn.Sequential(first_conv, nn.BatchNorm2d(out_channel))
        self.relu = nn.ReLU(inplace=True)

        second_conv = nn.Conv2d(out_channel, out_channel,
                                kernel_size=3, padding=1)
        self.conv2 = nn.Sequential(second_conv, nn.BatchNorm2d(out_channel))

        if not (self.in_channel == self.out_channel and self.stride == 1):
            projection = nn.Conv2d(in_channel, out_channel,
                                   kernel_size=1, stride=stride)
            self.down = nn.Sequential(projection, nn.BatchNorm2d(out_channel))

    def forward(self, b):
        t = self.conv2(self.relu(self.conv1(b)))
        # Same condition as in __init__, so self.down exists whenever used.
        if not (self.in_channel == self.out_channel and self.stride == 1):
            b = self.down(b)
        t += b
        return self.relu(t)
class Net(nn.Module):
    """Residual network with three stages of ``2 * args.scale`` blocks.

    Stages use 16, 32 and 64 channels; the second and third stage start with
    a stride-2 block. The classifier has 10 outputs.
    """

    def __init__(self, args):
        super(Net, self).__init__()
        n_blocks = 2 * args.scale

        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))

        self.layer1 = nn.Sequential(
            *[ResidualBlock(16, 16, 1) for _ in range(n_blocks)])
        self.layer2 = self._downsampling_stage(16, 32, n_blocks)
        self.layer3 = self._downsampling_stage(32, 64, n_blocks)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(64, 10)

        # Re-initialise conv weights and reset normalisation layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    @staticmethod
    def _downsampling_stage(in_channel, out_channel, n_blocks):
        """Build a stage whose first block changes channels and uses stride 2."""
        blocks = []
        for i in range(n_blocks):
            is_first = (i == 0)
            blocks.append(ResidualBlock(
                in_channel=(in_channel if is_first else out_channel),
                out_channel=out_channel,
                stride=(2 if is_first else 1)))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(logits, stem_output)`` for the batch ``x``."""
        s = self.stem(x)
        feat = self.layer3(self.layer2(self.layer1(s)))
        feat = self.avg_pool(feat)
        logits = self.fc(feat.reshape(feat.size(0), -1))
        return logits, s
import os
import fire
import time
import json
import random
from pprint import pprint
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from networks import *
from utils import *
def train(**kwargs):
    """Train a network, validating periodically and logging to TensorBoard.

    Creates ``./runs/<model_name>/`` holding ``kwargs.json``, the
    TensorBoard event files and ``model/model.pt`` (the weights with the best
    validation Acc@1 seen so far).

    Args:
        **kwargs: Command-line overrides forwarded to ``parse_args``.
    """
    print('\n[+] Parse arguments')
    args, kwargs = parse_args(kwargs)
    pprint(args)

    print('\n[+] Create log dir')
    model_name = get_model_name(args)
    log_dir = os.path.join('./runs', model_name)
    os.makedirs(os.path.join(log_dir, 'model'))
    # Context manager so kwargs.json is flushed and closed before training.
    with open(os.path.join(log_dir, 'kwargs.json'), 'w') as kwargs_file:
        json.dump(kwargs, kwargs_file)
    writer = SummaryWriter(log_dir=log_dir)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    print('\n[+] Create network')
    model = select_model(args)
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    #writer.add_graph(model)

    print('\n[+] Load dataset')
    transform = get_train_transform(args, model, log_dir)
    val_transform = get_valid_transform(args, model)
    train_dataset = get_dataset(args, transform, 'train')
    valid_dataset = get_dataset(args, val_transform, 'val')
    train_loader = iter(get_inf_dataloader(args, train_dataset))
    max_epoch = len(train_dataset) // args.batch_size
    best_acc = -1
    checkpoint_path = os.path.join(log_dir, "model","model.pt")

    print('\n[+] Start training')
    if torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(train_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer)

        if step % args.print_step == 0:
            # NOTE(review): current_epoch is not defined in this block;
            # presumably it arrives through `from utils import *` - confirm
            # that it is ever updated in this namespace.
            print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr']))
            writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
            writer.add_scalar('train/acc1', _train_res[0], global_step=step)
            writer.add_scalar('train/acc5', _train_res[1], global_step=step)
            writer.add_scalar('train/loss', _train_res[2], global_step=step)
            writer.add_scalar('train/forward_time', _train_res[3], global_step=step)
            writer.add_scalar('train/backward_time', _train_res[4], global_step=step)
            print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print(' Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print(' Loss : {}'.format(_train_res[2].data))
            print(' FW Time : {:.3f}ms'.format(_train_res[3]*1000))
            print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000))

        # Validate on the last step of every val_step-sized window.
        if step % args.val_step == args.val_step-1:
            valid_loader = iter(get_dataloader(args, valid_dataset))
            _valid_res = validate(args, model, criterion, valid_loader, step, writer)
            print('\n[+] Valid results')
            writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
            writer.add_scalar('valid/acc5', _valid_res[1], global_step=step)
            writer.add_scalar('valid/loss', _valid_res[2], global_step=step)
            print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100))
            print(' Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100))
            print(' Loss : {}'.format(_valid_res[2].data))

            if _valid_res[0] > best_acc:
                best_acc = _valid_res[0]
                # Save the unwrapped module: a DataParallel state dict has
                # "module."-prefixed keys that a plain model cannot load.
                to_save = model.module if isinstance(model, nn.DataParallel) else model
                torch.save(to_save.state_dict(), checkpoint_path)
                print('\n[+] Model saved')

    writer.close()


if __name__ == '__main__':
    fire.Fire(train)
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
from PIL import Image, ImageOps, ImageEnhance
class BaseTransform(ABC):
    """Base class for policy operations parameterised by ``(prob, mag)``.

    Calling an instance routes ``img`` through ``transforms.RandomApply``
    with ``self.transform`` as the only operation and ``self.prob`` as its
    probability argument. Subclasses implement ``transform``.
    """

    def __init__(self, prob, mag):
        self.prob, self.mag = prob, mag

    def __call__(self, img):
        gated = transforms.RandomApply([self.transform], self.prob)
        return gated(img)

    def __repr__(self):
        details = (type(self).__name__, self.prob, self.mag)
        return '%s(prob=%.2f, magnitude=%.2f)' % details

    @abstractmethod
    def transform(self, img):
        """Apply the operation to ``img`` and return the result."""
class ShearXY(BaseTransform):
    """Random affine shear; ``mag * 360`` is passed as the ``shear`` argument."""

    def transform(self, img):
        shear_op = transforms.RandomAffine(
            0, shear=self.mag * 360, resample=Image.BILINEAR)
        return shear_op(img)
class TranslateXY(BaseTransform):
    """Random affine translation; ``mag`` is used for both axes."""

    def transform(self, img):
        shift_op = transforms.RandomAffine(
            0, translate=(self.mag, self.mag), resample=Image.BILINEAR)
        return shift_op(img)
class Rotate(BaseTransform):
    """Random rotation; ``mag * 360`` is passed as the ``degrees`` argument."""

    def transform(self, img):
        # Image.BILINEAR is passed as the second positional argument.
        rotate_op = transforms.RandomRotation(self.mag * 360, Image.BILINEAR)
        return rotate_op(img)
class AutoContrast(BaseTransform):
    """Auto-contrast with an integer cutoff of ``int(mag * 49)``."""

    def transform(self, img):
        return ImageOps.autocontrast(img, cutoff=int(self.mag * 49))
class Invert(BaseTransform):
    """Invert the image with ``ImageOps.invert``; the magnitude is not used."""

    def transform(self, img):
        inverted = ImageOps.invert(img)
        return inverted
class Equalize(BaseTransform):
    """Equalize the image with ``ImageOps.equalize``; the magnitude is not used."""

    def transform(self, img):
        equalized = ImageOps.equalize(img)
        return equalized
class Solarize(BaseTransform):
    """Solarize with threshold ``(1 - mag) * 255``: a larger magnitude gives a lower threshold."""

    def transform(self, img):
        return ImageOps.solarize(img, (1 - self.mag) * 255)
class Posterize(BaseTransform):
    """Posterize to ``int((1 - mag) * 8)`` bits per channel."""

    def transform(self, img):
        # NOTE(review): a magnitude of 1.0 yields bits == 0 - confirm that
        # this is intended.
        keep_bits = int((1 - self.mag) * 8)
        return ImageOps.posterize(img, bits=keep_bits)
class Contrast(BaseTransform):
    """Contrast enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Contrast(img)
        return enhancer.enhance(self.mag * 10)
class Color(BaseTransform):
    """Color enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Color(img)
        return enhancer.enhance(self.mag * 10)
class Brightness(BaseTransform):
    """Brightness enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Brightness(img)
        return enhancer.enhance(self.mag * 10)
class Sharpness(BaseTransform):
    """Sharpness enhancement with factor ``mag * 10``."""

    def transform(self, img):
        enhancer = ImageEnhance.Sharpness(img)
        return enhancer.enhance(self.mag * 10)
class Cutout(BaseTransform):
    """Cut a single square hole of side ``24 * mag`` using ``CutoutOp``."""

    def transform(self, img):
        eraser = CutoutOp(n_holes=1, length=24 * self.mag)
        return eraser(img)
class CutoutOp(object):
    """
    https://github.com/uoguelph-mlrg/Cutout
    Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int or float): The length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to mask; its ``size`` is read as
                ``(width, height)`` and its pixels via ``np.asarray``.
        Returns:
            PIL.Image: Image with n_holes patches of dimension
            length x length set to zero (patches are clipped at the border).
        """
        w, h = img.size
        pixels = np.asarray(img).astype(np.uint8)
        mask = np.ones((h, w), np.uint8)
        half = self.length // 2

        for _ in range(self.n_holes):
            # Hole centre: row first, then column (keeps the RNG draw order).
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = int(np.clip(y - half, 0, h))
            y2 = int(np.clip(y + half, 0, h))
            x1 = int(np.clip(x - half, 0, w))
            x2 = int(np.clip(x + half, 0, w))
            mask[y1:y2, x1:x2] = 0

        # Add a channel axis only for multi-channel images. An (h, w, 1) mask
        # against an (h, w) single-channel array would broadcast to the wrong
        # shape.
        if pixels.ndim == 3:
            mask = mask[:, :, np.newaxis]

        # The mask is applied once; the former second multiplication was a
        # no-op because the mask only holds 0 and 1.
        return Image.fromarray(mask * pixels)
import os
import time
import importlib
import collections
import pickle as cp
import glob
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
# Locations of the training and validation frames. The paths point below
# '/content/drive/My Drive', presumably a Google Drive mount in Colab - adjust
# them when running elsewhere.
TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
# Module-level epoch counter, initialised to 0.
# NOTE(review): presumably updated by code further down this module - confirm.
current_epoch = 0
def split_dataset(args, dataset, k):
    """Create ``k`` stratified shuffle splits of the dataset indices.

    Each split holds out 10% of the samples (``test_size=0.1``); the splits
    are stratified on ``dataset.targets`` and seeded with ``args.seed``.

    Returns:
        Two tuples of length ``k``: the index arrays of the larger parts and
        the index arrays of the held-out parts.
    """
    indices = list(range(len(dataset)))
    labels = dataset.targets

    assert len(indices) == len(labels)

    splitter = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
    # Transpose [(train_0, test_0), ...] into (train_0, ...), (test_0, ...).
    Dm_indexes, Da_indexes = zip(*splitter.split(indices, labels))
    return Dm_indexes, Da_indexes
def concat_image_features(image, features, max_features=3):
    """Append normalised feature maps to the right of an image.

    Each of the first ``max_features`` maps in ``features`` is min-max
    scaled to [0, 1), resized to the image's spatial size, replicated
    across the image's channels and concatenated along the width axis.

    Args:
        image: Tensor indexed as (channels, height, width).
        features: Tensor whose first axis enumerates 2-D feature maps.
        max_features: Upper bound on how many maps are appended.

    Returns:
        Tensor of shape (channels, height, width * (1 + n)) where
        ``n = min(features.size(0), max_features)``.
    """
    channels, h, w = image.shape
    max_features = min(features.size(0), max_features)
    image_feature = image.clone()
    for i in range(max_features):
        feature = features[i:i+1]
        _min, _max = torch.min(feature), torch.max(feature)
        # The epsilon avoids a division by zero on constant maps.
        feature = (feature - _min) / (_max - _min + 1e-6)
        # F.interpolate wants a (N, C, H, W) input; it replaces the
        # deprecated F.upsample, and align_corners=False is the behaviour
        # upsample already defaulted to.
        feature = F.interpolate(feature.unsqueeze(0), size=(h, w),
                                mode="bilinear", align_corners=False)
        # Replicate over the image's own channel count (not a fixed 3) so
        # single-channel images can be concatenated as well.
        feature = feature.view(1, h, w).expand(channels, h, w)
        image_feature = torch.cat((image_feature, feature), 2)
    return image_feature
def get_model_name(args):
    """Build a run name of the form ``<timestamp>__<network>__<seed>``.

    The timestamp is the current local time formatted as
    ``%B_%d_%H:%M:%S``; ``args.network`` and ``args.seed`` supply the rest.
    """
    from datetime import datetime

    timestamp = datetime.now().strftime("%B_%d_%H:%M:%S")
    pieces = (timestamp, args.network, str(args.seed))
    return '__'.join(pieces)
def dict_to_namedtuple(d):
    """Convert a (possibly nested) dict into an ``Args`` namedtuple.

    Nested mappings are converted recursively. String values are passed
    through ``eval`` so that e.g. ``'0.1'`` or ``'True'`` become Python
    values; strings that fail to evaluate are kept verbatim.

    Args:
        d: Dict whose keys are valid identifiers. It is updated IN PLACE
            with the converted values.

    Returns:
        ``Args`` namedtuple with one field per key of ``d`` (fields sorted
        by name).
    """
    Args = collections.namedtuple('Args', sorted(d.keys()))
    for k, v in d.items():
        # isinstance (not an exact type check) so dict subclasses such as
        # OrderedDict are converted as well.
        if isinstance(v, dict):
            d[k] = dict_to_namedtuple(v)
        elif isinstance(v, str):
            try:
                # NOTE(security): eval executes arbitrary expressions and
                # resolves names visible here (a value such as 'os' would
                # evaluate to the module). Only feed trusted configuration;
                # consider ast.literal_eval if plain literals are enough.
                d[k] = eval(v)
            except Exception:
                # Not an evaluable expression: keep the raw string.
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                d[k] = v
    return Args(**d)
def parse_args(kwargs):
    """Fill ``kwargs`` with default settings and wrap them in a namedtuple.

    Args:
        kwargs: Dict of user-supplied settings; it is completed in place.

    Returns:
        ``(args, kwargs)``: the namedtuple produced by
        ``dict_to_namedtuple`` and the same (updated) dict.
    """
    # (key, default) pairs, applied in order; keys already present keep
    # the caller's value.
    defaults = (
        ('dataset', 'cifar10'),
        ('network', 'resnet_cifar10'),
        ('optimizer', 'adam'),
        ('learning_rate', 0.1),
        ('seed', None),
        ('use_cuda', True),
        ('num_workers', 4),
        ('print_step', 2000),
        ('val_step', 2000),
        ('scheduler', 'exp'),
        ('batch_size', 128),
        ('start_step', 0),
        ('max_step', 64000),
        ('fast_auto_augment', False),
        ('augment_path', None),
    )
    for key, fallback in defaults:
        kwargs.setdefault(key, fallback)

    # CUDA is only used when it was requested AND is actually available.
    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()

    # to named tuple
    args = dict_to_namedtuple(kwargs)
    return args, kwargs
def select_model(args):
    """Instantiate the network named by ``args.network`` and print it.

    If the name is an attribute of ``torchvision.models`` it is built with
    no arguments and wrapped in ``BaseNet``; otherwise the class ``Net`` of
    the module ``networks.<name>`` is instantiated with ``args``.
    """
    torchvision_zoo = vars(models)
    if args.network not in torchvision_zoo:
        net_module = importlib.import_module(
            'networks.{}'.format(args.network))
        model = getattr(net_module, 'Net')(args)
    else:
        backbone = torchvision_zoo[args.network]()
        model = BaseNet(backbone, args)
    print(model)
    return model
def select_optimizer(args, model):
    """Create the optimizer named by ``args.optimizer`` for ``model``.

    Supported names: ``'sgd'`` (momentum 0.9, weight decay 1e-4), ``'rms'``
    and ``'adam'``; all use ``args.learning_rate``.

    Raises:
        Exception: If the optimizer name is not recognised.
    """
    kind = args.optimizer
    if kind == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                               momentum=0.9, weight_decay=0.0001)
    if kind == 'rms':
        # RMSprop runs with library defaults apart from the learning rate
        # (no momentum / weight decay).
        return torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)
    if kind == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    raise Exception('Unknown Optimizer')
def select_scheduler(args, optimizer):
    """Create the learning-rate scheduler named by ``args.scheduler``.

    Returns ``None`` when the name is falsy or the string ``'None'``,
    a triangular2 ``CyclicLR`` for ``'clr'`` and an ``ExponentialLR`` for
    ``'exp'``.

    Raises:
        Exception: If the scheduler name is not recognised.
    """
    kind = args.scheduler
    if not kind or kind == 'None':
        return None
    lr_sched = torch.optim.lr_scheduler
    if kind == 'clr':
        return lr_sched.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2',
                                 step_size_up=250000, cycle_momentum=False)
    if kind == 'exp':
        return lr_sched.ExponentialLR(optimizer, gamma=0.9999283,
                                      last_epoch=-1)
    raise Exception('Unknown Scheduler')
class CustomDataset(Dataset):
    """Dataset backed by a single array loaded with ``np.load``.

    Args:
        path: Location passed to ``np.load``; the loaded array's first axis
            enumerates the samples.
        transform: Optional callable applied to each sample on access.

    NOTE(review): ``__getitem__`` returns only the image, while the training
    and validation loops in this module unpack ``(images, target)`` batches
    -- confirm how labels are meant to be supplied.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform
        self.img = np.load(path)
        self.len = self.img.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transform`` if one is set."""
        # Fetch the sample first, then apply the attribute actually set in
        # __init__ (`transform`, singular).
        img = self.img[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img
def get_dataset(args, transform, split='train'):
    """Build the dataset selected by ``args.dataset`` for the given split.

    Args:
        args: Configuration object; ``args.dataset`` is one of
            ``'cifar10'``, ``'imagenet'`` or ``'BraTS'``.
        transform: Transform forwarded to the dataset constructor.
        split: One of ``'train'``, ``'val'``, ``'test'``, ``'trainval'``.

    Returns:
        The dataset. For CIFAR-10 ``'train'``/``'val'`` this is a ``Subset``
        of the training set, using a split index that is cached on disk.

    Raises:
        Exception: If ``args.dataset`` is not recognised.

    NOTE(review): ``DATASET_PATH`` is not defined in the visible part of
    this module (only TRAIN_/VAL_DATASET_PATH are), so the cifar10 and
    imagenet branches will raise NameError unless it is defined elsewhere.
    """
    assert split in ['train', 'val', 'test', 'trainval']
    if args.dataset == 'cifar10':
        train = split in ['train', 'val', 'trainval']
        dataset = torchvision.datasets.CIFAR10(DATASET_PATH,
                                               train=train,
                                               transform=transform,
                                               download=True)
        if split in ['train', 'val']:
            split_path = os.path.join(DATASET_PATH,
                                      'cifar-10-batches-py', 'train_val_index.cp')
            # Create the train/val index once, then always reload it so
            # every run uses the same partition.
            if not os.path.exists(split_path):
                [train_index], [val_index] = split_dataset(args, dataset, k=1)
                split_index = {'train': train_index, 'val': val_index}
                with open(split_path, 'wb') as f:
                    cp.dump(split_index, f)
            with open(split_path, 'rb') as f:
                split_index = cp.load(f)
            dataset = Subset(dataset, split_index[split])
    elif args.dataset == 'imagenet':
        dataset = torchvision.datasets.ImageNet(DATASET_PATH,
                                                split=split,
                                                transform=transform,
                                                # value comparison, not `is`
                                                download=(split == 'val'))
    elif args.dataset == 'BraTS':
        # Every split other than 'train' reads the validation data.
        if split in ['train']:
            dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform)
        else:
            dataset = CustomDataset(VAL_DATASET_PATH, transform=transform)
    else:
        raise Exception('Unknown dataset')
    return dataset
def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
    """Wrap ``dataset`` in a ``DataLoader``.

    Batch size and worker count come from ``args.batch_size`` and
    ``args.num_workers``; ``shuffle`` and ``pin_memory`` are forwarded
    unchanged.
    """
    loader_options = {
        'batch_size': args.batch_size,
        'shuffle': shuffle,
        'num_workers': args.num_workers,
        'pin_memory': pin_memory,
    }
    return DataLoader(dataset, **loader_options)
def get_inf_dataloader(args, dataset):
    """Yield shuffled batches from ``dataset`` forever.

    Whenever the underlying loader is exhausted, the module-level
    ``current_epoch`` counter is incremented and a fresh shuffled loader is
    started.
    """
    global current_epoch

    def new_pass():
        # One shuffled pass over the dataset.
        return iter(get_dataloader(args, dataset, shuffle=True))

    batches = new_pass()
    while True:
        try:
            item = next(batches)
        except StopIteration:
            # End of a pass: count the epoch and restart.
            current_epoch += 1
            batches = new_pass()
            item = next(batches)
        yield item
def get_train_transform(args, model, log_dir=None):
    """Build (or load) the training-time transform and print it.

    With ``args.fast_auto_augment`` set, the transform is either unpickled
    from ``args.augment_path`` or searched with ``fast_auto_augment``; in
    both cases a copy is stored as ``<log_dir>/augmentation.cp`` when
    ``log_dir`` is given. Otherwise a fixed per-dataset pipeline is used.

    Args:
        args: Configuration object (``fast_auto_augment``, ``dataset``,
            ``augment_path`` are read).
        model: Model; ``model.img_size`` is read for the imagenet and BraTS
            pipelines, and the model is forwarded to the policy search.
        log_dir: Optional directory that receives the augmentation file.

    Returns:
        The transform.

    Raises:
        Exception: If ``args.dataset`` is not recognised.
    """
    if args.fast_auto_augment:
        assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet
        if args.augment_path:
            import shutil
            # NOTE(security): pickle executes code on load; only use
            # augmentation files from trusted sources.
            with open(args.augment_path, 'rb') as f:
                transform = cp.load(f)
            if log_dir:
                # shutil instead of a shell `cp`: paths may contain spaces
                # (e.g. ".../My Drive/..."), which the unquoted shell
                # command could not handle. Skipped when no log_dir is set.
                target = os.path.join(log_dir, 'augmentation.cp')
                try:
                    shutil.copyfile(args.augment_path, target)
                except OSError as err:
                    # Keeping a copy is best-effort (the shell call ignored
                    # failures too); report the failure and continue.
                    print('Could not copy {} to {}: {}'.format(
                        args.augment_path, target, err))
        else:
            # Only the policy search needs this project module.
            from fast_auto_augment import fast_auto_augment
            transform = fast_auto_augment(args, model, K=4, B=1, num_process=4)
            if log_dir:
                with open(os.path.join(log_dir, 'augmentation.cp'), 'wb') as f:
                    cp.dump(transform, f)
    elif args.dataset == 'cifar10':
        transform = transforms.Compose([
            transforms.Pad(4),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
    elif args.dataset == 'imagenet':
        resize_h, resize_w = model.img_size[0], int(model.img_size[1]*1.875)
        transform = transforms.Compose([
            transforms.Resize([resize_h, resize_w]),
            transforms.RandomCrop(model.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
    elif args.dataset == 'BraTS':
        resize_h, resize_w = 256, 256
        transform = transforms.Compose([
            transforms.Resize([resize_h, resize_w]),
            transforms.RandomCrop(model.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
    else:
        raise Exception('Unknown Dataset')
    print(transform)
    return transform
def get_valid_transform(args, model):
    """Build the deterministic validation transform for ``args.dataset``.

    Every pipeline is a resize followed by ``ToTensor``; only the resize
    target differs per dataset (``model.img_size`` is read for imagenet).

    Raises:
        Exception: If ``args.dataset`` is not recognised.
    """
    name = args.dataset
    if name == 'cifar10':
        resize = transforms.Resize(32)
    elif name == 'imagenet':
        target_h = model.img_size[0]
        target_w = int(model.img_size[1]*1.875)
        resize = transforms.Resize([target_h, target_w])
    elif name == 'BraTS':
        resize = transforms.Resize([256, 256])
    else:
        raise Exception('Unknown Dataset')
    return transforms.Compose([resize, transforms.ToTensor()])
def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    """Run one optimisation step on ``batch``.

    Args:
        args: Configuration object (``use_cuda``, ``print_step`` are read).
        model: Network returning ``(output, first)`` for a batch of images.
        optimizer: Optimizer stepped once.
        scheduler: Optional LR scheduler, stepped once when truthy.
        criterion: Loss callable applied to ``(output, target)``.
        batch: Pair ``(images, target)``.
        step: Global step, used for logging.
        writer: Optional summary writer exposing ``add_image``.
        device: Optional device; takes precedence over ``args.use_cuda``.

    Returns:
        ``(acc1, acc5, loss, forward_t, backward_t)``: top-1/top-5 hit
        counts divided by the batch size, the loss, and the wall-clock
        seconds spent in the forward and backward passes.
    """
    model.train()
    images, target = batch
    if device:
        images, target = images.to(device), target.to(device)
    elif args.use_cuda:
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

    # forward pass (timed)
    tic = time.time()
    output, first = model(images)
    forward_t = time.time() - tic
    loss = criterion(output, target)

    # accuracy() returns hit counts; turn them into per-batch fractions
    batch_n = images.size(0)
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    acc1 /= batch_n
    acc5 /= batch_n

    # backward pass (timed) and parameter update
    optimizer.zero_grad()
    tic = time.time()
    loss.backward()
    backward_t = time.time() - tic
    optimizer.step()
    if scheduler:
        scheduler.step()

    # periodically log up to 10 inputs next to their feature maps
    if writer and step % args.print_step == 0:
        for j in range(min(images.size(0), 10)):
            panel = concat_image_features(images[j], first[j])
            writer.add_image('train/input_image', panel, global_step=step)

    return acc1, acc5, loss, forward_t, backward_t
def validate(args, model, criterion, valid_loader, step, writer, device=None):
    """Evaluate ``model`` over ``valid_loader``.

    Args:
        args: Configuration object (``use_cuda`` is read).
        model: Network returning ``(output, first)`` for a batch of images.
        criterion: Loss callable applied to ``(output, target)``.
        valid_loader: Iterable of ``(images, target)`` batches; must yield
            at least one batch.
        step: Global step, used for logging.
        writer: Optional summary writer exposing ``add_image``.
        device: Optional device; takes precedence over ``args.use_cuda``.

    Returns:
        ``(acc1, acc5, loss, infer_t)``: top-1/top-5 accuracy over all
        samples, the loss of the LAST batch only, and the accumulated
        wall-clock seconds spent on transfer, forward pass and loss.
    """
    # switch to evaluate mode
    model.eval()
    acc1, acc5 = 0, 0
    samples = 0
    infer_t = 0
    with torch.no_grad():
        for images, target in valid_loader:
            start_t = time.time()
            if device:
                images = images.to(device)
                target = target.to(device)
            elif args.use_cuda:
                # Truthiness test, as in train_step: `is not None` was true
                # even for use_cuda=False and forced .cuda() on CPU runs.
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)
            # compute output
            output, first = model(images)
            loss = criterion(output, target)
            infer_t += time.time() - start_t
            # accumulate hit counts; normalised after the loop
            _acc1, _acc5 = accuracy(output, target, topk=(1, 5))
            acc1 += _acc1
            acc5 += _acc5
            samples += images.size(0)
    acc1 /= samples
    acc5 /= samples
    if writer:
        # log up to 10 inputs of the last batch next to their feature maps
        n_imgs = min(images.size(0), 10)
        for j in range(n_imgs):
            writer.add_image('valid/input_image',
                             concat_image_features(images[j], first[j]), global_step=step)
    return acc1, acc5, loss, infer_t
def accuracy(output, target, topk=(1,)):
    """Count correct predictions within the top-k scores for each k.

    Args:
        output: Score tensor of shape (batch, classes).
        target: Tensor of ``batch`` ground-truth class indices.
        topk: Iterable of k values.

    Returns:
        List with one 1-element float tensor per k, holding the NUMBER of
        samples whose target is among the k highest-scoring classes (a
        count, not a ratio -- callers divide by the sample count). A k
        larger than the number of classes behaves like k == classes.
    """
    with torch.no_grad():
        # Never ask topk for more entries than there are classes.
        maxk = min(max(topk), output.size(1))
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # reshape, not view: the slice of the transposed comparison
            # result may be non-contiguous, which view(-1) rejects.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k)
        return res