Showing
5 changed files
with
1071 additions
and
272 deletions
code/FAA2_VM/getAugmented_1.py
0 → 100644
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | +import pickle | ||
6 | + | ||
7 | +import torch | ||
8 | +import torch.nn as nn | ||
9 | +from torch.utils.tensorboard import SummaryWriter | ||
10 | + | ||
11 | +from utils import * | ||
12 | + | ||
13 | +# command | ||
14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
15 | + | ||
16 | +def eval(model_path): | ||
17 | + print('\n[+] Parse arguments') | ||
18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
20 | + args, kwargs = parse_args(kwargs) | ||
21 | + pprint(args) | ||
22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
23 | + | ||
24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
25 | + | ||
26 | + writer = SummaryWriter(log_dir=model_path) | ||
27 | + | ||
28 | + | ||
29 | + print('\n[+] Load transform') | ||
30 | + # list | ||
31 | + with open(cp_path, 'rb') as f: | ||
32 | + aug_transform_list = pickle.load(f) | ||
33 | + | ||
34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'test')) | ||
35 | + | ||
36 | + | ||
37 | + print('\n[+] Load dataset') | ||
38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
39 | + dataset = get_dataset(args, aug_transform, 'test') | ||
40 | + | ||
41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
42 | + | ||
43 | + for i, (images, target) in enumerate(loader): | ||
44 | + images = images.view(240, 240) | ||
45 | + | ||
46 | + # concat image | ||
47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
48 | + | ||
49 | + if i % 1000 == 0: | ||
50 | + print("\n images size: ", augmented_image_list[i].size()) # [240, 240] | ||
51 | + | ||
52 | + break | ||
53 | + # break | ||
54 | + | ||
55 | + | ||
56 | + # print(augmented_image_list) | ||
57 | + | ||
58 | + | ||
59 | + print('\n[+] Write on tensorboard') | ||
60 | + if writer: | ||
61 | + for i, data in enumerate(augmented_image_list): | ||
62 | + tag = 'img/' + str(i) | ||
63 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
64 | + break | ||
65 | + | ||
66 | + writer.close() | ||
67 | + | ||
68 | + | ||
69 | + # if writer: | ||
70 | + # for j in range(): | ||
71 | + # tag = 'img/' + str(img_count) + '_' + str(j) | ||
72 | + # # writer.add_image(tag, | ||
73 | + # # concat_image_features(images[j], first[j]), global_step=step) | ||
74 | + # # if j > 0: | ||
75 | + # # fore = concat_image_features(fore, images[j]) | ||
76 | + | ||
77 | + # writer.add_image(tag, fore, global_step=0) | ||
78 | + # img_count = img_count + 1 | ||
79 | + | ||
80 | + # writer.close() | ||
81 | + | ||
82 | +if __name__ == '__main__': | ||
83 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_all.py
0 → 100644
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | +import pickle | ||
6 | + | ||
7 | +import torch | ||
8 | +import torch.nn as nn | ||
9 | +from torch.utils.tensorboard import SummaryWriter | ||
10 | + | ||
11 | +from utils import * | ||
12 | + | ||
13 | +# command | ||
14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
15 | + | ||
16 | +def eval(model_path): | ||
17 | + print('\n[+] Parse arguments') | ||
18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
20 | + args, kwargs = parse_args(kwargs) | ||
21 | + pprint(args) | ||
22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
23 | + | ||
24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
25 | + | ||
26 | + writer = SummaryWriter(log_dir=model_path) | ||
27 | + | ||
28 | + | ||
29 | + print('\n[+] Load transform') | ||
30 | + # list | ||
31 | + with open(cp_path, 'rb') as f: | ||
32 | + aug_transform_list = pickle.load(f) | ||
33 | + | ||
34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'train')) | ||
35 | + | ||
36 | + | ||
37 | + print('\n[+] Load dataset') | ||
38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
39 | + dataset = get_dataset(args, aug_transform, 'train') | ||
40 | + | ||
41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
42 | + | ||
43 | + for i, (images, target) in enumerate(loader): | ||
44 | + images = images.view(240, 240) | ||
45 | + | ||
46 | + # concat image | ||
47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
48 | + | ||
49 | + | ||
50 | + | ||
51 | + | ||
52 | + print('\n[+] Write on tensorboard') | ||
53 | + if writer: | ||
54 | + for i, data in enumerate(augmented_image_list): | ||
55 | + tag = 'img/' + str(i) | ||
56 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
57 | + | ||
58 | + writer.close() | ||
59 | + | ||
60 | + | ||
61 | +if __name__ == '__main__': | ||
62 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_saveimg.py
0 → 100644
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | +import pickle | ||
6 | +import random | ||
7 | + | ||
8 | +import torch | ||
9 | +import torch.nn as nn | ||
10 | +from torchvision.utils import save_image | ||
11 | +from torch.utils.tensorboard import SummaryWriter | ||
12 | + | ||
13 | +from utils import * | ||
14 | + | ||
15 | +# command | ||
16 | +# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/' | ||
17 | + | ||
18 | +def eval(model_path): | ||
19 | + print('\n[+] Parse arguments') | ||
20 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
21 | + kwargs = json.loads(open(kwargs_path).read()) | ||
22 | + args, kwargs = parse_args(kwargs) | ||
23 | + pprint(args) | ||
24 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
25 | + | ||
26 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
27 | + | ||
28 | + writer = SummaryWriter(log_dir=model_path) | ||
29 | + | ||
30 | + | ||
31 | + print('\n[+] Load transform') | ||
32 | + # list to tensor | ||
33 | + with open(cp_path, 'rb') as f: | ||
34 | + aug_transform_list = pickle.load(f) | ||
35 | + | ||
36 | + transform = transforms.RandomChoice(aug_transform_list) | ||
37 | + | ||
38 | + | ||
39 | + print('\n[+] Load dataset') | ||
40 | + | ||
41 | + dataset = get_dataset(args, transform, 'train') | ||
42 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
43 | + | ||
44 | + | ||
45 | + print('\n[+] Save 1 random policy') | ||
46 | + os.makedirs(os.path.join(model_path, 'augmented_imgs')) | ||
47 | + save_dir = os.path.join(model_path, 'augmented_imgs') | ||
48 | + | ||
49 | + for i, (image, target) in enumerate(loader): | ||
50 | + image = image.view(240, 240) | ||
51 | + # save img | ||
52 | + save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png')) | ||
53 | + | ||
54 | + if(i % 100 == 0): | ||
55 | + print("\n saved images: ", i) | ||
56 | + | ||
57 | + print('\n[+] Finished to save') | ||
58 | + | ||
59 | +if __name__ == '__main__': | ||
60 | + fire.Fire(eval) | ||
61 | + | ||
62 | + | ||
63 | + |
code/classifier/classify_normal_lesion.ipynb
0 → 100644
1 | +{ | ||
2 | + "nbformat": 4, | ||
3 | + "nbformat_minor": 0, | ||
4 | + "metadata": { | ||
5 | + "colab": { | ||
6 | + "name": "classify normal/lesion.ipynb", | ||
7 | + "provenance": [], | ||
8 | + "collapsed_sections": [] | ||
9 | + }, | ||
10 | + "kernelspec": { | ||
11 | + "name": "python3", | ||
12 | + "display_name": "Python 3" | ||
13 | + }, | ||
14 | + "accelerator": "GPU" | ||
15 | + }, | ||
16 | + "cells": [ | ||
17 | + { | ||
18 | + "cell_type": "code", | ||
19 | + "metadata": { | ||
20 | + "id": "AjoTMXMCrFYX", | ||
21 | + "colab_type": "code", | ||
22 | + "outputId": "2548434e-72b7-4946-9748-070103017379", | ||
23 | + "colab": { | ||
24 | + "base_uri": "https://localhost:8080/", | ||
25 | + "height": 129 | ||
26 | + } | ||
27 | + }, | ||
28 | + "source": [ | ||
29 | + "from google.colab import drive\n", | ||
30 | + "drive.mount('/content/drive')" | ||
31 | + ], | ||
32 | + "execution_count": 1, | ||
33 | + "outputs": [ | ||
34 | + { | ||
35 | + "output_type": "stream", | ||
36 | + "text": [ | ||
37 | + "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n", | ||
38 | + "\n", | ||
39 | + "Enter your authorization code:\n", | ||
40 | + "··········\n", | ||
41 | + "Mounted at /content/drive\n" | ||
42 | + ], | ||
43 | + "name": "stdout" | ||
44 | + } | ||
45 | + ] | ||
46 | + }, | ||
47 | + { | ||
48 | + "cell_type": "code", | ||
49 | + "metadata": { | ||
50 | + "id": "lXK8NZfIyzeB", | ||
51 | + "colab_type": "code", | ||
52 | + "outputId": "17508683-94fb-45fa-b9df-f7ddb9401b68", | ||
53 | + "colab": { | ||
54 | + "base_uri": "https://localhost:8080/", | ||
55 | + "height": 146 | ||
56 | + } | ||
57 | + }, | ||
58 | + "source": [ | ||
59 | + "!git clone http://khuhub.khu.ac.kr/2020-1-capstone-design2/2016104167.git" | ||
60 | + ], | ||
61 | + "execution_count": 2, | ||
62 | + "outputs": [ | ||
63 | + { | ||
64 | + "output_type": "stream", | ||
65 | + "text": [ | ||
66 | + "Cloning into '2016104167'...\n", | ||
67 | + "remote: Counting objects: 11451, done.\u001b[K\n", | ||
68 | + "remote: Compressing objects: 100% (39/39), done.\u001b[K\n", | ||
69 | + "remote: Total 11451 (delta 15), reused 0 (delta 0)\u001b[K\n", | ||
70 | + "Receiving objects: 100% (11451/11451), 292.82 MiB | 384.00 KiB/s, done.\n", | ||
71 | + "Resolving deltas: 100% (1109/1109), done.\n", | ||
72 | + "Checking out files: 100% (15684/15684), done.\n" | ||
73 | + ], | ||
74 | + "name": "stdout" | ||
75 | + } | ||
76 | + ] | ||
77 | + }, | ||
78 | + { | ||
79 | + "cell_type": "code", | ||
80 | + "metadata": { | ||
81 | + "id": "TmGc36H2y5sI", | ||
82 | + "colab_type": "code", | ||
83 | + "outputId": "49ce70f0-10bb-48d8-bd89-de98a28c7893", | ||
84 | + "colab": { | ||
85 | + "base_uri": "https://localhost:8080/", | ||
86 | + "height": 35 | ||
87 | + } | ||
88 | + }, | ||
89 | + "source": [ | ||
90 | + "%cd '2016104167/code/classifier/'" | ||
91 | + ], | ||
92 | + "execution_count": 3, | ||
93 | + "outputs": [ | ||
94 | + { | ||
95 | + "output_type": "stream", | ||
96 | + "text": [ | ||
97 | + "/content/2016104167/code/classifier\n" | ||
98 | + ], | ||
99 | + "name": "stdout" | ||
100 | + } | ||
101 | + ] | ||
102 | + }, | ||
103 | + { | ||
104 | + "cell_type": "code", | ||
105 | + "metadata": { | ||
106 | + "id": "oJ08JUJCzEEE", | ||
107 | + "colab_type": "code", | ||
108 | + "outputId": "f9454032-73ad-444e-b708-e1228f6b3a0b", | ||
109 | + "colab": { | ||
110 | + "base_uri": "https://localhost:8080/", | ||
111 | + "height": 1000 | ||
112 | + } | ||
113 | + }, | ||
114 | + "source": [ | ||
115 | + "!python -m pip install -r \"requirements.txt\"" | ||
116 | + ], | ||
117 | + "execution_count": 4, | ||
118 | + "outputs": [ | ||
119 | + { | ||
120 | + "output_type": "stream", | ||
121 | + "text": [ | ||
122 | + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 1)) (0.16.0)\n", | ||
123 | + "Collecting tb-nightly\n", | ||
124 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/4e/46/4b95936aed44f2154d936160de6c58e3bd4cf8501152db21945617c84694/tb_nightly-2.3.0a20200425-py3-none-any.whl (2.9MB)\n", | ||
125 | + "\u001b[K |████████████████████████████████| 2.9MB 39.9MB/s \n", | ||
126 | + "\u001b[?25hRequirement already satisfied: hyperopt in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 3)) (0.1.2)\n", | ||
127 | + "Collecting pillow==6.2.1\n", | ||
128 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/10/5c/0e94e689de2476c4c5e644a3bd223a1c1b9e2bdb7c510191750be74fa786/Pillow-6.2.1-cp36-cp36m-manylinux1_x86_64.whl (2.1MB)\n", | ||
129 | + "\u001b[K |████████████████████████████████| 2.1MB 48.1MB/s \n", | ||
130 | + "\u001b[?25hRequirement already satisfied: natsort in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 5)) (5.5.0)\n", | ||
131 | + "Collecting fire\n", | ||
132 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)\n", | ||
133 | + "\u001b[K |████████████████████████████████| 81kB 12.1MB/s \n", | ||
134 | + "\u001b[?25hCollecting torchvision==0.2.2\n", | ||
135 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ce/a1/66d72a2fe580a9f0fcbaaa5b976911fbbde9dce9b330ba12791997b856e9/torchvision-0.2.2-py2.py3-none-any.whl (64kB)\n", | ||
136 | + "\u001b[K |████████████████████████████████| 71kB 11.8MB/s \n", | ||
137 | + "\u001b[?25hCollecting torch==1.1.0\n", | ||
138 | + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/69/60/f685fb2cfb3088736bafbc9bdbb455327bdc8906b606da9c9a81bae1c81e/torch-1.1.0-cp36-cp36m-manylinux1_x86_64.whl (676.9MB)\n", | ||
139 | + "\u001b[K |████████████████████████████████| 676.9MB 18kB/s \n", | ||
140 | + "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 9)) (1.0.3)\n", | ||
141 | + "Requirement already satisfied: sklearn in /usr/local/lib/python3.6/dist-packages (from -r requirements.txt (line 10)) (0.0)\n", | ||
142 | + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (3.2.1)\n", | ||
143 | + "Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (3.10.0)\n", | ||
144 | + "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.7.2)\n", | ||
145 | + "Requirement already satisfied: numpy>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.18.3)\n", | ||
146 | + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.6.0.post3)\n", | ||
147 | + "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (0.34.2)\n", | ||
148 | + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (0.9.0)\n", | ||
149 | + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (46.1.3)\n", | ||
150 | + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (2.21.0)\n", | ||
151 | + "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.12.0)\n", | ||
152 | + "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.0.1)\n", | ||
153 | + "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (1.28.1)\n", | ||
154 | + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly->-r requirements.txt (line 2)) (0.4.1)\n", | ||
155 | + "Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (2.4)\n", | ||
156 | + "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (4.38.0)\n", | ||
157 | + "Requirement already satisfied: pymongo in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (3.10.1)\n", | ||
158 | + "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from hyperopt->-r requirements.txt (line 3)) (1.4.1)\n", | ||
159 | + "Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->-r requirements.txt (line 6)) (1.1.0)\n", | ||
160 | + "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->-r requirements.txt (line 9)) (2.8.1)\n", | ||
161 | + "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->-r requirements.txt (line 9)) (2018.9)\n", | ||
162 | + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from sklearn->-r requirements.txt (line 10)) (0.22.2.post1)\n", | ||
163 | + "Requirement already satisfied: rsa<4.1,>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (4.0)\n", | ||
164 | + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (0.2.8)\n", | ||
165 | + "Requirement already satisfied: cachetools<3.2,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (3.1.1)\n", | ||
166 | + "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (3.0.4)\n", | ||
167 | + "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (2.8)\n", | ||
168 | + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (2020.4.5.1)\n", | ||
169 | + "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tb-nightly->-r requirements.txt (line 2)) (1.24.3)\n", | ||
170 | + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly->-r requirements.txt (line 2)) (1.3.0)\n", | ||
171 | + "Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->hyperopt->-r requirements.txt (line 3)) (4.4.2)\n", | ||
172 | + "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sklearn->-r requirements.txt (line 10)) (0.14.1)\n", | ||
173 | + "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<4.1,>=3.1.4->google-auth<2,>=1.6.3->tb-nightly->-r requirements.txt (line 2)) (0.4.8)\n", | ||
174 | + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly->-r requirements.txt (line 2)) (3.1.0)\n", | ||
175 | + "Building wheels for collected packages: fire\n", | ||
176 | + " Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | ||
177 | + " Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=090cc0b99c44969c5594966fd1af925b4ff73b02719d44b05fb4aacd07b9bfb3\n", | ||
178 | + " Stored in directory: /root/.cache/pip/wheels/c1/61/df/768b03527bf006b546dce284eb4249b185669e65afc5fbb2ac\n", | ||
179 | + "Successfully built fire\n", | ||
180 | + "\u001b[31mERROR: torchvision 0.2.2 has requirement tqdm==4.19.9, but you'll have tqdm 4.38.0 which is incompatible.\u001b[0m\n", | ||
181 | + "\u001b[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.\u001b[0m\n", | ||
182 | + "Installing collected packages: tb-nightly, pillow, fire, torch, torchvision\n", | ||
183 | + " Found existing installation: Pillow 7.0.0\n", | ||
184 | + " Uninstalling Pillow-7.0.0:\n", | ||
185 | + " Successfully uninstalled Pillow-7.0.0\n", | ||
186 | + " Found existing installation: torch 1.4.0\n", | ||
187 | + " Uninstalling torch-1.4.0:\n", | ||
188 | + " Successfully uninstalled torch-1.4.0\n", | ||
189 | + " Found existing installation: torchvision 0.5.0\n", | ||
190 | + " Uninstalling torchvision-0.5.0:\n", | ||
191 | + " Successfully uninstalled torchvision-0.5.0\n", | ||
192 | + "Successfully installed fire-0.3.1 pillow-6.2.1 tb-nightly-2.3.0a20200425 torch-1.1.0 torchvision-0.2.2\n" | ||
193 | + ], | ||
194 | + "name": "stdout" | ||
195 | + } | ||
196 | + ] | ||
197 | + }, | ||
198 | + { | ||
199 | + "cell_type": "code", | ||
200 | + "metadata": { | ||
201 | + "id": "jdayOoSYHJDf", | ||
202 | + "colab_type": "code", | ||
203 | + "outputId": "f663313d-1166-4f6c-f617-d3d331ce7793", | ||
204 | + "colab": { | ||
205 | + "base_uri": "https://localhost:8080/", | ||
206 | + "height": 1000 | ||
207 | + } | ||
208 | + }, | ||
209 | + "source": [ | ||
210 | + "!python train.py --use_cuda=True --network=resnet50 --optimizer=adam " | ||
211 | + ], | ||
212 | + "execution_count": 9, | ||
213 | + "outputs": [ | ||
214 | + { | ||
215 | + "output_type": "stream", | ||
216 | + "text": [ | ||
217 | + "\n", | ||
218 | + "[+] Parse arguments\n", | ||
219 | + "Args(augment_path=None, batch_size=8, dataset='BraTS', learning_rate=1e-06, max_step=2000, network='resnet50', num_workers=4, optimizer='adam', print_step=50, scheduler='exp', seed=None, start_step=0, use_cuda=True, val_step=50)\n", | ||
220 | + "\n", | ||
221 | + "[+] Create log dir\n", | ||
222 | + "2020-04-26 04:35:05.510266: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", | ||
223 | + "\n", | ||
224 | + "[+] Create network\n", | ||
225 | + "\n", | ||
226 | + "[+] Load dataset\n", | ||
227 | + "\n", | ||
228 | + "[+] Start training\n", | ||
229 | + "\n", | ||
230 | + "[+] Use 1 GPUs\n", | ||
231 | + "\n", | ||
232 | + "[+] Using GPU: Tesla P4 \n", | ||
233 | + "\n", | ||
234 | + "[+] Training step: 0/2000\tTraining epoch: 0/580\tElapsed time: 0.01min\tLearning rate: 9.999283e-07\n", | ||
235 | + " Acc@1 : 0.000%\n", | ||
236 | + " Loss : 7.1226911544799805\n", | ||
237 | + " FW Time : 125.206ms\n", | ||
238 | + " BW Time : 14.758ms\n", | ||
239 | + "\n", | ||
240 | + "[+] (Valid results) Valid step: 49/2000\n", | ||
241 | + " Acc@1 : 0.000%\n", | ||
242 | + " Loss : 7.333343982696533\n", | ||
243 | + "\n", | ||
244 | + "[+] Model saved\n", | ||
245 | + "\n", | ||
246 | + "[+] Training step: 50/2000\tTraining epoch: 0/580\tElapsed time: 0.17min\tLearning rate: 9.963498469652168e-07\n", | ||
247 | + " Acc@1 : 0.000%\n", | ||
248 | + " Loss : 6.839444160461426\n", | ||
249 | + " FW Time : 15.798ms\n", | ||
250 | + " BW Time : 10.145ms\n", | ||
251 | + "\n", | ||
252 | + "[+] (Valid results) Valid step: 99/2000\n", | ||
253 | + " Acc@1 : 0.000%\n", | ||
254 | + " Loss : 7.0915045738220215\n", | ||
255 | + "\n", | ||
256 | + "[+] Model saved\n", | ||
257 | + "\n", | ||
258 | + "[+] Training step: 100/2000\tTraining epoch: 0/580\tElapsed time: 0.33min\tLearning rate: 9.92784200174764e-07\n", | ||
259 | + " Acc@1 : 0.000%\n", | ||
260 | + " Loss : 6.130648136138916\n", | ||
261 | + " FW Time : 15.954ms\n", | ||
262 | + " BW Time : 13.310ms\n", | ||
263 | + "\n", | ||
264 | + "[+] (Valid results) Valid step: 149/2000\n", | ||
265 | + " Acc@1 : 12.000%\n", | ||
266 | + " Loss : 6.664271354675293\n", | ||
267 | + "\n", | ||
268 | + "[+] Model saved\n", | ||
269 | + "\n", | ||
270 | + "[+] Training step: 150/2000\tTraining epoch: 0/580\tElapsed time: 0.49min\tLearning rate: 9.89231313798811e-07\n", | ||
271 | + " Acc@1 : 25.000%\n", | ||
272 | + " Loss : 5.820688247680664\n", | ||
273 | + " FW Time : 15.717ms\n", | ||
274 | + " BW Time : 12.132ms\n", | ||
275 | + "\n", | ||
276 | + "[+] (Valid results) Valid step: 199/2000\n", | ||
277 | + " Acc@1 : 38.400%\n", | ||
278 | + " Loss : 6.647609233856201\n", | ||
279 | + "\n", | ||
280 | + "[+] Model saved\n", | ||
281 | + "\n", | ||
282 | + "[+] Training step: 200/2000\tTraining epoch: 0/580\tElapsed time: 0.65min\tLearning rate: 9.85691142171539e-07\n", | ||
283 | + " Acc@1 : 62.500%\n", | ||
284 | + " Loss : 5.464842319488525\n", | ||
285 | + " FW Time : 18.873ms\n", | ||
286 | + " BW Time : 14.734ms\n", | ||
287 | + "\n", | ||
288 | + "[+] (Valid results) Valid step: 249/2000\n", | ||
289 | + " Acc@1 : 51.600%\n", | ||
290 | + " Loss : 6.3311767578125\n", | ||
291 | + "\n", | ||
292 | + "[+] Model saved\n", | ||
293 | + "\n", | ||
294 | + "[+] Training step: 250/2000\tTraining epoch: 0/580\tElapsed time: 0.80min\tLearning rate: 9.821636397905564e-07\n", | ||
295 | + " Acc@1 : 62.500%\n", | ||
296 | + " Loss : 5.102325439453125\n", | ||
297 | + " FW Time : 20.299ms\n", | ||
298 | + " BW Time : 13.266ms\n", | ||
299 | + "\n", | ||
300 | + "[+] (Valid results) Valid step: 299/2000\n", | ||
301 | + " Acc@1 : 65.000%\n", | ||
302 | + " Loss : 5.964324951171875\n", | ||
303 | + "\n", | ||
304 | + "[+] Model saved\n", | ||
305 | + "\n", | ||
306 | + "[+] Training step: 300/2000\tTraining epoch: 0/580\tElapsed time: 0.97min\tLearning rate: 9.786487613163075e-07\n", | ||
307 | + " Acc@1 : 75.000%\n", | ||
308 | + " Loss : 4.567136287689209\n", | ||
309 | + " FW Time : 22.482ms\n", | ||
310 | + " BW Time : 11.849ms\n", | ||
311 | + "\n", | ||
312 | + "[+] (Valid results) Valid step: 349/2000\n", | ||
313 | + " Acc@1 : 67.000%\n", | ||
314 | + " Loss : 5.935492038726807\n", | ||
315 | + "\n", | ||
316 | + "[+] Model saved\n", | ||
317 | + "\n", | ||
318 | + "[+] Training step: 350/2000\tTraining epoch: 0/580\tElapsed time: 1.13min\tLearning rate: 9.751464615714972e-07\n", | ||
319 | + " Acc@1 : 75.000%\n", | ||
320 | + " Loss : 4.528119087219238\n", | ||
321 | + " FW Time : 15.663ms\n", | ||
322 | + " BW Time : 12.312ms\n", | ||
323 | + "\n", | ||
324 | + "[+] (Valid results) Valid step: 399/2000\n", | ||
325 | + " Acc@1 : 66.200%\n", | ||
326 | + " Loss : 6.177009105682373\n", | ||
327 | + "\n", | ||
328 | + "[+] Training step: 400/2000\tTraining epoch: 0/580\tElapsed time: 1.28min\tLearning rate: 9.716566955405052e-07\n", | ||
329 | + " Acc@1 : 50.000%\n", | ||
330 | + " Loss : 5.060157775878906\n", | ||
331 | + " FW Time : 18.538ms\n", | ||
332 | + " BW Time : 15.068ms\n", | ||
333 | + "\n", | ||
334 | + "[+] (Valid results) Valid step: 449/2000\n", | ||
335 | + " Acc@1 : 70.400%\n", | ||
336 | + " Loss : 5.513169288635254\n", | ||
337 | + "\n", | ||
338 | + "[+] Model saved\n", | ||
339 | + "\n", | ||
340 | + "[+] Training step: 450/2000\tTraining epoch: 0/580\tElapsed time: 1.44min\tLearning rate: 9.681794183688074e-07\n", | ||
341 | + " Acc@1 : 62.500%\n", | ||
342 | + " Loss : 4.341537952423096\n", | ||
343 | + " FW Time : 20.281ms\n", | ||
344 | + " BW Time : 11.464ms\n", | ||
345 | + "\n", | ||
346 | + "[+] (Valid results) Valid step: 499/2000\n", | ||
347 | + " Acc@1 : 71.200%\n", | ||
348 | + " Loss : 5.211921691894531\n", | ||
349 | + "\n", | ||
350 | + "[+] Model saved\n", | ||
351 | + "\n", | ||
352 | + "[+] Training step: 500/2000\tTraining epoch: 0/580\tElapsed time: 1.60min\tLearning rate: 9.647145853624042e-07\n", | ||
353 | + " Acc@1 : 62.500%\n", | ||
354 | + " Loss : 3.643502712249756\n", | ||
355 | + " FW Time : 35.242ms\n", | ||
356 | + " BW Time : 13.355ms\n", | ||
357 | + "\n", | ||
358 | + "[+] (Valid results) Valid step: 549/2000\n", | ||
359 | + " Acc@1 : 73.000%\n", | ||
360 | + " Loss : 5.065425872802734\n", | ||
361 | + "\n", | ||
362 | + "[+] Model saved\n", | ||
363 | + "\n", | ||
364 | + "[+] Training step: 550/2000\tTraining epoch: 0/580\tElapsed time: 1.76min\tLearning rate: 9.61262151987242e-07\n", | ||
365 | + " Acc@1 : 37.500%\n", | ||
366 | + " Loss : 4.352468013763428\n", | ||
367 | + " FW Time : 15.172ms\n", | ||
368 | + " BW Time : 14.894ms\n", | ||
369 | + "\n", | ||
370 | + "[+] (Valid results) Valid step: 599/2000\n", | ||
371 | + " Acc@1 : 72.800%\n", | ||
372 | + " Loss : 5.094676971435547\n", | ||
373 | + "\n", | ||
374 | + "[+] Training step: 600/2000\tTraining epoch: 0/580\tElapsed time: 1.92min\tLearning rate: 9.578220738686398e-07\n", | ||
375 | + " Acc@1 : 62.500%\n", | ||
376 | + " Loss : 2.67336106300354\n", | ||
377 | + " FW Time : 16.094ms\n", | ||
378 | + " BW Time : 11.950ms\n", | ||
379 | + "\n", | ||
380 | + "[+] (Valid results) Valid step: 649/2000\n", | ||
381 | + " Acc@1 : 74.200%\n", | ||
382 | + " Loss : 4.8995866775512695\n", | ||
383 | + "\n", | ||
384 | + "[+] Model saved\n", | ||
385 | + "\n", | ||
386 | + "[+] Training step: 650/2000\tTraining epoch: 0/580\tElapsed time: 2.08min\tLearning rate: 9.543943067907226e-07\n", | ||
387 | + " Acc@1 : 87.500%\n", | ||
388 | + " Loss : 2.3277320861816406\n", | ||
389 | + " FW Time : 22.967ms\n", | ||
390 | + " BW Time : 16.513ms\n", | ||
391 | + "\n", | ||
392 | + "[+] (Valid results) Valid step: 699/2000\n", | ||
393 | + " Acc@1 : 78.000%\n", | ||
394 | + " Loss : 4.331630706787109\n", | ||
395 | + "\n", | ||
396 | + "[+] Model saved\n", | ||
397 | + "\n", | ||
398 | + "[+] Training step: 700/2000\tTraining epoch: 0/580\tElapsed time: 2.24min\tLearning rate: 9.509788066958503e-07\n", | ||
399 | + " Acc@1 : 50.000%\n", | ||
400 | + " Loss : 4.2444071769714355\n", | ||
401 | + " FW Time : 26.178ms\n", | ||
402 | + " BW Time : 20.318ms\n", | ||
403 | + "\n", | ||
404 | + "[+] (Valid results) Valid step: 749/2000\n", | ||
405 | + " Acc@1 : 79.800%\n", | ||
406 | + " Loss : 4.202180862426758\n", | ||
407 | + "\n", | ||
408 | + "[+] Model saved\n", | ||
409 | + "\n", | ||
410 | + "[+] Training step: 750/2000\tTraining epoch: 0/580\tElapsed time: 2.40min\tLearning rate: 9.475755296840536e-07\n", | ||
411 | + " Acc@1 : 75.000%\n", | ||
412 | + " Loss : 2.0681729316711426\n", | ||
413 | + " FW Time : 14.539ms\n", | ||
414 | + " BW Time : 9.137ms\n", | ||
415 | + "\n", | ||
416 | + "[+] (Valid results) Valid step: 799/2000\n", | ||
417 | + " Acc@1 : 81.800%\n", | ||
418 | + " Loss : 3.595407009124756\n", | ||
419 | + "\n", | ||
420 | + "[+] Model saved\n", | ||
421 | + "\n", | ||
422 | + "[+] Training step: 800/2000\tTraining epoch: 0/580\tElapsed time: 2.56min\tLearning rate: 9.441844320124666e-07\n", | ||
423 | + " Acc@1 : 87.500%\n", | ||
424 | + " Loss : 1.7526265382766724\n", | ||
425 | + " FW Time : 16.346ms\n", | ||
426 | + " BW Time : 13.979ms\n", | ||
427 | + "\n", | ||
428 | + "[+] (Valid results) Valid step: 849/2000\n", | ||
429 | + " Acc@1 : 82.800%\n", | ||
430 | + " Loss : 3.435865879058838\n", | ||
431 | + "\n", | ||
432 | + "[+] Model saved\n", | ||
433 | + "\n", | ||
434 | + "[+] Training step: 850/2000\tTraining epoch: 0/580\tElapsed time: 2.72min\tLearning rate: 9.408054700947673e-07\n", | ||
435 | + " Acc@1 : 100.000%\n", | ||
436 | + " Loss : 1.327768325805664\n", | ||
437 | + " FW Time : 20.280ms\n", | ||
438 | + " BW Time : 11.583ms\n", | ||
439 | + "\n", | ||
440 | + "[+] (Valid results) Valid step: 899/2000\n", | ||
441 | + " Acc@1 : 86.800%\n", | ||
442 | + " Loss : 3.0217819213867188\n", | ||
443 | + "\n", | ||
444 | + "[+] Model saved\n", | ||
445 | + "\n", | ||
446 | + "[+] Training step: 900/2000\tTraining epoch: 0/580\tElapsed time: 2.88min\tLearning rate: 9.37438600500616e-07\n", | ||
447 | + " Acc@1 : 87.500%\n", | ||
448 | + " Loss : 1.7023265361785889\n", | ||
449 | + " FW Time : 16.003ms\n", | ||
450 | + " BW Time : 11.639ms\n", | ||
451 | + "\n", | ||
452 | + "[+] (Valid results) Valid step: 949/2000\n", | ||
453 | + " Acc@1 : 89.600%\n", | ||
454 | + " Loss : 2.5056638717651367\n", | ||
455 | + "\n", | ||
456 | + "[+] Model saved\n", | ||
457 | + "\n", | ||
458 | + "[+] Training step: 950/2000\tTraining epoch: 0/580\tElapsed time: 3.04min\tLearning rate: 9.340837799550989e-07\n", | ||
459 | + " Acc@1 : 87.500%\n", | ||
460 | + " Loss : 2.0191428661346436\n", | ||
461 | + " FW Time : 26.238ms\n", | ||
462 | + " BW Time : 11.809ms\n", | ||
463 | + "\n", | ||
464 | + "[+] (Valid results) Valid step: 999/2000\n", | ||
465 | + " Acc@1 : 92.600%\n", | ||
466 | + " Loss : 1.8181736469268799\n", | ||
467 | + "\n", | ||
468 | + "[+] Model saved\n", | ||
469 | + "\n", | ||
470 | + "[+] Training step: 1000/2000\tTraining epoch: 0/580\tElapsed time: 3.20min\tLearning rate: 9.307409653381686e-07\n", | ||
471 | + " Acc@1 : 87.500%\n", | ||
472 | + " Loss : 0.9778107404708862\n", | ||
473 | + " FW Time : 18.434ms\n", | ||
474 | + " BW Time : 17.896ms\n", | ||
475 | + "\n", | ||
476 | + "[+] (Valid results) Valid step: 1049/2000\n", | ||
477 | + " Acc@1 : 94.400%\n", | ||
478 | + " Loss : 1.7727973461151123\n", | ||
479 | + "\n", | ||
480 | + "[+] Model saved\n", | ||
481 | + "\n", | ||
482 | + "[+] Training step: 1050/2000\tTraining epoch: 0/580\tElapsed time: 3.36min\tLearning rate: 9.274101136840937e-07\n", | ||
483 | + " Acc@1 : 75.000%\n", | ||
484 | + " Loss : 1.6472196578979492\n", | ||
485 | + " FW Time : 22.264ms\n", | ||
486 | + " BW Time : 11.575ms\n", | ||
487 | + "\n", | ||
488 | + "[+] (Valid results) Valid step: 1099/2000\n", | ||
489 | + " Acc@1 : 96.400%\n", | ||
490 | + " Loss : 1.3225497007369995\n", | ||
491 | + "\n", | ||
492 | + "[+] Model saved\n", | ||
493 | + "\n", | ||
494 | + "[+] Training step: 1100/2000\tTraining epoch: 0/580\tElapsed time: 3.52min\tLearning rate: 9.240911821809037e-07\n", | ||
495 | + " Acc@1 : 87.500%\n", | ||
496 | + " Loss : 1.520838737487793\n", | ||
497 | + " FW Time : 17.655ms\n", | ||
498 | + " BW Time : 11.358ms\n", | ||
499 | + "\n", | ||
500 | + "[+] (Valid results) Valid step: 1149/2000\n", | ||
501 | + " Acc@1 : 97.200%\n", | ||
502 | + " Loss : 1.0447773933410645\n", | ||
503 | + "\n", | ||
504 | + "[+] Model saved\n", | ||
505 | + "\n", | ||
506 | + "[+] Training step: 1150/2000\tTraining epoch: 0/580\tElapsed time: 3.69min\tLearning rate: 9.207841281698394e-07\n", | ||
507 | + " Acc@1 : 100.000%\n", | ||
508 | + " Loss : 0.45810914039611816\n", | ||
509 | + " FW Time : 21.189ms\n", | ||
510 | + " BW Time : 21.055ms\n", | ||
511 | + "\n", | ||
512 | + "[+] (Valid results) Valid step: 1199/2000\n", | ||
513 | + " Acc@1 : 97.200%\n", | ||
514 | + " Loss : 0.6157697439193726\n", | ||
515 | + "\n", | ||
516 | + "[+] Model saved\n", | ||
517 | + "\n", | ||
518 | + "[+] Training step: 1200/2000\tTraining epoch: 0/580\tElapsed time: 3.85min\tLearning rate: 9.174889091448058e-07\n", | ||
519 | + " Acc@1 : 100.000%\n", | ||
520 | + " Loss : 0.2990947961807251\n", | ||
521 | + " FW Time : 16.315ms\n", | ||
522 | + " BW Time : 13.309ms\n", | ||
523 | + "\n", | ||
524 | + "[+] (Valid results) Valid step: 1249/2000\n", | ||
525 | + " Acc@1 : 98.200%\n", | ||
526 | + " Loss : 0.9130725860595703\n", | ||
527 | + "\n", | ||
528 | + "[+] Model saved\n", | ||
529 | + "\n", | ||
530 | + "[+] Training step: 1250/2000\tTraining epoch: 0/580\tElapsed time: 4.01min\tLearning rate: 9.142054827518248e-07\n", | ||
531 | + " Acc@1 : 100.000%\n", | ||
532 | + " Loss : 0.2261216640472412\n", | ||
533 | + " FW Time : 30.053ms\n", | ||
534 | + " BW Time : 18.239ms\n", | ||
535 | + "\n", | ||
536 | + "[+] (Valid results) Valid step: 1299/2000\n", | ||
537 | + " Acc@1 : 98.000%\n", | ||
538 | + " Loss : 0.4017503261566162\n", | ||
539 | + "\n", | ||
540 | + "[+] Training step: 1300/2000\tTraining epoch: 0/580\tElapsed time: 4.17min\tLearning rate: 9.109338067884897e-07\n", | ||
541 | + " Acc@1 : 100.000%\n", | ||
542 | + " Loss : 0.3248733878135681\n", | ||
543 | + " FW Time : 21.094ms\n", | ||
544 | + " BW Time : 13.526ms\n", | ||
545 | + "\n", | ||
546 | + "[+] (Valid results) Valid step: 1349/2000\n", | ||
547 | + " Acc@1 : 98.000%\n", | ||
548 | + " Loss : 0.3138556480407715\n", | ||
549 | + "\n", | ||
550 | + "[+] Training step: 1350/2000\tTraining epoch: 0/580\tElapsed time: 4.32min\tLearning rate: 9.076738392034251e-07\n", | ||
551 | + " Acc@1 : 100.000%\n", | ||
552 | + " Loss : 0.21796566247940063\n", | ||
553 | + " FW Time : 25.850ms\n", | ||
554 | + " BW Time : 19.265ms\n", | ||
555 | + "\n", | ||
556 | + "[+] (Valid results) Valid step: 1399/2000\n", | ||
557 | + " Acc@1 : 98.000%\n", | ||
558 | + " Loss : 0.16485238075256348\n", | ||
559 | + "\n", | ||
560 | + "[+] Training step: 1400/2000\tTraining epoch: 0/580\tElapsed time: 4.48min\tLearning rate: 9.044255380957452e-07\n", | ||
561 | + " Acc@1 : 100.000%\n", | ||
562 | + " Loss : 0.28114575147628784\n", | ||
563 | + " FW Time : 22.245ms\n", | ||
564 | + " BW Time : 15.607ms\n", | ||
565 | + "\n", | ||
566 | + "[+] (Valid results) Valid step: 1449/2000\n", | ||
567 | + " Acc@1 : 98.200%\n", | ||
568 | + " Loss : 0.20785784721374512\n", | ||
569 | + "\n", | ||
570 | + "[+] Model saved\n", | ||
571 | + "\n", | ||
572 | + "[+] Training step: 1450/2000\tTraining epoch: 0/580\tElapsed time: 4.64min\tLearning rate: 9.011888617145144e-07\n", | ||
573 | + " Acc@1 : 87.500%\n", | ||
574 | + " Loss : 1.18643319606781\n", | ||
575 | + " FW Time : 19.665ms\n", | ||
576 | + " BW Time : 12.618ms\n", | ||
577 | + "\n", | ||
578 | + "[+] (Valid results) Valid step: 1499/2000\n", | ||
579 | + " Acc@1 : 98.400%\n", | ||
580 | + " Loss : 0.17331039905548096\n", | ||
581 | + "\n", | ||
582 | + "[+] Model saved\n", | ||
583 | + "\n", | ||
584 | + "[+] Training step: 1500/2000\tTraining epoch: 0/580\tElapsed time: 4.80min\tLearning rate: 8.979637684582136e-07\n", | ||
585 | + " Acc@1 : 100.000%\n", | ||
586 | + " Loss : 0.27240896224975586\n", | ||
587 | + " FW Time : 24.560ms\n", | ||
588 | + " BW Time : 14.096ms\n", | ||
589 | + "\n", | ||
590 | + "[+] (Valid results) Valid step: 1549/2000\n", | ||
591 | + " Acc@1 : 99.600%\n", | ||
592 | + " Loss : 0.17186295986175537\n", | ||
593 | + "\n", | ||
594 | + "[+] Model saved\n", | ||
595 | + "\n", | ||
596 | + "[+] Training step: 1550/2000\tTraining epoch: 0/580\tElapsed time: 4.96min\tLearning rate: 8.947502168742003e-07\n", | ||
597 | + " Acc@1 : 87.500%\n", | ||
598 | + " Loss : 0.6044453382492065\n", | ||
599 | + " FW Time : 20.771ms\n", | ||
600 | + " BW Time : 16.470ms\n", | ||
601 | + "\n", | ||
602 | + "[+] (Valid results) Valid step: 1599/2000\n", | ||
603 | + " Acc@1 : 99.200%\n", | ||
604 | + " Loss : 0.06833314895629883\n", | ||
605 | + "\n", | ||
606 | + "[+] Training step: 1600/2000\tTraining epoch: 0/580\tElapsed time: 5.12min\tLearning rate: 8.915481656581816e-07\n", | ||
607 | + " Acc@1 : 100.000%\n", | ||
608 | + " Loss : 0.1431320309638977\n", | ||
609 | + " FW Time : 20.891ms\n", | ||
610 | + " BW Time : 12.692ms\n", | ||
611 | + "\n", | ||
612 | + "[+] (Valid results) Valid step: 1649/2000\n", | ||
613 | + " Acc@1 : 99.600%\n", | ||
614 | + " Loss : 0.09469139575958252\n", | ||
615 | + "\n", | ||
616 | + "[+] Model saved\n", | ||
617 | + "\n", | ||
618 | + "[+] Training step: 1650/2000\tTraining epoch: 0/580\tElapsed time: 5.28min\tLearning rate: 8.8835757365368e-07\n", | ||
619 | + " Acc@1 : 87.500%\n", | ||
620 | + " Loss : 0.8603015542030334\n", | ||
621 | + " FW Time : 15.960ms\n", | ||
622 | + " BW Time : 16.599ms\n", | ||
623 | + "\n", | ||
624 | + "[+] (Valid results) Valid step: 1699/2000\n", | ||
625 | + " Acc@1 : 99.600%\n", | ||
626 | + " Loss : 0.061615824699401855\n", | ||
627 | + "\n", | ||
628 | + "[+] Model saved\n", | ||
629 | + "\n", | ||
630 | + "[+] Training step: 1700/2000\tTraining epoch: 0/580\tElapsed time: 5.44min\tLearning rate: 8.851783998515047e-07\n", | ||
631 | + " Acc@1 : 62.500%\n", | ||
632 | + " Loss : 1.4177227020263672\n", | ||
633 | + " FW Time : 14.728ms\n", | ||
634 | + " BW Time : 14.786ms\n", | ||
635 | + "\n", | ||
636 | + "[+] (Valid results) Valid step: 1749/2000\n", | ||
637 | + " Acc@1 : 99.600%\n", | ||
638 | + " Loss : 0.03962230682373047\n", | ||
639 | + "\n", | ||
640 | + "[+] Model saved\n", | ||
641 | + "\n", | ||
642 | + "[+] Training step: 1750/2000\tTraining epoch: 0/580\tElapsed time: 5.61min\tLearning rate: 8.820106033892254e-07\n", | ||
643 | + " Acc@1 : 100.000%\n", | ||
644 | + " Loss : 0.13233846426010132\n", | ||
645 | + " FW Time : 15.973ms\n", | ||
646 | + " BW Time : 13.815ms\n", | ||
647 | + "\n", | ||
648 | + "[+] (Valid results) Valid step: 1799/2000\n", | ||
649 | + " Acc@1 : 99.800%\n", | ||
650 | + " Loss : 0.044447898864746094\n", | ||
651 | + "\n", | ||
652 | + "[+] Model saved\n", | ||
653 | + "\n", | ||
654 | + "[+] Training step: 1800/2000\tTraining epoch: 0/580\tElapsed time: 5.77min\tLearning rate: 8.788541435506462e-07\n", | ||
655 | + " Acc@1 : 100.000%\n", | ||
656 | + " Loss : 0.08350610733032227\n", | ||
657 | + " FW Time : 18.745ms\n", | ||
658 | + " BW Time : 13.062ms\n", | ||
659 | + "\n", | ||
660 | + "[+] (Valid results) Valid step: 1849/2000\n", | ||
661 | + " Acc@1 : 100.000%\n", | ||
662 | + " Loss : 0.06921112537384033\n", | ||
663 | + "\n", | ||
664 | + "[+] Model saved\n", | ||
665 | + "\n", | ||
666 | + "[+] Training step: 1850/2000\tTraining epoch: 0/580\tElapsed time: 5.93min\tLearning rate: 8.757089797652821e-07\n", | ||
667 | + " Acc@1 : 100.000%\n", | ||
668 | + " Loss : 0.14804929494857788\n", | ||
669 | + " FW Time : 41.547ms\n", | ||
670 | + " BW Time : 15.775ms\n", | ||
671 | + "\n", | ||
672 | + "[+] (Valid results) Valid step: 1899/2000\n", | ||
673 | + " Acc@1 : 100.000%\n", | ||
674 | + " Loss : 0.06944191455841064\n", | ||
675 | + "\n", | ||
676 | + "[+] Model saved\n", | ||
677 | + "\n", | ||
678 | + "[+] Training step: 1900/2000\tTraining epoch: 0/580\tElapsed time: 6.09min\tLearning rate: 8.725750716078392e-07\n", | ||
679 | + " Acc@1 : 100.000%\n", | ||
680 | + " Loss : 0.027304232120513916\n", | ||
681 | + " FW Time : 21.632ms\n", | ||
682 | + " BW Time : 17.365ms\n", | ||
683 | + "\n", | ||
684 | + "[+] (Valid results) Valid step: 1949/2000\n", | ||
685 | + " Acc@1 : 100.000%\n", | ||
686 | + " Loss : 0.05875754356384277\n", | ||
687 | + "\n", | ||
688 | + "[+] Model saved\n", | ||
689 | + "\n", | ||
690 | + "[+] Training step: 1950/2000\tTraining epoch: 0/580\tElapsed time: 6.25min\tLearning rate: 8.694523787976934e-07\n", | ||
691 | + " Acc@1 : 100.000%\n", | ||
692 | + " Loss : 0.13626277446746826\n", | ||
693 | + " FW Time : 40.104ms\n", | ||
694 | + " BW Time : 14.362ms\n", | ||
695 | + "\n", | ||
696 | + "[+] (Valid results) Valid step: 1999/2000\n", | ||
697 | + " Acc@1 : 100.000%\n", | ||
698 | + " Loss : 0.0167849063873291\n", | ||
699 | + "\n", | ||
700 | + "[+] Model saved\n" | ||
701 | + ], | ||
702 | + "name": "stdout" | ||
703 | + } | ||
704 | + ] | ||
705 | + }, | ||
706 | + { | ||
707 | + "cell_type": "code", | ||
708 | + "metadata": { | ||
709 | + "id": "3opAMwCLZYJC", | ||
710 | + "colab_type": "code", | ||
711 | + "outputId": "208d8102-ce39-4897-9f89-ed7f8f626c33", | ||
712 | + "colab": { | ||
713 | + "base_uri": "https://localhost:8080/", | ||
714 | + "height": 350 | ||
715 | + } | ||
716 | + }, | ||
717 | + "source": [ | ||
718 | + "!python \"eval.py\" --model_path='/content/drive/My Drive/CD2 Project/runs/classify/April_26_13:35:05__resnet50__None/' " | ||
719 | + ], | ||
720 | + "execution_count": 10, | ||
721 | + "outputs": [ | ||
722 | + { | ||
723 | + "output_type": "stream", | ||
724 | + "text": [ | ||
725 | + "\n", | ||
726 | + "[+] Parse arguments\n", | ||
727 | + "Args(augment_path=None, batch_size=8, dataset='BraTS', learning_rate=1e-06, max_step=2000, network='resnet50', num_workers=4, optimizer='adam', print_step=50, scheduler='exp', seed=None, start_step=0, use_cuda=True, val_step=50)\n", | ||
728 | + "\n", | ||
729 | + "[+] Create network\n", | ||
730 | + "\n", | ||
731 | + "[+] Load model\n", | ||
732 | + "\n", | ||
733 | + "[+] Load dataset\n", | ||
734 | + "\n", | ||
735 | + "[+] Start testing\n", | ||
736 | + "2020-04-26 05:02:26.970817: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1\n", | ||
737 | + "\n", | ||
738 | + "[+] Valid results\n", | ||
739 | + " Acc@1 : 100.000%\n", | ||
740 | + " Loss : 0.020\n", | ||
741 | + " Infer Time(per image) : 4.062ms\n" | ||
742 | + ], | ||
743 | + "name": "stdout" | ||
744 | + } | ||
745 | + ] | ||
746 | + }, | ||
747 | + { | ||
748 | + "cell_type": "code", | ||
749 | + "metadata": { | ||
750 | + "id": "F70Y3J9DHJwy", | ||
751 | + "colab_type": "code", | ||
752 | + "outputId": "502a62a8-de44-42c8-ec77-afc838a068c8", | ||
753 | + "colab": { | ||
754 | + "base_uri": "https://localhost:8080/", | ||
755 | + "height": 1000 | ||
756 | + } | ||
757 | + }, | ||
758 | + "source": [ | ||
759 | + "# train cifar10\n", | ||
760 | + "# resnet50 from pytorch\n", | ||
761 | + "# !python train.py --dataset=cifar10 --use_cuda=True --optimizer=adam --network=resnet50" | ||
762 | + ], | ||
763 | + "execution_count": 0, | ||
764 | + "outputs": [ | ||
765 | + { | ||
766 | + "output_type": "stream", | ||
767 | + "text": [ | ||
768 | + "\n", | ||
769 | + "[+] Parse arguments\n", | ||
770 | + "Args(augment_path=None, batch_size=32, dataset='cifar10', learning_rate=0.001, max_step=500, network='resnet50', num_workers=4, optimizer='adam', print_step=100, scheduler='exp', seed=None, start_step=0, use_cuda=True, val_step=100)\n", | ||
771 | + "\n", | ||
772 | + "[+] Create log dir\n", | ||
773 | + "\n", | ||
774 | + "[+] Create network\n", | ||
775 | + "\n", | ||
776 | + "[+] Load dataset\n", | ||
777 | + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/drive/My Drive/CD2 Project/cifar-10-python.tar.gz\n", | ||
778 | + "171MB [00:02, 70.7MB/s] \n", | ||
779 | + "Files already downloaded and verified\n", | ||
780 | + "\n", | ||
781 | + "[+] Start training\n", | ||
782 | + "\n", | ||
783 | + "[+] Use 1 GPUs\n", | ||
784 | + "\n", | ||
785 | + "[+] Using GPU: Tesla P100-PCIE-16GB \n", | ||
786 | + "\n", | ||
787 | + "[+] Training step: 0/500\tTraining epoch: 0/1406\tElapsed time: 0.01min\tLearning rate: 0.0009999283\n", | ||
788 | + " Acc@1 : 0.000%\n", | ||
789 | + " Loss : 6.797711372375488\n", | ||
790 | + " FW Time : 63.167ms\n", | ||
791 | + " BW Time : 14.364ms\n", | ||
792 | + "\n", | ||
793 | + "[+] Valid results\n", | ||
794 | + " Acc@1 : 22.620%\n", | ||
795 | + " Loss : 2.110281229019165\n", | ||
796 | + "\n", | ||
797 | + "[+] Model saved\n", | ||
798 | + "\n", | ||
799 | + "[+] Training step: 100/500\tTraining epoch: 0/1406\tElapsed time: 0.20min\tLearning rate: 0.000992784200174764\n", | ||
800 | + " Acc@1 : 12.500%\n", | ||
801 | + " Loss : 2.331458568572998\n", | ||
802 | + " FW Time : 28.192ms\n", | ||
803 | + " BW Time : 17.216ms\n", | ||
804 | + "\n", | ||
805 | + "[+] Valid results\n", | ||
806 | + " Acc@1 : 27.320%\n", | ||
807 | + " Loss : 1.5727815628051758\n", | ||
808 | + "\n", | ||
809 | + "[+] Model saved\n", | ||
810 | + "\n", | ||
811 | + "[+] Training step: 200/500\tTraining epoch: 0/1406\tElapsed time: 0.36min\tLearning rate: 0.0009856911421715388\n", | ||
812 | + " Acc@1 : 25.000%\n", | ||
813 | + " Loss : 1.8888376951217651\n", | ||
814 | + " FW Time : 26.493ms\n", | ||
815 | + " BW Time : 17.836ms\n", | ||
816 | + "\n", | ||
817 | + "[+] Valid results\n", | ||
818 | + " Acc@1 : 30.800%\n", | ||
819 | + " Loss : 1.4916565418243408\n", | ||
820 | + "\n", | ||
821 | + "[+] Model saved\n", | ||
822 | + "\n", | ||
823 | + "[+] Training step: 300/500\tTraining epoch: 0/1406\tElapsed time: 0.53min\tLearning rate: 0.0009786487613163062\n", | ||
824 | + " Acc@1 : 25.000%\n", | ||
825 | + " Loss : 2.3930575847625732\n", | ||
826 | + " FW Time : 24.473ms\n", | ||
827 | + " BW Time : 22.187ms\n", | ||
828 | + "\n", | ||
829 | + "[+] Valid results\n", | ||
830 | + " Acc@1 : 24.680%\n", | ||
831 | + " Loss : 1.992270827293396\n", | ||
832 | + "\n", | ||
833 | + "[+] Training step: 400/500\tTraining epoch: 0/1406\tElapsed time: 0.70min\tLearning rate: 0.000971656695540503\n", | ||
834 | + " Acc@1 : 25.000%\n", | ||
835 | + " Loss : 2.1769983768463135\n", | ||
836 | + " FW Time : 28.696ms\n", | ||
837 | + " BW Time : 22.384ms\n", | ||
838 | + "\n", | ||
839 | + "[+] Valid results\n", | ||
840 | + " Acc@1 : 32.320%\n", | ||
841 | + " Loss : 1.5474934577941895\n", | ||
842 | + "\n", | ||
843 | + "[+] Model saved\n" | ||
844 | + ], | ||
845 | + "name": "stdout" | ||
846 | + } | ||
847 | + ] | ||
848 | + }, | ||
849 | + { | ||
850 | + "cell_type": "code", | ||
851 | + "metadata": { | ||
852 | + "id": "Mhw6fBwCpHRd", | ||
853 | + "colab_type": "code", | ||
854 | + "colab": {} | ||
855 | + }, | ||
856 | + "source": [ | ||
857 | + "" | ||
858 | + ], | ||
859 | + "execution_count": 0, | ||
860 | + "outputs": [] | ||
861 | + } | ||
862 | + ] | ||
863 | +} | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/classifier/utils/util.py
deleted
100644 → 0
1 | -import os | ||
2 | -import time | ||
3 | -import importlib | ||
4 | -import collections | ||
5 | -import pickle as cp | ||
6 | -import glob | ||
7 | -import numpy as np | ||
8 | -import pandas as pd | ||
9 | - | ||
10 | -from natsort import natsorted | ||
11 | -from PIL import Image | ||
12 | -import torch | ||
13 | -import torchvision | ||
14 | -import torch.nn.functional as F | ||
15 | -import torchvision.models as models | ||
16 | -import torchvision.transforms as transforms | ||
17 | -from torch.utils.data import Subset | ||
18 | -from torch.utils.data import Dataset, DataLoader | ||
19 | - | ||
20 | -from sklearn.model_selection import StratifiedShuffleSplit | ||
21 | -from sklearn.model_selection import train_test_split | ||
22 | -from sklearn.model_selection import KFold | ||
23 | - | ||
24 | -from networks import * | ||
25 | - | ||
26 | - | ||
27 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' | ||
28 | -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' | ||
29 | -# VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/' | ||
30 | -# VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv' | ||
31 | - | ||
32 | -current_epoch = 0 | ||
33 | - | ||
34 | - | ||
35 | -def split_dataset(args, dataset, k): | ||
36 | - # load dataset | ||
37 | - X = list(range(len(dataset))) | ||
38 | - Y = dataset.targets | ||
39 | - | ||
40 | - # split to k-fold | ||
41 | - assert len(X) == len(Y) | ||
42 | - | ||
43 | - def _it_to_list(_it): | ||
44 | - return list(zip(*list(_it))) | ||
45 | - | ||
46 | - sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
47 | - Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
48 | - | ||
49 | - return Dm_indexes, Da_indexes | ||
50 | - | ||
51 | - | ||
52 | - | ||
53 | -def get_model_name(args): | ||
54 | - from datetime import datetime, timedelta, timezone | ||
55 | - now = datetime.now(timezone.utc) | ||
56 | - tz = timezone(timedelta(hours=9)) | ||
57 | - now = now.astimezone(tz) | ||
58 | - date_time = now.strftime("%B_%d_%H:%M:%S") | ||
59 | - model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
60 | - return model_name | ||
61 | - | ||
62 | - | ||
63 | -def dict_to_namedtuple(d): | ||
64 | - Args = collections.namedtuple('Args', sorted(d.keys())) | ||
65 | - | ||
66 | - for k,v in d.items(): | ||
67 | - if type(v) is dict: | ||
68 | - d[k] = dict_to_namedtuple(v) | ||
69 | - | ||
70 | - elif type(v) is str: | ||
71 | - try: | ||
72 | - d[k] = eval(v) | ||
73 | - except: | ||
74 | - d[k] = v | ||
75 | - | ||
76 | - args = Args(**d) | ||
77 | - return args | ||
78 | - | ||
79 | - | ||
80 | -def parse_args(kwargs): | ||
81 | - # combine with default args | ||
82 | - kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS' | ||
83 | - kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50' | ||
84 | - kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
85 | - kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.0001 | ||
86 | - kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
87 | - kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
88 | - kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
89 | - kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
90 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 | ||
91 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 | ||
92 | - kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
93 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
94 | - kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
95 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 | ||
96 | - kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
97 | - | ||
98 | - # to named tuple | ||
99 | - args = dict_to_namedtuple(kwargs) | ||
100 | - return args, kwargs | ||
101 | - | ||
102 | - | ||
103 | -def select_model(args): | ||
104 | - # grayResNet2 | ||
105 | - resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
106 | - 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
107 | - | ||
108 | - if args.network in resnet_dict: | ||
109 | - backbone = resnet_dict[args.network] | ||
110 | - model = basenet.BaseNet(backbone, args) | ||
111 | - else: | ||
112 | - Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
113 | - model = Net(args) | ||
114 | - | ||
115 | - #print(model) # print model architecture | ||
116 | - return model | ||
117 | - | ||
118 | - | ||
119 | -def select_optimizer(args, model): | ||
120 | - if args.optimizer == 'sgd': | ||
121 | - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
122 | - elif args.optimizer == 'rms': | ||
123 | - optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
124 | - elif args.optimizer == 'adam': | ||
125 | - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
126 | - else: | ||
127 | - raise Exception('Unknown Optimizer') | ||
128 | - return optimizer | ||
129 | - | ||
130 | - | ||
131 | -def select_scheduler(args, optimizer): | ||
132 | - if not args.scheduler or args.scheduler == 'None': | ||
133 | - return None | ||
134 | - elif args.scheduler =='clr': | ||
135 | - return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
136 | - elif args.scheduler =='exp': | ||
137 | - return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
138 | - else: | ||
139 | - raise Exception('Unknown Scheduler') | ||
140 | - | ||
141 | - | ||
142 | -class CustomDataset(Dataset): | ||
143 | - def __init__(self, data_path, csv_path): | ||
144 | - self.len = len(self.imgs) | ||
145 | - self.path = data_path | ||
146 | - self.imgs = natsorted(os.listdir(data_path)) | ||
147 | - | ||
148 | - df = pd.read_csv(csv_path) | ||
149 | - targets_list = [] | ||
150 | - | ||
151 | - for fname in self.imgs: | ||
152 | - row = df.loc[df['filename'] == fname] | ||
153 | - targets_list.append(row.iloc[0, 1]) | ||
154 | - | ||
155 | - self.targets = targets_list | ||
156 | - | ||
157 | - def __len__(self): | ||
158 | - return self.len | ||
159 | - | ||
160 | - def __getitem__(self, idx): | ||
161 | - img_loc = os.path.join(self.path, self.imgs[idx]) | ||
162 | - targets = self.targets[idx] | ||
163 | - image = Image.open(img_loc) | ||
164 | - return image, targets | ||
165 | - | ||
166 | - | ||
167 | - | ||
168 | -def get_dataset(args, transform, split='train'): | ||
169 | - assert split in ['train', 'val', 'test', 'trainval'] | ||
170 | - | ||
171 | - if split in ['train']: | ||
172 | - dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform) | ||
173 | - else: #test | ||
174 | - dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform) | ||
175 | - | ||
176 | - return dataset | ||
177 | - | ||
178 | - | ||
179 | -def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
180 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
181 | - batch_size=args.batch_size, | ||
182 | - shuffle=shuffle, | ||
183 | - num_workers=args.num_workers, | ||
184 | - pin_memory=pin_memory) | ||
185 | - return data_loader | ||
186 | - | ||
187 | - | ||
188 | -def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
189 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
190 | - batch_size=args.batch_size, | ||
191 | - shuffle=shuffle, | ||
192 | - num_workers=args.num_workers, | ||
193 | - pin_memory=pin_memory) | ||
194 | - return data_loader | ||
195 | - | ||
196 | - | ||
197 | -def get_inf_dataloader(args, dataset): | ||
198 | - global current_epoch | ||
199 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
200 | - | ||
201 | - while True: | ||
202 | - try: | ||
203 | - batch = next(data_loader) | ||
204 | - | ||
205 | - except StopIteration: | ||
206 | - current_epoch += 1 | ||
207 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
208 | - batch = next(data_loader) | ||
209 | - | ||
210 | - yield batch | ||
211 | - | ||
212 | - | ||
213 | - | ||
214 | - | ||
215 | -def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
216 | - model.train() | ||
217 | - images, target = batch | ||
218 | - | ||
219 | - if device: | ||
220 | - images = images.to(device) | ||
221 | - target = target.to(device) | ||
222 | - | ||
223 | - elif args.use_cuda: | ||
224 | - images = images.cuda(non_blocking=True) | ||
225 | - target = target.cuda(non_blocking=True) | ||
226 | - | ||
227 | - # compute output | ||
228 | - start_t = time.time() | ||
229 | - output, first = model(images) | ||
230 | - forward_t = time.time() - start_t | ||
231 | - loss = criterion(output, target) | ||
232 | - | ||
233 | - # measure accuracy and record loss | ||
234 | - acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
235 | - acc1 /= images.size(0) | ||
236 | - acc5 /= images.size(0) | ||
237 | - | ||
238 | - # compute gradient and do SGD step | ||
239 | - optimizer.zero_grad() | ||
240 | - start_t = time.time() | ||
241 | - loss.backward() | ||
242 | - backward_t = time.time() - start_t | ||
243 | - optimizer.step() | ||
244 | - if scheduler: scheduler.step() | ||
245 | - | ||
246 | - if writer and step % args.print_step == 0: | ||
247 | - n_imgs = min(images.size(0), 10) | ||
248 | - tag = 'train/' + str(step) | ||
249 | - for j in range(n_imgs): | ||
250 | - writer.add_image(tag, | ||
251 | - concat_image_features(images[j], first[j]), global_step=step) | ||
252 | - | ||
253 | - return acc1, acc5, loss, forward_t, backward_t | ||
254 | - | ||
255 | - | ||
256 | -#_acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
257 | -def accuracy(output, target, topk=(1,)): | ||
258 | - """Computes the accuracy over the k top predictions for the specified values of k""" | ||
259 | - with torch.no_grad(): | ||
260 | - maxk = max(topk) | ||
261 | - batch_size = target.size(0) | ||
262 | - | ||
263 | - _, pred = output.topk(maxk, 1, True, True) | ||
264 | - pred = pred.t() | ||
265 | - correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
266 | - | ||
267 | - res = [] | ||
268 | - for k in topk: | ||
269 | - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
270 | - res.append(correct_k) | ||
271 | - return res | ||
272 | - |
-
Please register or login to post a comment