Showing
5 changed files
with
1071 additions
and
272 deletions
code/FAA2_VM/getAugmented_1.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | +import pickle | ||
| 6 | + | ||
| 7 | +import torch | ||
| 8 | +import torch.nn as nn | ||
| 9 | +from torch.utils.tensorboard import SummaryWriter | ||
| 10 | + | ||
| 11 | +from utils import * | ||
| 12 | + | ||
| 13 | +# command | ||
| 14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
| 15 | + | ||
| 16 | +def eval(model_path): | ||
| 17 | + print('\n[+] Parse arguments') | ||
| 18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 20 | + args, kwargs = parse_args(kwargs) | ||
| 21 | + pprint(args) | ||
| 22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 23 | + | ||
| 24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 25 | + | ||
| 26 | + writer = SummaryWriter(log_dir=model_path) | ||
| 27 | + | ||
| 28 | + | ||
| 29 | + print('\n[+] Load transform') | ||
| 30 | + # list | ||
| 31 | + with open(cp_path, 'rb') as f: | ||
| 32 | + aug_transform_list = pickle.load(f) | ||
| 33 | + | ||
| 34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'test')) | ||
| 35 | + | ||
| 36 | + | ||
| 37 | + print('\n[+] Load dataset') | ||
| 38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
| 39 | + dataset = get_dataset(args, aug_transform, 'test') | ||
| 40 | + | ||
| 41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 42 | + | ||
| 43 | + for i, (images, target) in enumerate(loader): | ||
| 44 | + images = images.view(240, 240) | ||
| 45 | + | ||
| 46 | + # concat image | ||
| 47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
| 48 | + | ||
| 49 | + if i % 1000 == 0: | ||
| 50 | + print("\n images size: ", augmented_image_list[i].size()) # [240, 240] | ||
| 51 | + | ||
| 52 | + break | ||
| 53 | + # break | ||
| 54 | + | ||
| 55 | + | ||
| 56 | + # print(augmented_image_list) | ||
| 57 | + | ||
| 58 | + | ||
| 59 | + print('\n[+] Write on tensorboard') | ||
| 60 | + if writer: | ||
| 61 | + for i, data in enumerate(augmented_image_list): | ||
| 62 | + tag = 'img/' + str(i) | ||
| 63 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
| 64 | + break | ||
| 65 | + | ||
| 66 | + writer.close() | ||
| 67 | + | ||
| 68 | + | ||
| 69 | + # if writer: | ||
| 70 | + # for j in range(): | ||
| 71 | + # tag = 'img/' + str(img_count) + '_' + str(j) | ||
| 72 | + # # writer.add_image(tag, | ||
| 73 | + # # concat_image_features(images[j], first[j]), global_step=step) | ||
| 74 | + # # if j > 0: | ||
| 75 | + # # fore = concat_image_features(fore, images[j]) | ||
| 76 | + | ||
| 77 | + # writer.add_image(tag, fore, global_step=0) | ||
| 78 | + # img_count = img_count + 1 | ||
| 79 | + | ||
| 80 | + # writer.close() | ||
| 81 | + | ||
| 82 | +if __name__ == '__main__': | ||
| 83 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_all.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | +import pickle | ||
| 6 | + | ||
| 7 | +import torch | ||
| 8 | +import torch.nn as nn | ||
| 9 | +from torch.utils.tensorboard import SummaryWriter | ||
| 10 | + | ||
| 11 | +from utils import * | ||
| 12 | + | ||
| 13 | +# command | ||
| 14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
| 15 | + | ||
| 16 | +def eval(model_path): | ||
| 17 | + print('\n[+] Parse arguments') | ||
| 18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 20 | + args, kwargs = parse_args(kwargs) | ||
| 21 | + pprint(args) | ||
| 22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 23 | + | ||
| 24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 25 | + | ||
| 26 | + writer = SummaryWriter(log_dir=model_path) | ||
| 27 | + | ||
| 28 | + | ||
| 29 | + print('\n[+] Load transform') | ||
| 30 | + # list | ||
| 31 | + with open(cp_path, 'rb') as f: | ||
| 32 | + aug_transform_list = pickle.load(f) | ||
| 33 | + | ||
| 34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'train')) | ||
| 35 | + | ||
| 36 | + | ||
| 37 | + print('\n[+] Load dataset') | ||
| 38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
| 39 | + dataset = get_dataset(args, aug_transform, 'train') | ||
| 40 | + | ||
| 41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 42 | + | ||
| 43 | + for i, (images, target) in enumerate(loader): | ||
| 44 | + images = images.view(240, 240) | ||
| 45 | + | ||
| 46 | + # concat image | ||
| 47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
| 48 | + | ||
| 49 | + | ||
| 50 | + | ||
| 51 | + | ||
| 52 | + print('\n[+] Write on tensorboard') | ||
| 53 | + if writer: | ||
| 54 | + for i, data in enumerate(augmented_image_list): | ||
| 55 | + tag = 'img/' + str(i) | ||
| 56 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
| 57 | + | ||
| 58 | + writer.close() | ||
| 59 | + | ||
| 60 | + | ||
| 61 | +if __name__ == '__main__': | ||
| 62 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_saveimg.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | +import pickle | ||
| 6 | +import random | ||
| 7 | + | ||
| 8 | +import torch | ||
| 9 | +import torch.nn as nn | ||
| 10 | +from torchvision.utils import save_image | ||
| 11 | +from torch.utils.tensorboard import SummaryWriter | ||
| 12 | + | ||
| 13 | +from utils import * | ||
| 14 | + | ||
| 15 | +# command | ||
| 16 | +# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/' | ||
| 17 | + | ||
| 18 | +def eval(model_path): | ||
| 19 | + print('\n[+] Parse arguments') | ||
| 20 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 21 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 22 | + args, kwargs = parse_args(kwargs) | ||
| 23 | + pprint(args) | ||
| 24 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 25 | + | ||
| 26 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 27 | + | ||
| 28 | + writer = SummaryWriter(log_dir=model_path) | ||
| 29 | + | ||
| 30 | + | ||
| 31 | + print('\n[+] Load transform') | ||
| 32 | + # list to tensor | ||
| 33 | + with open(cp_path, 'rb') as f: | ||
| 34 | + aug_transform_list = pickle.load(f) | ||
| 35 | + | ||
| 36 | + transform = transforms.RandomChoice(aug_transform_list) | ||
| 37 | + | ||
| 38 | + | ||
| 39 | + print('\n[+] Load dataset') | ||
| 40 | + | ||
| 41 | + dataset = get_dataset(args, transform, 'train') | ||
| 42 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 43 | + | ||
| 44 | + | ||
| 45 | + print('\n[+] Save 1 random policy') | ||
| 46 | + os.makedirs(os.path.join(model_path, 'augmented_imgs')) | ||
| 47 | + save_dir = os.path.join(model_path, 'augmented_imgs') | ||
| 48 | + | ||
| 49 | + for i, (image, target) in enumerate(loader): | ||
| 50 | + image = image.view(240, 240) | ||
| 51 | + # save img | ||
| 52 | + save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png')) | ||
| 53 | + | ||
| 54 | + if(i % 100 == 0): | ||
| 55 | + print("\n saved images: ", i) | ||
| 56 | + | ||
| 57 | + print('\n[+] Finished to save') | ||
| 58 | + | ||
| 59 | +if __name__ == '__main__': | ||
| 60 | + fire.Fire(eval) | ||
| 61 | + | ||
| 62 | + | ||
| 63 | + |
code/classifier/classify_normal_lesion.ipynb
0 → 100644
| 1 | +{ | ||
| 2 | + "nbformat": 4, | ||
| 3 | + "nbformat_minor": 0, | ||
| 4 | + "metadata": { | ||
| 5 | + "colab": { | ||
| 6 | + "name": "classify normal/lesion.ipynb", | ||
| 7 | + "provenance": [], | ||
| 8 | + "collapsed_sections": [] | ||
| 9 | + }, | ||
| 10 | + "kernelspec": { | ||
| 11 | + "name": "python3", | ||
| 12 | + "display_name": "Python 3" | ||
| 13 | + }, | ||
| 14 | + "accelerator": "GPU" | ||
| 15 | + }, | ||
| 16 | + "cells": [ | ||
| 17 | + { | ||
| 18 | + "cell_type": "code", | ||
| 19 | + "metadata": { | ||
| 20 | + "id": "AjoTMXMCrFYX", | ||
| 21 | + "colab_type": "code", | ||
| 22 | + "outputId": "2548434e-72b7-4946-9748-070103017379", | ||
| 23 | + "colab": { | ||
| 24 | + "base_uri": "https://localhost:8080/", | ||
| 25 | + "height": 129 | ||
| 26 | + } | ||
| 27 | + }, | ||
| 28 | + "source": [ | ||
| 29 | + "from google.colab import drive\n", | ||
| 30 | + "drive.mount('/content/drive')" | ||
| 31 | + ], | ||
| 32 | + "execution_count": 1, | ||
| 33 | + "outputs": [ | ||
| 34 | + { | ||
| 35 | + "output_type": "stream", | ||
| 36 | + "text": [ | ||
| 37 | + "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n", | ||
| 38 | + "\n", | ||
| 39 | + "Enter your authorization code:\n", | ||
| 40 | + "··········\n", | ||
| 41 | + "Mounted at /content/drive\n" | ||
| 42 | + ], | ||
| 43 | + "name": "stdout" | ||
| 44 | + } | ||
| 45 | + ] | ||
| 46 | + }, | ||
| 47 | + { | ||
| 48 | + "cell_type": "code", | ||
| 49 | + "metadata": { | ||
| 50 | + "id": "lXK8NZfIyzeB", | ||
| 51 | + "colab_type": "code", | ||
| 52 | + "outputId": "17508683-94fb-45fa-b9df-f7ddb9401b68", | ||
| 53 | + "colab": { | ||
| 54 | + "base_uri": "https://localhost:8080/", | ||
| 55 | + "height": 146 | ||
| 56 | + } | ||
| 57 | + }, | ||
| 58 | + "source": [ | ||
| 59 | + "!git clone http://khuhub.khu.ac.kr/2020-1-capstone-design2/2016104167.git" | ||
| 60 | + ], | ||
| 61 | + "execution_count": 2, | ||
| 62 | + "outputs": [ | ||
| 63 | + { | ||
| 64 | + "output_type": "stream", | ||
| 65 | + "text": [ | ||
| 66 | + "Cloning into '2016104167'...\n", | ||
| 67 | + "remote: Counting objects: 11451, done.\u001b[K\n", | ||
| 68 | + "remote: Compressing objects: 100% (39/39), done.\u001b[K\n", | ||
| 69 | + "remote: Total 11451 (delta 15), reused 0 (delta 0)\u001b[K\n", | ||
| 70 | + "Receiving objects: 100% (11451/11451), 292.82 MiB | 384.00 KiB/s, done.\n", | ||
| 71 | + "Resolving deltas: 100% (1109/1109), done.\n", | ||
| 72 | + "Checking out files: 100% (15684/15684), done.\n" | ||
| 73 | + ], | ||
| 74 | + "name": "stdout" | ||
| 75 | + } | ||
| 76 | + ] | ||
| 77 | + }, | ||
| 78 | + { | ||
| 79 | + "cell_type": "code", | ||
| 80 | + "metadata": { | ||
| 81 | + "id": "TmGc36H2y5sI", | ||
| 82 | + "colab_type": "code", | ||
| 83 | + "outputId": "49ce70f0-10bb-48d8-bd89-de98a28c7893", | ||
| 84 | + "colab": { | ||
| 85 | + "base_uri": "https://localhost:8080/", | ||
| 86 | + "height": 35 | ||
| 87 | + } | ||
| 88 | + }, | ||
| 89 | + "source": [ | ||
| 90 | + "%cd '2016104167/code/classifier/'" | ||
| 91 | + ], | ||
| 92 | + "execution_count": 3, | ||
| 93 | + "outputs": [ | ||
| 94 | + { | ||
| 95 | + "output_type": "stream", | ||
| 96 | + "text": [ | ||
| 97 | + "/content/2016104167/code/classifier\n" | ||
| 98 | + ], | ||
| 99 | + "name": "stdout" | ||
| 100 | + } | ||
| 101 | + ] | ||
| 102 | + }, | ||
| 103 | + { | ||
| 104 | + "cell_type": "code", | ||
| 105 | + "metadata": { | ||
| 106 | + "id": "oJ08JUJCzEEE", | ||
| 107 | + "colab_type": "code", | ||
| 108 | + "outputId": "f9454032-73ad-444e-b708-e1228f6b3a0b", | ||
| 109 | + "colab": { | ||
| 110 | + "base_uri": "https://localhost:8080/", | ||
| 111 | + "height": 1000 | ||
| 112 | + } | ||
| 113 | + }, | ||
| 114 | + "source": [ | ||
| 115 | + "!python -m pip install -r \"requirements.txt\"" | ||
| 116 | + ], | ||
| 117 | + "execution_count": 4, | ||
| 118 | + "outputs": [ | ||
| 119 | + { | ||
| 120 | + "output_type": "stream", | ||
| 121 | + "text": [ | ||
| 122 | + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 1)) (0.16.0)\n", | ||
| 123 | + "Collecting tb-nightly\n", | ||
| 124 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/4e/46/4b95936aed44f2154d936160de6c58e3bd4cf8501152db21945617c84694/tb_nightly-2.3.0a20200425-py3-none-any.whl (2.9MB)\n", | ||
| 125 | + "\u001b[K |████████████████████████████████| 2.9MB 39.9MB/s \n", | ||
| 126 | + "\u001b[?25hRequirement already satisfied: hyperopt in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 3)) (0.1.2)\n", | ||
| 127 | + "Collecting pillow==6.2.1\n", | ||
| 128 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/10/5c/0e94e689de2476c4c5e644a3bd223a1c1b9e2bdb7c510191750be74fa786/Pillow-6.2.1-cp36-cp36m-manylinux1_x86_64.whl (2.1MB)\n", | ||
| 129 | + "\u001b[K |████████████████████████████████| 2.1MB 48.1MB/s \n", | ||
| 130 | + "\u001b[?25hRequirement already satisfied: natsort in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 5)) (5.5.0)\n", | ||
| 131 | + "Collecting fire\n", | ||
| 132 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)\n", | ||
| 133 | + "\u001b[K |████████████████████████████████| 81kB 12.1MB/s \n", | ||
| 134 | + "\u001b[?25hCollecting torchvision==0.2.2\n", | ||
| 135 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ce/a1/66d72a2fe580a9f0fcbaaa5b976911fbbde9dce9b330ba12791997b856e9/torchvision-0.2.2-py2.py3-none-any.whl (64kB)\n", | ||
| 136 | + "\u001b[K |████████████████████████████████| 71kB 11.8MB/s \n", | ||
| 137 | + "\u001b[?25hCollecting torch==1.1.0\n", | ||
| 138 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/69/60/f685fb2cfb3088736bafbc9bdbb455327bdc8906b606da9c9a81bae1c81e/torch-1.1.0-cp36-cp36m-manylinux1_x86_64.whl (676.9MB)\n", | ||
| 139 | + "\u001b[K |████████████████████████████████| 676.9MB 18kB/s \n", | ||
| 140 | + "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 9)) (1.0.3)\n", | ||
| 141 | + "Requirement already satisfied: sklearn in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 10)) (0.0)\n", | ||
| 142 | + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (3.2.1)\n", | ||
| 143 | + "Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (3.10.0)\n", | ||
| 144 | + "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.7.2)\n", | ||
| 145 | + "Requirement already satisfied: numpy>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.18.3)\n", | ||
| 146 | + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.6.0.post3)\n", | ||
| 147 | + "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (0.34.2)\n", | ||
| 148 | + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (0.9.0)\n", | ||
| 149 | + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (46.1.3)\n", | ||
| 150 | + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (2.21.0)\n", | ||
| 151 | + "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.12.0)\n", | ||
| 152 | + "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.0.1)\n", | ||
| 153 | + "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.28.1)\n", | ||
| 154 | + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (0.4.1)\n", | ||
| 155 | + "Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (2.4)\n", | ||
| 156 | + "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (4.38.0)\n", | ||
| 157 | + "Requirement already satisfied: pymongo in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (3.10.1)\n", | ||
| 158 | + "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (1.4.1)\n", | ||
| 159 | + "Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->-r requirements.txt (line 6)) (1.1.0)\n", | ||
| 160 | + "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->-r requirements.txt (line 9)) (2.8.1)\n", | ||
| 161 | + "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->-r requirements.txt (line 9)) (2018.9)\n", | ||
| 162 | + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from sklearn->-r requirements.txt (line 10)) (0.22.2.post1)\n", | ||
| 163 | + "Requirement already satisfied: rsa<4.1,>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (4.0)\n", | ||
| 164 | + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (0.2.8)\n", | ||
| 165 | + "Requirement already satisfied: cachetools<3.2,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (3.1.1)\n", | ||
| 166 | + "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (3.0.4)\n", | ||
| 167 | + "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (2.8)\n", | ||
| 168 | + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (2020.4.5.1)\n", | ||
| 169 | + "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (1.24.3)\n", | ||
| 170 | + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly->-r requirements.txt (line 2)) (1.3.0)\n", | ||
| 171 | + "Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->hyperopt->-r requirements.txt (line 3)) (4.4.2)\n", | ||
| 172 | + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sklearn->-r requirements.txt (line 10)) (0.14.1)\n", | ||
| 173 | + "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<4.1,>=3.1.4->google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (0.4.8)\n", | ||
| 174 | + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly->-r requirements.txt (line 2)) (3.1.0)\n", | ||
| 175 | + "Building wheels for collected packages: fire\n", | ||
| 176 | + " Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | ||
| 177 | + " Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=090cc0b99c44969c5594966fd1af925b4ff73b02719d44b05fb4aacd07b9bfb3\n", | ||
| 178 | + " Stored in directory: /root/.cache/pip/wheels/c1/61/df/768b03527bf006b546dce284eb4249b185669e65afc5fbb2ac\n", | ||
| 179 | + "Successfully built fire\n", | ||
| 180 | + "\u001b[31mERROR: torchvision 0.2.2 has requirement tqdm==4.19.9, but you'll have tqdm 4.38.0 which is incompatible.\u001b[0m\n", | ||
| 181 | + "\u001b[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.\u001b[0m\n", | ||
| 182 | + "Installing collected packages: tb-nightly, pillow, fire, torch, torchvision\n", | ||
| 183 | + " Found existing installation: Pillow 7.0.0\n", | ||
| 184 | + " Uninstalling Pillow-7.0.0:\n", | ||
| 185 | + " Successfully uninstalled Pillow-7.0.0\n", | ||
| 186 | + " Found existing installation: torch 1.4.0\n", | ||
| 187 | + " Uninstalling torch-1.4.0:\n", | ||
| 188 | + " Successfully uninstalled torch-1.4.0\n", | ||
| 189 | + " Found existing installation: torchvision 0.5.0\n", | ||
| 190 | + " Uninstalling torchvision-0.5.0:\n", | ||
| 191 | + " Successfully uninstalled torchvision-0.5.0\n", | ||
| 192 | + "Successfully installed fire-0.3.1 pillow-6.2.1 tb-nightly-2.3.0a20200425 torch-1.1.0 torchvision-0.2.2\n" | ||
| 193 | + ], | ||
| 194 | + "name": "stdout" | ||
| 195 | + } | ||
| 196 | + ] | ||
| 197 | + }, | ||
| 198 | + { | ||
| 199 | + "cell_type": "code", | ||
| 200 | + "metadata": { | ||
| 201 | + "id": "jdayOoSYHJDf", | ||
| 202 | + "colab_type": "code", | ||
| 203 | + "outputId": "f663313d-1166-4f6c-f617-d3d331ce7793", | ||
| 204 | + "colab": { | ||
| 205 | + "base_uri": "https://localhost:8080/", | ||
| 206 | + "height": 1000 | ||
| 207 | + } | ||
| 208 | + }, | ||
| 209 | + "source": [ | ||
| 210 | + "!python train.py --use_cuda=True --network=resnet50 --optimizer=adam " | ||
| 211 | + ], | ||
| 212 | + "execution_count": 9, | ||
| 213 | + "outputs": [ | ||
| 214 | + { | ||
| 215 | + "output_type": "stream", | ||
| 216 | + "text": [ | ||
| 217 | + "\n", | ||
| 218 | + "[+] Parse arguments\n", | ||
| 219 | + "Args(augment_path=None, batch_size=8, dataset='BraTS', learning_rate=1e-06, max_step=2000, network='resnet50', num_workers=4, optimizer='adam', print_step=50, scheduler='exp', seed=None, start_step=0, use_cuda=True, val_step=50)\n", | ||
| 220 | + "\n", | ||
| 221 | + "[+] Create log dir\n", | ||
| 222 | + "2020-04-26 04:35:05.510266: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", | ||
| 223 | + "\n", | ||
| 224 | + "[+] Create network\n", | ||
| 225 | + "\n", | ||
| 226 | + "[+] Load dataset\n", | ||
| 227 | + "\n", | ||
| 228 | + "[+] Start training\n", | ||
| 229 | + "\n", | ||
| 230 | + "[+] Use 1 GPUs\n", | ||
| 231 | + "\n", | ||
| 232 | + "[+] Using GPU: Tesla P4 \n", | ||
| 233 | + "\n", | ||
| 234 | + "[+] Training step: 0/2000\tTraining epoch: 0/580\tElapsed time: 0.01min\tLearning rate: 9.999283e-07\n", | ||
| 235 | + " Acc@1 : 0.000%\n", | ||
| 236 | + " Loss : 7.1226911544799805\n", | ||
| 237 | + " FW Time : 125.206ms\n", | ||
| 238 | + " BW Time : 14.758ms\n", | ||
| 239 | + "\n", | ||
| 240 | + "[+] (Valid results) Valid step: 49/2000\n", | ||
| 241 | + " Acc@1 : 0.000%\n", | ||
| 242 | + " Loss : 7.333343982696533\n", | ||
| 243 | + "\n", | ||
| 244 | + "[+] Model saved\n", | ||
| 245 | + "\n", | ||
| 246 | + "[+] Training step: 50/2000\tTraining epoch: 0/580\tElapsed time: 0.17min\tLearning rate: 9.963498469652168e-07\n", | ||
| 247 | + " Acc@1 : 0.000%\n", | ||
| 248 | + " Loss : 6.839444160461426\n", | ||
| 249 | + " FW Time : 15.798ms\n", | ||
| 250 | + " BW Time : 10.145ms\n", | ||
| 251 | + "\n", | ||
| 252 | + "[+] (Valid results) Valid step: 99/2000\n", | ||
| 253 | + " Acc@1 : 0.000%\n", | ||
| 254 | + " Loss : 7.0915045738220215\n", | ||
| 255 | + "\n", | ||
| 256 | + "[+] Model saved\n", | ||
| 257 | + "\n", | ||
| 258 | + "[+] Training step: 100/2000\tTraining epoch: 0/580\tElapsed time: 0.33min\tLearning rate: 9.92784200174764e-07\n", | ||
| 259 | + " Acc@1 : 0.000%\n", | ||
| 260 | + " Loss : 6.130648136138916\n", | ||
| 261 | + " FW Time : 15.954ms\n", | ||
| 262 | + " BW Time : 13.310ms\n", | ||
| 263 | + "\n", | ||
| 264 | + "[+] (Valid results) Valid step: 149/2000\n", | ||
| 265 | + " Acc@1 : 12.000%\n", | ||
| 266 | + " Loss : 6.664271354675293\n", | ||
| 267 | + "\n", | ||
| 268 | + "[+] Model saved\n", | ||
| 269 | + "\n", | ||
| 270 | + "[+] Training step: 150/2000\tTraining epoch: 0/580\tElapsed time: 0.49min\tLearning rate: 9.89231313798811e-07\n", | ||
| 271 | + " Acc@1 : 25.000%\n", | ||
| 272 | + " Loss : 5.820688247680664\n", | ||
| 273 | + " FW Time : 15.717ms\n", | ||
| 274 | + " BW Time : 12.132ms\n", | ||
| 275 | + "\n", | ||
| 276 | + "[+] (Valid results) Valid step: 199/2000\n", | ||
| 277 | + " Acc@1 : 38.400%\n", | ||
| 278 | + " Loss : 6.647609233856201\n", | ||
| 279 | + "\n", | ||
| 280 | + "[+] Model saved\n", | ||
| 281 | + "\n", | ||
| 282 | + "[+] Training step: 200/2000\tTraining epoch: 0/580\tElapsed time: 0.65min\tLearning rate: 9.85691142171539e-07\n", | ||
| 283 | + " Acc@1 : 62.500%\n", | ||
| 284 | + " Loss : 5.464842319488525\n", | ||
| 285 | + " FW Time : 18.873ms\n", | ||
| 286 | + " BW Time : 14.734ms\n", | ||
| 287 | + "\n", | ||
| 288 | + "[+] (Valid results) Valid step: 249/2000\n", | ||
| 289 | + " Acc@1 : 51.600%\n", | ||
| 290 | + " Loss : 6.3311767578125\n", | ||
| 291 | + "\n", | ||
| 292 | + "[+] Model saved\n", | ||
| 293 | + "\n", | ||
| 294 | + "[+] Training step: 250/2000\tTraining epoch: 0/580\tElapsed time: 0.80min\tLearning rate: 9.821636397905564e-07\n", | ||
| 295 | + " Acc@1 : 62.500%\n", | ||
| 296 | + " Loss : 5.102325439453125\n", | ||
| 297 | + " FW Time : 20.299ms\n", | ||
| 298 | + " BW Time : 13.266ms\n", | ||
| 299 | + "\n", | ||
| 300 | + "[+] (Valid results) Valid step: 299/2000\n", | ||
| 301 | + " Acc@1 : 65.000%\n", | ||
| 302 | + " Loss : 5.964324951171875\n", | ||
| 303 | + "\n", | ||
| 304 | + "[+] Model saved\n", | ||
| 305 | + "\n", | ||
| 306 | + "[+] Training step: 300/2000\tTraining epoch: 0/580\tElapsed time: 0.97min\tLearning rate: 9.786487613163075e-07\n", | ||
| 307 | + " Acc@1 : 75.000%\n", | ||
| 308 | + " Loss : 4.567136287689209\n", | ||
| 309 | + " FW Time : 22.482ms\n", | ||
| 310 | + " BW Time : 11.849ms\n", | ||
| 311 | + "\n", | ||
| 312 | + "[+] (Valid results) Valid step: 349/2000\n", | ||
| 313 | + " Acc@1 : 67.000%\n", | ||
| 314 | + " Loss : 5.935492038726807\n", | ||
| 315 | + "\n", | ||
| 316 | + "[+] Model saved\n", | ||
| 317 | + "\n", | ||
| 318 | + "[+] Training step: 350/2000\tTraining epoch: 0/580\tElapsed time: 1.13min\tLearning rate: 9.751464615714972e-07\n", | ||
| 319 | + " Acc@1 : 75.000%\n", | ||
| 320 | + " Loss : 4.528119087219238\n", | ||
| 321 | + " FW Time : 15.663ms\n", | ||
| 322 | + " BW Time : 12.312ms\n", | ||
| 323 | + "\n", | ||
| 324 | + "[+] (Valid results) Valid step: 399/2000\n", | ||
| 325 | + " Acc@1 : 66.200%\n", | ||
| 326 | + " Loss : 6.177009105682373\n", | ||
| 327 | + "\n", | ||
| 328 | + "[+] Training step: 400/2000\tTraining epoch: 0/580\tElapsed time: 1.28min\tLearning rate: 9.716566955405052e-07\n", | ||
| 329 | + " Acc@1 : 50.000%\n", | ||
| 330 | + " Loss : 5.060157775878906\n", | ||
| 331 | + " FW Time : 18.538ms\n", | ||
| 332 | + " BW Time : 15.068ms\n", | ||
| 333 | + "\n", | ||
| 334 | + "[+] (Valid results) Valid step: 449/2000\n", | ||
| 335 | + " Acc@1 : 70.400%\n", | ||
| 336 | + " Loss : 5.513169288635254\n", | ||
| 337 | + "\n", | ||
| 338 | + "[+] Model saved\n", | ||
| 339 | + "\n", | ||
| 340 | + "[+] Training step: 450/2000\tTraining epoch: 0/580\tElapsed time: 1.44min\tLearning rate: 9.681794183688074e-07\n", | ||
| 341 | + " Acc@1 : 62.500%\n", | ||
| 342 | + " Loss : 4.341537952423096\n", | ||
| 343 | + " FW Time : 20.281ms\n", | ||
| 344 | + " BW Time : 11.464ms\n", | ||
| 345 | + "\n", | ||
| 346 | + "[+] (Valid results) Valid step: 499/2000\n", | ||
| 347 | + " Acc@1 : 71.200%\n", | ||
| 348 | + " Loss : 5.211921691894531\n", | ||
| 349 | + "\n", | ||
| 350 | + "[+] Model saved\n", | ||
| 351 | + "\n", | ||
| 352 | + "[+] Training step: 500/2000\tTraining epoch: 0/580\tElapsed time: 1.60min\tLearning rate: 9.647145853624042e-07\n", | ||
| 353 | + " Acc@1 : 62.500%\n", | ||
| 354 | + " Loss : 3.643502712249756\n", | ||
| 355 | + " FW Time : 35.242ms\n", | ||
| 356 | + " BW Time : 13.355ms\n", | ||
| 357 | + "\n", | ||
| 358 | + "[+] (Valid results) Valid step: 549/2000\n", | ||
| 359 | + " Acc@1 : 73.000%\n", | ||
| 360 | + " Loss : 5.065425872802734\n", | ||
| 361 | + "\n", | ||
| 362 | + "[+] Model saved\n", | ||
| 363 | + "\n", | ||
| 364 | + "[+] Training step: 550/2000\tTraining epoch: 0/580\tElapsed time: 1.76min\tLearning rate: 9.61262151987242e-07\n", | ||
| 365 | + " Acc@1 : 37.500%\n", | ||
| 366 | + " Loss : 4.352468013763428\n", | ||
| 367 | + " FW Time : 15.172ms\n", | ||
| 368 | + " BW Time : 14.894ms\n", | ||
| 369 | + "\n", | ||
| 370 | + "[+] (Valid results) Valid step: 599/2000\n", | ||
| 371 | + " Acc@1 : 72.800%\n", | ||
| 372 | + " Loss : 5.094676971435547\n", | ||
| 373 | + "\n", | ||
| 374 | + "[+] Training step: 600/2000\tTraining epoch: 0/580\tElapsed time: 1.92min\tLearning rate: 9.578220738686398e-07\n", | ||
| 375 | + " Acc@1 : 62.500%\n", | ||
| 376 | + " Loss : 2.67336106300354\n", | ||
| 377 | + " FW Time : 16.094ms\n", | ||
| 378 | + " BW Time : 11.950ms\n", | ||
| 379 | + "\n", | ||
| 380 | + "[+] (Valid results) Valid step: 649/2000\n", | ||
| 381 | + " Acc@1 : 74.200%\n", | ||
| 382 | + " Loss : 4.8995866775512695\n", | ||
| 383 | + "\n", | ||
| 384 | + "[+] Model saved\n", | ||
| 385 | + "\n", | ||
| 386 | + "[+] Training step: 650/2000\tTraining epoch: 0/580\tElapsed time: 2.08min\tLearning rate: 9.543943067907226e-07\n", | ||
| 387 | + " Acc@1 : 87.500%\n", | ||
| 388 | + " Loss : 2.3277320861816406\n", | ||
| 389 | + " FW Time : 22.967ms\n", | ||
| 390 | + " BW Time : 16.513ms\n", | ||
| 391 | + "\n", | ||
| 392 | + "[+] (Valid results) Valid step: 699/2000\n", | ||
| 393 | + " Acc@1 : 78.000%\n", | ||
| 394 | + " Loss : 4.331630706787109\n", | ||
| 395 | + "\n", | ||
| 396 | + "[+] Model saved\n", | ||
| 397 | + "\n", | ||
| 398 | + "[+] Training step: 700/2000\tTraining epoch: 0/580\tElapsed time: 2.24min\tLearning rate: 9.509788066958503e-07\n", | ||
| 399 | + " Acc@1 : 50.000%\n", | ||
| 400 | + " Loss : 4.2444071769714355\n", | ||
| 401 | + " FW Time : 26.178ms\n", | ||
| 402 | + " BW Time : 20.318ms\n", | ||
| 403 | + "\n", | ||
| 404 | + "[+] (Valid results) Valid step: 749/2000\n", | ||
| 405 | + " Acc@1 : 79.800%\n", | ||
| 406 | + " Loss : 4.202180862426758\n", | ||
| 407 | + "\n", | ||
| 408 | + "[+] Model saved\n", | ||
| 409 | + "\n", | ||
| 410 | + "[+] Training step: 750/2000\tTraining epoch: 0/580\tElapsed time: 2.40min\tLearning rate: 9.475755296840536e-07\n", | ||
| 411 | + " Acc@1 : 75.000%\n", | ||
| 412 | + " Loss : 2.0681729316711426\n", | ||
| 413 | + " FW Time : 14.539ms\n", | ||
| 414 | + " BW Time : 9.137ms\n", | ||
| 415 | + "\n", | ||
| 416 | + "[+] (Valid results) Valid step: 799/2000\n", | ||
| 417 | + " Acc@1 : 81.800%\n", | ||
| 418 | + " Loss : 3.595407009124756\n", | ||
| 419 | + "\n", | ||
| 420 | + "[+] Model saved\n", | ||
| 421 | + "\n", | ||
| 422 | + "[+] Training step: 800/2000\tTraining epoch: 0/580\tElapsed time: 2.56min\tLearning rate: 9.441844320124666e-07\n", | ||
| 423 | + " Acc@1 : 87.500%\n", | ||
| 424 | + " Loss : 1.7526265382766724\n", | ||
| 425 | + " FW Time : 16.346ms\n", | ||
| 426 | + " BW Time : 13.979ms\n", | ||
| 427 | + "\n", | ||
| 428 | + "[+] (Valid results) Valid step: 849/2000\n", | ||
| 429 | + " Acc@1 : 82.800%\n", | ||
| 430 | + " Loss : 3.435865879058838\n", | ||
| 431 | + "\n", | ||
| 432 | + "[+] Model saved\n", | ||
| 433 | + "\n", | ||
| 434 | + "[+] Training step: 850/2000\tTraining epoch: 0/580\tElapsed time: 2.72min\tLearning rate: 9.408054700947673e-07\n", | ||
| 435 | + " Acc@1 : 100.000%\n", | ||
| 436 | + " Loss : 1.327768325805664\n", | ||
| 437 | + " FW Time : 20.280ms\n", | ||
| 438 | + " BW Time : 11.583ms\n", | ||
| 439 | + "\n", | ||
| 440 | + "[+] (Valid results) Valid step: 899/2000\n", | ||
| 441 | + " Acc@1 : 86.800%\n", | ||
| 442 | + " Loss : 3.0217819213867188\n", | ||
| 443 | + "\n", | ||
| 444 | + "[+] Model saved\n", | ||
| 445 | + "\n", | ||
| 446 | + "[+] Training step: 900/2000\tTraining epoch: 0/580\tElapsed time: 2.88min\tLearning rate: 9.37438600500616e-07\n", | ||
| 447 | + " Acc@1 : 87.500%\n", | ||
| 448 | + " Loss : 1.7023265361785889\n", | ||
| 449 | + " FW Time : 16.003ms\n", | ||
| 450 | + " BW Time : 11.639ms\n", | ||
| 451 | + "\n", | ||
| 452 | + "[+] (Valid results) Valid step: 949/2000\n", | ||
| 453 | + " Acc@1 : 89.600%\n", | ||
| 454 | + " Loss : 2.5056638717651367\n", | ||
| 455 | + "\n", | ||
| 456 | + "[+] Model saved\n", | ||
| 457 | + "\n", | ||
| 458 | + "[+] Training step: 950/2000\tTraining epoch: 0/580\tElapsed time: 3.04min\tLearning rate: 9.340837799550989e-07\n", | ||
| 459 | + " Acc@1 : 87.500%\n", | ||
| 460 | + " Loss : 2.0191428661346436\n", | ||
| 461 | + " FW Time : 26.238ms\n", | ||
| 462 | + " BW Time : 11.809ms\n", | ||
| 463 | + "\n", | ||
| 464 | + "[+] (Valid results) Valid step: 999/2000\n", | ||
| 465 | + " Acc@1 : 92.600%\n", | ||
| 466 | + " Loss : 1.8181736469268799\n", | ||
| 467 | + "\n", | ||
| 468 | + "[+] Model saved\n", | ||
| 469 | + "\n", | ||
| 470 | + "[+] Training step: 1000/2000\tTraining epoch: 0/580\tElapsed time: 3.20min\tLearning rate: 9.307409653381686e-07\n", | ||
| 471 | + " Acc@1 : 87.500%\n", | ||
| 472 | + " Loss : 0.9778107404708862\n", | ||
| 473 | + " FW Time : 18.434ms\n", | ||
| 474 | + " BW Time : 17.896ms\n", | ||
| 475 | + "\n", | ||
| 476 | + "[+] (Valid results) Valid step: 1049/2000\n", | ||
| 477 | + " Acc@1 : 94.400%\n", | ||
| 478 | + " Loss : 1.7727973461151123\n", | ||
| 479 | + "\n", | ||
| 480 | + "[+] Model saved\n", | ||
| 481 | + "\n", | ||
| 482 | + "[+] Training step: 1050/2000\tTraining epoch: 0/580\tElapsed time: 3.36min\tLearning rate: 9.274101136840937e-07\n", | ||
| 483 | + " Acc@1 : 75.000%\n", | ||
| 484 | + " Loss : 1.6472196578979492\n", | ||
| 485 | + " FW Time : 22.264ms\n", | ||
| 486 | + " BW Time : 11.575ms\n", | ||
| 487 | + "\n", | ||
| 488 | + "[+] (Valid results) Valid step: 1099/2000\n", | ||
| 489 | + " Acc@1 : 96.400%\n", | ||
| 490 | + " Loss : 1.3225497007369995\n", | ||
| 491 | + "\n", | ||
| 492 | + "[+] Model saved\n", | ||
| 493 | + "\n", | ||
| 494 | + "[+] Training step: 1100/2000\tTraining epoch: 0/580\tElapsed time: 3.52min\tLearning rate: 9.240911821809037e-07\n", | ||
| 495 | + " Acc@1 : 87.500%\n", | ||
| 496 | + " Loss : 1.520838737487793\n", | ||
| 497 | + " FW Time : 17.655ms\n", | ||
| 498 | + " BW Time : 11.358ms\n", | ||
| 499 | + "\n", | ||
| 500 | + "[+] (Valid results) Valid step: 1149/2000\n", | ||
| 501 | + " Acc@1 : 97.200%\n", | ||
| 502 | + " Loss : 1.0447773933410645\n", | ||
| 503 | + "\n", | ||
| 504 | + "[+] Model saved\n", | ||
| 505 | + "\n", | ||
| 506 | + "[+] Training step: 1150/2000\tTraining epoch: 0/580\tElapsed time: 3.69min\tLearning rate: 9.207841281698394e-07\n", | ||
| 507 | + " Acc@1 : 100.000%\n", | ||
| 508 | + " Loss : 0.45810914039611816\n", | ||
| 509 | + " FW Time : 21.189ms\n", | ||
| 510 | + " BW Time : 21.055ms\n", | ||
| 511 | + "\n", | ||
| 512 | + "[+] (Valid results) Valid step: 1199/2000\n", | ||
| 513 | + " Acc@1 : 97.200%\n", | ||
| 514 | + " Loss : 0.6157697439193726\n", | ||
| 515 | + "\n", | ||
| 516 | + "[+] Model saved\n", | ||
| 517 | + "\n", | ||
| 518 | + "[+] Training step: 1200/2000\tTraining epoch: 0/580\tElapsed time: 3.85min\tLearning rate: 9.174889091448058e-07\n", | ||
| 519 | + " Acc@1 : 100.000%\n", | ||
| 520 | + " Loss : 0.2990947961807251\n", | ||
| 521 | + " FW Time : 16.315ms\n", | ||
| 522 | + " BW Time : 13.309ms\n", | ||
| 523 | + "\n", | ||
| 524 | + "[+] (Valid results) Valid step: 1249/2000\n", | ||
| 525 | + " Acc@1 : 98.200%\n", | ||
| 526 | + " Loss : 0.9130725860595703\n", | ||
| 527 | + "\n", | ||
| 528 | + "[+] Model saved\n", | ||
| 529 | + "\n", | ||
| 530 | + "[+] Training step: 1250/2000\tTraining epoch: 0/580\tElapsed time: 4.01min\tLearning rate: 9.142054827518248e-07\n", | ||
| 531 | + " Acc@1 : 100.000%\n", | ||
| 532 | + " Loss : 0.2261216640472412\n", | ||
| 533 | + " FW Time : 30.053ms\n", | ||
| 534 | + " BW Time : 18.239ms\n", | ||
| 535 | + "\n", | ||
| 536 | + "[+] (Valid results) Valid step: 1299/2000\n", | ||
| 537 | + " Acc@1 : 98.000%\n", | ||
| 538 | + " Loss : 0.4017503261566162\n", | ||
| 539 | + "\n", | ||
| 540 | + "[+] Training step: 1300/2000\tTraining epoch: 0/580\tElapsed time: 4.17min\tLearning rate: 9.109338067884897e-07\n", | ||
| 541 | + " Acc@1 : 100.000%\n", | ||
| 542 | + " Loss : 0.3248733878135681\n", | ||
| 543 | + " FW Time : 21.094ms\n", | ||
| 544 | + " BW Time : 13.526ms\n", | ||
| 545 | + "\n", | ||
| 546 | + "[+] (Valid results) Valid step: 1349/2000\n", | ||
| 547 | + " Acc@1 : 98.000%\n", | ||
| 548 | + " Loss : 0.3138556480407715\n", | ||
| 549 | + "\n", | ||
| 550 | + "[+] Training step: 1350/2000\tTraining epoch: 0/580\tElapsed time: 4.32min\tLearning rate: 9.076738392034251e-07\n", | ||
| 551 | + " Acc@1 : 100.000%\n", | ||
| 552 | + " Loss : 0.21796566247940063\n", | ||
| 553 | + " FW Time : 25.850ms\n", | ||
| 554 | + " BW Time : 19.265ms\n", | ||
| 555 | + "\n", | ||
| 556 | + "[+] (Valid results) Valid step: 1399/2000\n", | ||
| 557 | + " Acc@1 : 98.000%\n", | ||
| 558 | + " Loss : 0.16485238075256348\n", | ||
| 559 | + "\n", | ||
| 560 | + "[+] Training step: 1400/2000\tTraining epoch: 0/580\tElapsed time: 4.48min\tLearning rate: 9.044255380957452e-07\n", | ||
| 561 | + " Acc@1 : 100.000%\n", | ||
| 562 | + " Loss : 0.28114575147628784\n", | ||
| 563 | + " FW Time : 22.245ms\n", | ||
| 564 | + " BW Time : 15.607ms\n", | ||
| 565 | + "\n", | ||
| 566 | + "[+] (Valid results) Valid step: 1449/2000\n", | ||
| 567 | + " Acc@1 : 98.200%\n", | ||
| 568 | + " Loss : 0.20785784721374512\n", | ||
| 569 | + "\n", | ||
| 570 | + "[+] Model saved\n", | ||
| 571 | + "\n", | ||
| 572 | + "[+] Training step: 1450/2000\tTraining epoch: 0/580\tElapsed time: 4.64min\tLearning rate: 9.011888617145144e-07\n", | ||
| 573 | + " Acc@1 : 87.500%\n", | ||
| 574 | + " Loss : 1.18643319606781\n", | ||
| 575 | + " FW Time : 19.665ms\n", | ||
| 576 | + " BW Time : 12.618ms\n", | ||
| 577 | + "\n", | ||
| 578 | + "[+] (Valid results) Valid step: 1499/2000\n", | ||
| 579 | + " Acc@1 : 98.400%\n", | ||
| 580 | + " Loss : 0.17331039905548096\n", | ||
| 581 | + "\n", | ||
| 582 | + "[+] Model saved\n", | ||
| 583 | + "\n", | ||
| 584 | + "[+] Training step: 1500/2000\tTraining epoch: 0/580\tElapsed time: 4.80min\tLearning rate: 8.979637684582136e-07\n", | ||
| 585 | + " Acc@1 : 100.000%\n", | ||
| 586 | + " Loss : 0.27240896224975586\n", | ||
| 587 | + " FW Time : 24.560ms\n", | ||
| 588 | + " BW Time : 14.096ms\n", | ||
| 589 | + "\n", | ||
| 590 | + "[+] (Valid results) Valid step: 1549/2000\n", | ||
| 591 | + " Acc@1 : 99.600%\n", | ||
| 592 | + " Loss : 0.17186295986175537\n", | ||
| 593 | + "\n", | ||
| 594 | + "[+] Model saved\n", | ||
| 595 | + "\n", | ||
| 596 | + "[+] Training step: 1550/2000\tTraining epoch: 0/580\tElapsed time: 4.96min\tLearning rate: 8.947502168742003e-07\n", | ||
| 597 | + " Acc@1 : 87.500%\n", | ||
| 598 | + " Loss : 0.6044453382492065\n", | ||
| 599 | + " FW Time : 20.771ms\n", | ||
| 600 | + " BW Time : 16.470ms\n", | ||
| 601 | + "\n", | ||
| 602 | + "[+] (Valid results) Valid step: 1599/2000\n", | ||
| 603 | + " Acc@1 : 99.200%\n", | ||
| 604 | + " Loss : 0.06833314895629883\n", | ||
| 605 | + "\n", | ||
| 606 | + "[+] Training step: 1600/2000\tTraining epoch: 0/580\tElapsed time: 5.12min\tLearning rate: 8.915481656581816e-07\n", | ||
| 607 | + " Acc@1 : 100.000%\n", | ||
| 608 | + " Loss : 0.1431320309638977\n", | ||
| 609 | + " FW Time : 20.891ms\n", | ||
| 610 | + " BW Time : 12.692ms\n", | ||
| 611 | + "\n", | ||
| 612 | + "[+] (Valid results) Valid step: 1649/2000\n", | ||
| 613 | + " Acc@1 : 99.600%\n", | ||
| 614 | + " Loss : 0.09469139575958252\n", | ||
| 615 | + "\n", | ||
| 616 | + "[+] Model saved\n", | ||
| 617 | + "\n", | ||
| 618 | + "[+] Training step: 1650/2000\tTraining epoch: 0/580\tElapsed time: 5.28min\tLearning rate: 8.8835757365368e-07\n", | ||
| 619 | + " Acc@1 : 87.500%\n", | ||
| 620 | + " Loss : 0.8603015542030334\n", | ||
| 621 | + " FW Time : 15.960ms\n", | ||
| 622 | + " BW Time : 16.599ms\n", | ||
| 623 | + "\n", | ||
| 624 | + "[+] (Valid results) Valid step: 1699/2000\n", | ||
| 625 | + " Acc@1 : 99.600%\n", | ||
| 626 | + " Loss : 0.061615824699401855\n", | ||
| 627 | + "\n", | ||
| 628 | + "[+] Model saved\n", | ||
| 629 | + "\n", | ||
| 630 | + "[+] Training step: 1700/2000\tTraining epoch: 0/580\tElapsed time: 5.44min\tLearning rate: 8.851783998515047e-07\n", | ||
| 631 | + " Acc@1 : 62.500%\n", | ||
| 632 | + " Loss : 1.4177227020263672\n", | ||
| 633 | + " FW Time : 14.728ms\n", | ||
| 634 | + " BW Time : 14.786ms\n", | ||
| 635 | + "\n", | ||
| 636 | + "[+] (Valid results) Valid step: 1749/2000\n", | ||
| 637 | + " Acc@1 : 99.600%\n", | ||
| 638 | + " Loss : 0.03962230682373047\n", | ||
| 639 | + "\n", | ||
| 640 | + "[+] Model saved\n", | ||
| 641 | + "\n", | ||
| 642 | + "[+] Training step: 1750/2000\tTraining epoch: 0/580\tElapsed time: 5.61min\tLearning rate: 8.820106033892254e-07\n", | ||
| 643 | + " Acc@1 : 100.000%\n", | ||
| 644 | + " Loss : 0.13233846426010132\n", | ||
| 645 | + " FW Time : 15.973ms\n", | ||
| 646 | + " BW Time : 13.815ms\n", | ||
| 647 | + "\n", | ||
| 648 | + "[+] (Valid results) Valid step: 1799/2000\n", | ||
| 649 | + " Acc@1 : 99.800%\n", | ||
| 650 | + " Loss : 0.044447898864746094\n", | ||
| 651 | + "\n", | ||
| 652 | + "[+] Model saved\n", | ||
| 653 | + "\n", | ||
| 654 | + "[+] Training step: 1800/2000\tTraining epoch: 0/580\tElapsed time: 5.77min\tLearning rate: 8.788541435506462e-07\n", | ||
| 655 | + " Acc@1 : 100.000%\n", | ||
| 656 | + " Loss : 0.08350610733032227\n", | ||
| 657 | + " FW Time : 18.745ms\n", | ||
| 658 | + " BW Time : 13.062ms\n", | ||
| 659 | + "\n", | ||
| 660 | + "[+] (Valid results) Valid step: 1849/2000\n", | ||
| 661 | + " Acc@1 : 100.000%\n", | ||
| 662 | + " Loss : 0.06921112537384033\n", | ||
| 663 | + "\n", | ||
| 664 | + "[+] Model saved\n", | ||
| 665 | + "\n", | ||
| 666 | + "[+] Training step: 1850/2000\tTraining epoch: 0/580\tElapsed time: 5.93min\tLearning rate: 8.757089797652821e-07\n", | ||
| 667 | + " Acc@1 : 100.000%\n", | ||
| 668 | + " Loss : 0.14804929494857788\n", | ||
| 669 | + " FW Time : 41.547ms\n", | ||
| 670 | + " BW Time : 15.775ms\n", | ||
| 671 | + "\n", | ||
| 672 | + "[+] (Valid results) Valid step: 1899/2000\n", | ||
| 673 | + " Acc@1 : 100.000%\n", | ||
| 674 | + " Loss : 0.06944191455841064\n", | ||
| 675 | + "\n", | ||
| 676 | + "[+] Model saved\n", | ||
| 677 | + "\n", | ||
| 678 | + "[+] Training step: 1900/2000\tTraining epoch: 0/580\tElapsed time: 6.09min\tLearning rate: 8.725750716078392e-07\n", | ||
| 679 | + " Acc@1 : 100.000%\n", | ||
| 680 | + " Loss : 0.027304232120513916\n", | ||
| 681 | + " FW Time : 21.632ms\n", | ||
| 682 | + " BW Time : 17.365ms\n", | ||
| 683 | + "\n", | ||
| 684 | + "[+] (Valid results) Valid step: 1949/2000\n", | ||
| 685 | + " Acc@1 : 100.000%\n", | ||
| 686 | + " Loss : 0.05875754356384277\n", | ||
| 687 | + "\n", | ||
| 688 | + "[+] Model saved\n", | ||
| 689 | + "\n", | ||
| 690 | + "[+] Training step: 1950/2000\tTraining epoch: 0/580\tElapsed time: 6.25min\tLearning rate: 8.694523787976934e-07\n", | ||
| 691 | + " Acc@1 : 100.000%\n", | ||
| 692 | + " Loss : 0.13626277446746826\n", | ||
| 693 | + " FW Time : 40.104ms\n", | ||
| 694 | + " BW Time : 14.362ms\n", | ||
| 695 | + "\n", | ||
| 696 | + "[+] (Valid results) Valid step: 1999/2000\n", | ||
| 697 | + " Acc@1 : 100.000%\n", | ||
| 698 | + " Loss : 0.0167849063873291\n", | ||
| 699 | + "\n", | ||
| 700 | + "[+] Model saved\n" | ||
| 701 | + ], | ||
| 702 | + "name": "stdout" | ||
| 703 | + } | ||
| 704 | + ] | ||
| 705 | + }, | ||
| 706 | + { | ||
| 707 | + "cell_type": "code", | ||
| 708 | + "metadata": { | ||
| 709 | + "id": "3opAMwCLZYJC", | ||
| 710 | + "colab_type": "code", | ||
| 711 | + "outputId": "208d8102-ce39-4897-9f89-ed7f8f626c33", | ||
| 712 | + "colab": { | ||
| 713 | + "base_uri": "https://localhost:8080/", | ||
| 714 | + "height": 350 | ||
| 715 | + } | ||
| 716 | + }, | ||
| 717 | + "source": [ | ||
| 718 | + "!python \"eval.py\" --model_path='/content/drive/My Drive/CD2 Project/runs/classify/April_26_13:35:05__resnet50__None/' " | ||
| 719 | + ], | ||
| 720 | + "execution_count": 10, | ||
| 721 | + "outputs": [ | ||
| 722 | + { | ||
| 723 | + "output_type": "stream", | ||
| 724 | + "text": [ | ||
| 725 | + "\n", | ||
| 726 | + "[+] Parse arguments\n", | ||
| 727 | + "Args(augment_path=None, batch_size=8, dataset='BraTS', learning_rate=1e-06, max_step=2000, network='resnet50', num_workers=4, optimizer='adam', print_step=50, scheduler='exp', seed=None, start_step=0, use_cuda=True, val_step=50)\n", | ||
| 728 | + "\n", | ||
| 729 | + "[+] Create network\n", | ||
| 730 | + "\n", | ||
| 731 | + "[+] Load model\n", | ||
| 732 | + "\n", | ||
| 733 | + "[+] Load dataset\n", | ||
| 734 | + "\n", | ||
| 735 | + "[+] Start testing\n", | ||
| 736 | + "2020-04-26 05:02:26.970817: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", | ||
| 737 | + "\n", | ||
| 738 | + "[+] Valid results\n", | ||
| 739 | + " Acc@1 : 100.000%\n", | ||
| 740 | + " Loss : 0.020\n", | ||
| 741 | + " Infer Time(per image) : 4.062ms\n" | ||
| 742 | + ], | ||
| 743 | + "name": "stdout" | ||
| 744 | + } | ||
| 745 | + ] | ||
| 746 | + }, | ||
| 747 | + { | ||
| 748 | + "cell_type": "code", | ||
| 749 | + "metadata": { | ||
| 750 | + "id": "F70Y3J9DHJwy", | ||
| 751 | + "colab_type": "code", | ||
| 752 | + "outputId": "502a62a8-de44-42c8-ec77-afc838a068c8", | ||
| 753 | + "colab": { | ||
| 754 | + "base_uri": "https://localhost:8080/", | ||
| 755 | + "height": 1000 | ||
| 756 | + } | ||
| 757 | + }, | ||
| 758 | + "source": [ | ||
| 759 | + "# train cifar10\n", | ||
| 760 | + "# resnet50 from pytorch\n", | ||
| 761 | + "# !python train.py --dataset=cifar10 --use_cuda=True --optimizer=adam --network=resnet50" | ||
| 762 | + ], | ||
| 763 | + "execution_count": 0, | ||
| 764 | + "outputs": [ | ||
| 765 | + { | ||
| 766 | + "output_type": "stream", | ||
| 767 | + "text": [ | ||
| 768 | + "\n", | ||
| 769 | + "[+] Parse arguments\n", | ||
| 770 | + "Args(augment_path=None, batch_size=32, dataset='cifar10', learning_rate=0.001, max_step=500, network='resnet50', num_workers=4, optimizer='adam', print_step=100, scheduler='exp', seed=None, start_step=0, use_cuda=True, val_step=100)\n", | ||
| 771 | + "\n", | ||
| 772 | + "[+] Create log dir\n", | ||
| 773 | + "\n", | ||
| 774 | + "[+] Create network\n", | ||
| 775 | + "\n", | ||
| 776 | + "[+] Load dataset\n", | ||
| 777 | + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/drive/My Drive/CD2 Project/cifar-10-python.tar.gz\n", | ||
| 778 | + "171MB [00:02, 70.7MB/s] \n", | ||
| 779 | + "Files already downloaded and verified\n", | ||
| 780 | + "\n", | ||
| 781 | + "[+] Start training\n", | ||
| 782 | + "\n", | ||
| 783 | + "[+] Use 1 GPUs\n", | ||
| 784 | + "\n", | ||
| 785 | + "[+] Using GPU: Tesla P100-PCIE-16GB \n", | ||
| 786 | + "\n", | ||
| 787 | + "[+] Training step: 0/500\tTraining epoch: 0/1406\tElapsed time: 0.01min\tLearning rate: 0.0009999283\n", | ||
| 788 | + " Acc@1 : 0.000%\n", | ||
| 789 | + " Loss : 6.797711372375488\n", | ||
| 790 | + " FW Time : 63.167ms\n", | ||
| 791 | + " BW Time : 14.364ms\n", | ||
| 792 | + "\n", | ||
| 793 | + "[+] Valid results\n", | ||
| 794 | + " Acc@1 : 22.620%\n", | ||
| 795 | + " Loss : 2.110281229019165\n", | ||
| 796 | + "\n", | ||
| 797 | + "[+] Model saved\n", | ||
| 798 | + "\n", | ||
| 799 | + "[+] Training step: 100/500\tTraining epoch: 0/1406\tElapsed time: 0.20min\tLearning rate: 0.000992784200174764\n", | ||
| 800 | + " Acc@1 : 12.500%\n", | ||
| 801 | + " Loss : 2.331458568572998\n", | ||
| 802 | + " FW Time : 28.192ms\n", | ||
| 803 | + " BW Time : 17.216ms\n", | ||
| 804 | + "\n", | ||
| 805 | + "[+] Valid results\n", | ||
| 806 | + " Acc@1 : 27.320%\n", | ||
| 807 | + " Loss : 1.5727815628051758\n", | ||
| 808 | + "\n", | ||
| 809 | + "[+] Model saved\n", | ||
| 810 | + "\n", | ||
| 811 | + "[+] Training step: 200/500\tTraining epoch: 0/1406\tElapsed time: 0.36min\tLearning rate: 0.0009856911421715388\n", | ||
| 812 | + " Acc@1 : 25.000%\n", | ||
| 813 | + " Loss : 1.8888376951217651\n", | ||
| 814 | + " FW Time : 26.493ms\n", | ||
| 815 | + " BW Time : 17.836ms\n", | ||
| 816 | + "\n", | ||
| 817 | + "[+] Valid results\n", | ||
| 818 | + " Acc@1 : 30.800%\n", | ||
| 819 | + " Loss : 1.4916565418243408\n", | ||
| 820 | + "\n", | ||
| 821 | + "[+] Model saved\n", | ||
| 822 | + "\n", | ||
| 823 | + "[+] Training step: 300/500\tTraining epoch: 0/1406\tElapsed time: 0.53min\tLearning rate: 0.0009786487613163062\n", | ||
| 824 | + " Acc@1 : 25.000%\n", | ||
| 825 | + " Loss : 2.3930575847625732\n", | ||
| 826 | + " FW Time : 24.473ms\n", | ||
| 827 | + " BW Time : 22.187ms\n", | ||
| 828 | + "\n", | ||
| 829 | + "[+] Valid results\n", | ||
| 830 | + " Acc@1 : 24.680%\n", | ||
| 831 | + " Loss : 1.992270827293396\n", | ||
| 832 | + "\n", | ||
| 833 | + "[+] Training step: 400/500\tTraining epoch: 0/1406\tElapsed time: 0.70min\tLearning rate: 0.000971656695540503\n", | ||
| 834 | + " Acc@1 : 25.000%\n", | ||
| 835 | + " Loss : 2.1769983768463135\n", | ||
| 836 | + " FW Time : 28.696ms\n", | ||
| 837 | + " BW Time : 22.384ms\n", | ||
| 838 | + "\n", | ||
| 839 | + "[+] Valid results\n", | ||
| 840 | + " Acc@1 : 32.320%\n", | ||
| 841 | + " Loss : 1.5474934577941895\n", | ||
| 842 | + "\n", | ||
| 843 | + "[+] Model saved\n" | ||
| 844 | + ], | ||
| 845 | + "name": "stdout" | ||
| 846 | + } | ||
| 847 | + ] | ||
| 848 | + }, | ||
| 849 | + { | ||
| 850 | + "cell_type": "code", | ||
| 851 | + "metadata": { | ||
| 852 | + "id": "Mhw6fBwCpHRd", | ||
| 853 | + "colab_type": "code", | ||
| 854 | + "colab": {} | ||
| 855 | + }, | ||
| 856 | + "source": [ | ||
| 857 | + "" | ||
| 858 | + ], | ||
| 859 | + "execution_count": 0, | ||
| 860 | + "outputs": [] | ||
| 861 | + } | ||
| 862 | + ] | ||
| 863 | +} | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
code/classifier/utils/util.py
deleted
100644 → 0
| 1 | -import os | ||
| 2 | -import time | ||
| 3 | -import importlib | ||
| 4 | -import collections | ||
| 5 | -import pickle as cp | ||
| 6 | -import glob | ||
| 7 | -import numpy as np | ||
| 8 | -import pandas as pd | ||
| 9 | - | ||
| 10 | -from natsort import natsorted | ||
| 11 | -from PIL import Image | ||
| 12 | -import torch | ||
| 13 | -import torchvision | ||
| 14 | -import torch.nn.functional as F | ||
| 15 | -import torchvision.models as models | ||
| 16 | -import torchvision.transforms as transforms | ||
| 17 | -from torch.utils.data import Subset | ||
| 18 | -from torch.utils.data import Dataset, DataLoader | ||
| 19 | - | ||
| 20 | -from sklearn.model_selection import StratifiedShuffleSplit | ||
| 21 | -from sklearn.model_selection import train_test_split | ||
| 22 | -from sklearn.model_selection import KFold | ||
| 23 | - | ||
| 24 | -from networks import * | ||
| 25 | - | ||
| 26 | - | ||
| 27 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' | ||
| 28 | -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' | ||
| 29 | -# VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/' | ||
| 30 | -# VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv' | ||
| 31 | - | ||
| 32 | -current_epoch = 0 | ||
| 33 | - | ||
| 34 | - | ||
| 35 | -def split_dataset(args, dataset, k): | ||
| 36 | - # load dataset | ||
| 37 | - X = list(range(len(dataset))) | ||
| 38 | - Y = dataset.targets | ||
| 39 | - | ||
| 40 | - # split to k-fold | ||
| 41 | - assert len(X) == len(Y) | ||
| 42 | - | ||
| 43 | - def _it_to_list(_it): | ||
| 44 | - return list(zip(*list(_it))) | ||
| 45 | - | ||
| 46 | - sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
| 47 | - Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
| 48 | - | ||
| 49 | - return Dm_indexes, Da_indexes | ||
| 50 | - | ||
| 51 | - | ||
| 52 | - | ||
| 53 | -def get_model_name(args): | ||
| 54 | - from datetime import datetime, timedelta, timezone | ||
| 55 | - now = datetime.now(timezone.utc) | ||
| 56 | - tz = timezone(timedelta(hours=9)) | ||
| 57 | - now = now.astimezone(tz) | ||
| 58 | - date_time = now.strftime("%B_%d_%H:%M:%S") | ||
| 59 | - model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
| 60 | - return model_name | ||
| 61 | - | ||
| 62 | - | ||
| 63 | -def dict_to_namedtuple(d): | ||
| 64 | - Args = collections.namedtuple('Args', sorted(d.keys())) | ||
| 65 | - | ||
| 66 | - for k,v in d.items(): | ||
| 67 | - if type(v) is dict: | ||
| 68 | - d[k] = dict_to_namedtuple(v) | ||
| 69 | - | ||
| 70 | - elif type(v) is str: | ||
| 71 | - try: | ||
| 72 | - d[k] = eval(v) | ||
| 73 | - except: | ||
| 74 | - d[k] = v | ||
| 75 | - | ||
| 76 | - args = Args(**d) | ||
| 77 | - return args | ||
| 78 | - | ||
| 79 | - | ||
| 80 | -def parse_args(kwargs): | ||
| 81 | - # combine with default args | ||
| 82 | - kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS' | ||
| 83 | - kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50' | ||
| 84 | - kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
| 85 | - kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.0001 | ||
| 86 | - kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
| 87 | - kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
| 88 | - kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
| 89 | - kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
| 90 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 | ||
| 91 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 | ||
| 92 | - kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
| 93 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
| 94 | - kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
| 95 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 | ||
| 96 | - kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
| 97 | - | ||
| 98 | - # to named tuple | ||
| 99 | - args = dict_to_namedtuple(kwargs) | ||
| 100 | - return args, kwargs | ||
| 101 | - | ||
| 102 | - | ||
| 103 | -def select_model(args): | ||
| 104 | - # grayResNet2 | ||
| 105 | - resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
| 106 | - 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
| 107 | - | ||
| 108 | - if args.network in resnet_dict: | ||
| 109 | - backbone = resnet_dict[args.network] | ||
| 110 | - model = basenet.BaseNet(backbone, args) | ||
| 111 | - else: | ||
| 112 | - Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
| 113 | - model = Net(args) | ||
| 114 | - | ||
| 115 | - #print(model) # print model architecture | ||
| 116 | - return model | ||
| 117 | - | ||
| 118 | - | ||
| 119 | -def select_optimizer(args, model): | ||
| 120 | - if args.optimizer == 'sgd': | ||
| 121 | - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
| 122 | - elif args.optimizer == 'rms': | ||
| 123 | - optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
| 124 | - elif args.optimizer == 'adam': | ||
| 125 | - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
| 126 | - else: | ||
| 127 | - raise Exception('Unknown Optimizer') | ||
| 128 | - return optimizer | ||
| 129 | - | ||
| 130 | - | ||
| 131 | -def select_scheduler(args, optimizer): | ||
| 132 | - if not args.scheduler or args.scheduler == 'None': | ||
| 133 | - return None | ||
| 134 | - elif args.scheduler =='clr': | ||
| 135 | - return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
| 136 | - elif args.scheduler =='exp': | ||
| 137 | - return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
| 138 | - else: | ||
| 139 | - raise Exception('Unknown Scheduler') | ||
| 140 | - | ||
| 141 | - | ||
| 142 | -class CustomDataset(Dataset): | ||
| 143 | - def __init__(self, data_path, csv_path): | ||
| 144 | - self.len = len(self.imgs) | ||
| 145 | - self.path = data_path | ||
| 146 | - self.imgs = natsorted(os.listdir(data_path)) | ||
| 147 | - | ||
| 148 | - df = pd.read_csv(csv_path) | ||
| 149 | - targets_list = [] | ||
| 150 | - | ||
| 151 | - for fname in self.imgs: | ||
| 152 | - row = df.loc[df['filename'] == fname] | ||
| 153 | - targets_list.append(row.iloc[0, 1]) | ||
| 154 | - | ||
| 155 | - self.targets = targets_list | ||
| 156 | - | ||
| 157 | - def __len__(self): | ||
| 158 | - return self.len | ||
| 159 | - | ||
| 160 | - def __getitem__(self, idx): | ||
| 161 | - img_loc = os.path.join(self.path, self.imgs[idx]) | ||
| 162 | - targets = self.targets[idx] | ||
| 163 | - image = Image.open(img_loc) | ||
| 164 | - return image, targets | ||
| 165 | - | ||
| 166 | - | ||
| 167 | - | ||
| 168 | -def get_dataset(args, transform, split='train'): | ||
| 169 | - assert split in ['train', 'val', 'test', 'trainval'] | ||
| 170 | - | ||
| 171 | - if split in ['train']: | ||
| 172 | - dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform) | ||
| 173 | - else: #test | ||
| 174 | - dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform) | ||
| 175 | - | ||
| 176 | - return dataset | ||
| 177 | - | ||
| 178 | - | ||
| 179 | -def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 180 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
| 181 | - batch_size=args.batch_size, | ||
| 182 | - shuffle=shuffle, | ||
| 183 | - num_workers=args.num_workers, | ||
| 184 | - pin_memory=pin_memory) | ||
| 185 | - return data_loader | ||
| 186 | - | ||
| 187 | - | ||
| 188 | -def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 189 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
| 190 | - batch_size=args.batch_size, | ||
| 191 | - shuffle=shuffle, | ||
| 192 | - num_workers=args.num_workers, | ||
| 193 | - pin_memory=pin_memory) | ||
| 194 | - return data_loader | ||
| 195 | - | ||
| 196 | - | ||
| 197 | -def get_inf_dataloader(args, dataset): | ||
| 198 | - global current_epoch | ||
| 199 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 200 | - | ||
| 201 | - while True: | ||
| 202 | - try: | ||
| 203 | - batch = next(data_loader) | ||
| 204 | - | ||
| 205 | - except StopIteration: | ||
| 206 | - current_epoch += 1 | ||
| 207 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 208 | - batch = next(data_loader) | ||
| 209 | - | ||
| 210 | - yield batch | ||
| 211 | - | ||
| 212 | - | ||
| 213 | - | ||
| 214 | - | ||
| 215 | -def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
| 216 | - model.train() | ||
| 217 | - images, target = batch | ||
| 218 | - | ||
| 219 | - if device: | ||
| 220 | - images = images.to(device) | ||
| 221 | - target = target.to(device) | ||
| 222 | - | ||
| 223 | - elif args.use_cuda: | ||
| 224 | - images = images.cuda(non_blocking=True) | ||
| 225 | - target = target.cuda(non_blocking=True) | ||
| 226 | - | ||
| 227 | - # compute output | ||
| 228 | - start_t = time.time() | ||
| 229 | - output, first = model(images) | ||
| 230 | - forward_t = time.time() - start_t | ||
| 231 | - loss = criterion(output, target) | ||
| 232 | - | ||
| 233 | - # measure accuracy and record loss | ||
| 234 | - acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 235 | - acc1 /= images.size(0) | ||
| 236 | - acc5 /= images.size(0) | ||
| 237 | - | ||
| 238 | - # compute gradient and do SGD step | ||
| 239 | - optimizer.zero_grad() | ||
| 240 | - start_t = time.time() | ||
| 241 | - loss.backward() | ||
| 242 | - backward_t = time.time() - start_t | ||
| 243 | - optimizer.step() | ||
| 244 | - if scheduler: scheduler.step() | ||
| 245 | - | ||
| 246 | - if writer and step % args.print_step == 0: | ||
| 247 | - n_imgs = min(images.size(0), 10) | ||
| 248 | - tag = 'train/' + str(step) | ||
| 249 | - for j in range(n_imgs): | ||
| 250 | - writer.add_image(tag, | ||
| 251 | - concat_image_features(images[j], first[j]), global_step=step) | ||
| 252 | - | ||
| 253 | - return acc1, acc5, loss, forward_t, backward_t | ||
| 254 | - | ||
| 255 | - | ||
| 256 | -#_acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 257 | -def accuracy(output, target, topk=(1,)): | ||
| 258 | - """Computes the accuracy over the k top predictions for the specified values of k""" | ||
| 259 | - with torch.no_grad(): | ||
| 260 | - maxk = max(topk) | ||
| 261 | - batch_size = target.size(0) | ||
| 262 | - | ||
| 263 | - _, pred = output.topk(maxk, 1, True, True) | ||
| 264 | - pred = pred.t() | ||
| 265 | - correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
| 266 | - | ||
| 267 | - res = [] | ||
| 268 | - for k in topk: | ||
| 269 | - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
| 270 | - res.append(correct_k) | ||
| 271 | - return res | ||
| 272 | - |
-
Please register or login to post a comment