Fast AutoAugment

A PyTorch implementation of Fast AutoAugment and EfficientNet.

Prerequisites

  • torch==1.1.0
  • torchvision==0.2.2
  • hyperopt==0.1.2
  • future==0.17.1
  • tb-nightly==1.15.0a20190622
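
Before training, it can help to confirm that the installed packages match the pins above. The following is a minimal sketch and not part of the repository: `parse_pins`, `installed_version`, and `find_mismatches` are hypothetical helper names, and the version lookup is injectable so the check can be tried without the packages installed.

```python
from importlib import metadata  # standard library, Python 3.8+

# Pins copied from the prerequisite list above.
PINS = """
torch==1.1.0
torchvision==0.2.2
hyperopt==0.1.2
future==0.17.1
tb-nightly==1.15.0a20190622
"""


def parse_pins(text):
    """Turn 'name==version' lines into a {name: version} dict."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, version = line.partition("==")
        pins[name.strip()] = version.strip()
    return pins


def installed_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Return {name: (wanted, found)} for every pin that is not satisfied."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(parse_pins(PINS)).items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```

The comparison is an exact string match, so a build with a local suffix (for example a CUDA-specific wheel) would be reported as a mismatch even if it is compatible.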

Usage

Training

CIFAR10

# ResNet20 (w/o FastAutoAugment)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=False

# ResNet20 (w/ FastAutoAugment)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True

# ResNet20 (w/ FastAutoAugment, Pre-found policy)
python train.py --seed=24 --scale=3 --optimizer=sgd --fast_auto_augment=True \
                --augment_path=runs/ResNet_Scale3_FastAutoAugment/augmentation.cp

# ResNet32 (w/o FastAutoAugment)
python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=False

# ResNet32 (w/ FastAutoAugment)
python train.py --seed=24 --scale=5 --optimizer=sgd --fast_auto_augment=True

# EfficientNet (w/ FastAutoAugment)
python train.py --seed=24 --pi=0 --optimizer=adam --fast_auto_augment=True \
                --network=efficientnet_cifar10 --activation=swish
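
The ResNet commands above pair `--scale=3` with ResNet20 and `--scale=5` with ResNet32. That is consistent with the usual CIFAR ResNet convention of depth = 6n + 2, where n is the number of residual blocks per stage. The sketch below assumes that this is what `--scale` means; `cifar_resnet_depth` is a hypothetical helper, not a function from this repository.

```python
def cifar_resnet_depth(scale):
    """Depth of a CIFAR-style ResNet with `scale` blocks per stage (assumed 6n + 2)."""
    if scale < 1:
        raise ValueError("scale must be a positive integer")
    # 3 stages x `scale` residual blocks x 2 conv layers each,
    # plus the first conv layer and the final fully connected layer.
    return 6 * scale + 2


for scale in (3, 5):
    print(f"--scale={scale} -> ResNet{cifar_resnet_depth(scale)}")
```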

ImageNet (any backbone network in torchvision.models can be used)

# BaseNet (w/o FastAutoAugment)
python train.py --seed=24 --dataset=imagenet --optimizer=adam --network=resnet50

# EfficientNet (w/ FastAutoAugment) (under construction)
python train.py --seed=24 --dataset=imagenet --pi=0 --optimizer=adam --fast_auto_augment=True \
                --network=efficientnet --activation=swish
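
The EfficientNet commands pass `--activation=swish`. This presumably selects the Swish activation, swish(x) = x * sigmoid(x). The scalar sketch below only illustrates the formula; the repository's version operates on tensors and may differ in detail.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # safe: x < 0, so exp(x) cannot overflow
    return z / (1.0 + z)


def swish(x):
    """Swish activation: x * sigmoid(x)."""
    return x * sigmoid(x)


if __name__ == "__main__":
    for x in (-2.0, 0.0, 2.0):
        print(f"swish({x}) = {swish(x):.4f}")
```

Unlike ReLU, Swish is smooth and slightly negative for small negative inputs, approaching the identity for large positive inputs.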

Eval

# Single Image testing
python eval.py --model_path=runs/ResNet_Scale3_Basline

# 5-crops testing
python eval.py --model_path=runs/ResNet_Scale3_Basline --five_crops=True
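
Five-crop testing usually means scoring the four corner crops and the center crop of each image and averaging the per-class scores. The sketch below illustrates that idea on nested lists with a stand-in `score_fn`. The crop size, the helper names, and the averaging rule are assumptions for illustration, not the behaviour of `eval.py`.

```python
def five_crop_boxes(height, width, crop_h, crop_w):
    """Return (top, left, crop_h, crop_w) boxes: four corners, then center."""
    if crop_h > height or crop_w > width:
        raise ValueError("crop size must not exceed image size")
    bottom = height - crop_h
    right = width - crop_w
    return [
        (0, 0, crop_h, crop_w),                     # top-left
        (0, right, crop_h, crop_w),                 # top-right
        (bottom, 0, crop_h, crop_w),                # bottom-left
        (bottom, right, crop_h, crop_w),            # bottom-right
        (bottom // 2, right // 2, crop_h, crop_w),  # center
    ]


def crop(image, box):
    """Cut a box out of an image stored as a list of rows."""
    top, left, h, w = box
    return [row[left:left + w] for row in image[top:top + h]]


def predict_five_crops(image, crop_size, score_fn):
    """Average per-class scores over five crops; return (label, mean_scores)."""
    height, width = len(image), len(image[0])
    boxes = five_crop_boxes(height, width, crop_size, crop_size)
    all_scores = [score_fn(crop(image, box)) for box in boxes]
    n_classes = len(all_scores[0])
    mean_scores = [
        sum(scores[c] for scores in all_scores) / len(all_scores)
        for c in range(n_classes)
    ]
    label = max(range(n_classes), key=mean_scores.__getitem__)
    return label, mean_scores
```

In a PyTorch pipeline the same cropping is available as `torchvision.transforms.FiveCrop`, with `score_fn` replaced by a forward pass of the trained network.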

Experiments

Fast AutoAugment

ResNet20 (CIFAR10)

  • Pre-trained model [Download]
  • Validation Curve

  • Evaluation (Acc @1)

Model             Valid    Test (Single)
ResNet20          90.70    91.45
ResNet20 + FAA    92.46    91.45

ResNet34 (CIFAR10)

  • Validation Curve

  • Evaluation (Acc @1)

Model             Valid    Test (Single)
ResNet34          91.54    91.47
ResNet34 + FAA    92.76    91.99
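
To read the two result tables side by side, the accuracy change from adding Fast AutoAugment can be computed directly from the reported numbers. The snippet below is a small arithmetic sketch; the values are copied from the tables above and `faa_gain` is a hypothetical helper name.

```python
# (valid, test) Acc@1 values copied from the tables above.
results = {
    "ResNet20": {"baseline": (90.70, 91.45), "faa": (92.46, 91.45)},
    "ResNet34": {"baseline": (91.54, 91.47), "faa": (92.76, 91.99)},
}


def faa_gain(entry):
    """Return (valid_gain, test_gain) in percentage points, rounded to 2 places."""
    base_valid, base_test = entry["baseline"]
    faa_valid, faa_test = entry["faa"]
    return round(faa_valid - base_valid, 2), round(faa_test - base_test, 2)


for model, entry in results.items():
    valid_gain, test_gain = faa_gain(entry)
    print(f"{model}: valid {valid_gain:+.2f}, test {test_gain:+.2f}")
```

Note that the reported ResNet20 test accuracy is identical with and without FAA, so its test gain comes out as zero.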

Found Policy [Download]
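
In Fast AutoAugment, a policy is a collection of sub-policies, and each sub-policy is a short sequence of operations, each applied with its own probability and magnitude. The sketch below shows that structure on a toy image stored as nested lists. The operation names, the `(name, probability, magnitude)` tuple layout, and the helper functions are assumptions for illustration; the format of the downloadable policy file may differ.

```python
import random


def flip_horizontal(image, magnitude):
    """Mirror each row; magnitude is unused for this operation."""
    return [row[::-1] for row in image]


def brighten(image, magnitude):
    """Add 255 * magnitude to every pixel, clipped to the 0..255 range."""
    shift = int(255 * magnitude)
    return [[max(0, min(255, pixel + shift)) for pixel in row] for row in image]


# Hypothetical operation table; a real policy draws from a larger set.
OPS = {"flip_horizontal": flip_horizontal, "brighten": brighten}


def apply_sub_policy(image, sub_policy, rng):
    """Apply each (name, probability, magnitude) operation in order."""
    for name, probability, magnitude in sub_policy:
        if rng.random() < probability:
            image = OPS[name](image, magnitude)
    return image


def apply_policy(image, policy, rng):
    """Pick one sub-policy uniformly at random and apply it."""
    return apply_sub_policy(image, rng.choice(policy), rng)


if __name__ == "__main__":
    policy = [
        [("flip_horizontal", 0.5, 0.0), ("brighten", 0.8, 0.1)],
        [("brighten", 0.3, 0.2)],
    ]
    print(apply_policy([[10, 20], [30, 40]], policy, random.Random(24)))
```

Passing an explicit `random.Random` instance keeps the augmentation reproducible, in the same spirit as the `--seed` flag used in the training commands.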

Augmented images