조현아

flair2seg & filter normal

Showing 64 changed files with 5314 additions and 0 deletions
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<module type="PYTHON_MODULE" version="4">
3 + <component name="NewModuleRootManager">
4 + <content url="file://$MODULE_DIR$" />
5 + <orderEntry type="inheritedJdk" />
6 + <orderEntry type="sourceFolder" forTests="false" />
7 + </component>
8 + <component name="TestRunnerService">
9 + <option name="PROJECT_TEST_RUNNER" value="Unittests" />
10 + </component>
11 +</module>
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="Encoding" addBOMForNewFiles="with NO BOM" />
4 +</project>
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="JavaScriptSettings">
4 + <option name="languageLevel" value="ES6" />
5 + </component>
6 + <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (code) (3)" project-jdk-type="Python SDK" />
7 +</project>
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="ProjectModuleManager">
4 + <modules>
5 + <module fileurl="file://$PROJECT_DIR$/.idea/code.iml" filepath="$PROJECT_DIR$/.idea/code.iml" />
6 + </modules>
7 + </component>
8 +</project>
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="ChangeListManager">
4 + <list default="true" id="38bbd81d-cf79-4a3f-8979-7c3ceb27bc32" name="Default Changelist" comment="" />
5 + <option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
6 + <option name="SHOW_DIALOG" value="false" />
7 + <option name="HIGHLIGHT_CONFLICTS" value="true" />
8 + <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
9 + <option name="LAST_RESOLUTION" value="IGNORE" />
10 + </component>
11 + <component name="ProjectFrameBounds">
12 + <option name="x" value="-9" />
13 + <option name="width" value="1825" />
14 + <option name="height" value="1039" />
15 + </component>
16 + <component name="ProjectView">
17 + <navigator proportions="" version="1">
18 + <foldersAlwaysOnTop value="true" />
19 + </navigator>
20 + <panes>
21 + <pane id="Scope" />
22 + <pane id="ProjectPane" />
23 + </panes>
24 + </component>
25 + <component name="PropertiesComponent">
26 + <property name="WebServerToolWindowFactoryState" value="false" />
27 + <property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
28 + <property name="nodejs_npm_path_reset_for_default_project" value="true" />
29 + <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
30 + </component>
31 + <component name="RunDashboard">
32 + <option name="ruleStates">
33 + <list>
34 + <RuleState>
35 + <option name="name" value="ConfigurationTypeDashboardGroupingRule" />
36 + </RuleState>
37 + <RuleState>
38 + <option name="name" value="StatusDashboardGroupingRule" />
39 + </RuleState>
40 + </list>
41 + </option>
42 + </component>
43 + <component name="SvnConfiguration">
44 + <configuration />
45 + </component>
46 + <component name="TaskManager">
47 + <task active="true" id="Default" summary="Default task">
48 + <changelist id="38bbd81d-cf79-4a3f-8979-7c3ceb27bc32" name="Default Changelist" comment="" />
49 + <created>1584871834688</created>
50 + <option name="number" value="Default" />
51 + <option name="presentableId" value="Default" />
52 + <updated>1584871834688</updated>
53 + <workItem from="1584871840396" duration="72000" />
54 + </task>
55 + <servers />
56 + </component>
57 + <component name="TimeTrackingManager">
58 + <option name="totallyTimeSpent" value="72000" />
59 + </component>
60 + <component name="ToolWindowManager">
61 + <frame x="-7" y="0" width="1460" height="831" extended-state="0" />
62 + <layout>
63 + <window_info id="Favorites" side_tool="true" />
64 + <window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.25" />
65 + <window_info id="Structure" order="1" side_tool="true" weight="0.25" />
66 + <window_info anchor="bottom" id="Docker" show_stripe_button="false" />
67 + <window_info anchor="bottom" id="Database Changes" />
68 + <window_info anchor="bottom" id="Version Control" />
69 + <window_info anchor="bottom" id="Python Console" />
70 + <window_info anchor="bottom" id="Terminal" />
71 + <window_info anchor="bottom" id="Event Log" side_tool="true" />
72 + <window_info anchor="bottom" id="Message" order="0" />
73 + <window_info anchor="bottom" id="Find" order="1" />
74 + <window_info anchor="bottom" id="Run" order="2" />
75 + <window_info anchor="bottom" id="Debug" order="3" weight="0.4" />
76 + <window_info anchor="bottom" id="Cvs" order="4" weight="0.25" />
77 + <window_info anchor="bottom" id="Inspection" order="5" weight="0.4" />
78 + <window_info anchor="bottom" id="TODO" order="6" />
79 + <window_info anchor="right" id="SciView" />
80 + <window_info anchor="right" id="Database" />
81 + <window_info anchor="right" id="Commander" internal_type="SLIDING" order="0" type="SLIDING" weight="0.4" />
82 + <window_info anchor="right" id="Ant Build" order="1" weight="0.25" />
83 + <window_info anchor="right" content_ui="combo" id="Hierarchy" order="2" weight="0.25" />
84 + </layout>
85 + </component>
86 + <component name="TypeScriptGeneratedFilesManager">
87 + <option name="version" value="1" />
88 + </component>
89 +</project>
\ No newline at end of file
1 +{
2 + "cells": [
3 + {
4 + "cell_type": "code",
5 + "execution_count": 1,
6 + "metadata": {},
7 + "outputs": [],
8 + "source": [
9 + "import pandas as pd"
10 + ]
11 + },
12 + {
13 + "cell_type": "code",
14 + "execution_count": 3,
15 + "metadata": {},
16 + "outputs": [
17 + {
18 + "data": {
19 + "text/html": [
20 + "<div>\n",
21 + "<style scoped>\n",
22 + " .dataframe tbody tr th:only-of-type {\n",
23 + " vertical-align: middle;\n",
24 + " }\n",
25 + "\n",
26 + " .dataframe tbody tr th {\n",
27 + " vertical-align: top;\n",
28 + " }\n",
29 + "\n",
30 + " .dataframe thead th {\n",
31 + " text-align: right;\n",
32 + " }\n",
33 + "</style>\n",
34 + "<table border=\"1\" class=\"dataframe\">\n",
35 + " <thead>\n",
36 + " <tr style=\"text-align: right;\">\n",
37 + " <th></th>\n",
38 + " <th>id</th>\n",
39 + " <th>Label</th>\n",
40 + " <th>Subject</th>\n",
41 + " <th>Date</th>\n",
42 + " <th>Gender</th>\n",
43 + " <th>Age</th>\n",
44 + " <th>mmse</th>\n",
45 + " <th>ageAtEntry</th>\n",
46 + " <th>cdr</th>\n",
47 + " <th>commun</th>\n",
48 + " <th>...</th>\n",
49 + " <th>memory</th>\n",
50 + " <th>orient</th>\n",
51 + " <th>perscare</th>\n",
52 + " <th>apoe</th>\n",
53 + " <th>sumbox</th>\n",
54 + " <th>acsparnt</th>\n",
55 + " <th>height</th>\n",
56 + " <th>weight</th>\n",
57 + " <th>primStudy</th>\n",
58 + " <th>acsStudy</th>\n",
59 + " </tr>\n",
60 + " </thead>\n",
61 + " <tbody>\n",
62 + " <tr>\n",
63 + " <th>0</th>\n",
64 + " <td>/@WEBAPP/images/r.gif</td>\n",
65 + " <td>OAS30001_ClinicalData_d3025</td>\n",
66 + " <td>OAS30001</td>\n",
67 + " <td>NaN</td>\n",
68 + " <td>female</td>\n",
69 + " <td>NaN</td>\n",
70 + " <td>30.0</td>\n",
71 + " <td>65.149895</td>\n",
72 + " <td>0.0</td>\n",
73 + " <td>0.0</td>\n",
74 + " <td>...</td>\n",
75 + " <td>0.0</td>\n",
76 + " <td>0.0</td>\n",
77 + " <td>0.0</td>\n",
78 + " <td>23.0</td>\n",
79 + " <td>0.0</td>\n",
80 + " <td>NaN</td>\n",
81 + " <td>64.0</td>\n",
82 + " <td>180.0</td>\n",
83 + " <td>NaN</td>\n",
84 + " <td>NaN</td>\n",
85 + " </tr>\n",
86 + " <tr>\n",
87 + " <th>1</th>\n",
88 + " <td>/@WEBAPP/images/r.gif</td>\n",
89 + " <td>OAS30001_ClinicalData_d3977</td>\n",
90 + " <td>OAS30001</td>\n",
91 + " <td>NaN</td>\n",
92 + " <td>female</td>\n",
93 + " <td>NaN</td>\n",
94 + " <td>29.0</td>\n",
95 + " <td>65.149895</td>\n",
96 + " <td>0.0</td>\n",
97 + " <td>0.0</td>\n",
98 + " <td>...</td>\n",
99 + " <td>0.0</td>\n",
100 + " <td>0.0</td>\n",
101 + " <td>0.0</td>\n",
102 + " <td>23.0</td>\n",
103 + " <td>0.0</td>\n",
104 + " <td>NaN</td>\n",
105 + " <td>NaN</td>\n",
106 + " <td>NaN</td>\n",
107 + " <td>NaN</td>\n",
108 + " <td>NaN</td>\n",
109 + " </tr>\n",
110 + " <tr>\n",
111 + " <th>2</th>\n",
112 + " <td>/@WEBAPP/images/r.gif</td>\n",
113 + " <td>OAS30001_ClinicalData_d3332</td>\n",
114 + " <td>OAS30001</td>\n",
115 + " <td>NaN</td>\n",
116 + " <td>female</td>\n",
117 + " <td>NaN</td>\n",
118 + " <td>30.0</td>\n",
119 + " <td>65.149895</td>\n",
120 + " <td>0.0</td>\n",
121 + " <td>0.0</td>\n",
122 + " <td>...</td>\n",
123 + " <td>0.0</td>\n",
124 + " <td>0.0</td>\n",
125 + " <td>0.0</td>\n",
126 + " <td>23.0</td>\n",
127 + " <td>0.0</td>\n",
128 + " <td>NaN</td>\n",
129 + " <td>63.0</td>\n",
130 + " <td>185.0</td>\n",
131 + " <td>NaN</td>\n",
132 + " <td>NaN</td>\n",
133 + " </tr>\n",
134 + " <tr>\n",
135 + " <th>3</th>\n",
136 + " <td>/@WEBAPP/images/r.gif</td>\n",
137 + " <td>OAS30001_ClinicalData_d0000</td>\n",
138 + " <td>OAS30001</td>\n",
139 + " <td>NaN</td>\n",
140 + " <td>female</td>\n",
141 + " <td>NaN</td>\n",
142 + " <td>28.0</td>\n",
143 + " <td>65.149895</td>\n",
144 + " <td>0.0</td>\n",
145 + " <td>0.0</td>\n",
146 + " <td>...</td>\n",
147 + " <td>0.0</td>\n",
148 + " <td>0.0</td>\n",
149 + " <td>0.0</td>\n",
150 + " <td>23.0</td>\n",
151 + " <td>0.0</td>\n",
152 + " <td>NaN</td>\n",
153 + " <td>NaN</td>\n",
154 + " <td>NaN</td>\n",
155 + " <td>NaN</td>\n",
156 + " <td>NaN</td>\n",
157 + " </tr>\n",
158 + " <tr>\n",
159 + " <th>4</th>\n",
160 + " <td>/@WEBAPP/images/r.gif</td>\n",
161 + " <td>OAS30001_ClinicalData_d1456</td>\n",
162 + " <td>OAS30001</td>\n",
163 + " <td>NaN</td>\n",
164 + " <td>female</td>\n",
165 + " <td>NaN</td>\n",
166 + " <td>30.0</td>\n",
167 + " <td>65.149895</td>\n",
168 + " <td>0.0</td>\n",
169 + " <td>0.0</td>\n",
170 + " <td>...</td>\n",
171 + " <td>0.0</td>\n",
172 + " <td>0.0</td>\n",
173 + " <td>0.0</td>\n",
174 + " <td>23.0</td>\n",
175 + " <td>0.0</td>\n",
176 + " <td>NaN</td>\n",
177 + " <td>63.0</td>\n",
178 + " <td>173.0</td>\n",
179 + " <td>NaN</td>\n",
180 + " <td>NaN</td>\n",
181 + " </tr>\n",
182 + " </tbody>\n",
183 + "</table>\n",
184 + "<p>5 rows × 27 columns</p>\n",
185 + "</div>"
186 + ],
187 + "text/plain": [
188 + " id Label Subject Date Gender \\\n",
189 + "0 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3025 OAS30001 NaN female \n",
190 + "1 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3977 OAS30001 NaN female \n",
191 + "2 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3332 OAS30001 NaN female \n",
192 + "3 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d0000 OAS30001 NaN female \n",
193 + "4 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d1456 OAS30001 NaN female \n",
194 + "\n",
195 + " Age mmse ageAtEntry cdr commun ... memory orient perscare apoe \\\n",
196 + "0 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
197 + "1 NaN 29.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
198 + "2 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
199 + "3 NaN 28.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
200 + "4 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
201 + "\n",
202 + " sumbox acsparnt height weight primStudy acsStudy \n",
203 + "0 0.0 NaN 64.0 180.0 NaN NaN \n",
204 + "1 0.0 NaN NaN NaN NaN NaN \n",
205 + "2 0.0 NaN 63.0 185.0 NaN NaN \n",
206 + "3 0.0 NaN NaN NaN NaN NaN \n",
207 + "4 0.0 NaN 63.0 173.0 NaN NaN \n",
208 + "\n",
209 + "[5 rows x 27 columns]"
210 + ]
211 + },
212 + "execution_count": 3,
213 + "metadata": {},
214 + "output_type": "execute_result"
215 + }
216 + ],
217 + "source": [
218 +    "all_data = pd.read_csv(\"../data/ADRC clinical data_all.csv\")\n",
219 + "\n",
220 + "all_data.head()"
221 + ]
222 + },
223 + {
224 + "cell_type": "code",
225 + "execution_count": 18,
226 + "metadata": {},
227 + "outputs": [
228 + {
229 + "name": "stdout",
230 + "output_type": "stream",
231 + "text": [
232 + "Subject\n",
233 + "OAS30001 0.0\n",
234 + "OAS30002 0.0\n",
235 + "OAS30003 0.0\n",
236 + "OAS30004 0.0\n",
237 + "OAS30005 0.0\n",
238 + " ... \n",
239 + "OAS31168 0.0\n",
240 + "OAS31169 3.0\n",
241 + "OAS31170 2.0\n",
242 + "OAS31171 2.0\n",
243 + "OAS31172 0.0\n",
244 + "Name: cdr, Length: 1098, dtype: float64\n"
245 + ]
246 + }
247 + ],
248 + "source": [
249 + "ad = all_data.groupby(['Subject'])['cdr'].max()\n",
250 + "print(ad)"
251 + ]
252 + },
253 + {
254 + "cell_type": "code",
255 + "execution_count": 21,
256 + "metadata": {},
257 + "outputs": [
258 + {
259 + "data": {
260 + "text/plain": [
261 + "'OAS30001'"
262 + ]
263 + },
264 + "execution_count": 21,
265 + "metadata": {},
266 + "output_type": "execute_result"
267 + }
268 + ],
269 + "source": [
270 + "ad.index[0]"
271 + ]
272 + },
273 + {
274 + "cell_type": "code",
275 + "execution_count": 22,
276 + "metadata": {},
277 + "outputs": [
278 + {
279 + "ename": "SyntaxError",
280 + "evalue": "unexpected EOF while parsing (<ipython-input-22-b8a078b72aca>, line 5)",
281 + "output_type": "error",
282 + "traceback": [
283 + "\u001b[1;36m File \u001b[1;32m\"<ipython-input-22-b8a078b72aca>\"\u001b[1;36m, line \u001b[1;32m5\u001b[0m\n\u001b[1;33m #print(filtered)\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m unexpected EOF while parsing\n"
284 + ]
285 + }
286 + ],
287 + "source": [
288 + "filtered = []\n",
289 + "for i, val in enumerate(ad):\n",
290 +    "    if val == 0:\n",
291 + " filtered.append(ad.index[i])\n",
292 + "#print(filtered)"
293 + ]
294 + },
295 + {
296 + "cell_type": "code",
297 + "execution_count": 23,
298 + "metadata": {},
299 + "outputs": [],
300 + "source": [
301 + "df_filtered = pd.DataFrame(filtered)\n",
302 +    "df_filtered.to_csv('../data/ADRC clinical data_normal.csv')"
303 + ]
304 + },
305 + {
306 + "cell_type": "code",
307 + "execution_count": null,
308 + "metadata": {},
309 + "outputs": [],
310 + "source": []
311 + }
312 + ],
313 + "metadata": {
314 + "kernelspec": {
315 + "display_name": "ML",
316 + "language": "python",
317 + "name": "ml"
318 + },
319 + "language_info": {
320 + "codemirror_mode": {
321 + "name": "ipython",
322 + "version": 3
323 + },
324 + "file_extension": ".py",
325 + "mimetype": "text/x-python",
326 + "name": "python",
327 + "nbconvert_exporter": "python",
328 + "pygments_lexer": "ipython3",
329 + "version": "3.7.4"
330 + }
331 + },
332 + "nbformat": 4,
333 + "nbformat_minor": 2
334 +}
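Note on the notebook above: it collapses the ADRC clinical table to one row per subject by taking the maximum CDR over all visits, then keeps the subjects whose maximum is 0 (never rated cognitively impaired). A minimal pandas sketch of the same filter, assuming the CSV layout the cells use ('Subject' and 'cdr' columns; paths as above):

    import pandas as pd

    all_data = pd.read_csv("../data/ADRC clinical data_all.csv")
    max_cdr = all_data.groupby("Subject")["cdr"].max()   # worst CDR per subject
    normal = max_cdr[max_cdr == 0].index                 # subjects that never left CDR 0
    pd.DataFrame(normal).to_csv("../data/ADRC clinical data_normal.csv")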
This diff could not be displayed because it is too large.
1 +# Byte-compiled / optimized / DLL files
2 +__pycache__/
3 +*.py[cod]
4 +*$py.class
5 +
6 +# C extensions
7 +*.so
8 +
9 +# Distribution / packaging
10 +.Python
11 +build/
12 +develop-eggs/
13 +dist/
14 +downloads/
15 +eggs/
16 +.eggs/
17 +lib/
18 +lib64/
19 +parts/
20 +sdist/
21 +var/
22 +wheels/
23 +*.egg-info/
24 +.installed.cfg
25 +*.egg
26 +MANIFEST
27 +
28 +# PyInstaller
29 +# Usually these files are written by a python script from a template
30 +# before PyInstaller builds the exe, so as to inject date/other infos into it.
31 +*.manifest
32 +*.spec
33 +
34 +# Installer logs
35 +pip-log.txt
36 +pip-delete-this-directory.txt
37 +
38 +# Unit test / coverage reports
39 +htmlcov/
40 +.tox/
41 +.coverage
42 +.coverage.*
43 +.cache
44 +nosetests.xml
45 +coverage.xml
46 +*.cover
47 +.hypothesis/
48 +.pytest_cache/
49 +
50 +# Translations
51 +*.mo
52 +*.pot
53 +
54 +# Django stuff:
55 +*.log
56 +local_settings.py
57 +db.sqlite3
58 +
59 +# Flask stuff:
60 +instance/
61 +.webassets-cache
62 +
63 +# Scrapy stuff:
64 +.scrapy
65 +
66 +# Sphinx documentation
67 +docs/_build/
68 +
69 +# PyBuilder
70 +target/
71 +
72 +# Jupyter Notebook
73 +.ipynb_checkpoints
74 +
75 +# pyenv
76 +.python-version
77 +
78 +# celery beat schedule file
79 +celerybeat-schedule
80 +
81 +# SageMath parsed files
82 +*.sage.py
83 +
84 +# Environments
85 +.env
86 +.venv
87 +env/
88 +venv/
89 +ENV/
90 +env.bak/
91 +venv.bak/
92 +
93 +# Spyder project settings
94 +.spyderproject
95 +.spyproject
96 +
97 +# Rope project settings
98 +.ropeproject
99 +
100 +# mkdocs documentation
101 +/site
102 +
103 +# mypy
104 +.mypy_cache/
This diff could not be displayed because it is too large.
1 +"""
2 +Reference :
3 +- https://github.com/hysts/pytorch_image_classification/blob/master/augmentations/mixup.py
4 +- https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/imagenet_input.py#L120
5 +"""
6 +
7 +import numpy as np
8 +import torch
9 +
10 +from FastAutoAugment.metrics import CrossEntropyLabelSmooth
11 +
12 +
13 +def mixup(data, targets, alpha):
14 + indices = torch.randperm(data.size(0))
15 + shuffled_data = data[indices]
16 + shuffled_targets = targets[indices]
17 +
18 + lam = np.random.beta(alpha, alpha)
19 + lam = max(lam, 1. - lam)
20 + assert 0.0 <= lam <= 1.0, lam
21 + data = data * lam + shuffled_data * (1 - lam)
22 +
23 + return data, targets, shuffled_targets, lam
24 +
25 +
26 +class CrossEntropyMixUpLabelSmooth(torch.nn.Module):
27 + def __init__(self, num_classes, epsilon, reduction='mean'):
28 + super(CrossEntropyMixUpLabelSmooth, self).__init__()
29 + self.ce = CrossEntropyLabelSmooth(num_classes, epsilon, reduction=reduction)
30 +
31 + def forward(self, input, target1, target2, lam): # pylint: disable=redefined-builtin
32 + return lam * self.ce(input, target1) + (1 - lam) * self.ce(input, target2)
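As a sanity reference, a minimal training-step sketch of how mixup() and CrossEntropyMixUpLabelSmooth above compose; the toy model, optimizer, and the alpha/epsilon values are illustrative stand-ins, not from this diff:

    import torch

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = CrossEntropyMixUpLabelSmooth(num_classes=10, epsilon=0.1)

    data = torch.randn(8, 3, 32, 32)                     # toy batch
    targets = torch.randint(0, 10, (8,))
    data, t1, t2, lam = mixup(data, targets, alpha=0.2)  # blend inputs, keep both label sets
    loss = criterion(model(data), t1, t2, lam)           # lam-weighted smoothed cross-entropy
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()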
1 +# code in this file is adapted from rpmcruz/autoaugment
2 +# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 +import random
4 +
5 +import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 +import numpy as np
7 +import torch
8 +from torchvision.transforms.transforms import Compose
9 +
10 +random_mirror = True
11 +
12 +
13 +def ShearX(img, v): # [-0.3, 0.3]
14 + assert -0.3 <= v <= 0.3
15 + if random_mirror and random.random() > 0.5:
16 + v = -v
17 + return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18 +
19 +
20 +def ShearY(img, v): # [-0.3, 0.3]
21 + assert -0.3 <= v <= 0.3
22 + if random_mirror and random.random() > 0.5:
23 + v = -v
24 + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25 +
26 +
27 +def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
28 + assert -0.45 <= v <= 0.45
29 + if random_mirror and random.random() > 0.5:
30 + v = -v
31 + v = v * img.size[0]
32 + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 +
34 +
35 +def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
36 + assert -0.45 <= v <= 0.45
37 + if random_mirror and random.random() > 0.5:
38 + v = -v
39 + v = v * img.size[1]
40 + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
41 +
42 +
43 +def TranslateXAbs(img, v):  # absolute pixel shift, v in [0, 10]
44 + assert 0 <= v <= 10
45 + if random.random() > 0.5:
46 + v = -v
47 + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
48 +
49 +
50 +def TranslateYAbs(img, v):  # absolute pixel shift, v in [0, 10]
51 + assert 0 <= v <= 10
52 + if random.random() > 0.5:
53 + v = -v
54 + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55 +
56 +
57 +def Rotate(img, v): # [-30, 30]
58 + assert -30 <= v <= 30
59 + if random_mirror and random.random() > 0.5:
60 + v = -v
61 + return img.rotate(v)
62 +
63 +
64 +def AutoContrast(img, _):
65 + return PIL.ImageOps.autocontrast(img)
66 +
67 +
68 +def Invert(img, _):
69 + return PIL.ImageOps.invert(img)
70 +
71 +
72 +def Equalize(img, _):
73 + return PIL.ImageOps.equalize(img)
74 +
75 +
76 +def Flip(img, _): # not from the paper
77 + return PIL.ImageOps.mirror(img)
78 +
79 +
80 +def Solarize(img, v): # [0, 256]
81 + assert 0 <= v <= 256
82 + return PIL.ImageOps.solarize(img, v)
83 +
84 +
85 +def Posterize(img, v): # [4, 8]
86 + assert 4 <= v <= 8
87 + v = int(v)
88 + return PIL.ImageOps.posterize(img, v)
89 +
90 +
91 +def Posterize2(img, v): # [0, 4]
92 + assert 0 <= v <= 4
93 + v = int(v)
94 + return PIL.ImageOps.posterize(img, v)
95 +
96 +
97 +def Contrast(img, v): # [0.1,1.9]
98 + assert 0.1 <= v <= 1.9
99 + return PIL.ImageEnhance.Contrast(img).enhance(v)
100 +
101 +
102 +def Color(img, v): # [0.1,1.9]
103 + assert 0.1 <= v <= 1.9
104 + return PIL.ImageEnhance.Color(img).enhance(v)
105 +
106 +
107 +def Brightness(img, v): # [0.1,1.9]
108 + assert 0.1 <= v <= 1.9
109 + return PIL.ImageEnhance.Brightness(img).enhance(v)
110 +
111 +
112 +def Sharpness(img, v): # [0.1,1.9]
113 + assert 0.1 <= v <= 1.9
114 + return PIL.ImageEnhance.Sharpness(img).enhance(v)
115 +
116 +
117 +def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
118 + assert 0.0 <= v <= 0.2
119 + if v <= 0.:
120 + return img
121 +
122 + v = v * img.size[0]
123 + return CutoutAbs(img, v)
124 +
125 +
126 +def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
127 + # assert 0 <= v <= 20
128 + if v < 0:
129 + return img
130 + w, h = img.size
131 + x0 = np.random.uniform(w)
132 + y0 = np.random.uniform(h)
133 +
134 + x0 = int(max(0, x0 - v / 2.))
135 + y0 = int(max(0, y0 - v / 2.))
136 + x1 = min(w, x0 + v)
137 + y1 = min(h, y0 + v)
138 +
139 + xy = (x0, y0, x1, y1)
140 + color = (125, 123, 114)
141 + # color = (0, 0, 0)
142 + img = img.copy()
143 + PIL.ImageDraw.Draw(img).rectangle(xy, color)
144 + return img
145 +
146 +
147 +def SamplePairing(imgs): # [0, 0.4]
148 + def f(img1, v):
149 + i = np.random.choice(len(imgs))
150 + img2 = PIL.Image.fromarray(imgs[i])
151 + return PIL.Image.blend(img1, img2, v)
152 +
153 + return f
154 +
155 +
156 +def augment_list(for_autoaug=True):  # 16 operations and their ranges
157 + l = [
158 + (ShearX, -0.3, 0.3), # 0
159 + (ShearY, -0.3, 0.3), # 1
160 + (TranslateX, -0.45, 0.45), # 2
161 + (TranslateY, -0.45, 0.45), # 3
162 + (Rotate, -30, 30), # 4
163 + (AutoContrast, 0, 1), # 5
164 + (Invert, 0, 1), # 6
165 + (Equalize, 0, 1), # 7
166 + (Solarize, 0, 256), # 8
167 + (Posterize, 4, 8), # 9
168 + (Contrast, 0.1, 1.9), # 10
169 + (Color, 0.1, 1.9), # 11
170 + (Brightness, 0.1, 1.9), # 12
171 + (Sharpness, 0.1, 1.9), # 13
172 + (Cutout, 0, 0.2), # 14
173 + # (SamplePairing(imgs), 0, 0.4), # 15
174 + ]
175 + if for_autoaug:
176 + l += [
177 + (CutoutAbs, 0, 20), # compatible with auto-augment
178 +        (Posterize2, 0, 4),
179 +        (TranslateXAbs, 0, 10),
180 +        (TranslateYAbs, 0, 10),
181 + ]
182 + return l
183 +
184 +
185 +augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
186 +
187 +
188 +def get_augment(name):
189 + return augment_dict[name]
190 +
191 +
192 +def apply_augment(img, name, level):
193 + augment_fn, low, high = get_augment(name)
194 + return augment_fn(img.copy(), level * (high - low) + low)
195 +
196 +
197 +class Lighting(object):
198 + """Lighting noise(AlexNet - style PCA - based noise)"""
199 +
200 + def __init__(self, alphastd, eigval, eigvec):
201 + self.alphastd = alphastd
202 + self.eigval = torch.Tensor(eigval)
203 + self.eigvec = torch.Tensor(eigvec)
204 +
205 + def __call__(self, img):
206 + if self.alphastd == 0:
207 + return img
208 +
209 + alpha = img.new().resize_(3).normal_(0, self.alphastd)
210 + rgb = self.eigvec.type_as(img).clone() \
211 + .mul(alpha.view(1, 3).expand(3, 3)) \
212 + .mul(self.eigval.view(1, 3).expand(3, 3)) \
213 + .sum(1).squeeze()
214 +
215 + return img.add(rgb.view(3, 1, 1).expand_as(img))
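Usage note for apply_augment above: level is a float in [0, 1] that is rescaled into each op's (low, high) range from augment_list. A small sketch (the blank test image is illustrative):

    import PIL.Image

    img = PIL.Image.new('RGB', (32, 32))
    out = apply_augment(img, 'Rotate', 0.5)     # 0.5 * (30 - -30) + -30 = 0 degrees
    out = apply_augment(out, 'Solarize', 0.25)  # 0.25 * (256 - 0) + 0 = threshold 64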
1 +import copy
2 +import logging
3 +import warnings
4 +
5 +formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
6 +warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
7 +warnings.filterwarnings("ignore", "DeprecationWarning: 'saved_variables' is deprecated", UserWarning)
8 +
9 +
10 +def get_logger(name, level=logging.DEBUG):
11 + logger = logging.getLogger(name)
12 + logger.handlers.clear()
13 + logger.setLevel(level)
14 + ch = logging.StreamHandler()
15 + ch.setLevel(level)
16 + ch.setFormatter(formatter)
17 + logger.addHandler(ch)
18 + return logger
19 +
20 +
21 +def add_filehandler(logger, filepath, level=logging.DEBUG):
22 + fh = logging.FileHandler(filepath)
23 + fh.setLevel(level)
24 + fh.setFormatter(formatter)
25 + logger.addHandler(fh)
26 +
27 +
28 +class EMA:
29 + def __init__(self, mu):
30 + self.mu = mu
31 + self.shadow = {}
32 +
33 + def state_dict(self):
34 + return copy.deepcopy(self.shadow)
35 +
36 + def __len__(self):
37 + return len(self.shadow)
38 +
39 + def __call__(self, module, step=None):
40 + if step is None:
41 + mu = self.mu
42 + else:
43 + # see : https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/ExponentialMovingAverage?hl=PL
44 + mu = min(self.mu, (1. + step) / (10 + step))
45 +
46 + for name, x in module.state_dict().items():
47 + if name in self.shadow:
48 + new_average = (1.0 - mu) * x + mu * self.shadow[name]
49 + self.shadow[name] = new_average.clone()
50 + else:
51 + self.shadow[name] = x.clone()
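The EMA above blends (1 - mu) of the live weights into the shadow copy on every call, with the TF-style warmup mu = min(mu, (1 + step) / (10 + step)) when a step index is passed. A minimal sketch (the stand-in module and mu value are illustrative):

    import torch

    model = torch.nn.Linear(4, 2)             # stand-in module
    ema = EMA(0.999)
    for step in range(100):
        # ... optimizer step on model would happen here ...
        ema(model, step)                       # shadow <- mu*shadow + (1-mu)*weights
    model.load_state_dict(ema.state_dict())    # evaluate with the averaged weights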
1 +import logging
2 +
3 +import numpy as np
4 +import os
5 +
6 +import math
7 +import random
8 +import torch
9 +import torchvision
10 +from PIL import Image
11 +
12 +from torch.utils.data import SubsetRandomSampler, Sampler, Subset, ConcatDataset
13 +import torch.distributed as dist
14 +from torchvision.transforms import transforms
15 +from sklearn.model_selection import StratifiedShuffleSplit
16 +from theconf import Config as C
17 +
18 +from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet
19 +from FastAutoAugment.augmentations import *
20 +from FastAutoAugment.common import get_logger
21 +from FastAutoAugment.imagenet import ImageNet
22 +from FastAutoAugment.networks.efficientnet_pytorch.model import EfficientNet
23 +
24 +logger = get_logger('Fast AutoAugment')
25 +logger.setLevel(logging.INFO)
26 +_IMAGENET_PCA = {
27 + 'eigval': [0.2175, 0.0188, 0.0045],
28 + 'eigvec': [
29 + [-0.5675, 0.7192, 0.4009],
30 + [-0.5808, -0.0045, -0.8140],
31 + [-0.5836, -0.6948, 0.4203],
32 + ]
33 +}
34 +_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
35 +
36 +
37 +def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode=False, target_lb=-1):
38 + if 'cifar' in dataset or 'svhn' in dataset:
39 + transform_train = transforms.Compose([
40 + transforms.RandomCrop(32, padding=4),
41 + transforms.RandomHorizontalFlip(),
42 + transforms.ToTensor(),
43 + transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
44 + ])
45 + transform_test = transforms.Compose([
46 + transforms.ToTensor(),
47 + transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
48 + ])
49 + elif 'imagenet' in dataset:
50 + input_size = 224
51 + sized_size = 256
52 +
53 + if 'efficientnet' in C.get()['model']['type']:
54 + input_size = EfficientNet.get_image_size(C.get()['model']['type'])
55 + sized_size = input_size + 32 # TODO
56 + # sized_size = int(round(input_size / 224. * 256))
57 + # sized_size = input_size
58 + logger.info('size changed to %d/%d.' % (input_size, sized_size))
59 +
60 + transform_train = transforms.Compose([
61 + EfficientNetRandomCrop(input_size),
62 + transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
63 + # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
64 + transforms.RandomHorizontalFlip(),
65 + transforms.ColorJitter(
66 + brightness=0.4,
67 + contrast=0.4,
68 + saturation=0.4,
69 + ),
70 + transforms.ToTensor(),
71 + Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
72 + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
73 + ])
74 +
75 + transform_test = transforms.Compose([
76 + EfficientNetCenterCrop(input_size),
77 + transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
78 + transforms.ToTensor(),
79 + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80 + ])
81 +
82 + else:
83 + raise ValueError('dataset=%s' % dataset)
84 +
85 + total_aug = augs = None
86 + if isinstance(C.get()['aug'], list):
87 + logger.debug('augmentation provided.')
88 + transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
89 + else:
90 + logger.debug('augmentation: %s' % C.get()['aug'])
91 + if C.get()['aug'] == 'fa_reduced_cifar10':
92 + transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
93 +
94 + elif C.get()['aug'] == 'fa_reduced_imagenet':
95 + transform_train.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))
96 +
97 + elif C.get()['aug'] == 'fa_reduced_svhn':
98 + transform_train.transforms.insert(0, Augmentation(fa_reduced_svhn()))
99 +
100 + elif C.get()['aug'] == 'arsaug':
101 + transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
102 + elif C.get()['aug'] == 'autoaug_cifar10':
103 + transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
104 + elif C.get()['aug'] == 'autoaug_extend':
105 + transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
106 + elif C.get()['aug'] in ['default']:
107 + pass
108 + else:
109 + raise ValueError('not found augmentations. %s' % C.get()['aug'])
110 +
111 + if C.get()['cutout'] > 0:
112 + transform_train.transforms.append(CutoutDefault(C.get()['cutout']))
113 +
114 + if dataset == 'cifar10':
115 + total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
116 + testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
117 + elif dataset == 'reduced_cifar10':
118 + total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
119 + sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0) # 4000 trainset
120 + sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
121 + train_idx, valid_idx = next(sss)
122 + targets = [total_trainset.targets[idx] for idx in train_idx]
123 + total_trainset = Subset(total_trainset, train_idx)
124 + total_trainset.targets = targets
125 +
126 + testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
127 + elif dataset == 'cifar100':
128 + total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
129 + testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
130 + elif dataset == 'svhn':
131 + trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
132 + extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train)
133 + total_trainset = ConcatDataset([trainset, extraset])
134 + testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
135 + elif dataset == 'reduced_svhn':
136 + total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
137 + sss = StratifiedShuffleSplit(n_splits=1, test_size=73257-1000, random_state=0) # 1000 trainset
138 + sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
139 + train_idx, valid_idx = next(sss)
140 + targets = [total_trainset.targets[idx] for idx in train_idx]
141 + total_trainset = Subset(total_trainset, train_idx)
142 + total_trainset.targets = targets
143 +
144 + testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
145 + elif dataset == 'imagenet':
146 + total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
147 + testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
148 +
149 + # compatibility
150 + total_trainset.targets = [lb for _, lb in total_trainset.samples]
151 + elif dataset == 'reduced_imagenet':
152 + # randomly chosen indices
153 +# idx120 = sorted(random.sample(list(range(1000)), k=120))
154 + idx120 = [16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181, 189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304, 307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388, 399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497, 506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619, 626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695, 699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787, 797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883, 897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949, 959]
155 + total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
156 + testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
157 +
158 + # compatibility
159 + total_trainset.targets = [lb for _, lb in total_trainset.samples]
160 +
161 +        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 50000, random_state=0)  # 50000-image train split
162 + sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
163 + train_idx, valid_idx = next(sss)
164 +
165 + # filter out
166 +        train_idx = list(filter(lambda x: total_trainset.targets[x] in idx120, train_idx))
167 +        valid_idx = list(filter(lambda x: total_trainset.targets[x] in idx120, valid_idx))
168 + test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))
169 +
170 + targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
171 + for idx in range(len(total_trainset.samples)):
172 + if total_trainset.samples[idx][1] not in idx120:
173 + continue
174 + total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
175 + total_trainset = Subset(total_trainset, train_idx)
176 + total_trainset.targets = targets
177 +
178 + for idx in range(len(testset.samples)):
179 + if testset.samples[idx][1] not in idx120:
180 + continue
181 + testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
182 + testset = Subset(testset, test_idx)
183 + print('reduced_imagenet train=', len(total_trainset))
184 + else:
185 + raise ValueError('invalid dataset name=%s' % dataset)
186 +
187 + if total_aug is not None and augs is not None:
188 + total_trainset.set_preaug(augs, total_aug)
189 + print('set_preaug-')
190 +
191 + train_sampler = None
192 + if split > 0.0:
193 + sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
194 + sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
195 + for _ in range(split_idx + 1):
196 + train_idx, valid_idx = next(sss)
197 +
198 + if target_lb >= 0:
199 + train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
200 + valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]
201 +
202 + train_sampler = SubsetRandomSampler(train_idx)
203 + valid_sampler = SubsetSampler(valid_idx)
204 +
205 + if multinode:
206 + train_sampler = torch.utils.data.distributed.DistributedSampler(Subset(total_trainset, train_idx), num_replicas=dist.get_world_size(), rank=dist.get_rank())
207 + else:
208 + valid_sampler = SubsetSampler([])
209 +
210 + if multinode:
211 + train_sampler = torch.utils.data.distributed.DistributedSampler(total_trainset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
212 + logger.info(f'----- dataset with DistributedSampler {dist.get_rank()}/{dist.get_world_size()}')
213 +
214 + trainloader = torch.utils.data.DataLoader(
215 +        total_trainset, batch_size=batch, shuffle=train_sampler is None, num_workers=8, pin_memory=True,
216 + sampler=train_sampler, drop_last=True)
217 + validloader = torch.utils.data.DataLoader(
218 + total_trainset, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True,
219 + sampler=valid_sampler, drop_last=False)
220 +
221 + testloader = torch.utils.data.DataLoader(
222 + testset, batch_size=batch, shuffle=False, num_workers=8, pin_memory=True,
223 + drop_last=False
224 + )
225 + return train_sampler, trainloader, validloader, testloader
226 +
227 +
228 +class CutoutDefault(object):
229 + """
230 + Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
231 + """
232 + def __init__(self, length):
233 + self.length = length
234 +
235 + def __call__(self, img):
236 + h, w = img.size(1), img.size(2)
237 + mask = np.ones((h, w), np.float32)
238 + y = np.random.randint(h)
239 + x = np.random.randint(w)
240 +
241 + y1 = np.clip(y - self.length // 2, 0, h)
242 + y2 = np.clip(y + self.length // 2, 0, h)
243 + x1 = np.clip(x - self.length // 2, 0, w)
244 + x2 = np.clip(x + self.length // 2, 0, w)
245 +
246 + mask[y1: y2, x1: x2] = 0.
247 + mask = torch.from_numpy(mask)
248 + mask = mask.expand_as(img)
249 + img *= mask
250 + return img
251 +
252 +
253 +class Augmentation(object):
254 + def __init__(self, policies):
255 + self.policies = policies
256 +
257 + def __call__(self, img):
258 + for _ in range(1):
259 + policy = random.choice(self.policies)
260 + for name, pr, level in policy:
261 + if random.random() > pr:
262 + continue
263 + img = apply_augment(img, name, level)
264 + return img
265 +
266 +
267 +class EfficientNetRandomCrop:
268 + def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3), area_range=(0.08, 1.0), max_attempts=10):
269 + assert 0.0 < min_covered
270 + assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1]
271 + assert 0 < area_range[0] <= area_range[1]
272 + assert 1 <= max_attempts
273 +
274 + self.min_covered = min_covered
275 + self.aspect_ratio_range = aspect_ratio_range
276 + self.area_range = area_range
277 + self.max_attempts = max_attempts
278 + self._fallback = EfficientNetCenterCrop(imgsize)
279 +
280 + def __call__(self, img):
281 + # https://github.com/tensorflow/tensorflow/blob/9274bcebb31322370139467039034f8ff852b004/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc#L111
282 + original_width, original_height = img.size
283 + min_area = self.area_range[0] * (original_width * original_height)
284 + max_area = self.area_range[1] * (original_width * original_height)
285 +
286 + for _ in range(self.max_attempts):
287 + aspect_ratio = random.uniform(*self.aspect_ratio_range)
288 + height = int(round(math.sqrt(min_area / aspect_ratio)))
289 + max_height = int(round(math.sqrt(max_area / aspect_ratio)))
290 +
291 + if max_height * aspect_ratio > original_width:
292 + max_height = (original_width + 0.5 - 1e-7) / aspect_ratio
293 + max_height = int(max_height)
294 + if max_height * aspect_ratio > original_width:
295 + max_height -= 1
296 +
297 + if max_height > original_height:
298 + max_height = original_height
299 +
300 + if height >= max_height:
301 + height = max_height
302 +
303 + height = int(round(random.uniform(height, max_height)))
304 + width = int(round(height * aspect_ratio))
305 + area = width * height
306 +
307 + if area < min_area or area > max_area:
308 + continue
309 + if width > original_width or height > original_height:
310 + continue
311 + if area < self.min_covered * (original_width * original_height):
312 + continue
313 + if width == original_width and height == original_height:
314 + return self._fallback(img) # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L102
315 +
316 + x = random.randint(0, original_width - width)
317 + y = random.randint(0, original_height - height)
318 + return img.crop((x, y, x + width, y + height))
319 +
320 + return self._fallback(img)
321 +
322 +
323 +class EfficientNetCenterCrop:
324 + def __init__(self, imgsize):
325 + self.imgsize = imgsize
326 +
327 + def __call__(self, img):
328 + """Crop the given PIL Image and resize it to desired size.
329 +
330 + Args:
331 + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
332 + output_size (sequence or int): (height, width) of the crop box. If int,
333 + it is used for both directions
334 + Returns:
335 + PIL Image: Cropped image.
336 + """
337 + image_width, image_height = img.size
338 + image_short = min(image_width, image_height)
339 +
340 + crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short
341 +
342 + crop_height, crop_width = crop_size, crop_size
343 + crop_top = int(round((image_height - crop_height) / 2.))
344 + crop_left = int(round((image_width - crop_width) / 2.))
345 + return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
346 +
347 +
348 +class SubsetSampler(Sampler):
349 + r"""Samples elements from a given list of indices, without replacement.
350 +
351 + Arguments:
352 + indices (sequence): a sequence of indices
353 + """
354 +
355 + def __init__(self, indices):
356 + self.indices = indices
357 +
358 + def __iter__(self):
359 + return (i for i in self.indices)
360 +
361 + def __len__(self):
362 + return len(self.indices)
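Call-shape note for get_dataloaders above: the augmentation policy, cutout size, and model type are read from the global theconf Config rather than passed in. A hedged sketch, assuming a config is already loaded with at least aug, cutout, and model.type set:

    # e.g. a config containing: aug: fa_reduced_cifar10, cutout: 16, model: {type: wresnet28_10}
    train_sampler, trainloader, validloader, testloader = get_dataloaders(
        'cifar10', batch=128, dataroot='./data', split=0.15, split_idx=0)
    images, labels = next(iter(trainloader))  # images: [128, 3, 32, 32] after policy + cutout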
1 +from __future__ import print_function
2 +import os
3 +import shutil
4 +import torch
5 +
6 +ARCHIVE_DICT = {
7 + 'train': {
8 + 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
9 + 'md5': '1d675b47d978889d74fa0da5fadfb00e',
10 + },
11 + 'val': {
12 + 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
13 + 'md5': '29b22e2961454d5413ddabcf34fc5622',
14 + },
15 + 'devkit': {
16 + 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
17 + 'md5': 'fa75699e90414af021442c21a62c3abf',
18 + }
19 +}
20 +
21 +
22 +import torchvision
23 +from torchvision.datasets.utils import check_integrity, download_url
24 +
25 +
26 +# copy ILSVRC/ImageSets/CLS-LOC/train_cls.txt to ./root/
27 +# to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file
28 +class ImageNet(torchvision.datasets.ImageFolder):
29 + """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
30 +
31 + Args:
32 + root (string): Root directory of the ImageNet Dataset.
33 + split (string, optional): The dataset split, supports ``train``, or ``val``.
34 + download (bool, optional): If true, downloads the dataset from the internet and
35 + puts it in root directory. If dataset is already downloaded, it is not
36 + downloaded again.
37 + transform (callable, optional): A function/transform that takes in an PIL image
38 + and returns a transformed version. E.g, ``transforms.RandomCrop``
39 + target_transform (callable, optional): A function/transform that takes in the
40 + target and transforms it.
41 + loader (callable, optional): A function to load an image given its path.
42 +
43 + Attributes:
44 + classes (list): List of the class names.
45 + class_to_idx (dict): Dict with items (class_name, class_index).
46 + wnids (list): List of the WordNet IDs.
47 + wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
48 + imgs (list): List of (image path, class_index) tuples
49 + targets (list): The class_index value for each image in the dataset
50 + """
51 +
52 + def __init__(self, root, split='train', download=False, **kwargs):
53 + root = self.root = os.path.expanduser(root)
54 + self.split = self._verify_split(split)
55 +
56 + if download:
57 + self.download()
58 + wnid_to_classes = self._load_meta_file()[0]
59 +
60 + # to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file
61 + listfile = os.path.join(root, 'train_cls.txt')
62 + if split == 'train' and os.path.exists(listfile):
63 + torchvision.datasets.VisionDataset.__init__(self, root, **kwargs)
64 + with open(listfile, 'r') as f:
65 + datalist = [
66 + line.strip().split(' ')[0]
67 + for line in f.readlines()
68 + if line.strip()
69 + ]
70 +
71 + classes = list(set([line.split('/')[0] for line in datalist]))
72 + classes.sort()
73 + class_to_idx = {classes[i]: i for i in range(len(classes))}
74 +
75 + samples = [
76 + (os.path.join(self.split_folder, line + '.JPEG'), class_to_idx[line.split('/')[0]])
77 + for line in datalist
78 + ]
79 +
80 + self.loader = torchvision.datasets.folder.default_loader
81 + self.extensions = torchvision.datasets.folder.IMG_EXTENSIONS
82 +
83 + self.classes = classes
84 + self.class_to_idx = class_to_idx
85 + self.samples = samples
86 + self.targets = [s[1] for s in samples]
87 +
88 + self.imgs = self.samples
89 + else:
90 + super(ImageNet, self).__init__(self.split_folder, **kwargs)
91 +
92 + self.root = root
93 +
94 + idcs = [idx for _, idx in self.imgs]
95 + self.wnids = self.classes
96 + self.wnid_to_idx = {wnid: idx for idx, wnid in zip(idcs, self.wnids)}
97 + self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
98 + self.class_to_idx = {cls: idx
99 + for clss, idx in zip(self.classes, idcs)
100 + for cls in clss}
101 +
102 + def download(self):
103 + if not check_integrity(self.meta_file):
104 + tmpdir = os.path.join(self.root, 'tmp')
105 +
106 + archive_dict = ARCHIVE_DICT['devkit']
107 + download_and_extract_tar(archive_dict['url'], self.root,
108 + extract_root=tmpdir,
109 + md5=archive_dict['md5'])
110 + devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
111 + meta = parse_devkit(os.path.join(tmpdir, devkit_folder))
112 + self._save_meta_file(*meta)
113 +
114 + shutil.rmtree(tmpdir)
115 +
116 + if not os.path.isdir(self.split_folder):
117 + archive_dict = ARCHIVE_DICT[self.split]
118 + download_and_extract_tar(archive_dict['url'], self.root,
119 + extract_root=self.split_folder,
120 + md5=archive_dict['md5'])
121 +
122 + if self.split == 'train':
123 + prepare_train_folder(self.split_folder)
124 + elif self.split == 'val':
125 + val_wnids = self._load_meta_file()[1]
126 + prepare_val_folder(self.split_folder, val_wnids)
127 + else:
128 + msg = ("You set download=True, but a folder '{}' already exist in "
129 + "the root directory. If you want to re-download or re-extract the "
130 + "archive, delete the folder.")
131 + print(msg.format(self.split))
132 +
133 + @property
134 + def meta_file(self):
135 + return os.path.join(self.root, 'meta.bin')
136 +
137 + def _load_meta_file(self):
138 + if check_integrity(self.meta_file):
139 + return torch.load(self.meta_file)
140 + raise RuntimeError("Meta file not found or corrupted.",
141 + "You can use download=True to create it.")
142 +
143 + def _save_meta_file(self, wnid_to_class, val_wnids):
144 + torch.save((wnid_to_class, val_wnids), self.meta_file)
145 +
146 + def _verify_split(self, split):
147 + if split not in self.valid_splits:
148 +            msg = "Unknown split {}. ".format(split)
149 +            msg += "Valid splits are {}.".format(", ".join(self.valid_splits))
150 + raise ValueError(msg)
151 + return split
152 +
153 + @property
154 + def valid_splits(self):
155 + return 'train', 'val'
156 +
157 + @property
158 + def split_folder(self):
159 + return os.path.join(self.root, self.split)
160 +
161 + def extra_repr(self):
162 + return "Split: {split}".format(**self.__dict__)
163 +
164 +
165 +def extract_tar(src, dest=None, gzip=None, delete=False):
166 + import tarfile
167 +
168 + if dest is None:
169 + dest = os.path.dirname(src)
170 + if gzip is None:
171 + gzip = src.lower().endswith('.gz')
172 +
173 + mode = 'r:gz' if gzip else 'r'
174 + with tarfile.open(src, mode) as tarfh:
175 + tarfh.extractall(path=dest)
176 +
177 + if delete:
178 + os.remove(src)
179 +
180 +
181 +def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
182 + md5=None, **kwargs):
183 + download_root = os.path.expanduser(download_root)
184 + if extract_root is None:
185 + extract_root = download_root
186 + if filename is None:
187 + filename = os.path.basename(url)
188 +
189 + if not check_integrity(os.path.join(download_root, filename), md5):
190 + download_url(url, download_root, filename=filename, md5=md5)
191 +
192 + extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
193 +
194 +
195 +def parse_devkit(root):
196 + idx_to_wnid, wnid_to_classes = parse_meta(root)
197 + val_idcs = parse_val_groundtruth(root)
198 + val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
199 + return wnid_to_classes, val_wnids
200 +
201 +
202 +def parse_meta(devkit_root, path='data', filename='meta.mat'):
203 + import scipy.io as sio
204 +
205 + metafile = os.path.join(devkit_root, path, filename)
206 + meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
207 + nums_children = list(zip(*meta))[4]
208 + meta = [meta[idx] for idx, num_children in enumerate(nums_children)
209 + if num_children == 0]
210 + idcs, wnids, classes = list(zip(*meta))[:3]
211 + classes = [tuple(clss.split(', ')) for clss in classes]
212 + idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
213 + wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
214 + return idx_to_wnid, wnid_to_classes
215 +
216 +
217 +def parse_val_groundtruth(devkit_root, path='data',
218 + filename='ILSVRC2012_validation_ground_truth.txt'):
219 + with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
220 + val_idcs = txtfh.readlines()
221 + return [int(val_idx) for val_idx in val_idcs]
222 +
223 +
224 +def prepare_train_folder(folder):
225 + for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
226 + extract_tar(archive, os.path.splitext(archive)[0], delete=True)
227 +
228 +
229 +def prepare_val_folder(folder, wnids):
230 + img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
231 +
232 + for wnid in set(wnids):
233 + os.mkdir(os.path.join(folder, wnid))
234 +
235 + for wnid, img_file in zip(wnids, img_files):
236 + shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
237 +
238 +
239 +def _splitexts(root):
240 + exts = []
241 + ext = '.'
242 + while ext:
243 + root, ext = os.path.splitext(root)
244 + exts.append(ext)
245 + return root, ''.join(reversed(exts))
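Note on the listfile fast path in ImageNet above: when ILSVRC/ImageSets/CLS-LOC/train_cls.txt has been copied into the dataset root, the class scan skips the slow os.walk over ~1.2M files. A hedged usage sketch (paths are illustrative; meta.bin must already exist, e.g. from a prior download=True run):

    dataset = ImageNet(root='./data/imagenet-pytorch', split='train')
    path, label = dataset.samples[0]  # (JPEG path under train/<wnid>/, class index)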
1 +import torch
2 +
3 +from theconf import Config as C
4 +
5 +
6 +def adjust_learning_rate_resnet(optimizer):
7 + """
8 + Sets the learning rate to the initial LR decayed by 10 on every predefined epochs
9 + Ref: AutoAugment
10 + """
11 +
12 + if C.get()['epoch'] == 90:
13 + return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80])
14 + elif C.get()['epoch'] == 270: # autoaugment
15 + return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240])
16 + else:
17 + raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch'])
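For reference, the epoch=270 branch above reproduces the AutoAugment ResNet schedule: the learning rate drops 10x at epochs 90, 180, and 240. An equivalent standalone sketch (the SGD optimizer is a stand-in):

    import torch

    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    sched = torch.optim.lr_scheduler.MultiStepLR(opt, [90, 180, 240])
    for epoch in range(270):
        # ... train one epoch ...
        sched.step()  # lr: 0.1 -> 0.01 -> 0.001 -> 0.0001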
1 +import copy
2 +
3 +import torch
4 +import numpy as np
5 +from collections import defaultdict
6 +
7 +from torch import nn
8 +
9 +
10 +def accuracy(output, target, topk=(1,)):
11 + """Computes the precision@k for the specified values of k"""
12 + maxk = max(topk)
13 + batch_size = target.size(0)
14 +
15 + _, pred = output.topk(maxk, 1, True, True)
16 + pred = pred.t()
17 + correct = pred.eq(target.view(1, -1).expand_as(pred))
18 +
19 + res = []
20 + for k in topk:
21 + correct_k = correct[:k].view(-1).float().sum(0)
22 + res.append(correct_k.mul_(1. / batch_size))
23 + return res
24 +
25 +
26 +class CrossEntropyLabelSmooth(torch.nn.Module):
27 + def __init__(self, num_classes, epsilon, reduction='mean'):
28 + super(CrossEntropyLabelSmooth, self).__init__()
29 + self.num_classes = num_classes
30 + self.epsilon = epsilon
31 + self.reduction = reduction
32 + self.logsoftmax = torch.nn.LogSoftmax(dim=1)
33 +
34 + def forward(self, input, target): # pylint: disable=redefined-builtin
35 + log_probs = self.logsoftmax(input)
36 + targets = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1)
37 + if self.epsilon > 0.0:
38 + targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
39 + targets = targets.detach()
40 + loss = (-targets * log_probs)
41 +
42 + if self.reduction in ['avg', 'mean']:
43 + loss = torch.mean(torch.sum(loss, dim=1))
44 + elif self.reduction == 'sum':
45 + loss = loss.sum()
46 + return loss
47 +
48 +
49 +class Accumulator:
50 + def __init__(self):
51 + self.metrics = defaultdict(lambda: 0.)
52 +
53 + def add(self, key, value):
54 + self.metrics[key] += value
55 +
56 +    def add_dict(self, d):
57 +        for key, value in d.items():
58 + self.add(key, value)
59 +
60 + def __getitem__(self, item):
61 + return self.metrics[item]
62 +
63 + def __setitem__(self, key, value):
64 + self.metrics[key] = value
65 +
66 + def get_dict(self):
67 + return copy.deepcopy(dict(self.metrics))
68 +
69 + def items(self):
70 + return self.metrics.items()
71 +
72 + def __str__(self):
73 + return str(dict(self.metrics))
74 +
75 + def __truediv__(self, other):
76 + newone = Accumulator()
77 + for key, value in self.items():
78 + if isinstance(other, str):
79 + if other != key:
80 + newone[key] = value / self[other]
81 + else:
82 + newone[key] = value
83 + else:
84 + newone[key] = value / other
85 + return newone
86 +
87 +
88 +class SummaryWriterDummy:
89 + def __init__(self, log_dir):
90 + pass
91 +
92 + def add_scalar(self, *args, **kwargs):
93 + pass
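CrossEntropyLabelSmooth above rescales the one-hot target by (1 - epsilon) and adds epsilon / num_classes uniformly before the usual -sum(targets * log_probs). A quick check that epsilon=0 recovers standard cross-entropy:

    import torch

    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    ls = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.0)
    assert torch.allclose(ls(logits, target), torch.nn.functional.cross_entropy(logits, target))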
1 +import torch
2 +
3 +from torch import nn
4 +from torch.nn import DataParallel
5 +from torch.nn.parallel import DistributedDataParallel
6 +import torch.backends.cudnn as cudnn
7 +# from torchvision import models
8 +import numpy as np
9 +
10 +from FastAutoAugment.networks.resnet import ResNet
11 +from FastAutoAugment.networks.pyramidnet import PyramidNet
12 +from FastAutoAugment.networks.shakeshake.shake_resnet import ShakeResNet
13 +from FastAutoAugment.networks.wideresnet import WideResNet
14 +from FastAutoAugment.networks.shakeshake.shake_resnext import ShakeResNeXt
15 +from FastAutoAugment.networks.efficientnet_pytorch import EfficientNet, RoutingFn
16 +from FastAutoAugment.tf_port.tpu_bn import TpuBatchNormalization
17 +
18 +
19 +def get_model(conf, num_class=10, local_rank=-1):
20 + name = conf['type']
21 +
22 + if name == 'resnet50':
23 + model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True)
24 + elif name == 'resnet200':
25 + model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True)
26 + elif name == 'wresnet40_2':
27 + model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class)
28 + elif name == 'wresnet28_10':
29 + model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class)
30 +
31 + elif name == 'shakeshake26_2x32d':
32 + model = ShakeResNet(26, 32, num_class)
33 + elif name == 'shakeshake26_2x64d':
34 + model = ShakeResNet(26, 64, num_class)
35 + elif name == 'shakeshake26_2x96d':
36 + model = ShakeResNet(26, 96, num_class)
37 + elif name == 'shakeshake26_2x112d':
38 + model = ShakeResNet(26, 112, num_class)
39 +
40 + elif name == 'shakeshake26_2x96d_next':
41 + model = ShakeResNeXt(26, 96, 4, num_class)
42 +
43 + elif name == 'pyramid':
44 + model = PyramidNet('cifar10', depth=conf['depth'], alpha=conf['alpha'], num_classes=num_class, bottleneck=conf['bottleneck'])
45 +
46 + elif 'efficientnet' in name:
47 + model = EfficientNet.from_name(name, condconv_num_expert=conf['condconv_num_expert'], norm_layer=None) # TpuBatchNormalization
48 + if local_rank >= 0:
49 + model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
50 + def kernel_initializer(module):
51 + def get_fan_in_out(module):
52 + num_input_fmaps = module.weight.size(1)
53 + num_output_fmaps = module.weight.size(0)
54 + receptive_field_size = 1
55 + if module.weight.dim() > 2:
56 + receptive_field_size = module.weight[0][0].numel()
57 + fan_in = num_input_fmaps * receptive_field_size
58 + fan_out = num_output_fmaps * receptive_field_size
59 + return fan_in, fan_out
60 +
61 + if isinstance(module, torch.nn.Conv2d):
62 + # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L58
63 + fan_in, fan_out = get_fan_in_out(module)
64 + torch.nn.init.normal_(module.weight, mean=0.0, std=np.sqrt(2.0 / fan_out))
65 + if module.bias is not None:
66 + torch.nn.init.constant_(module.bias, val=0.)
67 + elif isinstance(module, RoutingFn):
68 + torch.nn.init.xavier_uniform_(module.weight)
69 + torch.nn.init.constant_(module.bias, val=0.)
70 + elif isinstance(module, torch.nn.Linear):
71 + # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L82
72 + fan_in, fan_out = get_fan_in_out(module)
73 + delta = 1.0 / np.sqrt(fan_out)
74 + torch.nn.init.uniform_(module.weight, a=-delta, b=delta)
75 + if module.bias is not None:
76 + torch.nn.init.constant_(module.bias, val=0.)
77 + model.apply(kernel_initializer)
78 + else:
79 +        raise NameError('no model named %s' % name)
80 +
81 + if local_rank >= 0:
82 + device = torch.device('cuda', local_rank)
83 + model = model.to(device)
84 + model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
85 + else:
86 + model = model.cuda()
87 +# model = DataParallel(model)
88 +
89 + cudnn.benchmark = True
90 + return model
91 +
92 +
93 +def num_class(dataset):
94 + return {
95 + 'cifar10': 10,
96 + 'reduced_cifar10': 10,
97 + 'cifar10.1': 10,
98 + 'cifar100': 100,
99 + 'svhn': 10,
100 + 'reduced_svhn': 10,
101 + 'imagenet': 1000,
102 + 'reduced_imagenet': 120,
103 + }[dataset]
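
A minimal sketch of how these two helpers compose (hypothetical config dict; assumes a CUDA device is available, since get_model moves the model to the GPU):

    conf = {'type': 'wresnet28_10'}                       # illustrative config
    model = get_model(conf, num_class=num_class('cifar10'))  # WideResNet-28-10 on cuda
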
1 +__version__ = "0.5.1"
2 +from .model import EfficientNet, RoutingFn
3 +from .utils import (
4 + GlobalParams,
5 + BlockArgs,
6 + BlockDecoder,
7 + efficientnet,
8 + get_model_params,
9 +)
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +import torch.nn as nn
3 +import torch.nn.functional as F
4 +from torch._six import container_abcs
5 +
6 +from itertools import repeat
7 +from functools import partial
8 +from typing import Union, List, Tuple, Optional, Callable
9 +import numpy as np
10 +import math
11 +
12 +
13 +def _ntuple(n):
14 + def parse(x):
15 + if isinstance(x, container_abcs.Iterable):
16 + return x
17 + return tuple(repeat(x, n))
18 + return parse
19 +
20 +
21 +_single = _ntuple(1)
22 +_pair = _ntuple(2)
23 +_triple = _ntuple(3)
24 +_quadruple = _ntuple(4)
25 +
26 +
27 +def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
28 + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
29 +
30 +
31 +def _get_padding(kernel_size, stride=1, dilation=1, **_):
32 + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
33 + return padding
34 +
35 +
36 +def _calc_same_pad(i: int, k: int, s: int, d: int):
37 + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
38 +
39 +
40 +def conv2d_same(
41 + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
42 + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
43 + ih, iw = x.size()[-2:]
44 + kh, kw = weight.size()[-2:]
45 + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
46 + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
47 + if pad_h > 0 or pad_w > 0:
48 + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
49 + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
50 +
51 +
52 +def get_padding_value(padding, kernel_size, **kwargs):
53 + dynamic = False
54 + if isinstance(padding, str):
55 + # for any string padding, the padding will be calculated for you, one of three ways
56 + padding = padding.lower()
57 + if padding == 'same':
58 + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
59 + if _is_static_pad(kernel_size, **kwargs):
60 + # static case, no extra overhead
61 + padding = _get_padding(kernel_size, **kwargs)
62 + else:
63 + # dynamic padding
64 + padding = 0
65 + dynamic = True
66 + elif padding == 'valid':
67 + # 'VALID' padding, same as padding=0
68 + padding = 0
69 + else:
70 + # Default to PyTorch style 'same'-ish symmetric padding
71 + padding = _get_padding(kernel_size, **kwargs)
72 + return padding, dynamic
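
A worked example of the 'SAME' arithmetic above, using the same definitions (illustrative values only): a 224-px input with kernel 3, stride 2, dilation 1 needs asymmetric padding, so the dynamic path is taken.

    import math
    i, k, s, d = 224, 3, 2, 1
    pad = max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)  # _calc_same_pad -> 1
    # _is_static_pad(3, stride=2) is False, so get_padding_value('same', 3, stride=2)
    # returns (0, True), and conv2d_same applies F.pad([0, 1, 0, 1]) at run time.
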
73 +
74 +
75 +def get_condconv_initializer(initializer, num_experts, expert_shape):
76 + def condconv_initializer(weight):
77 + """CondConv initializer function."""
78 + num_params = np.prod(expert_shape)
79 + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or weight.shape[1] != num_params):
80 + raise ValueError('CondConv variables must have shape [num_experts, num_params]')
81 + for i in range(num_experts):
82 + initializer(weight[i].view(expert_shape))
83 + return condconv_initializer
84 +
85 +
86 +class CondConv2d(nn.Module):
87 + """ Conditional Convolution
88 + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
89 + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
90 + https://github.com/pytorch/pytorch/issues/17983
91 + """
92 + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
93 +
94 + def __init__(self, in_channels, out_channels, kernel_size=3,
95 + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
96 + super(CondConv2d, self).__init__()
97 + assert num_experts > 1
98 +
99 + if isinstance(stride, container_abcs.Iterable) and len(stride) == 1:
100 + stride = stride[0]
101 + # print('CondConv', num_experts)
102 +
103 + self.in_channels = in_channels
104 + self.out_channels = out_channels
105 + self.kernel_size = _pair(kernel_size)
106 + self.stride = _pair(stride)
107 + padding_val, is_padding_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
108 + self.dynamic_padding = is_padding_dynamic # checked in forward() so it works with torchscript
109 + self.padding = _pair(padding_val)
110 + self.dilation = _pair(dilation)
111 + self.groups = groups
112 + self.num_experts = num_experts
113 +
114 + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
115 + weight_num_param = 1
116 + for wd in self.weight_shape:
117 + weight_num_param *= wd
118 + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
119 +
120 + if bias:
121 + self.bias_shape = (self.out_channels,)
122 + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
123 + else:
124 + self.register_parameter('bias', None)
125 +
126 + self.reset_parameters()
127 +
128 + def reset_parameters(self):
129 + num_input_fmaps = self.weight.size(1)
130 + num_output_fmaps = self.weight.size(0)
131 + receptive_field_size = 1
132 + if self.weight.dim() > 2:
133 + receptive_field_size = self.weight[0][0].numel()
134 + fan_in = num_input_fmaps * receptive_field_size
135 + fan_out = num_output_fmaps * receptive_field_size
136 +
137 + init_weight = get_condconv_initializer(partial(nn.init.normal_, mean=0.0, std=np.sqrt(2.0 / fan_out)), self.num_experts, self.weight_shape)
138 + init_weight(self.weight)
139 + if self.bias is not None:
140 + # fan_in = np.prod(self.weight_shape[1:])
141 + # bound = 1 / math.sqrt(fan_in)
142 + init_bias = get_condconv_initializer(partial(nn.init.constant_, val=0), self.num_experts, self.bias_shape)
143 + init_bias(self.bias)
144 +
145 + def forward(self, x, routing_weights):
146 + x_orig = x
147 + B, C, H, W = x.shape
148 + weight = torch.matmul(routing_weights, self.weight) # (Expert x out x in x 3x3) --> (B x out x in x 3x3)
149 + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
150 + weight = weight.view(new_weight_shape) # (B*out x in x 3 x 3)
151 + bias = None
152 + if self.bias is not None:
153 + bias = torch.matmul(routing_weights, self.bias)
154 + bias = bias.view(B * self.out_channels)
155 + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
156 + x = x.view(1, B * C, H, W)
157 + if self.dynamic_padding:
158 + out = conv2d_same(
159 + x, weight, bias, stride=self.stride, padding=self.padding,
160 + dilation=self.dilation, groups=self.groups * B)
161 + else:
162 + out = F.conv2d(
163 + x, weight, bias, stride=self.stride, padding=self.padding,
164 + dilation=self.dilation, groups=self.groups * B)
165 +
166 + # out : (1 x B*out x ...)
167 + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
168 +
169 + # out2 = self.forward_legacy(x_orig, routing_weights)
170 + # lt = torch.lt(torch.abs(torch.add(out, -out2)), 1e-8)
171 + # assert torch.all(lt), torch.abs(torch.add(out, -out2))[lt]
172 + # print('checked')
173 + return out
174 +
175 + def forward_legacy(self, x, routing_weights):
176 + # Literal port (from TF definition)
177 + B, C, H, W = x.shape
178 + weight = torch.matmul(routing_weights, self.weight) # (Expert x out x in x 3x3) --> (B x out x in x 3x3)
179 + x = torch.split(x, 1, 0)
180 + weight = torch.split(weight, 1, 0)
181 + if self.bias is not None:
182 + bias = torch.matmul(routing_weights, self.bias)
183 + bias = torch.split(bias, 1, 0)
184 + else:
185 + bias = [None] * B
186 + out = []
187 + if self.dynamic_padding:
188 + conv_fn = conv2d_same
189 + else:
190 + conv_fn = F.conv2d
191 + for xi, wi, bi in zip(x, weight, bias):
192 + wi = wi.view(*self.weight_shape)
193 + if bi is not None:
194 + bi = bi.view(*self.bias_shape)
195 + out.append(conv_fn(
196 + xi, wi, bi, stride=self.stride, padding=self.padding,
197 + dilation=self.dilation, groups=self.groups))
198 + out = torch.cat(out, 0)
199 + return out
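
A minimal usage sketch for CondConv2d (illustrative shapes; in practice the routing weights come from a RoutingFn over pooled features, as in MBConvBlock below):

    import torch
    conv = CondConv2d(16, 32, kernel_size=3, padding='same', num_experts=4)
    x = torch.randn(8, 16, 32, 32)
    routing = torch.sigmoid(torch.randn(8, 4))  # per-sample mixture over 4 experts
    y = conv(x, routing)                        # -> (8, 32, 32, 32)
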
1 +import torch
2 +from torch import nn
3 +from torch.nn import functional as F
4 +
5 +from functools import partial
6 +from .utils import (
7 + round_filters,
8 + round_repeats,
9 + drop_connect,
10 + get_same_padding_conv2d,
11 + get_model_params,
12 + efficientnet_params,
13 + load_pretrained_weights,
14 + MemoryEfficientSwish,
15 +)
16 +
17 +
18 +class RoutingFn(nn.Linear):
19 + pass
20 +
21 +
22 +class MBConvBlock(nn.Module):
23 + """
24 + Mobile Inverted Residual Bottleneck Block
25 +
26 + Args:
27 + block_args (namedtuple): BlockArgs, see above
28 + global_params (namedtuple): GlobalParam, see above
29 +
30 + Attributes:
31 + has_se (bool): Whether the block contains a Squeeze and Excitation layer.
32 + """
33 +
34 + def __init__(self, block_args, global_params, norm_layer=None):
35 + super().__init__()
36 + self._block_args = block_args
37 + self._bn_mom = 1 - global_params.batch_norm_momentum
38 + self._bn_eps = global_params.batch_norm_epsilon
39 + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
40 + self.id_skip = block_args.id_skip # skip connection and drop connect
41 + if norm_layer is None:
42 + norm_layer = nn.BatchNorm2d
43 +
44 + self.condconv_num_expert = block_args.condconv_num_expert
45 + if self._is_condconv():
46 + self.routing_fn = RoutingFn(self._block_args.input_filters, self.condconv_num_expert)
47 +
48 + # Get static or dynamic convolution depending on image size
49 + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size, condconv_num_expert=block_args.condconv_num_expert)
50 + Conv2dse = get_same_padding_conv2d(image_size=global_params.image_size)
51 +
52 + # Expansion phase
53 + inp = self._block_args.input_filters # number of input channels
54 + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
55 + if self._block_args.expand_ratio != 1:
56 + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
57 + self._bn0 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
58 +
59 + # Depthwise convolution phase
60 + k = self._block_args.kernel_size
61 + s = self._block_args.stride
62 + self._depthwise_conv = Conv2d(
63 + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
64 + kernel_size=k, stride=s, bias=False)
65 + self._bn1 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
66 +
67 + # Squeeze and Excitation layer, if desired
68 + if self.has_se:
69 + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
70 + self._se_reduce = Conv2dse(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
71 + self._se_expand = Conv2dse(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
72 +
73 + # Output phase
74 + final_oup = self._block_args.output_filters
75 + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
76 + self._bn2 = norm_layer(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
77 + self._swish = MemoryEfficientSwish()
78 +
79 + def _is_condconv(self):
80 + return self.condconv_num_expert > 1
81 +
82 + def forward(self, inputs, drop_connect_rate=None):
83 + """
84 + :param inputs: input tensor
85 + :param drop_connect_rate: drop connect rate (float, between 0 and 1)
86 + :return: output of block
87 + """
88 +
89 + if self._is_condconv():
90 + feat = F.adaptive_avg_pool2d(inputs, 1).flatten(1)
91 + routing_w = torch.sigmoid(self.routing_fn(feat))
92 +
93 + if self._block_args.expand_ratio != 1:
94 + _expand_conv = partial(self._expand_conv, routing_weights=routing_w)
95 + _depthwise_conv = partial(self._depthwise_conv, routing_weights=routing_w)
96 + _project_conv = partial(self._project_conv, routing_weights=routing_w)
97 + else:
98 + if self._block_args.expand_ratio != 1:
99 + _expand_conv = self._expand_conv
100 + _depthwise_conv, _project_conv = self._depthwise_conv, self._project_conv
101 +
102 + # Expansion and Depthwise Convolution
103 + x = inputs
104 + if self._block_args.expand_ratio != 1:
105 + x = self._swish(self._bn0(_expand_conv(inputs)))
106 + x = self._swish(self._bn1(_depthwise_conv(x)))
107 +
108 + # Squeeze and Excitation
109 + if self.has_se:
110 + x_squeezed = F.adaptive_avg_pool2d(x, 1)
111 + x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
112 + x = torch.sigmoid(x_squeezed) * x
113 +
114 + x = self._bn2(_project_conv(x))
115 +
116 + # Skip connection and drop connect
117 + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
118 + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
119 + if drop_connect_rate:
120 + x = drop_connect(x, drop_p=drop_connect_rate, training=self.training)
121 + x = x + inputs # skip connection
122 + return x
123 +
124 + def set_swish(self):
125 + """Sets swish function as memory efficient (for training) or standard (for export)"""
126 + self._swish = MemoryEfficientSwish()
127 +
128 +
129 +class EfficientNet(nn.Module):
130 + """
131 + An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
132 +
133 + Args:
134 + blocks_args (list): A list of BlockArgs to construct blocks
135 + global_params (namedtuple): A set of GlobalParams shared between blocks
136 +
137 + Example:
138 + model = EfficientNet.from_pretrained('efficientnet-b0')
139 +
140 + """
141 +
142 + def __init__(self, blocks_args=None, global_params=None, norm_layer=None):
143 + super().__init__()
144 + assert isinstance(blocks_args, list), 'blocks_args should be a list'
145 + assert len(blocks_args) > 0, 'blocks_args must not be empty'
146 + self._global_params = global_params
147 + self._blocks_args = blocks_args
148 + if norm_layer is None:
149 + norm_layer = nn.BatchNorm2d
150 +
151 + # Get static or dynamic convolution depending on image size
152 + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
153 +
154 + # Batch norm parameters
155 + bn_mom = 1 - self._global_params.batch_norm_momentum
156 + bn_eps = self._global_params.batch_norm_epsilon
157 +
158 + # Stem
159 + in_channels = 3 # rgb
160 + out_channels = round_filters(32, self._global_params) # number of output channels
161 + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
162 + self._bn0 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
163 +
164 + # Build blocks
165 + self._blocks = nn.ModuleList([])
166 + for idx, block_args in enumerate(self._blocks_args):
167 + # Update block input and output filters based on depth multiplier.
168 + block_args = block_args._replace(
169 + input_filters=round_filters(block_args.input_filters, self._global_params),
170 + output_filters=round_filters(block_args.output_filters, self._global_params),
171 + num_repeat=round_repeats(block_args.num_repeat, self._global_params)
172 + )
173 +
174 + # The first block needs to take care of stride and filter size increase.
175 + self._blocks.append(MBConvBlock(block_args, self._global_params, norm_layer=norm_layer))
176 + if block_args.num_repeat > 1:
177 + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
178 + for _ in range(block_args.num_repeat - 1):
179 + self._blocks.append(MBConvBlock(block_args, self._global_params, norm_layer=norm_layer))
180 +
181 + # Head
182 + in_channels = block_args.output_filters # output of final block
183 + out_channels = round_filters(1280, self._global_params)
184 + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
185 + self._bn1 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
186 +
187 + # Final linear layer
188 + self._avg_pooling = nn.AdaptiveAvgPool2d(1)
189 + self._dropout = nn.Dropout(self._global_params.dropout_rate)
190 + self._fc = nn.Linear(out_channels, self._global_params.num_classes)
191 + self._swish = MemoryEfficientSwish()
192 +
193 + def set_swish(self):
194 + """Sets swish function as memory efficient (for training) or standard (for export)"""
195 + self._swish = MemoryEfficientSwish()
196 + for block in self._blocks:
197 + block.set_swish()
198 +
199 + def extract_features(self, inputs):
200 + """ Returns output of the final convolution layer """
201 +
202 + # Stem
203 + x = self._swish(self._bn0(self._conv_stem(inputs)))
204 +
205 + # Blocks
206 + for idx, block in enumerate(self._blocks):
207 + drop_connect_rate = self._global_params.drop_connect_rate
208 + if drop_connect_rate:
209 + drop_connect_rate *= float(idx) / len(self._blocks)
210 + x = block(x, drop_connect_rate=drop_connect_rate)
211 +
212 + # Head
213 + x = self._swish(self._bn1(self._conv_head(x)))
214 +
215 + return x
216 +
217 + def forward(self, inputs):
218 + """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
219 + bs = inputs.size(0)
220 + # Convolution layers
221 + x = self.extract_features(inputs)
222 +
223 + # Pooling and final linear layer
224 + x = self._avg_pooling(x)
225 + x = x.view(bs, -1)
226 + x = self._dropout(x)
227 + x = self._fc(x)
228 + return x
229 +
230 + @classmethod
231 + def from_name(cls, model_name, override_params=None, norm_layer=None, condconv_num_expert=1):
232 + cls._check_model_name_is_valid(model_name)
233 + blocks_args, global_params = get_model_params(model_name, override_params, condconv_num_expert=condconv_num_expert)
234 + return cls(blocks_args, global_params, norm_layer=norm_layer)
235 +
236 + @classmethod
237 + def from_pretrained(cls, model_name, num_classes=1000):
238 + model = cls.from_name(model_name, override_params={'num_classes': num_classes})
239 + load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
240 +
241 + return model
242 +
243 + @classmethod
244 + def get_image_size(cls, model_name):
245 + cls._check_model_name_is_valid(model_name)
246 + _, _, res, _ = efficientnet_params(model_name)
247 + return res
248 +
249 + @classmethod
250 + def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
251 + """ Validates model name. None that pretrained weights are only available for
252 + the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
253 + num_models = 4 if also_need_pretrained_weights else 8
254 + valid_models = ['efficientnet-b'+str(i) for i in range(num_models)]
255 + if model_name not in valid_models:
256 + raise ValueError(f'model_name={model_name} should be one of: ' + ', '.join(valid_models))
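
A minimal construction sketch (names as defined above; per utils.efficientnet below, CondConv experts are only enabled for the last three block groups):

    model = EfficientNet.from_name('efficientnet-b0', condconv_num_expert=4)
    # or, without CondConv:
    model = EfficientNet.from_name('efficientnet-b0')
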
1 +"""
2 +This file contains helper functions for building the model and for loading model parameters.
3 +These helper functions are built to mirror those in the official TensorFlow implementation.
4 +"""
5 +
6 +import re
7 +import math
8 +import collections
9 +from functools import partial
10 +import torch
11 +from torch import nn
12 +from torch.nn import functional as F
13 +from torch.utils import model_zoo
14 +
15 +########################################################################
16 +############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
17 +########################################################################
18 +
19 +
20 +# Parameters for the entire model (stem, all blocks, and head)
21 +from FastAutoAugment.networks.efficientnet_pytorch.condconv import CondConv2d
22 +
23 +GlobalParams = collections.namedtuple('GlobalParams', [
24 + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
25 + 'num_classes', 'width_coefficient', 'depth_coefficient',
26 + 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
27 +
28 +# Parameters for an individual model block
29 +BlockArgs = collections.namedtuple('BlockArgs', [
30 + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
31 + 'expand_ratio', 'id_skip', 'stride', 'se_ratio', 'condconv_num_expert'])
32 +
33 +# Change namedtuple defaults
34 +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
35 +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
36 +
37 +
38 +class SwishImplementation(torch.autograd.Function):
39 + @staticmethod
40 + def forward(ctx, i):
41 + result = i * torch.sigmoid(i)
42 + ctx.save_for_backward(i)
43 + return result
44 +
45 + @staticmethod
46 + def backward(ctx, grad_output):
47 + i = ctx.saved_tensors[0]
48 + sigmoid_i = torch.sigmoid(i)
49 + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
50 +
51 +
52 +class MemoryEfficientSwish(nn.Module):
53 + def forward(self, x):
54 + return SwishImplementation.apply(x)
55 +
56 +
57 +def round_filters(filters, global_params):
58 + """ Calculate and round number of filters based on depth multiplier. """
59 + multiplier = global_params.width_coefficient
60 + if not multiplier:
61 + return filters
62 + divisor = global_params.depth_divisor
63 + min_depth = global_params.min_depth
64 + filters *= multiplier
65 + min_depth = min_depth or divisor
66 + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
67 + if new_filters < 0.9 * filters: # prevent rounding by more than 10%
68 + new_filters += divisor
69 + return int(new_filters)
70 +
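A worked example of the rounding rule above (b4-style width 1.4, divisor 8; illustrative): 32 * 1.4 = 44.8 rounds to the nearest multiple of 8 (48), and since 48 >= 0.9 * 44.8 no correction is applied.

    gp = GlobalParams(width_coefficient=1.4, depth_divisor=8, min_depth=None)
    round_filters(32, gp)  # -> 48
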
71 +
72 +def round_repeats(repeats, global_params):
73 + """ Round number of filters based on depth multiplier. """
74 + multiplier = global_params.depth_coefficient
75 + if not multiplier:
76 + return repeats
77 + return int(math.ceil(multiplier * repeats))
78 +
79 +
80 +def drop_connect(inputs, drop_p, training):
81 + """ Drop connect. """
82 + if not training:
83 + return inputs * (1. - drop_p)
84 + batch_size = inputs.shape[0]
85 + random_tensor = torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
86 + binary_tensor = random_tensor > drop_p
87 + output = inputs * binary_tensor.float()
88 + # output = inputs / (1. - drop_p) * binary_tensor.float()
89 + return output
90 +
91 + # if not training: return inputs
92 + # batch_size = inputs.shape[0]
93 + # keep_prob = 1 - drop_p
94 + # random_tensor = keep_prob
95 + # random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
96 + # binary_tensor = torch.floor(random_tensor)
97 + # output = inputs / keep_prob * binary_tensor
98 + # return output
99 +
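Note the eval-time behavior of this port: it scales activations by (1 - drop_p) at inference rather than rescaling survivors at train time. A quick illustrative check:

    import torch
    x = torch.ones(4, 1, 1, 1)
    drop_connect(x, drop_p=0.2, training=False)  # deterministic: 0.8 * x
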
100 +
101 +def get_same_padding_conv2d(image_size=None, condconv_num_expert=1):
102 + """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
103 + Static padding is necessary for ONNX exporting of models. """
104 + if condconv_num_expert > 1:
105 + return partial(CondConv2d, num_experts=condconv_num_expert)
106 + elif image_size is None:
107 + return Conv2dDynamicSamePadding
108 + else:
109 + return partial(Conv2dStaticSamePadding, image_size=image_size)
110 +
111 +
112 +class Conv2dDynamicSamePadding(nn.Conv2d):
113 + """ 2D Convolutions like TensorFlow, for a dynamic image size """
114 +
115 + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
116 + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
117 + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
118 +
119 + def forward(self, x):
120 + ih, iw = x.size()[-2:]
121 + kh, kw = self.weight.size()[-2:]
122 + sh, sw = self.stride
123 + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
124 + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
125 + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
126 + if pad_h > 0 or pad_w > 0:
127 + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
128 + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
129 +
130 +
131 +class Conv2dStaticSamePadding(nn.Conv2d):
132 + """ 2D Convolutions like TensorFlow, for a fixed image size"""
133 +
134 + def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
135 + super().__init__(in_channels, out_channels, kernel_size, **kwargs)
136 + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
137 +
138 + # Calculate padding based on image size and save it
139 + assert image_size is not None
140 + ih, iw = image_size if isinstance(image_size, list) else [image_size, image_size]
141 + kh, kw = self.weight.size()[-2:]
142 + sh, sw = self.stride
143 + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
144 + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
145 + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
146 + if pad_h > 0 or pad_w > 0:
147 + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
148 + else:
149 + self.static_padding = Identity()
150 +
151 + def forward(self, x):
152 + x = self.static_padding(x)
153 + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
154 + return x
155 +
156 +
157 +class Identity(nn.Module):
158 + def __init__(self, ):
159 + super(Identity, self).__init__()
160 +
161 + def forward(self, input):
162 + return input
163 +
164 +
165 +########################################################################
166 +############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
167 +########################################################################
168 +
169 +
170 +def efficientnet_params(model_name):
171 + """ Map EfficientNet model name to parameter coefficients. """
172 + params_dict = {
173 + # Coefficients: width,depth,res,dropout
174 + 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
175 + 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
176 + 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
177 + 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
178 + 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
179 + 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
180 + 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
181 + 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
182 + }
183 + return params_dict[model_name]
184 +
185 +
186 +class BlockDecoder(object):
187 + """ Block Decoder for readability, straight from the official TensorFlow repository """
188 +
189 + @staticmethod
190 + def _decode_block_string(block_string):
191 + """ Gets a block through a string notation of arguments. """
192 + assert isinstance(block_string, str)
193 +
194 + ops = block_string.split('_')
195 + options = {}
196 + for op in ops:
197 + splits = re.split(r'(\d.*)', op)
198 + if len(splits) >= 2:
199 + key, value = splits[:2]
200 + options[key] = value
201 +
202 + # Check stride
203 + assert (('s' in options and len(options['s']) == 1) or
204 + (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
205 +
206 + return BlockArgs(
207 + kernel_size=int(options['k']),
208 + num_repeat=int(options['r']),
209 + input_filters=int(options['i']),
210 + output_filters=int(options['o']),
211 + expand_ratio=int(options['e']),
212 + id_skip=('noskip' not in block_string),
213 + se_ratio=float(options['se']) if 'se' in options else None,
214 + stride=[int(options['s'][0])],
215 + condconv_num_expert=0
216 + )
217 +
218 + @staticmethod
219 + def _encode_block_string(block):
220 + """Encodes a block to a string."""
221 + args = [
222 + 'r%d' % block.num_repeat,
223 + 'k%d' % block.kernel_size,
224 + 's%d%d' % (block.stride[0], block.stride[0]), # BlockArgs stores a single-element stride list
225 + 'e%s' % block.expand_ratio,
226 + 'i%d' % block.input_filters,
227 + 'o%d' % block.output_filters
228 + ]
229 + if 0 < block.se_ratio <= 1:
230 + args.append('se%s' % block.se_ratio)
231 + if block.id_skip is False:
232 + args.append('noskip')
233 + return '_'.join(args)
234 +
235 + @staticmethod
236 + def decode(string_list):
237 + """
238 + Decodes a list of string notations to specify blocks inside the network.
239 +
240 + :param string_list: a list of strings, each string is a notation of block
241 + :return: a list of BlockArgs namedtuples of block args
242 + """
243 + assert isinstance(string_list, list)
244 + blocks_args = []
245 + for block_string in string_list:
246 + blocks_args.append(BlockDecoder._decode_block_string(block_string))
247 + return blocks_args
248 +
249 + @staticmethod
250 + def encode(blocks_args):
251 + """
252 + Encodes a list of BlockArgs to a list of strings.
253 +
254 + :param blocks_args: a list of BlockArgs namedtuples of block args
255 + :return: a list of strings, each string is a notation of block
256 + """
257 + block_strings = []
258 + for block in blocks_args:
259 + block_strings.append(BlockDecoder._encode_block_string(block))
260 + return block_strings
261 +
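An illustrative round trip through the decoder (string notation as used in efficientnet() below):

    args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])
    args[0].kernel_size   # -> 3
    args[0].se_ratio      # -> 0.25
    args[0].stride        # -> [1]
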
262 +
263 +def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
264 + drop_connect_rate=0.2, image_size=None, num_classes=1000, condconv_num_expert=1):
265 + """ Creates a efficientnet model. """
266 +
267 + blocks_args = [
268 + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
269 + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
270 + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
271 + 'r1_k3_s11_e6_i192_o320_se0.25',
272 + ]
273 + blocks_args = BlockDecoder.decode(blocks_args)
274 +
275 + blocks_args_new = blocks_args[:-3]
276 + for blocks_arg in blocks_args[-3:]:
277 + blocks_arg = blocks_arg._replace(condconv_num_expert=condconv_num_expert)
278 + blocks_args_new.append(blocks_arg)
279 + blocks_args = blocks_args_new
280 +
281 + global_params = GlobalParams(
282 + batch_norm_momentum=0.99,
283 + batch_norm_epsilon=1e-3,
284 + dropout_rate=dropout_rate,
285 + drop_connect_rate=drop_connect_rate,
286 + # data_format='channels_last', # removed, this is always true in PyTorch
287 + num_classes=num_classes,
288 + width_coefficient=width_coefficient,
289 + depth_coefficient=depth_coefficient,
290 + depth_divisor=8,
291 + min_depth=None,
292 + image_size=image_size,
293 + )
294 +
295 + return blocks_args, global_params
296 +
297 +
298 +def get_model_params(model_name, override_params, condconv_num_expert=1):
299 + """ Get the block args and global params for a given model """
300 + if model_name.startswith('efficientnet'):
301 + w, d, s, p = efficientnet_params(model_name)
302 + # note: all models have drop connect rate = 0.2
303 + blocks_args, global_params = efficientnet(
304 + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s, condconv_num_expert=condconv_num_expert)
305 + else:
306 + raise NotImplementedError('model name is not pre-defined: %s' % model_name)
307 + if override_params:
308 + # ValueError will be raised here if override_params has fields not included in global_params.
309 + global_params = global_params._replace(**override_params)
310 + return blocks_args, global_params
311 +
312 +
313 +url_map = {
314 + 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
315 + 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
316 + 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
317 + 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
318 + 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
319 + 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
320 + 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
321 + 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
322 +}
323 +
324 +
325 +def load_pretrained_weights(model, model_name, load_fc=True):
326 + """ Loads pretrained weights, and downloads if loading for the first time. """
327 + state_dict = model_zoo.load_url(url_map[model_name])
328 + if load_fc:
329 + model.load_state_dict(state_dict)
330 + else:
331 + state_dict.pop('_fc.weight')
332 + state_dict.pop('_fc.bias')
333 + res = model.load_state_dict(state_dict, strict=False)
334 + assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
335 + print('Loaded pretrained weights for {}'.format(model_name))
1 +import torch
2 +import torch.nn as nn
3 +import math
4 +
5 +from FastAutoAugment.networks.shakedrop import ShakeDrop
6 +
7 +
8 +def conv3x3(in_planes, out_planes, stride=1):
9 + """
10 + 3x3 convolution with padding
11 + """
12 + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
13 +
14 +
15 +class BasicBlock(nn.Module):
16 + outchannel_ratio = 1
17 +
18 + def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0):
19 + super(BasicBlock, self).__init__()
20 + self.bn1 = nn.BatchNorm2d(inplanes)
21 + self.conv1 = conv3x3(inplanes, planes, stride)
22 + self.bn2 = nn.BatchNorm2d(planes)
23 + self.conv2 = conv3x3(planes, planes)
24 + self.bn3 = nn.BatchNorm2d(planes)
25 + self.relu = nn.ReLU(inplace=True)
26 + self.downsample = downsample
27 + self.stride = stride
28 + self.shake_drop = ShakeDrop(p_shakedrop)
29 +
30 + def forward(self, x):
31 +
32 + out = self.bn1(x)
33 + out = self.conv1(out)
34 + out = self.bn2(out)
35 + out = self.relu(out)
36 + out = self.conv2(out)
37 + out = self.bn3(out)
38 +
39 + out = self.shake_drop(out)
40 +
41 + if self.downsample is not None:
42 + shortcut = self.downsample(x)
43 + featuremap_size = shortcut.size()[2:4]
44 + else:
45 + shortcut = x
46 + featuremap_size = out.size()[2:4]
47 +
48 + batch_size = out.size()[0]
49 + residual_channel = out.size()[1]
50 + shortcut_channel = shortcut.size()[1]
51 +
52 + if residual_channel != shortcut_channel:
53 + padding = torch.autograd.Variable(
54 + torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0],
55 + featuremap_size[1]).fill_(0))
56 + out += torch.cat((shortcut, padding), 1)
57 + else:
58 + out += shortcut
59 +
60 + return out
61 +
62 +
63 +class Bottleneck(nn.Module):
64 + outchannel_ratio = 4
65 +
66 + def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0):
67 + super(Bottleneck, self).__init__()
68 + self.bn1 = nn.BatchNorm2d(inplanes)
69 + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
70 + self.bn2 = nn.BatchNorm2d(planes)
71 + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
72 + padding=1, bias=False)
73 + self.bn3 = nn.BatchNorm2d(planes)
74 + self.conv3 = nn.Conv2d(planes, planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False)
75 + self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio)
76 + self.relu = nn.ReLU(inplace=True)
77 + self.downsample = downsample
78 + self.stride = stride
79 + self.shake_drop = ShakeDrop(p_shakedrop)
80 +
81 + def forward(self, x):
82 +
83 + out = self.bn1(x)
84 + out = self.conv1(out)
85 +
86 + out = self.bn2(out)
87 + out = self.relu(out)
88 + out = self.conv2(out)
89 +
90 + out = self.bn3(out)
91 + out = self.relu(out)
92 + out = self.conv3(out)
93 +
94 + out = self.bn4(out)
95 +
96 + out = self.shake_drop(out)
97 +
98 + if self.downsample is not None:
99 + shortcut = self.downsample(x)
100 + featuremap_size = shortcut.size()[2:4]
101 + else:
102 + shortcut = x
103 + featuremap_size = out.size()[2:4]
104 +
105 + batch_size = out.size()[0]
106 + residual_channel = out.size()[1]
107 + shortcut_channel = shortcut.size()[1]
108 +
109 + if residual_channel != shortcut_channel:
110 + padding = torch.autograd.Variable(
111 + torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0],
112 + featuremap_size[1]).fill_(0))
113 + out += torch.cat((shortcut, padding), 1)
114 + else:
115 + out += shortcut
116 +
117 + return out
118 +
119 +
120 +class PyramidNet(nn.Module):
121 +
122 + def __init__(self, dataset, depth, alpha, num_classes, bottleneck=True):
123 + super(PyramidNet, self).__init__()
124 + self.dataset = dataset
125 + if self.dataset.startswith('cifar'):
126 + self.inplanes = 16
127 + if bottleneck:
128 + n = int((depth - 2) / 9)
129 + block = Bottleneck
130 + else:
131 + n = int((depth - 2) / 6)
132 + block = BasicBlock
133 +
134 + self.addrate = alpha / (3 * n * 1.0)
135 + self.ps_shakedrop = [1. - (1.0 - (0.5 / (3 * n)) * (i + 1)) for i in range(3 * n)]
136 +
137 + self.input_featuremap_dim = self.inplanes
138 + self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False)
139 + self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
140 +
141 + self.featuremap_dim = self.input_featuremap_dim
142 + self.layer1 = self.pyramidal_make_layer(block, n)
143 + self.layer2 = self.pyramidal_make_layer(block, n, stride=2)
144 + self.layer3 = self.pyramidal_make_layer(block, n, stride=2)
145 +
146 + self.final_featuremap_dim = self.input_featuremap_dim
147 + self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
148 + self.relu_final = nn.ReLU(inplace=True)
149 + self.avgpool = nn.AvgPool2d(8)
150 + self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
151 +
152 + elif dataset == 'imagenet':
153 + blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
154 + layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3],
155 + 200: [3, 24, 36, 3]}
156 +
157 + if layers.get(depth) is None:
158 + if bottleneck:
159 + blocks[depth] = Bottleneck
160 + temp_cfg = int((depth - 2) / 12)
161 + else:
162 + blocks[depth] = BasicBlock
163 + temp_cfg = int((depth - 2) / 8)
164 +
165 + layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg]
166 + print('=> the layer configuration for each stage is set to', layers[depth])
167 +
168 + self.inplanes = 64
169 + self.addrate = alpha / (sum(layers[depth]) * 1.0)
170 +
171 + self.input_featuremap_dim = self.inplanes
172 + self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False)
173 + self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
174 + self.relu = nn.ReLU(inplace=True)
175 + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
176 +
177 + self.featuremap_dim = self.input_featuremap_dim
178 + self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0])
179 + self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2)
180 + self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2)
181 + self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2)
182 +
183 + self.final_featuremap_dim = self.input_featuremap_dim
184 + self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
185 + self.relu_final = nn.ReLU(inplace=True)
186 + self.avgpool = nn.AvgPool2d(7)
187 + self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
188 +
189 + for m in self.modules():
190 + if isinstance(m, nn.Conv2d):
191 + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
192 + m.weight.data.normal_(0, math.sqrt(2. / n))
193 + elif isinstance(m, nn.BatchNorm2d):
194 + m.weight.data.fill_(1)
195 + m.bias.data.zero_()
196 +
197 + assert len(self.ps_shakedrop) == 0, self.ps_shakedrop
198 +
199 + def pyramidal_make_layer(self, block, block_depth, stride=1):
200 + downsample = None
201 + if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio:
202 + downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True)
203 +
204 + layers = []
205 + self.featuremap_dim = self.featuremap_dim + self.addrate
206 + layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample, p_shakedrop=self.ps_shakedrop.pop(0)))
207 + for i in range(1, block_depth):
208 + temp_featuremap_dim = self.featuremap_dim + self.addrate
209 + layers.append(
210 + block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1, p_shakedrop=self.ps_shakedrop.pop(0)))
211 + self.featuremap_dim = temp_featuremap_dim
212 + self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio
213 +
214 + return nn.Sequential(*layers)
215 +
216 + def forward(self, x):
217 + if self.dataset == 'cifar10' or self.dataset == 'cifar100':
218 + x = self.conv1(x)
219 + x = self.bn1(x)
220 +
221 + x = self.layer1(x)
222 + x = self.layer2(x)
223 + x = self.layer3(x)
224 +
225 + x = self.bn_final(x)
226 + x = self.relu_final(x)
227 + x = self.avgpool(x)
228 + x = x.view(x.size(0), -1)
229 + x = self.fc(x)
230 +
231 + elif self.dataset == 'imagenet':
232 + x = self.conv1(x)
233 + x = self.bn1(x)
234 + x = self.relu(x)
235 + x = self.maxpool(x)
236 +
237 + x = self.layer1(x)
238 + x = self.layer2(x)
239 + x = self.layer3(x)
240 + x = self.layer4(x)
241 +
242 + x = self.bn_final(x)
243 + x = self.relu_final(x)
244 + x = self.avgpool(x)
245 + x = x.view(x.size(0), -1)
246 + x = self.fc(x)
247 +
248 + return x
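
A construction sketch (illustrative hyperparameters; the forward pass pads shortcut channels with torch.cuda tensors, so actually running it assumes a GPU):

    net = PyramidNet('cifar10', depth=272, alpha=200, num_classes=10, bottleneck=True)
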
1 +# Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2 +
3 +import torch.nn as nn
4 +import math
5 +
6 +
7 +def conv3x3(in_planes, out_planes, stride=1):
8 + "3x3 convolution with padding"
9 + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 + padding=1, bias=False)
11 +
12 +
13 +class BasicBlock(nn.Module):
14 + expansion = 1
15 +
16 + def __init__(self, inplanes, planes, stride=1, downsample=None):
17 + super(BasicBlock, self).__init__()
18 + self.conv1 = conv3x3(inplanes, planes, stride)
19 + self.bn1 = nn.BatchNorm2d(planes)
20 + self.conv2 = conv3x3(planes, planes)
21 + self.bn2 = nn.BatchNorm2d(planes)
22 + self.relu = nn.ReLU(inplace=True)
23 +
24 + self.downsample = downsample
25 + self.stride = stride
26 +
27 + def forward(self, x):
28 + residual = x
29 +
30 + out = self.conv1(x)
31 + out = self.bn1(out)
32 + out = self.relu(out)
33 +
34 + out = self.conv2(out)
35 + out = self.bn2(out)
36 +
37 + if self.downsample is not None:
38 + residual = self.downsample(x)
39 +
40 + out += residual
41 + out = self.relu(out)
42 +
43 + return out
44 +
45 +
46 +class Bottleneck(nn.Module):
47 + expansion = 4
48 +
49 + def __init__(self, inplanes, planes, stride=1, downsample=None):
50 + super(Bottleneck, self).__init__()
51 +
52 + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
53 + self.bn1 = nn.BatchNorm2d(planes)
54 + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
55 + self.bn2 = nn.BatchNorm2d(planes)
56 + self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False)
57 + self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
58 + self.relu = nn.ReLU(inplace=True)
59 +
60 + self.downsample = downsample
61 + self.stride = stride
62 +
63 + def forward(self, x):
64 + residual = x
65 +
66 + out = self.conv1(x)
67 + out = self.bn1(out)
68 + out = self.relu(out)
69 +
70 + out = self.conv2(out)
71 + out = self.bn2(out)
72 + out = self.relu(out)
73 +
74 + out = self.conv3(out)
75 + out = self.bn3(out)
76 + if self.downsample is not None:
77 + residual = self.downsample(x)
78 +
79 + out += residual
80 + out = self.relu(out)
81 +
82 + return out
83 +
84 +class ResNet(nn.Module):
85 + def __init__(self, dataset, depth, num_classes, bottleneck=False):
86 + super(ResNet, self).__init__()
87 + self.dataset = dataset
88 + if self.dataset.startswith('cifar'):
89 + self.inplanes = 16
90 + if bottleneck:
92 + n = int((depth - 2) / 9)
93 + block = Bottleneck
94 + else:
95 + n = int((depth - 2) / 6)
96 + block = BasicBlock
97 +
98 + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
99 + self.bn1 = nn.BatchNorm2d(self.inplanes)
100 + self.relu = nn.ReLU(inplace=True)
101 + self.layer1 = self._make_layer(block, 16, n)
102 + self.layer2 = self._make_layer(block, 32, n, stride=2)
103 + self.layer3 = self._make_layer(block, 64, n, stride=2)
104 + # self.avgpool = nn.AvgPool2d(8)
105 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
106 + self.fc = nn.Linear(64 * block.expansion, num_classes)
107 +
108 + elif dataset == 'imagenet':
109 + blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
110 + layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
111 + assert depth in layers, 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)'
112 +
113 + self.inplanes = 64
114 + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
115 + self.bn1 = nn.BatchNorm2d(64)
116 + self.relu = nn.ReLU(inplace=True)
117 + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
118 + self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0])
119 + self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2)
120 + self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2)
121 + self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2)
122 + # self.avgpool = nn.AvgPool2d(7)
123 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
124 + self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes)
125 +
126 + for m in self.modules():
127 + if isinstance(m, nn.Conv2d):
128 + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
129 + m.weight.data.normal_(0, math.sqrt(2. / n))
130 + elif isinstance(m, nn.BatchNorm2d):
131 + m.weight.data.fill_(1)
132 + m.bias.data.zero_()
133 +
134 + def _make_layer(self, block, planes, blocks, stride=1):
135 + downsample = None
136 + if stride != 1 or self.inplanes != planes * block.expansion:
137 + downsample = nn.Sequential(
138 + nn.Conv2d(self.inplanes, planes * block.expansion,
139 + kernel_size=1, stride=stride, bias=False),
140 + nn.BatchNorm2d(planes * block.expansion),
141 + )
142 +
143 + layers = []
144 + layers.append(block(self.inplanes, planes, stride, downsample))
145 + self.inplanes = planes * block.expansion
146 + for i in range(1, blocks):
147 + layers.append(block(self.inplanes, planes))
148 +
149 + return nn.Sequential(*layers)
150 +
151 + def forward(self, x):
152 + if self.dataset == 'cifar10' or self.dataset == 'cifar100':
153 + x = self.conv1(x)
154 + x = self.bn1(x)
155 + x = self.relu(x)
156 +
157 + x = self.layer1(x)
158 + x = self.layer2(x)
159 + x = self.layer3(x)
160 +
161 + x = self.avgpool(x)
162 + x = x.view(x.size(0), -1)
163 + x = self.fc(x)
164 +
165 + elif self.dataset == 'imagenet':
166 + x = self.conv1(x)
167 + x = self.bn1(x)
168 + x = self.relu(x)
169 + x = self.maxpool(x)
170 +
171 + x = self.layer1(x)
172 + x = self.layer2(x)
173 + x = self.layer3(x)
174 + x = self.layer4(x)
175 +
176 + x = self.avgpool(x)
177 + x = x.view(x.size(0), -1)
178 + x = self.fc(x)
179 +
180 + return x
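
A CPU-friendly smoke test for the CIFAR variant (illustrative; with BasicBlock the depth must satisfy depth = 6n + 2):

    import torch
    net = ResNet('cifar10', depth=56, num_classes=10, bottleneck=False)
    net(torch.randn(2, 3, 32, 32)).shape  # -> torch.Size([2, 10])
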
1 +# -*- coding: utf-8 -*-
2 +
3 +import torch
4 +import torch.nn as nn
5 +import torch.nn.functional as F
6 +from torch.autograd import Variable
7 +
8 +
9 +class ShakeDropFunction(torch.autograd.Function):
10 +
11 + @staticmethod
12 + def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]):
13 + if training:
14 + gate = torch.cuda.FloatTensor([0]).bernoulli_(1 - p_drop)
15 + ctx.save_for_backward(gate)
16 + if gate.item() == 0:
17 + alpha = torch.cuda.FloatTensor(x.size(0)).uniform_(*alpha_range)
18 + alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x)
19 + return alpha * x
20 + else:
21 + return x
22 + else:
23 + return (1 - p_drop) * x
24 +
25 + @staticmethod
26 + def backward(ctx, grad_output):
27 + gate = ctx.saved_tensors[0]
28 + if gate.item() == 0:
29 + beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_(0, 1)
30 + beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output)
31 + beta = Variable(beta)
32 + return beta * grad_output, None, None, None
33 + else:
34 + return grad_output, None, None, None
35 +
36 +
37 +class ShakeDrop(nn.Module):
38 +
39 + def __init__(self, p_drop=0.5, alpha_range=[-1, 1]):
40 + super(ShakeDrop, self).__init__()
41 + self.p_drop = p_drop
42 + self.alpha_range = alpha_range
43 +
44 + def forward(self, x):
45 + return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range)
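
A minimal usage sketch (assumes CUDA, since ShakeDropFunction allocates torch.cuda.FloatTensor internally during training):

    import torch
    sd = ShakeDrop(p_drop=0.5).cuda()
    y = sd(torch.randn(4, 16, 8, 8, device='cuda'))  # scaled by a random alpha, or passed through
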
1 +# -*- coding: utf-8 -*-
2 +
3 +import math
4 +
5 +import torch.nn as nn
6 +import torch.nn.functional as F
7 +
8 +from FastAutoAugment.networks.shakeshake.shakeshake import ShakeShake
9 +from FastAutoAugment.networks.shakeshake.shakeshake import Shortcut
10 +
11 +
12 +class ShakeBlock(nn.Module):
13 +
14 + def __init__(self, in_ch, out_ch, stride=1):
15 + super(ShakeBlock, self).__init__()
16 + self.equal_io = in_ch == out_ch
17 + self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride)
18 +
19 + self.branch1 = self._make_branch(in_ch, out_ch, stride)
20 + self.branch2 = self._make_branch(in_ch, out_ch, stride)
21 +
22 + def forward(self, x):
23 + h1 = self.branch1(x)
24 + h2 = self.branch2(x)
25 + h = ShakeShake.apply(h1, h2, self.training)
26 + h0 = x if self.equal_io else self.shortcut(x)
27 + return h + h0
28 +
29 + def _make_branch(self, in_ch, out_ch, stride=1):
30 + return nn.Sequential(
31 + nn.ReLU(inplace=False),
32 + nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
33 + nn.BatchNorm2d(out_ch),
34 + nn.ReLU(inplace=False),
35 + nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
36 + nn.BatchNorm2d(out_ch))
37 +
38 +
39 +class ShakeResNet(nn.Module):
40 +
41 + def __init__(self, depth, w_base, label):
42 + super(ShakeResNet, self).__init__()
43 + n_units = (depth - 2) // 6
44 +
45 + in_chs = [16, w_base, w_base * 2, w_base * 4]
46 + self.in_chs = in_chs
47 +
48 + self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1)
49 + self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1])
50 + self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2)
51 + self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2)
52 + self.fc_out = nn.Linear(in_chs[3], label)
53 +
54 + # Initialize parameters
55 + for m in self.modules():
56 + if isinstance(m, nn.Conv2d):
57 + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
58 + m.weight.data.normal_(0, math.sqrt(2. / n))
59 + elif isinstance(m, nn.BatchNorm2d):
60 + m.weight.data.fill_(1)
61 + m.bias.data.zero_()
62 + elif isinstance(m, nn.Linear):
63 + m.bias.data.zero_()
64 +
65 + def forward(self, x):
66 + h = self.c_in(x)
67 + h = self.layer1(h)
68 + h = self.layer2(h)
69 + h = self.layer3(h)
70 + h = F.relu(h)
71 + h = F.avg_pool2d(h, 8)
72 + h = h.view(-1, self.in_chs[3])
73 + h = self.fc_out(h)
74 + return h
75 +
76 + def _make_layer(self, n_units, in_ch, out_ch, stride=1):
77 + layers = []
78 + for i in range(int(n_units)):
79 + layers.append(ShakeBlock(in_ch, out_ch, stride=stride))
80 + in_ch, stride = out_ch, 1
81 + return nn.Sequential(*layers)
1 +# -*- coding: utf-8 -*-
2 +
3 +import math
4 +
5 +import torch.nn as nn
6 +import torch.nn.functional as F
7 +
8 +from FastAutoAugment.networks.shakeshake.shakeshake import ShakeShake
9 +from FastAutoAugment.networks.shakeshake.shakeshake import Shortcut
10 +
11 +
12 +class ShakeBottleNeck(nn.Module):
13 +
14 + def __init__(self, in_ch, mid_ch, out_ch, cardinary, stride=1):
15 + super(ShakeBottleNeck, self).__init__()
16 + self.equal_io = in_ch == out_ch
17 + self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride)
18 +
19 + self.branch1 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride)
20 + self.branch2 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride)
21 +
22 + def forward(self, x):
23 + h1 = self.branch1(x)
24 + h2 = self.branch2(x)
25 + h = ShakeShake.apply(h1, h2, self.training)
26 + h0 = x if self.equal_io else self.shortcut(x)
27 + return h + h0
28 +
29 + def _make_branch(self, in_ch, mid_ch, out_ch, cardinary, stride=1):
30 + return nn.Sequential(
31 + nn.Conv2d(in_ch, mid_ch, 1, padding=0, bias=False),
32 + nn.BatchNorm2d(mid_ch),
33 + nn.ReLU(inplace=False),
34 + nn.Conv2d(mid_ch, mid_ch, 3, padding=1, stride=stride, groups=cardinary, bias=False),
35 + nn.BatchNorm2d(mid_ch),
36 + nn.ReLU(inplace=False),
37 + nn.Conv2d(mid_ch, out_ch, 1, padding=0, bias=False),
38 + nn.BatchNorm2d(out_ch))
39 +
40 +
41 +class ShakeResNeXt(nn.Module):
42 +
43 + def __init__(self, depth, w_base, cardinary, label):
44 + super(ShakeResNeXt, self).__init__()
45 + n_units = (depth - 2) // 9
46 + n_chs = [64, 128, 256, 1024]
47 + self.n_chs = n_chs
48 + self.in_ch = n_chs[0]
49 +
50 + self.c_in = nn.Conv2d(3, n_chs[0], 3, padding=1)
51 + self.layer1 = self._make_layer(n_units, n_chs[0], w_base, cardinary)
52 + self.layer2 = self._make_layer(n_units, n_chs[1], w_base, cardinary, 2)
53 + self.layer3 = self._make_layer(n_units, n_chs[2], w_base, cardinary, 2)
54 + self.fc_out = nn.Linear(n_chs[3], label)
55 +
56 + # Initialize parameters
57 + for m in self.modules():
58 + if isinstance(m, nn.Conv2d):
59 + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
60 + m.weight.data.normal_(0, math.sqrt(2. / n))
61 + elif isinstance(m, nn.BatchNorm2d):
62 + m.weight.data.fill_(1)
63 + m.bias.data.zero_()
64 + elif isinstance(m, nn.Linear):
65 + m.bias.data.zero_()
66 +
67 + def forward(self, x):
68 + h = self.c_in(x)
69 + h = self.layer1(h)
70 + h = self.layer2(h)
71 + h = self.layer3(h)
72 + h = F.relu(h)
73 + h = F.avg_pool2d(h, 8)
74 + h = h.view(-1, self.n_chs[3])
75 + h = self.fc_out(h)
76 + return h
77 +
78 + def _make_layer(self, n_units, n_ch, w_base, cardinary, stride=1):
79 + layers = []
80 + mid_ch, out_ch = n_ch * (w_base // 64) * cardinary, n_ch * 4
81 + for i in range(n_units):
82 + layers.append(ShakeBottleNeck(self.in_ch, mid_ch, out_ch, cardinary, stride=stride))
83 + self.in_ch, stride = out_ch, 1
84 + return nn.Sequential(*layers)
1 +# -*- coding: utf-8 -*-
2 +
3 +import torch
4 +import torch.nn as nn
5 +import torch.nn.functional as F
6 +from torch.autograd import Variable
7 +
8 +
9 +class ShakeShake(torch.autograd.Function):
10 +
11 + @staticmethod
12 + def forward(ctx, x1, x2, training=True):
13 + if training:
14 + alpha = torch.cuda.FloatTensor(x1.size(0)).uniform_()
15 + alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1)
16 + else:
17 + alpha = 0.5
18 + return alpha * x1 + (1 - alpha) * x2
19 +
20 + @staticmethod
21 + def backward(ctx, grad_output):
22 + beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_()
23 + beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output)
24 + beta = Variable(beta)
25 +
26 + return beta * grad_output, (1 - beta) * grad_output, None
27 +
28 +
29 +class Shortcut(nn.Module):
30 +
31 + def __init__(self, in_ch, out_ch, stride):
32 + super(Shortcut, self).__init__()
33 + self.stride = stride
34 + self.conv1 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False)
35 + self.conv2 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False)
36 + self.bn = nn.BatchNorm2d(out_ch)
37 +
38 + def forward(self, x):
39 + h = F.relu(x)
40 +
41 + h1 = F.avg_pool2d(h, 1, self.stride)
42 + h1 = self.conv1(h1)
43 +
44 + h2 = F.avg_pool2d(F.pad(h, (-1, 1, -1, 1)), 1, self.stride)
45 + h2 = self.conv2(h2)
46 +
47 + h = torch.cat((h1, h2), 1)
48 + return self.bn(h)
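
The forward/backward decoupling above draws an independent beta in the backward pass; in eval mode the two branches are simply averaged. An illustrative check (the eval path below is CPU-safe; the training path allocates torch.cuda tensors):

    import torch
    x1 = torch.ones(2, 1, 1, 1)
    x2 = torch.zeros_like(x1)
    ShakeShake.apply(x1, x2, False)  # eval: 0.5 * x1 + 0.5 * x2 -> all 0.5
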
1 +import torch.nn as nn
2 +import torch.nn.init as init
3 +import torch.nn.functional as F
4 +import numpy as np
5 +
6 +
7 +def conv3x3(in_planes, out_planes, stride=1):
8 + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
9 +
10 +
11 +def conv_init(m):
12 + classname = m.__class__.__name__
13 + if classname.find('Conv') != -1:
14 + init.xavier_uniform_(m.weight, gain=np.sqrt(2))
15 + init.constant_(m.bias, 0)
16 + elif classname.find('BatchNorm') != -1:
17 + init.constant_(m.weight, 1)
18 + init.constant_(m.bias, 0)
19 +
20 +
21 +class WideBasic(nn.Module):
22 + def __init__(self, in_planes, planes, dropout_rate, stride=1):
23 + super(WideBasic, self).__init__()
24 + self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.9)
25 + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
26 + self.dropout = nn.Dropout(p=dropout_rate)
27 + self.bn2 = nn.BatchNorm2d(planes, momentum=0.9)
28 + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
29 +
30 + self.shortcut = nn.Sequential()
31 + if stride != 1 or in_planes != planes:
32 + self.shortcut = nn.Sequential(
33 + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
34 + )
35 +
36 + def forward(self, x):
37 + out = self.dropout(self.conv1(F.relu(self.bn1(x))))
38 + out = self.conv2(F.relu(self.bn2(out)))
39 + out += self.shortcut(x)
40 +
41 + return out
42 +
43 +
44 +class WideResNet(nn.Module):
45 + def __init__(self, depth, widen_factor, dropout_rate, num_classes):
46 + super(WideResNet, self).__init__()
47 + self.in_planes = 16
48 +
49 + assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
50 + n = int((depth - 4) / 6)
51 + k = widen_factor
52 +
53 + nStages = [16, 16*k, 32*k, 64*k]
54 +
55 + self.conv1 = conv3x3(3, nStages[0])
56 + self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
57 + self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
58 + self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
59 + self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
60 + self.linear = nn.Linear(nStages[3], num_classes)
61 +
62 + # self.apply(conv_init)
63 +
64 + def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
65 + strides = [stride] + [1]*(num_blocks-1)
66 + layers = []
67 +
68 + for stride in strides:
69 + layers.append(block(self.in_planes, planes, dropout_rate, stride))
70 + self.in_planes = planes
71 +
72 + return nn.Sequential(*layers)
73 +
74 + def forward(self, x):
75 + out = self.conv1(x)
76 + out = self.layer1(out)
77 + out = self.layer2(out)
78 + out = self.layer3(out)
79 + out = F.relu(self.bn1(out))
80 + # out = F.avg_pool2d(out, 8)
81 + out = F.adaptive_avg_pool2d(out, (1, 1))
82 + out = out.view(out.size(0), -1)
83 + out = self.linear(out)
84 +
85 + return out
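A quick smoke test for the model above (a sketch; runs on CPU with CIFAR-sized inputs):

```python
import torch

# WRN-28-10: depth must satisfy depth = 6n + 4.
model = WideResNet(depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10)
x = torch.randn(2, 3, 32, 32)  # (batch, channels, height, width)
logits = model(x)
print(logits.shape)            # torch.Size([2, 10])
```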
1 +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +# ==============================================================================
15 +
16 +import os
17 +import psutil
18 +import re
19 +import signal
20 +import subprocess
21 +import sys
22 +import threading
23 +import time
24 +
25 +
26 +GRACEFUL_TERMINATION_TIME_S = 5
27 +
28 +
29 +def terminate_executor_shell_and_children(pid):
30 + print('terminate_executor_shell_and_children+', pid)
31 + # If the shell has already exited, there is no need to terminate its children.
32 + try:
33 + p = psutil.Process(pid)
34 + except psutil.NoSuchProcess:
35 + print('nosuchprocess')
36 + return
37 +
38 + # Terminate children gracefully.
39 + for child in p.children():
40 + try:
41 + child.terminate()
42 + except psutil.NoSuchProcess:
43 + pass
44 +
45 + # Wait for graceful termination.
46 + time.sleep(GRACEFUL_TERMINATION_TIME_S)
47 +
48 + # Send SIGSTOP to the executor shell so it stops making progress while its children are killed.
49 + p.send_signal(signal.SIGSTOP)
50 +
51 + # Kill children recursively.
52 + for child in p.children(recursive=True):
53 + try:
54 + child.kill()
55 + except psutil.NoSuchProcess:
56 + pass
57 +
58 + # Kill shell itself.
59 + p.kill()
60 + print('terminate_executor_shell_and_children-', pid)
61 +
62 +
63 +def forward_stream(src_fd, dst_stream, prefix, index):
64 + with os.fdopen(src_fd, 'r') as src:
65 + line_buffer = ''
66 + while True:
67 + text = os.read(src.fileno(), 1000)
68 + if not isinstance(text, str):
69 + text = text.decode('utf-8')
70 + if not text:
71 + break
72 +
73 + for line in re.split('([\r\n])', text):
74 + line_buffer += line
75 + if line == '\r' or line == '\n':
76 + if index is not None:
77 + localtime = time.asctime(time.localtime(time.time()))
78 + line_buffer = '{time}[{rank}]<{prefix}>:{line}'.format(
79 + time=localtime,
80 + rank=str(index),
81 + prefix=prefix,
82 + line=line_buffer
83 + )
84 +
85 + dst_stream.write(line_buffer)
86 + dst_stream.flush()
87 + line_buffer = ''
88 +
89 +
90 +def execute(command, env=None, stdout=None, stderr=None, index=None, event=None):
91 + # Make a pipe for the subprocess stdout/stderr.
92 + (stdout_r, stdout_w) = os.pipe()
93 + (stderr_r, stderr_w) = os.pipe()
94 +
95 + # Make a pipe for notifying the child that parent has died.
96 + (r, w) = os.pipe()
97 +
98 + middleman_pid = os.fork()
99 + if middleman_pid == 0:
100 + # Close unused file descriptors to enforce PIPE behavior.
101 + os.close(w)
102 + os.setsid()
103 +
104 + executor_shell = subprocess.Popen(command, shell=True, env=env,
105 + stdout=stdout_w, stderr=stderr_w)
106 +
107 + sigterm_received = threading.Event()
108 +
109 + def set_sigterm_received(signum, frame):
110 + sigterm_received.set()
111 +
112 + signal.signal(signal.SIGINT, set_sigterm_received)
113 + signal.signal(signal.SIGTERM, set_sigterm_received)
114 +
115 + def kill_executor_children_if_parent_dies():
116 + # This read blocks until the pipe is closed on the other side
117 + # due to the process termination.
118 + os.read(r, 1)
119 + terminate_executor_shell_and_children(executor_shell.pid)
120 +
121 + bg = threading.Thread(target=kill_executor_children_if_parent_dies)
122 + bg.daemon = True
123 + bg.start()
124 +
125 + def kill_executor_children_if_sigterm_received():
126 + sigterm_received.wait()
127 + terminate_executor_shell_and_children(executor_shell.pid)
128 +
129 + bg = threading.Thread(target=kill_executor_children_if_sigterm_received)
130 + bg.daemon = True
131 + bg.start()
132 +
133 + exit_code = executor_shell.wait()
134 + os._exit(exit_code)
135 +
136 + # Close unused file descriptors to enforce PIPE behavior.
137 + os.close(r)
138 + os.close(stdout_w)
139 + os.close(stderr_w)
140 +
141 + # Redirect command stdout & stderr to provided streams or sys.stdout/sys.stderr.
142 + # This is useful for Jupyter Notebook that uses custom sys.stdout/sys.stderr or
143 + # for redirecting to a file on disk.
144 + if stdout is None:
145 + stdout = sys.stdout
146 + if stderr is None:
147 + stderr = sys.stderr
148 + stdout_fwd = threading.Thread(target=forward_stream, args=(stdout_r, stdout, 'stdout', index))
149 + stderr_fwd = threading.Thread(target=forward_stream, args=(stderr_r, stderr, 'stderr', index))
150 + stdout_fwd.start()
151 + stderr_fwd.start()
152 +
153 + def kill_middleman_if_master_thread_terminate():
154 + event.wait()
155 + try:
156 + os.kill(middleman_pid, signal.SIGTERM)
157 + except:
158 + # The process has already been killed elsewhere
159 + pass
160 +
161 + # TODO: Currently this requires explicit declaration of the event and a signal handler to set
162 + # the event (gloo_run.py:_launch_jobs()). Need to figure out a generalized way to hide this behind
163 + # interfaces.
164 + if event is not None:
165 + bg_thread = threading.Thread(target=kill_middleman_if_master_thread_terminate)
166 + bg_thread.daemon = True
167 + bg_thread.start()
168 +
169 + try:
170 + res, status = os.waitpid(middleman_pid, 0)
171 + except:
172 + # interrupted, send middleman TERM signal which will terminate children
173 + os.kill(middleman_pid, signal.SIGTERM)
174 + while True:
175 + try:
176 + _, status = os.waitpid(middleman_pid, 0)
177 + break
178 + except:
179 + # interrupted, wait for middleman to finish
180 + pass
181 +
182 + stdout_fwd.join()
183 + stderr_fwd.join()
184 + exit_code = status >> 8
185 + return exit_code
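A usage sketch for `execute` (POSIX-only, since it relies on `os.fork`, pipes, and signals; the command string is illustrative):

```python
if __name__ == '__main__':
    # Each forwarded output line is prefixed with a timestamp, rank and stream name.
    rc = execute('echo hello && echo oops >&2', index=0)
    print('exit code:', rc)
```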
1 +import copy
2 +import os
3 +import sys
4 +import time
5 +from collections import OrderedDict, defaultdict
6 +
7 +import torch
8 +
9 +import numpy as np
10 +from hyperopt import hp
11 +import ray
12 +import gorilla
13 +from ray.tune.trial import Trial
14 +from ray.tune.trial_runner import TrialRunner
15 +from ray.tune.suggest import HyperOptSearch
16 +from ray.tune import register_trainable, run_experiments
17 +from tqdm import tqdm
18 +
19 +from FastAutoAugment.archive import remove_deplicates, policy_decoder
20 +from FastAutoAugment.augmentations import augment_list
21 +from FastAutoAugment.common import get_logger, add_filehandler
22 +from FastAutoAugment.data import get_dataloaders
23 +from FastAutoAugment.metrics import Accumulator
24 +from FastAutoAugment.networks import get_model, num_class
25 +from FastAutoAugment.train import train_and_eval
26 +from theconf import Config as C, ConfigArgumentParser
27 +
28 +
29 +top1_valid_by_cv = defaultdict(list)
30 +
31 +
32 +def step_w_log(self):
33 + original = gorilla.get_original_attribute(ray.tune.trial_runner.TrialRunner, 'step')
34 +
35 + # log
36 + cnts = OrderedDict()
37 + for status in [Trial.RUNNING, Trial.TERMINATED, Trial.PENDING, Trial.PAUSED, Trial.ERROR]:
38 + cnt = len(list(filter(lambda x: x.status == status, self._trials)))
39 + cnts[status] = cnt
40 + best_top1_acc = 0.
41 + for trial in filter(lambda x: x.status == Trial.TERMINATED, self._trials):
42 + if not trial.last_result:
43 + continue
44 + best_top1_acc = max(best_top1_acc, trial.last_result['top1_valid'])
45 + print('iter', self._iteration, 'top1_acc=%.3f' % best_top1_acc, cnts, end='\r')
46 + return original(self)
47 +
48 +
49 +patch = gorilla.Patch(ray.tune.trial_runner.TrialRunner, 'step', step_w_log, settings=gorilla.Settings(allow_hit=True))
50 +gorilla.apply(patch)
51 +
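For readers unfamiliar with gorilla, the patch above wraps `TrialRunner.step` so that every scheduler step also logs trial counts. A self-contained sketch of the same monkey-patching pattern (the `Greeter` class is a hypothetical stand-in):

```python
import gorilla

class Greeter:
    def hello(self):
        return 'hello'

def hello_w_log(self):
    # Retrieve the unpatched method, add behavior, then delegate.
    original = gorilla.get_original_attribute(Greeter, 'hello')
    print('calling hello')
    return original(self)

patch = gorilla.Patch(Greeter, 'hello', hello_w_log, settings=gorilla.Settings(allow_hit=True))
gorilla.apply(patch)
print(Greeter().hello())  # prints 'calling hello', then 'hello'
```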
52 +
53 +logger = get_logger('Fast AutoAugment')
54 +
55 +
56 +def _get_path(dataset, model, tag):
57 + return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models/%s_%s_%s.model' % (dataset, model, tag)) # TODO
58 +
59 +
60 +@ray.remote(num_gpus=4, max_calls=1)
61 +def train_model(config, dataroot, augment, cv_ratio_test, cv_fold, save_path=None, skip_exist=False):
62 + C.get()
63 + C.get().conf = config
64 + C.get()['aug'] = augment
65 +
66 + result = train_and_eval(None, dataroot, cv_ratio_test, cv_fold, save_path=save_path, only_eval=skip_exist)
67 + return C.get()['model']['type'], cv_fold, result
68 +
69 +
70 +def eval_tta(config, augment, reporter):
71 + C.get()
72 + C.get().conf = config
73 + cv_ratio_test, cv_fold, save_path = augment['cv_ratio_test'], augment['cv_fold'], augment['save_path']
74 +
75 + # setup - provided augmentation rules
76 + C.get()['aug'] = policy_decoder(augment, augment['num_policy'], augment['num_op'])
77 +
78 + # eval
79 + model = get_model(C.get()['model'], num_class(C.get()['dataset']))
80 + ckpt = torch.load(save_path)
81 + if 'model' in ckpt:
82 + model.load_state_dict(ckpt['model'])
83 + else:
84 + model.load_state_dict(ckpt)
85 + model.eval()
86 +
87 + loaders = []
88 + for _ in range(augment['num_policy']): # TODO
89 + _, tl, validloader, tl2 = get_dataloaders(C.get()['dataset'], C.get()['batch'], augment['dataroot'], cv_ratio_test, split_idx=cv_fold)
90 + loaders.append(iter(validloader))
91 + del tl, tl2
92 +
93 + start_t = time.time()
94 + metrics = Accumulator()
95 + loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
96 + try:
97 + while True:
98 + losses = []
99 + corrects = []
100 + for loader in loaders:
101 + data, label = next(loader)
102 + data = data.cuda()
103 + label = label.cuda()
104 +
105 + pred = model(data)
106 +
107 + loss = loss_fn(pred, label)
108 + losses.append(loss.detach().cpu().numpy())
109 +
110 + _, pred = pred.topk(1, 1, True, True)
111 + pred = pred.t()
112 + correct = pred.eq(label.view(1, -1).expand_as(pred)).detach().cpu().numpy()
113 + corrects.append(correct)
114 + del loss, correct, pred, data, label
115 +
116 + losses = np.concatenate(losses)
117 + losses_min = np.min(losses, axis=0).squeeze()
118 +
119 + corrects = np.concatenate(corrects)
120 + corrects_max = np.max(corrects, axis=0).squeeze()
121 + metrics.add_dict({
122 + 'minus_loss': -1 * np.sum(losses_min),
123 + 'correct': np.sum(corrects_max),
124 + 'cnt': len(corrects_max)
125 + })
126 + del corrects, corrects_max
127 + except StopIteration:
128 + pass
129 +
130 + del model
131 + metrics = metrics / 'cnt'
132 + gpu_secs = (time.time() - start_t) * torch.cuda.device_count()
133 + reporter(minus_loss=metrics['minus_loss'], top1_valid=metrics['correct'], elapsed_time=gpu_secs, done=True)
134 + return metrics['correct']
135 +
136 +
137 +if __name__ == '__main__':
138 + import json
139 + from pystopwatch2 import PyStopwatch
140 + w = PyStopwatch()
141 +
142 + parser = ConfigArgumentParser(conflict_handler='resolve')
143 + parser.add_argument('--dataroot', type=str, default='/data/private/pretrainedmodels', help='torchvision data folder')
144 + parser.add_argument('--until', type=int, default=5)
145 + parser.add_argument('--num-op', type=int, default=2)
146 + parser.add_argument('--num-policy', type=int, default=5)
147 + parser.add_argument('--num-search', type=int, default=200)
148 + parser.add_argument('--cv-ratio', type=float, default=0.4)
149 + parser.add_argument('--decay', type=float, default=-1)
150 + parser.add_argument('--redis', type=str, default='gpu-cloud-vnode30.dakao.io:23655')
151 + parser.add_argument('--per-class', action='store_true')
152 + parser.add_argument('--resume', action='store_true')
153 + parser.add_argument('--smoke-test', action='store_true')
154 + args = parser.parse_args()
155 +
156 + if args.decay > 0:
157 + logger.info('decay=%.4f' % args.decay)
158 + C.get()['optimizer']['decay'] = args.decay
159 +
160 + add_filehandler(logger, os.path.join('models', '%s_%s_cv%.1f.log' % (C.get()['dataset'], C.get()['model']['type'], args.cv_ratio)))
161 + logger.info('configuration...')
162 + logger.info(json.dumps(C.get().conf, sort_keys=True, indent=4))
163 + logger.info('initialize ray...')
164 + ray.init(redis_address=args.redis)
165 +
166 + num_result_per_cv = 10
167 + cv_num = 5
168 + copied_c = copy.deepcopy(C.get().conf)
169 +
170 + logger.info('search augmentation policies, dataset=%s model=%s' % (C.get()['dataset'], C.get()['model']['type']))
171 + logger.info('----- Train without Augmentations cv=%d ratio(test)=%.1f -----' % (cv_num, args.cv_ratio))
172 + w.start(tag='train_no_aug')
173 + paths = [_get_path(C.get()['dataset'], C.get()['model']['type'], 'ratio%.1f_fold%d' % (args.cv_ratio, i)) for i in range(cv_num)]
174 + print(paths)
175 + reqs = [
176 + train_model.remote(copy.deepcopy(copied_c), args.dataroot, C.get()['aug'], args.cv_ratio, i, save_path=paths[i], skip_exist=True)
177 + for i in range(cv_num)]
178 +
179 + tqdm_epoch = tqdm(range(C.get()['epoch']))
180 + is_done = False
181 + for epoch in tqdm_epoch:
182 + while True:
183 + epochs_per_cv = OrderedDict()
184 + for cv_idx in range(cv_num):
185 + try:
186 + latest_ckpt = torch.load(paths[cv_idx])
187 + if 'epoch' not in latest_ckpt:
188 + epochs_per_cv['cv%d' % (cv_idx + 1)] = C.get()['epoch']
189 + continue
190 + epochs_per_cv['cv%d' % (cv_idx+1)] = latest_ckpt['epoch']
191 + except Exception as e:
192 + continue
193 + tqdm_epoch.set_postfix(epochs_per_cv)
194 + if len(epochs_per_cv) == cv_num and min(epochs_per_cv.values()) >= C.get()['epoch']:
195 + is_done = True
196 + if len(epochs_per_cv) == cv_num and min(epochs_per_cv.values()) >= epoch:
197 + break
198 + time.sleep(10)
199 + if is_done:
200 + break
201 +
202 + logger.info('getting results...')
203 + pretrain_results = ray.get(reqs)
204 + for r_model, r_cv, r_dict in pretrain_results:
205 + logger.info('model=%s cv=%d top1_train=%.4f top1_valid=%.4f' % (r_model, r_cv+1, r_dict['top1_train'], r_dict['top1_valid']))
206 + logger.info('processed in %.4f secs' % w.pause('train_no_aug'))
207 +
208 + if args.until == 1:
209 + sys.exit(0)
210 +
211 + logger.info('----- Search Test-Time Augmentation Policies -----')
212 + w.start(tag='search')
213 +
214 + ops = augment_list(False)
215 + space = {}
216 + for i in range(args.num_policy):
217 + for j in range(args.num_op):
218 + space['policy_%d_%d' % (i, j)] = hp.choice('policy_%d_%d' % (i, j), list(range(0, len(ops))))
219 + space['prob_%d_%d' % (i, j)] = hp.uniform('prob_%d_%d' % (i, j), 0.0, 1.0)
220 + space['level_%d_%d' % (i, j)] = hp.uniform('level_%d_%d' % (i, j), 0.0, 1.0)
221 +
222 + final_policy_set = []
223 + total_computation = 0
224 + reward_attr = 'top1_valid' # top1_valid or minus_loss
225 + for _ in range(1): # run multiple times.
226 + for cv_fold in range(cv_num):
227 + name = "search_%s_%s_fold%d_ratio%.1f" % (C.get()['dataset'], C.get()['model']['type'], cv_fold, args.cv_ratio)
228 + print(name)
229 + register_trainable(name, lambda augs, rpt: eval_tta(copy.deepcopy(copied_c), augs, rpt))
230 + algo = HyperOptSearch(space, max_concurrent=4*20, reward_attr=reward_attr)
231 +
232 + exp_config = {
233 + name: {
234 + 'run': name,
235 + 'num_samples': 4 if args.smoke_test else args.num_search,
236 + 'resources_per_trial': {'gpu': 1},
237 + 'stop': {'training_iteration': args.num_policy},
238 + 'config': {
239 + 'dataroot': args.dataroot, 'save_path': paths[cv_fold],
240 + 'cv_ratio_test': args.cv_ratio, 'cv_fold': cv_fold,
241 + 'num_op': args.num_op, 'num_policy': args.num_policy
242 + },
243 + }
244 + }
245 + results = run_experiments(exp_config, search_alg=algo, scheduler=None, verbose=0, queue_trials=True, resume=args.resume, raise_on_failed_trial=False)
246 + print()
247 + results = [x for x in results if x.last_result is not None]
248 + results = sorted(results, key=lambda x: x.last_result[reward_attr], reverse=True)
249 +
250 + # calculate computation usage
251 + for result in results:
252 + total_computation += result.last_result['elapsed_time']
253 +
254 + for result in results[:num_result_per_cv]:
255 + final_policy = policy_decoder(result.config, args.num_policy, args.num_op)
256 + logger.info('loss=%.12f top1_valid=%.4f %s' % (result.last_result['minus_loss'], result.last_result['top1_valid'], final_policy))
257 +
258 + final_policy = remove_deplicates(final_policy)
259 + final_policy_set.extend(final_policy)
260 +
261 + logger.info(json.dumps(final_policy_set))
262 + logger.info('final_policy=%d' % len(final_policy_set))
263 + logger.info('processed in %.4f secs, gpu hours=%.4f' % (w.pause('search'), total_computation / 3600.))
264 + logger.info('----- Train with Augmentations model=%s dataset=%s aug=%s ratio(test)=%.1f -----' % (C.get()['model']['type'], C.get()['dataset'], C.get()['aug'], args.cv_ratio))
265 + w.start(tag='train_aug')
266 +
267 + num_experiments = 5
268 + default_path = [_get_path(C.get()['dataset'], C.get()['model']['type'], 'ratio%.1f_default%d' % (args.cv_ratio, _)) for _ in range(num_experiments)]
269 + augment_path = [_get_path(C.get()['dataset'], C.get()['model']['type'], 'ratio%.1f_augment%d' % (args.cv_ratio, _)) for _ in range(num_experiments)]
270 + reqs = [train_model.remote(copy.deepcopy(copied_c), args.dataroot, C.get()['aug'], 0.0, 0, save_path=default_path[_], skip_exist=True) for _ in range(num_experiments)] + \
271 + [train_model.remote(copy.deepcopy(copied_c), args.dataroot, final_policy_set, 0.0, 0, save_path=augment_path[_]) for _ in range(num_experiments)]
272 +
273 + tqdm_epoch = tqdm(range(C.get()['epoch']))
274 + is_done = False
275 + for epoch in tqdm_epoch:
276 + while True:
277 + epochs = OrderedDict()
278 + for exp_idx in range(num_experiments):
279 + try:
280 + if os.path.exists(default_path[exp_idx]):
281 + latest_ckpt = torch.load(default_path[exp_idx])
282 + epochs['default_exp%d' % (exp_idx + 1)] = latest_ckpt['epoch']
283 + except:
284 + pass
285 + try:
286 + if os.path.exists(augment_path[exp_idx]):
287 + latest_ckpt = torch.load(augment_path[exp_idx])
288 + epochs['augment_exp%d' % (exp_idx + 1)] = latest_ckpt['epoch']
289 + except:
290 + pass
291 +
292 + tqdm_epoch.set_postfix(epochs)
293 + if len(epochs) == num_experiments*2 and min(epochs.values()) >= C.get()['epoch']:
294 + is_done = True
295 + if len(epochs) == num_experiments*2 and min(epochs.values()) >= epoch:
296 + break
297 + time.sleep(10)
298 + if is_done:
299 + break
300 +
301 + logger.info('getting results...')
302 + final_results = ray.get(reqs)
303 +
304 + for train_mode in ['default', 'augment']:
305 + avg = 0.
306 + for _ in range(num_experiments):
307 + r_model, r_cv, r_dict = final_results.pop(0)
308 + logger.info('[%s] top1_train=%.4f top1_test=%.4f' % (train_mode, r_dict['top1_train'], r_dict['top1_test']))
309 + avg += r_dict['top1_test']
310 + avg /= num_experiments
311 + logger.info('[%s] top1_test average=%.4f (#experiments=%d)' % (train_mode, avg, num_experiments))
312 + logger.info('processed in %.4f secs' % w.pause('train_aug'))
313 +
314 + logger.info(w)
1 +import torch
2 +from torch.optim.optimizer import Optimizer
3 +
4 +
5 +class RMSpropTF(Optimizer):
6 + r"""Implements RMSprop algorithm.
7 + Reimplements the original formulation to match TensorFlow's RMSprop.
8 + Proposed by G. Hinton in his
9 + `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
10 + The centered version first appears in `Generating Sequences
11 + With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
12 + The implementation here adds epsilon to the squared-gradient moving average
13 + before taking the square root, matching TensorFlow (PyTorch interchanges these two operations). The effective
14 + learning rate is thus :math:`\alpha/\sqrt{v + \epsilon}` rather than PyTorch's :math:`\alpha/(\sqrt{v} + \epsilon)`, where :math:`\alpha`
15 + is the scheduled learning rate and :math:`v` is the weighted moving average
16 + of the squared gradient.
17 + Arguments:
18 + params (iterable): iterable of parameters to optimize or dicts defining
19 + parameter groups
20 + lr (float, optional): learning rate (default: 1e-2)
21 + momentum (float, optional): momentum factor; this port requires a positive value (default: 0)
22 + alpha (float, optional): smoothing constant (default: 0.99)
23 + eps (float, optional): term added to the denominator to improve
24 + numerical stability (default: 1e-8)
27 + weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
28 + """
29 +
30 + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, momentum=0, weight_decay=0.0):
31 + if not 0.0 <= lr:
32 + raise ValueError("Invalid learning rate: {}".format(lr))
33 + if not 0.0 <= eps:
34 + raise ValueError("Invalid epsilon value: {}".format(eps))
35 + if not 0.0 < momentum:
36 + raise ValueError("Invalid momentum value: {} (this port requires momentum > 0)".format(momentum))
37 + if not 0.0 <= alpha:
38 + raise ValueError("Invalid alpha value: {}".format(alpha))
40 +
41 + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, weight_decay=weight_decay)
42 + super(RMSpropTF, self).__init__(params, defaults)
43 + self.initialized = False
44 +
45 + def __setstate__(self, state):
46 + super(RMSpropTF, self).__setstate__(state)
47 + for group in self.param_groups:
48 + group.setdefault('momentum', 0)
49 +
50 + def load_state_dict(self, state_dict):
51 + super(RMSpropTF, self).load_state_dict(state_dict)
52 + self.initialized = True
53 +
54 + def step(self, closure=None):
55 + """Performs a single optimization step.
56 + We modified PyTorch's RMSprop to behave the same as TensorFlow's.
57 + See : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/training_ops.cc#L485
58 +
59 + Arguments:
60 + closure (callable, optional): A closure that reevaluates the model
61 + and returns the loss.
62 + """
63 + loss = None
64 + if closure is not None:
65 + loss = closure()
66 +
67 + for group in self.param_groups:
68 + for p in group['params']:
69 + if p.grad is None:
70 + continue
71 + grad = p.grad.data
72 + if grad.is_sparse:
73 + raise RuntimeError('RMSprop does not support sparse gradients')
74 + state = self.state[p]
75 +
76 + # State initialization
77 + if len(state) == 0:
78 + assert not self.initialized
79 + state['step'] = 0
80 + state['ms'] = torch.ones_like(p.data) #, memory_format=torch.preserve_format)
81 + state['mom'] = torch.zeros_like(p.data) #, memory_format=torch.preserve_format)
82 +
83 + # weight decay -----
84 + if group['weight_decay'] > 0:
85 + grad = grad.add(group['weight_decay'], p.data)
86 +
87 + rho = group['alpha']
88 + ms = state['ms']
89 + mom = state['mom']
90 + state['step'] += 1
91 +
92 + # ms.mul_(rho).addcmul_(1 - rho, grad, grad)
93 + ms.add_(torch.mul(grad, grad).add_(-ms) * (1. - rho))
94 + assert group['momentum'] > 0
95 +
96 + # new rmsprop
97 + mom.mul_(group['momentum']).addcdiv_(group['lr'], grad, (ms + group['eps']).sqrt())
98 +
99 + p.data.add_(-1.0, mom)
100 +
101 + return loss
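A usage sketch under the PyTorch 1.2 environment this repository targets (the hyperparameters mirror the rmsprop settings used in train.py; note that this port requires a positive momentum and adds eps inside the square root):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 2)
opt = RMSpropTF(model.parameters(), lr=0.008, alpha=0.9, momentum=0.9, eps=0.001)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
```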
1 +import torch
2 +from torch.nn import BatchNorm2d
3 +from torch.nn.parameter import Parameter
4 +import torch.distributed as dist
5 +from torch import nn
6 +
7 +
8 +class TpuBatchNormalization(nn.Module):
9 + # Ref : https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/utils.py#L113
10 + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
11 + track_running_stats=True):
12 + super(TpuBatchNormalization, self).__init__() # num_features, eps, momentum, affine, track_running_stats)
13 +
14 + self.weight = Parameter(torch.ones(num_features))
15 + self.bias = Parameter(torch.zeros(num_features))
16 +
17 + self.register_buffer('running_mean', torch.zeros(num_features))
18 + self.register_buffer('running_var', torch.ones(num_features))
19 + self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
20 +
21 + self.eps = eps
22 + self.momentum = momentum
23 +
24 + def _reduce_avg(self, t):
25 + dist.all_reduce(t, dist.ReduceOp.SUM)
26 + t.mul_(1. / dist.get_world_size())
27 +
28 + def forward(self, input):
29 + if not self.training or not dist.is_initialized():
30 + bn = (input - self.running_mean.view(1, self.running_mean.shape[0], 1, 1)) / \
31 + (torch.sqrt(self.running_var.view(1, self.running_var.shape[0], 1, 1) + self.eps))
32 + # print(self.weight.shape, self.bias.shape)
33 + return bn.mul(self.weight.view(1, self.weight.shape[0], 1, 1)).add(self.bias.view(1, self.bias.shape[0], 1, 1))
34 +
35 + shard_mean, shard_invstd = torch.batch_norm_stats(input, self.eps)
36 + shard_vars = (1. / shard_invstd) ** 2 - self.eps
37 +
38 + shard_square_of_mean = torch.mul(shard_mean, shard_mean)
39 + shard_mean_of_square = shard_vars + shard_square_of_mean
40 +
41 + group_mean = shard_mean.clone().detach()
42 + self._reduce_avg(group_mean)
43 + group_mean_of_square = shard_mean_of_square.clone().detach()
44 + self._reduce_avg(group_mean_of_square)
45 + group_vars = group_mean_of_square - torch.mul(group_mean, group_mean)
46 +
47 + group_mean = group_mean.detach()
48 + group_vars = group_vars.detach()
49 +
50 + # print(self.running_mean.shape, self.running_var.shape)
51 + self.running_mean.mul_(1. - self.momentum).add_(group_mean.mul(self.momentum))
52 + self.running_var.mul_(1. - self.momentum).add_(group_vars.mul(self.momentum))
53 + self.num_batches_tracked.add_(1)
54 +
55 + # print(input.shape, group_mean.view(1, group_mean.shape[0], 1, 1).shape, group_vars.view(1, group_vars.shape[0], 1, 1).shape, self.eps)
56 + bn = (input - group_mean.view(1, group_mean.shape[0], 1, 1)) / (torch.sqrt(group_vars.view(1, group_vars.shape[0], 1, 1) + self.eps))
57 + # print(self.weight.shape, self.bias.shape)
58 + return bn.mul(self.weight.view(1, self.weight.shape[0], 1, 1)).add(self.bias.view(1, self.bias.shape[0], 1, 1))
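A single-process sanity check for the non-distributed branch (a sketch: with `torch.distributed` uninitialized, `forward` falls back to the running statistics, which should match `F.batch_norm` in eval mode):

```python
import torch
import torch.nn.functional as F

bn = TpuBatchNormalization(num_features=8)
bn.eval()  # use running statistics

x = torch.randn(4, 8, 16, 16)
out = bn(x)
ref = F.batch_norm(x, bn.running_mean, bn.running_var,
                   bn.weight, bn.bias, training=False, eps=bn.eps)
print(torch.allclose(out, ref, atol=1e-6))  # expected: True
```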
1 +import sys
2 +
3 +
4 +sys.path.append('/data/private/fast-autoaugment-public') # TODO
5 +
6 +import itertools
7 +import json
8 +import logging
9 +import math
10 +import os
11 +from collections import OrderedDict
12 +
13 +import torch
14 +from torch import nn, optim
15 +from torch.nn.parallel.data_parallel import DataParallel
16 +from torch.nn.parallel import DistributedDataParallel
17 +import torch.distributed as dist
18 +
19 +from tqdm import tqdm
20 +from theconf import Config as C, ConfigArgumentParser
21 +
22 +from FastAutoAugment.common import get_logger, EMA, add_filehandler
23 +from FastAutoAugment.data import get_dataloaders
24 +from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet
25 +from FastAutoAugment.metrics import accuracy, Accumulator, CrossEntropyLabelSmooth
26 +from FastAutoAugment.networks import get_model, num_class
27 +from FastAutoAugment.tf_port.rmsprop import RMSpropTF
28 +from FastAutoAugment.aug_mixup import CrossEntropyMixUpLabelSmooth, mixup
29 +from warmup_scheduler import GradualWarmupScheduler
30 +
31 +logger = get_logger('Fast AutoAugment')
32 +logger.setLevel(logging.INFO)
33 +
34 +
35 +def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None, is_master=True, ema=None, wd=0.0, tqdm_disabled=False):
36 + if verbose:
37 + loader = tqdm(loader, disable=tqdm_disabled)
38 + loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))
39 +
40 + params_without_bn = [params for name, params in model.named_parameters() if not ('_bn' in name or '.bn' in name)]
41 +
42 + loss_ema = None
43 + metrics = Accumulator()
44 + cnt = 0
45 + total_steps = len(loader)
46 + steps = 0
47 + for data, label in loader:
48 + steps += 1
49 + data, label = data.cuda(), label.cuda()
50 +
51 + if C.get().conf.get('mixup', 0.0) <= 0.0 or optimizer is None:
52 + preds = model(data)
53 + loss = loss_fn(preds, label)
54 + else: # mixup
55 + data, targets, shuffled_targets, lam = mixup(data, label, C.get()['mixup'])
56 + preds = model(data)
57 + loss = loss_fn(preds, targets, shuffled_targets, lam)
58 + del shuffled_targets, lam
59 +
60 + if optimizer:
61 + loss += wd * (1. / 2.) * sum([torch.sum(p ** 2) for p in params_without_bn])
62 + loss.backward()
63 + if C.get()['optimizer']['clip'] > 0:
64 + nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer']['clip'])
65 + optimizer.step()
66 + optimizer.zero_grad()
67 +
68 + if ema is not None:
69 + ema(model, (epoch - 1) * total_steps + steps)
70 +
71 + top1, top5 = accuracy(preds, label, (1, 5))
72 + metrics.add_dict({
73 + 'loss': loss.item() * len(data),
74 + 'top1': top1.item() * len(data),
75 + 'top5': top5.item() * len(data),
76 + })
77 + cnt += len(data)
78 + if loss_ema:
79 + loss_ema = loss_ema * 0.9 + loss.item() * 0.1
80 + else:
81 + loss_ema = loss.item()
82 + if verbose:
83 + postfix = metrics / cnt
84 + if optimizer:
85 + postfix['lr'] = optimizer.param_groups[0]['lr']
86 + postfix['loss_ema'] = loss_ema
87 + loader.set_postfix(postfix)
88 +
89 + if scheduler is not None:
90 + scheduler.step(epoch - 1 + float(steps) / total_steps)
91 +
92 + del preds, loss, top1, top5, data, label
93 +
94 + if tqdm_disabled and verbose:
95 + if optimizer:
96 + logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr'])
97 + else:
98 + logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)
99 +
100 + metrics /= cnt
101 + if optimizer:
102 + metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
103 + if verbose:
104 + for key, value in metrics.items():
105 + writer.add_scalar(key, value, epoch)
106 + return metrics
107 +
108 +
109 +def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metric='last', save_path=None, only_eval=False, local_rank=-1, evaluation_interval=5):
110 + total_batch = C.get()["batch"]
111 + if local_rank >= 0:
112 + dist.init_process_group(backend='nccl', init_method='env://', world_size=int(os.environ['WORLD_SIZE']))
113 + device = torch.device('cuda', local_rank)
114 + torch.cuda.set_device(device)
115 +
116 + C.get()['lr'] *= dist.get_world_size()
117 + logger.info(f'local batch={C.get()["batch"]} world_size={dist.get_world_size()} ----> total batch={C.get()["batch"] * dist.get_world_size()}')
118 + total_batch = C.get()["batch"] * dist.get_world_size()
119 +
120 + is_master = local_rank < 0 or dist.get_rank() == 0
121 + if is_master:
122 + add_filehandler(logger, args.save + '.log')
123 +
124 + if not reporter:
125 + reporter = lambda **kwargs: 0
126 +
127 + max_epoch = C.get()['epoch']
128 + trainsampler, trainloader, validloader, testloader_ = get_dataloaders(C.get()['dataset'], C.get()['batch'], dataroot, test_ratio, split_idx=cv_fold, multinode=(local_rank >= 0))
129 +
130 + # create a model & an optimizer
131 + model = get_model(C.get()['model'], num_class(C.get()['dataset']), local_rank=local_rank)
132 + model_ema = get_model(C.get()['model'], num_class(C.get()['dataset']), local_rank=-1)
133 + model_ema.eval()
134 +
135 + criterion_ce = criterion = CrossEntropyLabelSmooth(num_class(C.get()['dataset']), C.get().conf.get('lb_smooth', 0))
136 + if C.get().conf.get('mixup', 0.0) > 0.0:
137 + criterion = CrossEntropyMixUpLabelSmooth(num_class(C.get()['dataset']), C.get().conf.get('lb_smooth', 0))
138 + if C.get()['optimizer']['type'] == 'sgd':
139 + optimizer = optim.SGD(
140 + model.parameters(),
141 + lr=C.get()['lr'],
142 + momentum=C.get()['optimizer'].get('momentum', 0.9),
143 + weight_decay=0.0,
144 + nesterov=C.get()['optimizer'].get('nesterov', True)
145 + )
146 + elif C.get()['optimizer']['type'] == 'rmsprop':
147 + optimizer = RMSpropTF(
148 + model.parameters(),
149 + lr=C.get()['lr'],
150 + weight_decay=0.0,
151 + alpha=0.9, momentum=0.9,
152 + eps=0.001
153 + )
154 + else:
155 + raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type'])
156 +
157 + lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
158 + if lr_scheduler_type == 'cosine':
159 + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'], eta_min=0.)
160 + elif lr_scheduler_type == 'resnet':
161 + scheduler = adjust_learning_rate_resnet(optimizer)
162 + elif lr_scheduler_type == 'efficientnet':
163 + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 0.97 ** int((x + C.get()['lr_schedule']['warmup']['epoch']) / 2.4))
164 + else:
165 + raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type)
166 +
167 + if C.get()['lr_schedule'].get('warmup', None) and C.get()['lr_schedule']['warmup']['epoch'] > 0:
168 + scheduler = GradualWarmupScheduler(
169 + optimizer,
170 + multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
171 + total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
172 + after_scheduler=scheduler
173 + )
174 +
175 + if not tag or not is_master:
176 + from FastAutoAugment.metrics import SummaryWriterDummy as SummaryWriter
177 + logger.warning('tag not provided, no tensorboard log.')
178 + else:
179 + from tensorboardX import SummaryWriter
180 + writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']]
181 +
182 + if C.get()['optimizer']['ema'] > 0.0 and is_master:
183 + # https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4?u=ildoonet
184 + ema = EMA(C.get()['optimizer']['ema'])
185 + else:
186 + ema = None
187 +
188 + result = OrderedDict()
189 + epoch_start = 1
190 + if save_path != 'test.pth':  # not restricted to is_master: every rank must load the checkpoint itself (it cannot be broadcast)
191 + if save_path and os.path.exists(save_path):
192 + logger.info('%s file found. loading...' % save_path)
193 + data = torch.load(save_path)
194 + key = 'model' if 'model' in data else 'state_dict'
195 +
196 + if 'epoch' not in data:
197 + model.load_state_dict(data)
198 + else:
199 + logger.info('checkpoint epoch@%d' % data['epoch'])
200 + if not isinstance(model, (DataParallel, DistributedDataParallel)):
201 + model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()})
202 + else:
203 + model.load_state_dict({k if 'module.' in k else 'module.'+k: v for k, v in data[key].items()})
204 + logger.info('optimizer.load_state_dict+')
205 + optimizer.load_state_dict(data['optimizer'])
206 + if data['epoch'] < C.get()['epoch']:
207 + epoch_start = data['epoch']
208 + else:
209 + only_eval = True
210 + if ema is not None:
211 + ema.shadow = data.get('ema', {}) if isinstance(data.get('ema', {}), dict) else data['ema'].state_dict()
212 + del data
213 + else:
214 + logger.info('"%s" file not found. skip to pretrain weights...' % save_path)
215 + if only_eval:
216 + logger.warning('model checkpoint not found. only-evaluation mode is off.')
217 + only_eval = False
218 +
219 + if local_rank >= 0:
220 + for name, x in model.state_dict().items():
221 + dist.broadcast(x, 0)
222 + logger.info(f'multinode init. local_rank={dist.get_rank()} is_master={is_master}')
223 + torch.cuda.synchronize()
224 +
225 + tqdm_disabled = bool(os.environ.get('TASK_NAME', '')) and local_rank != 0 # KakaoBrain Environment
226 +
227 + if only_eval:
228 + logger.info('evaluation only+')
229 + model.eval()
230 + rs = dict()
231 + rs['train'] = run_epoch(model, trainloader, criterion, None, desc_default='train', epoch=0, writer=writers[0], is_master=is_master)
232 +
233 + with torch.no_grad():
234 + rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=0, writer=writers[1], is_master=is_master)
235 + rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=0, writer=writers[2], is_master=is_master)
236 + if ema is not None and len(ema) > 0:
237 + model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
238 + rs['valid'] = run_epoch(model_ema, validloader, criterion_ce, None, desc_default='valid(EMA)', epoch=0, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled)
239 + rs['test'] = run_epoch(model_ema, testloader_, criterion_ce, None, desc_default='*test(EMA)', epoch=0, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled)
240 + for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
241 + if setname not in rs:
242 + continue
243 + result['%s_%s' % (key, setname)] = rs[setname][key]
244 + result['epoch'] = 0
245 + return result
246 +
247 + # train loop
248 + best_top1 = 0
249 + for epoch in range(epoch_start, max_epoch + 1):
250 + if local_rank >= 0:
251 + trainsampler.set_epoch(epoch)
252 +
253 + model.train()
254 + rs = dict()
255 + rs['train'] = run_epoch(model, trainloader, criterion, optimizer, desc_default='train', epoch=epoch, writer=writers[0], verbose=(is_master and local_rank <= 0), scheduler=scheduler, ema=ema, wd=C.get()['optimizer']['decay'], tqdm_disabled=tqdm_disabled)
256 + model.eval()
257 +
258 + if math.isnan(rs['train']['loss']):
259 + raise Exception('train loss is NaN.')
260 +
261 + if ema is not None and C.get()['optimizer']['ema_interval'] > 0 and epoch % C.get()['optimizer']['ema_interval'] == 0:
262 + logger.info(f'ema synced+ rank={dist.get_rank()}')
264 + model.load_state_dict(ema.state_dict())
265 + for name, x in model.state_dict().items():
266 + # print(name)
267 + dist.broadcast(x, 0)
268 + torch.cuda.synchronize()
269 + logger.info(f'ema synced- rank={dist.get_rank()}')
270 +
271 + if is_master and (epoch % evaluation_interval == 0 or epoch == max_epoch):
272 + with torch.no_grad():
273 + rs['valid'] = run_epoch(model, validloader, criterion_ce, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled)
274 + rs['test'] = run_epoch(model, testloader_, criterion_ce, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled)
275 +
276 + if ema is not None:
277 + model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
278 + rs['valid'] = run_epoch(model_ema, validloader, criterion_ce, None, desc_default='valid(EMA)', epoch=epoch, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled)
279 + rs['test'] = run_epoch(model_ema, testloader_, criterion_ce, None, desc_default='*test(EMA)', epoch=epoch, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled)
280 +
281 + logger.info(
282 + f'epoch={epoch} '
283 + f'[train] loss={rs["train"]["loss"]:.4f} top1={rs["train"]["top1"]:.4f} '
284 + f'[valid] loss={rs["valid"]["loss"]:.4f} top1={rs["valid"]["top1"]:.4f} '
285 + f'[test] loss={rs["test"]["loss"]:.4f} top1={rs["test"]["top1"]:.4f} '
286 + )
287 +
288 + if metric == 'last' or rs[metric]['top1'] > best_top1:
289 + if metric != 'last':
290 + best_top1 = rs[metric]['top1']
291 + for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
292 + result['%s_%s' % (key, setname)] = rs[setname][key]
293 + result['epoch'] = epoch
294 +
295 + writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
296 + writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)
297 +
298 + reporter(
299 + loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'],
300 + loss_test=rs['test']['loss'], top1_test=rs['test']['top1']
301 + )
302 +
303 + # save checkpoint
304 + if is_master and save_path:
305 + logger.info('save model@%d to %s, err=%.4f' % (epoch, save_path, 1 - best_top1))
306 + torch.save({
307 + 'epoch': epoch,
308 + 'log': {
309 + 'train': rs['train'].get_dict(),
310 + 'valid': rs['valid'].get_dict(),
311 + 'test': rs['test'].get_dict(),
312 + },
313 + 'optimizer': optimizer.state_dict(),
314 + 'model': model.state_dict(),
315 + 'ema': ema.state_dict() if ema is not None else None,
316 + }, save_path)
317 +
318 + del model
319 +
320 + result['top1_test'] = best_top1
321 + return result
322 +
323 +
324 +if __name__ == '__main__':
325 + parser = ConfigArgumentParser(conflict_handler='resolve')
326 + parser.add_argument('--tag', type=str, default='')
327 + parser.add_argument('--dataroot', type=str, default='/data/private/pretrainedmodels', help='torchvision data folder')
328 + parser.add_argument('--save', type=str, default='test.pth')
329 + parser.add_argument('--cv-ratio', type=float, default=0.0)
330 + parser.add_argument('--cv', type=int, default=0)
331 + parser.add_argument('--local_rank', type=int, default=-1)
332 + parser.add_argument('--evaluation-interval', type=int, default=5)
333 + parser.add_argument('--only-eval', action='store_true')
334 + args = parser.parse_args()
335 +
336 + assert (args.only_eval and args.save) or not args.only_eval, 'checkpoint path not provided in evaluation mode.'
337 +
338 + if not args.only_eval:
339 + if args.save:
340 + logger.info('checkpoint will be saved at %s' % args.save)
341 + else:
342 + logger.warning('Provide --save argument to save the checkpoint. Without it, training result will not be saved!')
343 +
344 + import time
345 + t = time.time()
346 + result = train_and_eval(args.tag, args.dataroot, test_ratio=args.cv_ratio, cv_fold=args.cv, save_path=args.save, only_eval=args.only_eval, local_rank=args.local_rank, metric='test', evaluation_interval=args.evaluation_interval)
347 + elapsed = time.time() - t
348 +
349 + logger.info('done.')
350 + logger.info('model: %s' % C.get()['model'])
351 + logger.info('augmentation: %s' % C.get()['aug'])
352 + logger.info('\n' + json.dumps(result, indent=4))
353 + logger.info('elapsed time: %.3f Hours' % (elapsed / 3600.))
354 + logger.info('top1 error in testset: %.4f' % (1. - result['top1_test']))
355 + logger.info(args.save)
1 +import sys
2 +sys.path.append('/data/private/fast-autoaugment-public') # TODO
3 +
4 +import time
5 +import os
6 +import threading
7 +import six
8 +from six.moves import queue
9 +
10 +from FastAutoAugment import safe_shell_exec
11 +
12 +
13 +def _exec_command(command):
14 + host_output = six.StringIO()
15 + try:
16 + exit_code = safe_shell_exec.execute(command,
17 + stdout=host_output,
18 + stderr=host_output)
19 + if exit_code != 0:
20 + print('Launching task function was not successful:\n{host_output}'.format(host_output=host_output.getvalue()))
21 + os._exit(exit_code)
22 + finally:
23 + host_output.close()
24 + return exit_code
25 +
26 +
27 +def execute_function_multithreaded(fn,
28 + args_list,
29 + block_until_all_done=True,
30 + max_concurrent_executions=1000):
31 + """
32 + Executes fn in multiple threads each with one set of the args in the
33 + args_list.
34 + :param fn: function to be executed
35 + :type fn:
36 + :param args_list:
37 + :type args_list: list(list)
38 + :param block_until_all_done: if is True, function will block until all the
39 + threads are done and will return the results of each thread's execution.
40 + :type block_until_all_done: bool
41 + :param max_concurrent_executions:
42 + :type max_concurrent_executions: int
43 + :return:
44 + If block_until_all_done is False, returns None. If block_until_all_done is
45 + True, function returns the dict of results.
46 + {
47 + index: execution result of fn with args_list[index]
48 + }
49 + :rtype: dict
50 + """
51 + result_queue = queue.Queue()
52 + worker_queue = queue.Queue()
53 +
54 + for i, arg in enumerate(args_list):
55 + arg.append(i)
56 + worker_queue.put(arg)
57 +
58 + def fn_execute():
59 + while True:
60 + try:
61 + arg = worker_queue.get(block=False)
62 + except queue.Empty:
63 + return
64 + exec_index = arg[-1]
65 + res = fn(*arg[:-1])
66 + result_queue.put((exec_index, res))
67 +
68 + threads = []
69 + number_of_threads = min(max_concurrent_executions, len(args_list))
70 +
71 + for _ in range(number_of_threads):
72 + thread = threading.Thread(target=fn_execute)
73 + if not block_until_all_done:
74 + thread.daemon = True
75 + thread.start()
76 + threads.append(thread)
77 +
78 + # Returns the results only if block_until_all_done is set.
79 + results = None
80 + if block_until_all_done:
81 + # Because join() cannot be interrupted by signal, a single join()
82 + # needs to be separated into join()s with timeout in a while loop.
83 + have_alive_child = True
84 + while have_alive_child:
85 + have_alive_child = False
86 + for t in threads:
87 + t.join(0.1)
88 + if t.is_alive():
89 + have_alive_child = True
90 +
91 + results = {}
92 + while not result_queue.empty():
93 + item = result_queue.get()
94 + results[item[0]] = item[1]
95 +
96 + if len(results) != len(args_list):
97 + raise RuntimeError(
98 + 'Some threads for func {func} did not complete '
99 + 'successfully.'.format(func=fn.__name__))
100 + return results
101 +
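A usage sketch for `execute_function_multithreaded` (the worker function is hypothetical; each `args_list` entry is a list of positional arguments, and an execution index is appended internally):

```python
def add(a, b):
    return a + b

# Blocks until both threads finish and returns {index: result}.
results = execute_function_multithreaded(add, [[1, 2], [3, 4]])
print(results)  # e.g. {0: 3, 1: 7}
```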
102 +
103 +if __name__ == '__main__':
104 + import argparse
105 +
106 + parser = argparse.ArgumentParser()
107 + parser.add_argument('--host', type=str)
108 + parser.add_argument('--num-gpus', type=int, default=4)
109 + parser.add_argument('--master', type=str, default='task1')
110 + parser.add_argument('--port', type=int, default=1958)
111 + parser.add_argument('-c', '--conf', type=str)
112 + parser.add_argument('--args', type=str, default='')
113 +
114 + args = parser.parse_args()
115 +
116 + try:
117 + hosts = ['task%d' % (x + 1) for x in range(int(args.host))]
118 + except:
119 + hosts = args.host.split(',')
120 +
121 + cwd = os.getcwd()
122 + command_list = []
123 + for node_rank, host in enumerate(hosts):
124 + ssh_cmd = f'ssh -t -t -o StrictHostKeyChecking=no {host} -p 22 ' \
125 + f'\'bash -O huponexit -c "cd {cwd} && ' \
126 + f'python -m torch.distributed.launch --nproc_per_node={args.num_gpus} --nnodes={len(hosts)} ' \
127 + f'--master_addr={args.master} --master_port={args.port} --node_rank={node_rank} ' \
128 + f'FastAutoAugment/train.py -c {args.conf} {args.args}"' \
129 + '\''
130 + print(ssh_cmd)
131 +
132 + command_list.append([ssh_cmd])
133 +
134 + execute_function_multithreaded(_exec_command,
135 + command_list[1:],
136 + block_until_all_done=False)
137 +
138 + print(command_list[0])
139 +
140 + while True:
141 + time.sleep(1)
142 +
143 + # thread = threading.Thread(target=safe_shell_exec.execute, args=(command_list[0][0],))
144 + # thread.start()
145 + # thread.join()
146 +
147 + # while True:
148 + # time.sleep(1)
1 +MIT License
2 +
3 +Copyright (c) 2019 Ildoo Kim
4 +
5 +Permission is hereby granted, free of charge, to any person obtaining a copy
6 +of this software and associated documentation files (the "Software"), to deal
7 +in the Software without restriction, including without limitation the rights
8 +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 +copies of the Software, and to permit persons to whom the Software is
10 +furnished to do so, subject to the following conditions:
11 +
12 +The above copyright notice and this permission notice shall be included in all
13 +copies or substantial portions of the Software.
14 +
15 +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 +SOFTWARE.
1 +# Fast AutoAugment **(Accepted at NeurIPS 2019)**
2 +
3 +Official [Fast AutoAugment](https://arxiv.org/abs/1905.00397) implementation in PyTorch.
4 +
5 +- Fast AutoAugment learns augmentation policies using a more efficient search strategy based on density matching.
6 +- Fast AutoAugment speeds up the search time by orders of magnitude while maintaining comparable performance.
7 +
8 +<p align="center">
9 +<img src="etc/search.jpg" height=350>
10 +</p>
11 +
12 +## Results
13 +
14 +### CIFAR-10 / 100
15 +
16 +Search : **3.5 GPU Hours (1428x faster than AutoAugment)**, WResNet-40x2 on Reduced CIFAR-10
17 +
18 +| Model(CIFAR-10) | Baseline | Cutout | AutoAugment | Fast AutoAugment<br/>(transfer/direct) | |
19 +|-------------------------|------------|------------|-------------|------------------|----|
20 +| Wide-ResNet-40-2 | 5.3 | 4.1 | 3.7 | 3.6 / 3.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_wresnet40x2_top1_3.52.pth) |
21 +| Wide-ResNet-28-10 | 3.9 | 3.1 | 2.6 | 2.7 / 2.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_wresnet28x10_top1.pth) |
22 +| Shake-Shake(26 2x32d) | 3.6 | 3.0 | 2.5 | 2.7 / 2.5 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_shake26_2x32d_top1_2.68.pth) |
23 +| Shake-Shake(26 2x96d) | 2.9 | 2.6 | 2.0 | 2.0 / 2.0 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_shake26_2x96d_top1_1.97.pth) |
24 +| Shake-Shake(26 2x112d) | 2.8 | 2.6 | 1.9 | 2.0 / 1.9 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_shake26_2x112d_top1_2.04.pth) |
25 +| PyramidNet+ShakeDrop | 2.7 | 2.3 | 1.5 | 1.8 / 1.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_pyramid272_top1_1.44.pth) |
26 +
27 +| Model(CIFAR-100) | Baseline | Cutout | AutoAugment | Fast AutoAugment<br/>(transfer/direct) | |
28 +|-----------------------|------------|------------|-------------|------------------|----|
29 +| Wide-ResNet-40-2 | 26.0 | 25.2 | 20.7 | 20.7 / 20.6 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_wresnet40x2_top1_20.43.pth) |
30 +| Wide-ResNet-28-10 | 18.8 | 18.4 | 17.1 | 17.3 / 17.3 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_wresnet28x10_top1_17.17.pth) |
31 +| Shake-Shake(26 2x96d) | 17.1 | 16.0 | 14.3 | 14.9 / 14.6 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_shake26_2x96d_top1_15.15.pth) |
32 +| PyramidNet+ShakeDrop | 14.0 | 12.2 | 10.7 | 11.9 / 11.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_pyramid272_top1_11.74.pth) |
33 +
34 +### ImageNet
35 +
36 +Search : **450 GPU Hours (33x faster than AutoAugment)**, ResNet-50 on Reduced ImageNet
37 +
38 +| Model | Baseline | AutoAugment | Fast AutoAugment<br/>(Top1/Top5) | |
39 +|------------|------------|-------------|------------------|----|
40 +| ResNet-50 | 23.7 / 6.9 | 22.4 / 6.2 | **22.4 / 6.3** | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/imagenet_resnet50_top1_22.2.pth) |
41 +| ResNet-200 | 21.5 / 5.8 | 20.0 / 5.0 | **19.4 / 4.7** | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/imagenet_resnet200_top1_19.4.pth) |
42 +
43 +Notes
44 +* We evaluated ResNet-50 and ResNet-200 at resolutions of 224 and 320, respectively. According to the original ResNet paper, ResNet-200 was tested at a resolution of 320, and our ResNet-200 baseline performed similarly at that resolution.
45 +* However, after our recent code clean-up and bugfixes, we found that the model performs similarly to that baseline even when using 224x224.
46 +* At 224x224, ResNet-200 achieves **20.0 / 5.2**. A download link for the trained model is [here](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/imagenet_resnet200_res224.pth).
47 +
48 +We have conducted additional experiments with EfficientNet.
49 +
50 +| Model | Baseline | AutoAugment | | Our Baseline(Batch) | +Fast AA |
51 +|-------|------------|-------------|---|---------------------|----------|
52 +| B0 | 23.2 | 22.7 | | 22.96 | 22.68 |
53 +
54 +### SVHN Test
55 +
56 +Search : **1.5 GPU Hours**
57 +
58 +| | Baseline | AutoAug / Our | Fast AutoAugment |
59 +|----------------------------------|---------:|--------------:|--------:|
60 +| Wide-Resnet28x10 | 1.5 | 1.1 | 1.1 |
61 +
62 +## Run
63 +
64 +We conducted experiments under
65 +
66 +- python 3.6.9
67 +- pytorch 1.2.0, torchvision 0.4.0, cuda10
68 +
69 +### Search an augmentation policy
70 +
71 +Please read Ray's documentation (https://github.com/ray-project/ray) to construct a proper Ray cluster, then run search.py with the master's Redis address.
72 +
73 +```
74 +$ python search.py -c confs/wresnet40x2_cifar10_b512.yaml --dataroot ... --redis ...
75 +```
76 +
77 +### Train a model with found policies
78 +
79 +You can train network architectures on CIFAR-10 / 100 and ImageNet with our searched policies.
80 +
81 +- fa_reduced_cifar10 : reduced CIFAR-10(4k images), WResNet-40x2
82 +- fa_reduced_imagenet : reduced ImageNet(50k images, 120 classes), ResNet-50
83 +
84 +```
85 +$ export PYTHONPATH=$PYTHONPATH:$PWD
86 +$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10
87 +$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100
88 +$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10
89 +$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100
90 +...
91 +$ python FastAutoAugment/train.py -c confs/resnet50_b512.yaml --aug fa_reduced_imagenet
92 +$ python FastAutoAugment/train.py -c confs/resnet200_b512.yaml --aug fa_reduced_imagenet
93 +```
94 +
95 +By adding the --only-eval and --save arguments, you can evaluate trained models without training, as shown below.
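For example, to evaluate the downloaded Wide-ResNet-40-2 checkpoint from the CIFAR-10 table above (adjust the path to wherever you saved it):

```
$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10 --only-eval --save cifar10_wresnet40x2_top1_3.52.pth
```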
96 +
97 +If you want to train with multiple GPUs or nodes, use `torch.distributed.launch`, for example:
98 +
99 +```bash
100 +$ python -m torch.distributed.launch --nproc_per_node={num_gpu_per_node} --nnodes={num_node} --master_addr={master} --master_port={master_port} --node_rank={0,1,2,...,num_node} FastAutoAugment/train.py -c confs/efficientnet_b4.yaml --aug fa_reduced_imagenet
101 +```
102 +
103 +## Citation
104 +
105 +If you use this code in your research, please cite our [paper](https://arxiv.org/abs/1905.00397).
106 +
107 +```
108 +@inproceedings{lim2019fast,
109 + title={Fast AutoAugment},
110 + author={Lim, Sungbin and Kim, Ildoo and Kim, Taesup and Kim, Chiheon and Kim, Sungwoong},
111 + booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
112 + year={2019}
113 +}
114 +```
115 +
116 +## Contact for Issues
117 +- Ildoo Kim, ildoo.kim@kakaobrain.com
118 +
119 +## References & Opensources
120 +
121 +We increase the batch size and adapt the learning rate accordingly to speed up training. Otherwise, we set hyperparameters equal to AutoAugment's whenever possible. For unknown hyperparameters, we follow values from the original references or tune them to match baseline performances.
122 +
123 +- **ResNet** : [paper1](https://arxiv.org/abs/1512.03385), [paper2](https://arxiv.org/abs/1603.05027), [code](https://github.com/osmr/imgclsmob/tree/master/pytorch/pytorchcv/models)
124 +- **PyramidNet** : [paper](https://arxiv.org/abs/1610.02915), [code](https://github.com/dyhan0920/PyramidNet-PyTorch)
125 +- **Wide-ResNet** : [code](https://github.com/meliketoy/wide-resnet.pytorch)
126 +- **Shake-Shake** : [code](https://github.com/owruby/shake-shake_pytorch)
127 +- **ShakeDrop Regularization** : [paper](https://arxiv.org/abs/1802.02375), [code](https://github.com/owruby/shake-drop_pytorch)
128 +- **AutoAugment** : [code](https://github.com/tensorflow/models/tree/master/research/autoaugment)
129 +- **Ray** : [code](https://github.com/ray-project/ray)
130 +- **HyperOpt** : [code](https://github.com/hyperopt/hyperopt)
This diff could not be displayed because it is too large.
1 +model:
2 + type: efficientnet-b0
3 + condconv_num_expert: 1 # if this is greater than 1 (e.g. 4), condconv is activated.
4 +dataset: imagenet
5 +aug: fa_reduced_imagenet
6 +cutout: 0
7 +batch: 128 # per gpu
8 +epoch: 350
9 +lr: 0.008 # 0.256 for 4096 batch
10 +lr_schedule:
11 + type: 'efficientnet'
12 + warmup:
13 + multiplier: 1
14 + epoch: 5
15 +optimizer:
16 + type: rmsprop
17 + decay: 0.00001
18 + clip: 0
19 + ema: 0.9999
20 + ema_interval: -1
21 +lb_smooth: 0.1
...\ No newline at end of file ...\ No newline at end of file
1 +model:
2 + type: efficientnet-b0
3 + condconv_num_expert: 8 # if this is greater than 1 (e.g. 4), condconv is activated.
4 +dataset: imagenet
5 +aug: fa_reduced_imagenet
6 +cutout: 0
7 +batch: 128 # per gpu
8 +epoch: 350
9 +lr: 0.008 # 0.256 for 4096 batch
10 +lr_schedule:
11 + type: 'efficientnet'
12 + warmup:
13 + multiplier: 1
14 + epoch: 5
15 +optimizer:
16 + type: rmsprop
17 + decay: 0.00001
18 + clip: 0
19 + ema: 0.9999
20 + ema_interval: -1
21 +lb_smooth: 0.1
22 +mixup: 0.2
1 +model:
2 + type: efficientnet-b1
3 +  condconv_num_expert: 1 # values greater than 1 (e.g. 4) enable CondConv
4 +dataset: imagenet
5 +aug: fa_reduced_imagenet
6 +cutout: 0
7 +batch: 128 # per gpu
8 +epoch: 350
9 +lr: 0.008 # 0.256 for 4096 batch
10 +lr_schedule:
11 + type: 'efficientnet'
12 + warmup:
13 + multiplier: 1
14 + epoch: 5
15 +optimizer:
16 + type: rmsprop
17 + decay: 0.00001
18 + clip: 0
19 + ema: 0.9999
20 + ema_interval: -1
21 +lb_smooth: 0.1
1 +model:
2 + type: efficientnet-b2
3 +  condconv_num_expert: 1 # values greater than 1 (e.g. 4) enable CondConv
4 +dataset: imagenet
5 +aug: fa_reduced_imagenet
6 +cutout: 0
7 +batch: 128 # per gpu
8 +epoch: 350
9 +lr: 0.008 # 0.256 for 4096 batch
10 +lr_schedule:
11 + type: 'efficientnet'
12 + warmup:
13 + multiplier: 1
14 + epoch: 5
15 +optimizer:
16 + type: rmsprop
17 + decay: 0.00001
18 + clip: 0
19 + ema: 0.9999
20 + ema_interval: -1
21 +lb_smooth: 0.1
1 +model:
2 + type: efficientnet-b3
3 +  condconv_num_expert: 1 # values greater than 1 (e.g. 4) enable CondConv
4 +dataset: imagenet
5 +aug: fa_reduced_imagenet
6 +cutout: 0
7 +batch: 64 # per gpu
8 +epoch: 350
9 +lr: 0.004 # 0.256 for 4096 batch
10 +lr_schedule:
11 + type: 'efficientnet'
12 + warmup:
13 + multiplier: 1
14 + epoch: 5
15 +optimizer:
16 + type: rmsprop
17 + decay: 0.00001
18 + clip: 0
19 + ema: 0.9999
20 + ema_interval: -1
21 +lb_smooth: 0.1
1 +model:
2 + type: efficientnet-b4
3 +  condconv_num_expert: 1 # values greater than 1 (e.g. 4) enable CondConv
4 +dataset: imagenet
5 +aug: fa_reduced_imagenet
6 +cutout: 0
7 +batch: 32 # per gpu
8 +epoch: 350
9 +lr: 0.002 # 0.256 for 4096 batch
10 +lr_schedule:
11 + type: 'efficientnet'
12 + warmup:
13 + multiplier: 1
14 + epoch: 5
15 +optimizer:
16 + type: rmsprop
17 + decay: 0.00001
18 + clip: 0
19 + ema: 0.9999
20 + ema_interval: -1
21 +lb_smooth: 0.1
1 +model:
2 + type: pyramid
3 + depth: 272
4 + alpha: 200
5 + bottleneck: True
6 +dataset: cifar10
7 +aug: fa_reduced_cifar10
8 +cutout: 16
9 +batch: 64
10 +epoch: 1800
11 +lr: 0.05
12 +lr_schedule:
13 + type: 'cosine'
14 + warmup:
15 + multiplier: 1
16 + epoch: 5
17 +optimizer:
18 + type: sgd
19 + nesterov: True
20 + decay: 0.00005
1 +model:
2 + type: resnet200
3 +dataset: imagenet
4 +aug: fa_reduced_imagenet
5 +cutout: 0
6 +batch: 64
7 +epoch: 270
8 +lr: 0.025
9 +lr_schedule:
10 + type: 'resnet'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.0001
18 + clip: 0
1 +model:
2 + type: resnet50
3 +dataset: imagenet
4 +aug: fa_reduced_imagenet
5 +cutout: 0
6 +batch: 128
7 +epoch: 270
8 +lr: 0.05
9 +lr_schedule:
10 + type: 'resnet'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.0001
18 + clip: 0
19 + ema: 0
1 +model:
2 + type: resnet50
3 +dataset: imagenet
4 +aug: fa_reduced_imagenet
5 +cutout: 0
6 +batch: 128
7 +epoch: 270
8 +lr: 0.05
9 +lr_schedule:
10 + type: 'resnet'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.0001
18 + clip: 0
19 + ema: 0
20 +#lb_smooth: 0.1
21 +mixup: 0.2
1 +model:
2 + type: shakeshake26_2x112d
3 +dataset: cifar10
4 +aug: fa_reduced_cifar10
5 +cutout: 16
6 +batch: 128
7 +epoch: 1800
8 +lr: 0.01
9 +lr_schedule:
10 + type: 'cosine'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.002
1 +model:
2 + type: shakeshake26_2x32d
3 +dataset: cifar10
4 +aug: fa_reduced_cifar10
5 +cutout: 16
6 +batch: 128
7 +epoch: 1800
8 +lr: 0.01
9 +lr_schedule:
10 + type: 'cosine'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.001
1 +model:
2 + type: shakeshake26_2x96d
3 +dataset: cifar10
4 +aug: fa_reduced_cifar10
5 +cutout: 16
6 +batch: 128
7 +epoch: 1800
8 +lr: 0.01
9 +lr_schedule:
10 + type: 'cosine'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.001
1 +model:
2 + type: wresnet28_10
3 +dataset: cifar10
4 +aug: fa_reduced_cifar10
5 +cutout: 16
6 +batch: 128
7 +epoch: 200
8 +lr: 0.1
9 +lr_schedule:
10 + type: 'cosine'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.0005
\ No newline at end of file
1 +model:
2 + type: wresnet28_10
3 +dataset: svhn
4 +aug: fa_reduced_svhn
5 +cutout: 20
6 +batch: 128
7 +epoch: 200
8 +lr: 0.01
9 +lr_schedule:
10 + type: 'cosine'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.0005
\ No newline at end of file
1 +model:
2 + type: wresnet40_2
3 +dataset: cifar10
4 +aug: fa_reduced_cifar10
5 +cutout: 16
6 +batch: 128
7 +epoch: 200
8 +lr: 0.1
9 +lr_schedule:
10 + type: 'cosine'
11 + warmup:
12 + multiplier: 1
13 + epoch: 5
14 +optimizer:
15 + type: sgd
16 + nesterov: True
17 + decay: 0.0002
1 +git+https://github.com/wbaek/theconf
2 +git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git@08f7d5e
3 +git+https://github.com/ildoonet/pystopwatch2.git
4 +git+https://github.com/hyperopt/hyperopt.git
5 +git+https://github.com/kakaobrain/torchlars
6 +
7 +pretrainedmodels
8 +tqdm
9 +tensorboardx
10 +scikit-learn
11 +ray
12 +matplotlib
13 +psutil
14 +requests
\ No newline at end of file
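These pinned Git dependencies and PyPI packages (presumably the repo's `requirements.txt`) install in one step:

```bash
$ pip install -r requirements.txt
```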
1 +{
2 + "cells": [
3 + {
4 + "cell_type": "code",
5 + "execution_count": 1,
6 + "metadata": {},
7 + "outputs": [],
8 + "source": [
9 + "import pandas as pd"
10 + ]
11 + },
12 + {
13 + "cell_type": "code",
14 + "execution_count": 3,
15 + "metadata": {},
16 + "outputs": [
17 + {
18 + "data": {
19 + "text/html": [
20 + "<div>\n",
21 + "<style scoped>\n",
22 + " .dataframe tbody tr th:only-of-type {\n",
23 + " vertical-align: middle;\n",
24 + " }\n",
25 + "\n",
26 + " .dataframe tbody tr th {\n",
27 + " vertical-align: top;\n",
28 + " }\n",
29 + "\n",
30 + " .dataframe thead th {\n",
31 + " text-align: right;\n",
32 + " }\n",
33 + "</style>\n",
34 + "<table border=\"1\" class=\"dataframe\">\n",
35 + " <thead>\n",
36 + " <tr style=\"text-align: right;\">\n",
37 + " <th></th>\n",
38 + " <th>id</th>\n",
39 + " <th>Label</th>\n",
40 + " <th>Subject</th>\n",
41 + " <th>Date</th>\n",
42 + " <th>Gender</th>\n",
43 + " <th>Age</th>\n",
44 + " <th>mmse</th>\n",
45 + " <th>ageAtEntry</th>\n",
46 + " <th>cdr</th>\n",
47 + " <th>commun</th>\n",
48 + " <th>...</th>\n",
49 + " <th>memory</th>\n",
50 + " <th>orient</th>\n",
51 + " <th>perscare</th>\n",
52 + " <th>apoe</th>\n",
53 + " <th>sumbox</th>\n",
54 + " <th>acsparnt</th>\n",
55 + " <th>height</th>\n",
56 + " <th>weight</th>\n",
57 + " <th>primStudy</th>\n",
58 + " <th>acsStudy</th>\n",
59 + " </tr>\n",
60 + " </thead>\n",
61 + " <tbody>\n",
62 + " <tr>\n",
63 + " <th>0</th>\n",
64 + " <td>/@WEBAPP/images/r.gif</td>\n",
65 + " <td>OAS30001_ClinicalData_d3025</td>\n",
66 + " <td>OAS30001</td>\n",
67 + " <td>NaN</td>\n",
68 + " <td>female</td>\n",
69 + " <td>NaN</td>\n",
70 + " <td>30.0</td>\n",
71 + " <td>65.149895</td>\n",
72 + " <td>0.0</td>\n",
73 + " <td>0.0</td>\n",
74 + " <td>...</td>\n",
75 + " <td>0.0</td>\n",
76 + " <td>0.0</td>\n",
77 + " <td>0.0</td>\n",
78 + " <td>23.0</td>\n",
79 + " <td>0.0</td>\n",
80 + " <td>NaN</td>\n",
81 + " <td>64.0</td>\n",
82 + " <td>180.0</td>\n",
83 + " <td>NaN</td>\n",
84 + " <td>NaN</td>\n",
85 + " </tr>\n",
86 + " <tr>\n",
87 + " <th>1</th>\n",
88 + " <td>/@WEBAPP/images/r.gif</td>\n",
89 + " <td>OAS30001_ClinicalData_d3977</td>\n",
90 + " <td>OAS30001</td>\n",
91 + " <td>NaN</td>\n",
92 + " <td>female</td>\n",
93 + " <td>NaN</td>\n",
94 + " <td>29.0</td>\n",
95 + " <td>65.149895</td>\n",
96 + " <td>0.0</td>\n",
97 + " <td>0.0</td>\n",
98 + " <td>...</td>\n",
99 + " <td>0.0</td>\n",
100 + " <td>0.0</td>\n",
101 + " <td>0.0</td>\n",
102 + " <td>23.0</td>\n",
103 + " <td>0.0</td>\n",
104 + " <td>NaN</td>\n",
105 + " <td>NaN</td>\n",
106 + " <td>NaN</td>\n",
107 + " <td>NaN</td>\n",
108 + " <td>NaN</td>\n",
109 + " </tr>\n",
110 + " <tr>\n",
111 + " <th>2</th>\n",
112 + " <td>/@WEBAPP/images/r.gif</td>\n",
113 + " <td>OAS30001_ClinicalData_d3332</td>\n",
114 + " <td>OAS30001</td>\n",
115 + " <td>NaN</td>\n",
116 + " <td>female</td>\n",
117 + " <td>NaN</td>\n",
118 + " <td>30.0</td>\n",
119 + " <td>65.149895</td>\n",
120 + " <td>0.0</td>\n",
121 + " <td>0.0</td>\n",
122 + " <td>...</td>\n",
123 + " <td>0.0</td>\n",
124 + " <td>0.0</td>\n",
125 + " <td>0.0</td>\n",
126 + " <td>23.0</td>\n",
127 + " <td>0.0</td>\n",
128 + " <td>NaN</td>\n",
129 + " <td>63.0</td>\n",
130 + " <td>185.0</td>\n",
131 + " <td>NaN</td>\n",
132 + " <td>NaN</td>\n",
133 + " </tr>\n",
134 + " <tr>\n",
135 + " <th>3</th>\n",
136 + " <td>/@WEBAPP/images/r.gif</td>\n",
137 + " <td>OAS30001_ClinicalData_d0000</td>\n",
138 + " <td>OAS30001</td>\n",
139 + " <td>NaN</td>\n",
140 + " <td>female</td>\n",
141 + " <td>NaN</td>\n",
142 + " <td>28.0</td>\n",
143 + " <td>65.149895</td>\n",
144 + " <td>0.0</td>\n",
145 + " <td>0.0</td>\n",
146 + " <td>...</td>\n",
147 + " <td>0.0</td>\n",
148 + " <td>0.0</td>\n",
149 + " <td>0.0</td>\n",
150 + " <td>23.0</td>\n",
151 + " <td>0.0</td>\n",
152 + " <td>NaN</td>\n",
153 + " <td>NaN</td>\n",
154 + " <td>NaN</td>\n",
155 + " <td>NaN</td>\n",
156 + " <td>NaN</td>\n",
157 + " </tr>\n",
158 + " <tr>\n",
159 + " <th>4</th>\n",
160 + " <td>/@WEBAPP/images/r.gif</td>\n",
161 + " <td>OAS30001_ClinicalData_d1456</td>\n",
162 + " <td>OAS30001</td>\n",
163 + " <td>NaN</td>\n",
164 + " <td>female</td>\n",
165 + " <td>NaN</td>\n",
166 + " <td>30.0</td>\n",
167 + " <td>65.149895</td>\n",
168 + " <td>0.0</td>\n",
169 + " <td>0.0</td>\n",
170 + " <td>...</td>\n",
171 + " <td>0.0</td>\n",
172 + " <td>0.0</td>\n",
173 + " <td>0.0</td>\n",
174 + " <td>23.0</td>\n",
175 + " <td>0.0</td>\n",
176 + " <td>NaN</td>\n",
177 + " <td>63.0</td>\n",
178 + " <td>173.0</td>\n",
179 + " <td>NaN</td>\n",
180 + " <td>NaN</td>\n",
181 + " </tr>\n",
182 + " </tbody>\n",
183 + "</table>\n",
184 + "<p>5 rows × 27 columns</p>\n",
185 + "</div>"
186 + ],
187 + "text/plain": [
188 + " id Label Subject Date Gender \\\n",
189 + "0 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3025 OAS30001 NaN female \n",
190 + "1 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3977 OAS30001 NaN female \n",
191 + "2 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3332 OAS30001 NaN female \n",
192 + "3 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d0000 OAS30001 NaN female \n",
193 + "4 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d1456 OAS30001 NaN female \n",
194 + "\n",
195 + " Age mmse ageAtEntry cdr commun ... memory orient perscare apoe \\\n",
196 + "0 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
197 + "1 NaN 29.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
198 + "2 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
199 + "3 NaN 28.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
200 + "4 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n",
201 + "\n",
202 + " sumbox acsparnt height weight primStudy acsStudy \n",
203 + "0 0.0 NaN 64.0 180.0 NaN NaN \n",
204 + "1 0.0 NaN NaN NaN NaN NaN \n",
205 + "2 0.0 NaN 63.0 185.0 NaN NaN \n",
206 + "3 0.0 NaN NaN NaN NaN NaN \n",
207 + "4 0.0 NaN 63.0 173.0 NaN NaN \n",
208 + "\n",
209 + "[5 rows x 27 columns]"
210 + ]
211 + },
212 + "execution_count": 3,
213 + "metadata": {},
214 + "output_type": "execute_result"
215 + }
216 + ],
217 + "source": [
218 +    "all_data = pd.read_csv(\"../data/ADRC clinical data_all.csv\")\n",
219 + "\n",
220 + "all_data.head()"
221 + ]
222 + },
223 + {
224 + "cell_type": "code",
225 + "execution_count": 18,
226 + "metadata": {},
227 + "outputs": [
228 + {
229 + "name": "stdout",
230 + "output_type": "stream",
231 + "text": [
232 + "Subject\n",
233 + "OAS30001 0.0\n",
234 + "OAS30002 0.0\n",
235 + "OAS30003 0.0\n",
236 + "OAS30004 0.0\n",
237 + "OAS30005 0.0\n",
238 + " ... \n",
239 + "OAS31168 0.0\n",
240 + "OAS31169 3.0\n",
241 + "OAS31170 2.0\n",
242 + "OAS31171 2.0\n",
243 + "OAS31172 0.0\n",
244 + "Name: cdr, Length: 1098, dtype: float64\n"
245 + ]
246 + }
247 + ],
248 + "source": [
249 + "ad = all_data.groupby(['Subject'])['cdr'].max()\n",
250 + "print(ad)"
251 + ]
252 + },
253 + {
254 + "cell_type": "code",
255 + "execution_count": 21,
256 + "metadata": {},
257 + "outputs": [
258 + {
259 + "data": {
260 + "text/plain": [
261 + "'OAS30001'"
262 + ]
263 + },
264 + "execution_count": 21,
265 + "metadata": {},
266 + "output_type": "execute_result"
267 + }
268 + ],
269 + "source": [
270 + "ad.index[0]"
271 + ]
272 + },
273 + {
274 + "cell_type": "code",
275 + "execution_count": 22,
276 + "metadata": {},
277 +   "outputs": [],
278 +   "source": [
279 +    "# Subjects whose maximum CDR is 0 are considered cognitively normal.\n",
280 +    "filtered = ad.index[ad == 0].tolist()\n",
281 +    "#print(filtered)"
293 + ]
294 + },
295 + {
296 + "cell_type": "code",
297 + "execution_count": 23,
298 + "metadata": {},
299 + "outputs": [],
300 + "source": [
301 + "df_filtered = pd.DataFrame(filtered)\n",
302 +    "df_filtered.to_csv('../data/ADRC clinical data_normal.csv')"
303 + ]
304 + },
305 + {
306 + "cell_type": "code",
307 + "execution_count": null,
308 + "metadata": {},
309 + "outputs": [],
310 + "source": []
311 + }
312 + ],
313 + "metadata": {
314 + "kernelspec": {
315 + "display_name": "ML",
316 + "language": "python",
317 + "name": "ml"
318 + },
319 + "language_info": {
320 + "codemirror_mode": {
321 + "name": "ipython",
322 + "version": 3
323 + },
324 + "file_extension": ".py",
325 + "mimetype": "text/x-python",
326 + "name": "python",
327 + "nbconvert_exporter": "python",
328 + "pygments_lexer": "ipython3",
329 + "version": "3.7.4"
330 + }
331 + },
332 + "nbformat": 4,
333 + "nbformat_minor": 2
334 +}
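Condensed, the notebook's filtering logic fits in a short standalone script. A minimal sketch, assuming the same CSV layout (`Subject` and `cdr` columns) and relative paths as above:

```python
import pandas as pd

# Each subject's maximum CDR across visits; a subject whose CDR never
# exceeds 0 is treated as cognitively normal.
all_data = pd.read_csv('../data/ADRC clinical data_all.csv')
max_cdr = all_data.groupby('Subject')['cdr'].max()

normal = max_cdr.index[max_cdr == 0]
pd.DataFrame(normal).to_csv('../data/ADRC clinical data_normal.csv')
```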
(diff not shown: file too large to display)
1 +%load('..\data\MICCAI_BraTS_2019_Data_Training\name_mapping.csv')
2 +
3 +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\';
4 +outfolder = '..\testfolder\';
5 +id = 'BraTS19_2013_2_1';
6 +
7 +type = 'flair.nii';
8 +filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii
9 +flair_path = strcat(inputheader, id, '\', filename,'\', filename);
10 +%disp(flair_path);
11 +flair = niftiread(flair_path); %size 240x240x155
12 +cp_flair = flair;
13 +
14 +type = 'seg.nii';
15 +filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii
16 +seg_path = strcat(inputheader, id, '\', filename, '\', filename);
17 +seg = niftiread(seg_path);
18 +
19 +[x,y,z] = size(seg);
20 +
21 +% copy flair, segment flair data
22 +
23 +cp_flair(seg == 0) = 0;
24 +
25 +% save a segmented data
26 +type = 'seg_flair.nii';
27 +filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg_flair.nii
28 +outpath = strcat(outfolder, filename);
29 +%niftiwrite(cp_flair, outpath);
30 +
31 +% exploratory checks, kept commented out:
32 +%whos seg
33 +%extract = seg(84, :, 86);
34 +%cp84 = cp_flair(84,:,86);
35 +%flair84 = flair(84,:, 86);
36 +%[flair,info] = ReadData3D(filename)
37 +%whos flair
38 +%volumeViewer(flair);
\ No newline at end of file
1 +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\';
2 +outfolder = strcat('..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\');
3 +
4 +files = dir(inputheader);
5 +id = {files.name};
6 +% keep only subdirectories, excluding the '.' and '..' entries
7 +dirFlag = [files.isdir] & ~strcmp(id, '.') & ~strcmp(id, '..');
8 +subFolders = files(dirFlag);
9 +disp(length(subFolders));
10 +
11 +% for k = 1 : length(subFolders)
12 +% fprintf('Sub folder #%d = %s\n', k, subFolders(k).name);
13 +% end
14 +
15 +for i = 1 : length(subFolders)
16 +
17 + id = subFolders(i).name;
18 + fprintf('Sub folder #%d = %s\n', i, id);
19 +
20 + type = 'flair.nii';
21 + filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii
22 + flair_path = strcat(inputheader, id, '\', filename,'\', filename);
23 + flair = niftiread(flair_path); %size 240x240x155
24 + cp_flair = flair;
25 +
26 + type = 'seg.nii';
27 + filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii
28 + seg_path = strcat(inputheader, id, '\', filename, '\', filename);
29 + seg = niftiread(seg_path);
30 +
31 + [x,y,z] = size(seg);
32 +
33 + % copy flair, segment flair data
34 +
35 + cp_flair(seg == 0) = 0;
36 +
37 + % save a segmented data
38 + type = 'seg_flair.nii';
39 + filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg_flair.nii
40 + outpath = strcat(outfolder, filename);
41 + niftiwrite(cp_flair, outpath);
42 +
43 +end
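For readers working in Python rather than MATLAB, the same batch masking can be sketched with nibabel (an assumption; the repo itself does this in MATLAB). Note the MATLAB script reads each scan from a nested `<id>\<id>_flair.nii\<id>_flair.nii` path; the sketch below assumes the flatter `<id>/<id>_flair.nii` layout, so adjust as needed:

```python
import os
import nibabel as nib  # assumed dependency; not part of this repo

input_dir = '../data/MICCAI_BraTS_2019_Data_Training/HGG'
out_dir = '../data/MICCAI_BraTS_2019_Data_Training/HGG_seg_flair'
os.makedirs(out_dir, exist_ok=True)

for subject in sorted(os.listdir(input_dir)):
    subj_dir = os.path.join(input_dir, subject)
    if not os.path.isdir(subj_dir):
        continue
    flair_img = nib.load(os.path.join(subj_dir, subject + '_flair.nii'))
    seg = nib.load(os.path.join(subj_dir, subject + '_seg.nii')).get_fdata()
    flair = flair_img.get_fdata()
    flair[seg == 0] = 0  # zero out every voxel outside the tumor segmentation
    out_img = nib.Nifti1Image(flair, flair_img.affine)
    nib.save(out_img, os.path.join(out_dir, subject + '_seg_flair.nii'))
```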
1 +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\';
2 +outfolder = '..\testfolder\';
3 +id = 'BraTS19_2013_2_1';
4 +
5 +type = 'seg_flair.nii';
6 +filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg_flair.nii
7 +flair_path = strcat(outfolder, filename);
8 +%disp(flair_path);
9 +seg_flair = niftiread(flair_path); %size 240x240x155
10 +
11 +line84 = seg_flair(84, :, 86); % inspect a single line of voxels as a sanity check
\ No newline at end of file