Showing 64 changed files with 5314 additions and 0 deletions
code/.idea/code.iml
0 → 100644
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="TestRunnerService">
+    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
+  </component>
+</module>
\ No newline at end of file
code/.idea/encodings.xml
0 → 100644
code/.idea/misc.xml
0 → 100644
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="JavaScriptSettings">
+    <option name="languageLevel" value="ES6" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (code) (3)" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
code/.idea/modules.xml
0 → 100644
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/code.iml" filepath="$PROJECT_DIR$/.idea/code.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
code/.idea/workspace.xml
0 → 100644
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ChangeListManager">
+    <list default="true" id="38bbd81d-cf79-4a3f-8979-7c3ceb27bc32" name="Default Changelist" comment="" />
+    <option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
+    <option name="SHOW_DIALOG" value="false" />
+    <option name="HIGHLIGHT_CONFLICTS" value="true" />
+    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+    <option name="LAST_RESOLUTION" value="IGNORE" />
+  </component>
+  <component name="ProjectFrameBounds">
+    <option name="x" value="-9" />
+    <option name="width" value="1825" />
+    <option name="height" value="1039" />
+  </component>
+  <component name="ProjectView">
+    <navigator proportions="" version="1">
+      <foldersAlwaysOnTop value="true" />
+    </navigator>
+    <panes>
+      <pane id="Scope" />
+      <pane id="ProjectPane" />
+    </panes>
+  </component>
+  <component name="PropertiesComponent">
+    <property name="WebServerToolWindowFactoryState" value="false" />
+    <property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
+    <property name="nodejs_npm_path_reset_for_default_project" value="true" />
+    <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
+  </component>
+  <component name="RunDashboard">
+    <option name="ruleStates">
+      <list>
+        <RuleState>
+          <option name="name" value="ConfigurationTypeDashboardGroupingRule" />
+        </RuleState>
+        <RuleState>
+          <option name="name" value="StatusDashboardGroupingRule" />
+        </RuleState>
+      </list>
+    </option>
+  </component>
+  <component name="SvnConfiguration">
+    <configuration />
+  </component>
+  <component name="TaskManager">
+    <task active="true" id="Default" summary="Default task">
+      <changelist id="38bbd81d-cf79-4a3f-8979-7c3ceb27bc32" name="Default Changelist" comment="" />
+      <created>1584871834688</created>
+      <option name="number" value="Default" />
+      <option name="presentableId" value="Default" />
+      <updated>1584871834688</updated>
+      <workItem from="1584871840396" duration="72000" />
+    </task>
+    <servers />
+  </component>
+  <component name="TimeTrackingManager">
+    <option name="totallyTimeSpent" value="72000" />
+  </component>
+  <component name="ToolWindowManager">
+    <frame x="-7" y="0" width="1460" height="831" extended-state="0" />
+    <layout>
+      <window_info id="Favorites" side_tool="true" />
+      <window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.25" />
+      <window_info id="Structure" order="1" side_tool="true" weight="0.25" />
+      <window_info anchor="bottom" id="Docker" show_stripe_button="false" />
+      <window_info anchor="bottom" id="Database Changes" />
+      <window_info anchor="bottom" id="Version Control" />
+      <window_info anchor="bottom" id="Python Console" />
+      <window_info anchor="bottom" id="Terminal" />
+      <window_info anchor="bottom" id="Event Log" side_tool="true" />
+      <window_info anchor="bottom" id="Message" order="0" />
+      <window_info anchor="bottom" id="Find" order="1" />
+      <window_info anchor="bottom" id="Run" order="2" />
+      <window_info anchor="bottom" id="Debug" order="3" weight="0.4" />
+      <window_info anchor="bottom" id="Cvs" order="4" weight="0.25" />
+      <window_info anchor="bottom" id="Inspection" order="5" weight="0.4" />
+      <window_info anchor="bottom" id="TODO" order="6" />
+      <window_info anchor="right" id="SciView" />
+      <window_info anchor="right" id="Database" />
+      <window_info anchor="right" id="Commander" internal_type="SLIDING" order="0" type="SLIDING" weight="0.4" />
+      <window_info anchor="right" id="Ant Build" order="1" weight="0.25" />
+      <window_info anchor="right" content_ui="combo" id="Hierarchy" order="2" weight="0.25" />
+    </layout>
+  </component>
+  <component name="TypeScriptGeneratedFilesManager">
+    <option name="version" value="1" />
+  </component>
+</project>
\ No newline at end of file
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>id</th>\n",
+       "      <th>Label</th>\n",
+       "      <th>Subject</th>\n",
+       "      <th>Date</th>\n",
+       "      <th>Gender</th>\n",
+       "      <th>Age</th>\n",
+       "      <th>mmse</th>\n",
+       "      <th>ageAtEntry</th>\n",
+       "      <th>cdr</th>\n",
+       "      <th>commun</th>\n",
+       "      <th>...</th>\n",
+       "      <th>memory</th>\n",
+       "      <th>orient</th>\n",
+       "      <th>perscare</th>\n",
+       "      <th>apoe</th>\n",
+       "      <th>sumbox</th>\n",
+       "      <th>acsparnt</th>\n",
+       "      <th>height</th>\n",
+       "      <th>weight</th>\n",
+       "      <th>primStudy</th>\n",
+       "      <th>acsStudy</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>/@WEBAPP/images/r.gif</td>\n",
+       "      <td>OAS30001_ClinicalData_d3025</td>\n",
+       "      <td>OAS30001</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>female</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>30.0</td>\n",
+       "      <td>65.149895</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>23.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>64.0</td>\n",
+       "      <td>180.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>/@WEBAPP/images/r.gif</td>\n",
+       "      <td>OAS30001_ClinicalData_d3977</td>\n",
+       "      <td>OAS30001</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>female</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>29.0</td>\n",
+       "      <td>65.149895</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>23.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>/@WEBAPP/images/r.gif</td>\n",
+       "      <td>OAS30001_ClinicalData_d3332</td>\n",
+       "      <td>OAS30001</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>female</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>30.0</td>\n",
+       "      <td>65.149895</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>23.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>63.0</td>\n",
+       "      <td>185.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>/@WEBAPP/images/r.gif</td>\n",
+       "      <td>OAS30001_ClinicalData_d0000</td>\n",
+       "      <td>OAS30001</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>female</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>28.0</td>\n",
+       "      <td>65.149895</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>23.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>/@WEBAPP/images/r.gif</td>\n",
+       "      <td>OAS30001_ClinicalData_d1456</td>\n",
+       "      <td>OAS30001</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>female</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>30.0</td>\n",
+       "      <td>65.149895</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>...</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>23.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>63.0</td>\n",
+       "      <td>173.0</td>\n",
+       "      <td>NaN</td>\n",
+       "      <td>NaN</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>5 rows × 27 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "                      id                        Label   Subject  Date  Gender  \\\n",
+       "0  /@WEBAPP/images/r.gif  OAS30001_ClinicalData_d3025  OAS30001   NaN  female   \n",
+       "1  /@WEBAPP/images/r.gif  OAS30001_ClinicalData_d3977  OAS30001   NaN  female   \n",
+       "2  /@WEBAPP/images/r.gif  OAS30001_ClinicalData_d3332  OAS30001   NaN  female   \n",
+       "3  /@WEBAPP/images/r.gif  OAS30001_ClinicalData_d0000  OAS30001   NaN  female   \n",
+       "4  /@WEBAPP/images/r.gif  OAS30001_ClinicalData_d1456  OAS30001   NaN  female   \n",
+       "\n",
+       "   Age  mmse  ageAtEntry  cdr  commun  ...  memory  orient  perscare  apoe  \\\n",
+       "0  NaN  30.0   65.149895  0.0     0.0  ...     0.0     0.0       0.0  23.0   \n",
+       "1  NaN  29.0   65.149895  0.0     0.0  ...     0.0     0.0       0.0  23.0   \n",
+       "2  NaN  30.0   65.149895  0.0     0.0  ...     0.0     0.0       0.0  23.0   \n",
+       "3  NaN  28.0   65.149895  0.0     0.0  ...     0.0     0.0       0.0  23.0   \n",
+       "4  NaN  30.0   65.149895  0.0     0.0  ...     0.0     0.0       0.0  23.0   \n",
+       "\n",
+       "   sumbox  acsparnt  height  weight  primStudy  acsStudy  \n",
+       "0     0.0       NaN    64.0   180.0        NaN       NaN  \n",
+       "1     0.0       NaN     NaN     NaN        NaN       NaN  \n",
+       "2     0.0       NaN    63.0   185.0        NaN       NaN  \n",
+       "3     0.0       NaN     NaN     NaN        NaN       NaN  \n",
+       "4     0.0       NaN    63.0   173.0        NaN       NaN  \n",
+       "\n",
+       "[5 rows x 27 columns]"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "all_data = pd.read_csv(\"../data/ADRC clinical data_all.csv\")\n",
+    "\n",
+    "all_data.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Subject\n",
+      "OAS30001    0.0\n",
+      "OAS30002    0.0\n",
+      "OAS30003    0.0\n",
+      "OAS30004    0.0\n",
+      "OAS30005    0.0\n",
+      "           ... \n",
+      "OAS31168    0.0\n",
+      "OAS31169    3.0\n",
+      "OAS31170    2.0\n",
+      "OAS31171    2.0\n",
+      "OAS31172    0.0\n",
+      "Name: cdr, Length: 1098, dtype: float64\n"
+     ]
+    }
+   ],
+   "source": [
+    "ad = all_data.groupby(['Subject'])['cdr'].max()\n",
+    "print(ad)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'OAS30001'"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ad.index[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "SyntaxError",
+     "evalue": "unexpected EOF while parsing (<ipython-input-22-b8a078b72aca>, line 5)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;36m  File \u001b[1;32m\"<ipython-input-22-b8a078b72aca>\"\u001b[1;36m, line \u001b[1;32m5\u001b[0m\n\u001b[1;33m    #print(filtered)\u001b[0m\n\u001b[1;37m    ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m unexpected EOF while parsing\n"
+     ]
+    }
+   ],
+   "source": [
+    "filtered = []\n",
+    "for i, val in enumerate(ad):\n",
+    "    if val == 0:\n",
+    "        filtered.append(ad.index[i])\n",
+    "#print(filtered)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_filtered = pd.DataFrame(filtered)\n",
+    "df_filtered.to_csv('../data/ADRC clinical data_normal.csv')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ML",
+   "language": "python",
+   "name": "ml"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
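Note: the subject filtering in the notebook above can be written more compactly with boolean indexing; a sketch assuming the same CSV layout and relative paths as the notebook:

    import pandas as pd

    all_data = pd.read_csv("../data/ADRC clinical data_all.csv")
    max_cdr = all_data.groupby("Subject")["cdr"].max()   # worst CDR score per subject
    normal = max_cdr[max_cdr == 0].index                 # subjects whose CDR never exceeds 0
    pd.DataFrame(normal).to_csv("../data/ADRC clinical data_normal.csv")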
This diff could not be displayed because it is too large.
code/fast-autoaugment-master/.gitignore
0 → 100644
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
File mode changed
This diff could not be displayed because it is too large.
+"""
+Reference :
+- https://github.com/hysts/pytorch_image_classification/blob/master/augmentations/mixup.py
+- https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/imagenet_input.py#L120
+"""
+
+import numpy as np
+import torch
+
+from FastAutoAugment.metrics import CrossEntropyLabelSmooth
+
+
+def mixup(data, targets, alpha):
+    indices = torch.randperm(data.size(0))
+    shuffled_data = data[indices]
+    shuffled_targets = targets[indices]
+
+    lam = np.random.beta(alpha, alpha)
+    lam = max(lam, 1. - lam)
+    assert 0.0 <= lam <= 1.0, lam
+    data = data * lam + shuffled_data * (1 - lam)
+
+    return data, targets, shuffled_targets, lam
+
+
+class CrossEntropyMixUpLabelSmooth(torch.nn.Module):
+    def __init__(self, num_classes, epsilon, reduction='mean'):
+        super(CrossEntropyMixUpLabelSmooth, self).__init__()
+        self.ce = CrossEntropyLabelSmooth(num_classes, epsilon, reduction=reduction)
+
+    def forward(self, input, target1, target2, lam):  # pylint: disable=redefined-builtin
+        return lam * self.ce(input, target1) + (1 - lam) * self.ce(input, target2)
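A minimal usage sketch for the mixup helpers above (the import path FastAutoAugment.aug_mixup is an assumption, since this file's name was elided from the diff; the linear model and random batch are placeholders):

    import torch
    from FastAutoAugment.aug_mixup import mixup, CrossEntropyMixUpLabelSmooth  # module name assumed

    model = torch.nn.Linear(3 * 32 * 32, 10)  # stand-in classifier
    criterion = CrossEntropyMixUpLabelSmooth(num_classes=10, epsilon=0.1)

    data = torch.randn(8, 3 * 32 * 32)
    targets = torch.randint(0, 10, (8,))

    # Mix the batch, then weight the two label losses by the mixing coefficient lam.
    mixed, t1, t2, lam = mixup(data, targets, alpha=0.2)
    loss = criterion(model(mixed), t1, t2, lam)
    loss.backward()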
+# code in this file is adapted from rpmcruz/autoaugment
+# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
+import random
+
+import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
+import numpy as np
+import torch
+from torchvision.transforms.transforms import Compose
+
+random_mirror = True
+
+
+def ShearX(img, v):  # [-0.3, 0.3]
+    assert -0.3 <= v <= 0.3
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
+
+
+def ShearY(img, v):  # [-0.3, 0.3]
+    assert -0.3 <= v <= 0.3
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
+
+
+def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    assert -0.45 <= v <= 0.45
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    v = v * img.size[0]
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    assert -0.45 <= v <= 0.45
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    v = v * img.size[1]
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def TranslateXAbs(img, v):  # [0, 10] (absolute pixels)
+    assert 0 <= v <= 10
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateYAbs(img, v):  # [0, 10] (absolute pixels)
+    assert 0 <= v <= 10
+    if random.random() > 0.5:
+        v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def Rotate(img, v):  # [-30, 30]
+    assert -30 <= v <= 30
+    if random_mirror and random.random() > 0.5:
+        v = -v
+    return img.rotate(v)
+
+
+def AutoContrast(img, _):
+    return PIL.ImageOps.autocontrast(img)
+
+
+def Invert(img, _):
+    return PIL.ImageOps.invert(img)
+
+
+def Equalize(img, _):
+    return PIL.ImageOps.equalize(img)
+
+
+def Flip(img, _):  # not from the paper
+    return PIL.ImageOps.mirror(img)
+
+
+def Solarize(img, v):  # [0, 256]
+    assert 0 <= v <= 256
+    return PIL.ImageOps.solarize(img, v)
+
+
+def Posterize(img, v):  # [4, 8]
+    assert 4 <= v <= 8
+    v = int(v)
+    return PIL.ImageOps.posterize(img, v)
+
+
+def Posterize2(img, v):  # [0, 4]
+    assert 0 <= v <= 4
+    v = int(v)
+    return PIL.ImageOps.posterize(img, v)
+
+
+def Contrast(img, v):  # [0.1, 1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Contrast(img).enhance(v)
+
+
+def Color(img, v):  # [0.1, 1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Color(img).enhance(v)
+
+
+def Brightness(img, v):  # [0.1, 1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Brightness(img).enhance(v)
+
+
+def Sharpness(img, v):  # [0.1, 1.9]
+    assert 0.1 <= v <= 1.9
+    return PIL.ImageEnhance.Sharpness(img).enhance(v)
+
+
+def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
+    assert 0.0 <= v <= 0.2
+    if v <= 0.:
+        return img
+
+    v = v * img.size[0]
+    return CutoutAbs(img, v)
+
+
+def CutoutAbs(img, v):  # absolute patch size in pixels
+    # assert 0 <= v <= 20
+    if v < 0:
+        return img
+    w, h = img.size
+    x0 = np.random.uniform(w)
+    y0 = np.random.uniform(h)
+
+    x0 = int(max(0, x0 - v / 2.))
+    y0 = int(max(0, y0 - v / 2.))
+    x1 = min(w, x0 + v)
+    y1 = min(h, y0 + v)
+
+    xy = (x0, y0, x1, y1)
+    color = (125, 123, 114)
+    # color = (0, 0, 0)
+    img = img.copy()
+    PIL.ImageDraw.Draw(img).rectangle(xy, color)
+    return img
+
+
+def SamplePairing(imgs):  # [0, 0.4]
+    def f(img1, v):
+        i = np.random.choice(len(imgs))
+        img2 = PIL.Image.fromarray(imgs[i])
+        return PIL.Image.blend(img1, img2, v)
+
+    return f
+
+
+def augment_list(for_autoaug=True):  # 16 operations and their ranges
+    l = [
+        (ShearX, -0.3, 0.3),  # 0
+        (ShearY, -0.3, 0.3),  # 1
+        (TranslateX, -0.45, 0.45),  # 2
+        (TranslateY, -0.45, 0.45),  # 3
+        (Rotate, -30, 30),  # 4
+        (AutoContrast, 0, 1),  # 5
+        (Invert, 0, 1),  # 6
+        (Equalize, 0, 1),  # 7
+        (Solarize, 0, 256),  # 8
+        (Posterize, 4, 8),  # 9
+        (Contrast, 0.1, 1.9),  # 10
+        (Color, 0.1, 1.9),  # 11
+        (Brightness, 0.1, 1.9),  # 12
+        (Sharpness, 0.1, 1.9),  # 13
+        (Cutout, 0, 0.2),  # 14
+        # (SamplePairing(imgs), 0, 0.4),  # 15
+    ]
+    if for_autoaug:
+        l += [
+            (CutoutAbs, 0, 20),  # 16 - compatible with auto-augment
+            (Posterize2, 0, 4),  # 17
+            (TranslateXAbs, 0, 10),  # 18
+            (TranslateYAbs, 0, 10),  # 19
+        ]
+    return l
+
+
+augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
+
+
+def get_augment(name):
+    return augment_dict[name]
+
+
+def apply_augment(img, name, level):
+    augment_fn, low, high = get_augment(name)
+    return augment_fn(img.copy(), level * (high - low) + low)
+
+
+class Lighting(object):
+    """Lighting noise (AlexNet-style PCA-based noise)"""
+
+    def __init__(self, alphastd, eigval, eigvec):
+        self.alphastd = alphastd
+        self.eigval = torch.Tensor(eigval)
+        self.eigvec = torch.Tensor(eigvec)
+
+    def __call__(self, img):
+        if self.alphastd == 0:
+            return img
+
+        alpha = img.new().resize_(3).normal_(0, self.alphastd)
+        rgb = self.eigvec.type_as(img).clone() \
+            .mul(alpha.view(1, 3).expand(3, 3)) \
+            .mul(self.eigval.view(1, 3).expand(3, 3)) \
+            .sum(1).squeeze()
+
+        return img.add(rgb.view(3, 1, 1).expand_as(img))
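For reference, apply_augment takes a normalized level in [0, 1] and maps it linearly onto the (low, high) range registered in augment_list; a quick sketch on a placeholder image:

    import PIL.Image

    img = PIL.Image.new('RGB', (32, 32), (128, 128, 128))  # placeholder image
    out = apply_augment(img, 'Rotate', 0.75)   # 0.75 of [-30, 30] -> 15 degrees (sign may flip via random_mirror)
    out = apply_augment(out, 'Solarize', 0.5)  # 0.5 of [0, 256] -> threshold 128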
+import copy
+import logging
+import warnings
+
+formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
+warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
+warnings.filterwarnings("ignore", "DeprecationWarning: 'saved_variables' is deprecated", UserWarning)
+
+
+def get_logger(name, level=logging.DEBUG):
+    logger = logging.getLogger(name)
+    logger.handlers.clear()
+    logger.setLevel(level)
+    ch = logging.StreamHandler()
+    ch.setLevel(level)
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+    return logger
+
+
+def add_filehandler(logger, filepath, level=logging.DEBUG):
+    fh = logging.FileHandler(filepath)
+    fh.setLevel(level)
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
+
+
+class EMA:
+    def __init__(self, mu):
+        self.mu = mu
+        self.shadow = {}
+
+    def state_dict(self):
+        return copy.deepcopy(self.shadow)
+
+    def __len__(self):
+        return len(self.shadow)
+
+    def __call__(self, module, step=None):
+        if step is None:
+            mu = self.mu
+        else:
+            # see : https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/ExponentialMovingAverage?hl=PL
+            mu = min(self.mu, (1. + step) / (10 + step))
+
+        for name, x in module.state_dict().items():
+            if name in self.shadow:
+                new_average = (1.0 - mu) * x + mu * self.shadow[name]
+                self.shadow[name] = new_average.clone()
+            else:
+                self.shadow[name] = x.clone()
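A usage sketch for the EMA helper (the model and training loop are placeholders); __call__ snapshots module.state_dict(), so the shadow can be loaded back into a same-shaped module for evaluation:

    import torch

    model = torch.nn.Linear(4, 2)
    ema = EMA(mu=0.999)

    for step in range(100):
        # ... optimizer update would go here ...
        ema(model, step)  # early steps use the smaller (1 + step) / (10 + step) factor

    eval_model = torch.nn.Linear(4, 2)
    eval_model.load_state_dict(ema.state_dict())  # evaluate with the averaged weights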
+import logging
+
+import numpy as np
+import os
+
+import math
+import random
+import torch
+import torchvision
+from PIL import Image
+
+from torch.utils.data import SubsetRandomSampler, Sampler, Subset, ConcatDataset
+import torch.distributed as dist
+from torchvision.transforms import transforms
+from sklearn.model_selection import StratifiedShuffleSplit
+from theconf import Config as C
+
+from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet
+from FastAutoAugment.augmentations import *
+from FastAutoAugment.common import get_logger
+from FastAutoAugment.imagenet import ImageNet
+from FastAutoAugment.networks.efficientnet_pytorch.model import EfficientNet
+
+logger = get_logger('Fast AutoAugment')
+logger.setLevel(logging.INFO)
+_IMAGENET_PCA = {
+    'eigval': [0.2175, 0.0188, 0.0045],
+    'eigvec': [
+        [-0.5675, 0.7192, 0.4009],
+        [-0.5808, -0.0045, -0.8140],
+        [-0.5836, -0.6948, 0.4203],
+    ]
+}
+_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+
+
+def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode=False, target_lb=-1):
+    if 'cifar' in dataset or 'svhn' in dataset:
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
+        ])
+        transform_test = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
+        ])
+    elif 'imagenet' in dataset:
+        input_size = 224
+        sized_size = 256
+
+        if 'efficientnet' in C.get()['model']['type']:
+            input_size = EfficientNet.get_image_size(C.get()['model']['type'])
+            sized_size = input_size + 32  # TODO
+            # sized_size = int(round(input_size / 224. * 256))
+            # sized_size = input_size
+            logger.info('size changed to %d/%d.' % (input_size, sized_size))
+
+        transform_train = transforms.Compose([
+            EfficientNetRandomCrop(input_size),
+            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
+            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
+            transforms.RandomHorizontalFlip(),
+            transforms.ColorJitter(
+                brightness=0.4,
+                contrast=0.4,
+                saturation=0.4,
+            ),
+            transforms.ToTensor(),
+            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+        transform_test = transforms.Compose([
+            EfficientNetCenterCrop(input_size),
+            transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    else:
+        raise ValueError('dataset=%s' % dataset)
+
+    total_aug = augs = None
+    if isinstance(C.get()['aug'], list):
+        logger.debug('augmentation provided.')
+        transform_train.transforms.insert(0, Augmentation(C.get()['aug']))
+    else:
+        logger.debug('augmentation: %s' % C.get()['aug'])
+        if C.get()['aug'] == 'fa_reduced_cifar10':
+            transform_train.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
+
+        elif C.get()['aug'] == 'fa_reduced_imagenet':
+            transform_train.transforms.insert(0, Augmentation(fa_resnet50_rimagenet()))
+
+        elif C.get()['aug'] == 'fa_reduced_svhn':
+            transform_train.transforms.insert(0, Augmentation(fa_reduced_svhn()))
+
+        elif C.get()['aug'] == 'arsaug':
+            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
+        elif C.get()['aug'] == 'autoaug_cifar10':
+            transform_train.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
+        elif C.get()['aug'] == 'autoaug_extend':
+            transform_train.transforms.insert(0, Augmentation(autoaug_policy()))
+        elif C.get()['aug'] in ['default']:
+            pass
+        else:
+            raise ValueError('not found augmentations. %s' % C.get()['aug'])
+
+    if C.get()['cutout'] > 0:
+        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))
+
+    if dataset == 'cifar10':
+        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
+        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
+    elif dataset == 'reduced_cifar10':
+        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
+        sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0)  # 4000 trainset
+        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
+        train_idx, valid_idx = next(sss)
+        targets = [total_trainset.targets[idx] for idx in train_idx]
+        total_trainset = Subset(total_trainset, train_idx)
+        total_trainset.targets = targets
+
+        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
+    elif dataset == 'cifar100':
+        total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
+        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
+    elif dataset == 'svhn':
+        trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
+        extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train)
+        total_trainset = ConcatDataset([trainset, extraset])
+        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
+    elif dataset == 'reduced_svhn':
+        total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
+        sss = StratifiedShuffleSplit(n_splits=1, test_size=73257-1000, random_state=0)  # 1000 trainset
+        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
+        train_idx, valid_idx = next(sss)
+        targets = [total_trainset.targets[idx] for idx in train_idx]
+        total_trainset = Subset(total_trainset, train_idx)
+        total_trainset.targets = targets
+
+        testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
+    elif dataset == 'imagenet':
+        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
+        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
+
+        # compatibility
+        total_trainset.targets = [lb for _, lb in total_trainset.samples]
+    elif dataset == 'reduced_imagenet':
+        # randomly chosen indices
+#        idx120 = sorted(random.sample(list(range(1000)), k=120))
+        idx120 = [16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181, 189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304, 307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388, 399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497, 506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619, 626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695, 699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787, 797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883, 897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949, 959]
+        total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
+        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
+
+        # compatibility
+        total_trainset.targets = [lb for _, lb in total_trainset.samples]
+
+        sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 50000, random_state=0)  # 50000 trainset
+        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
+        train_idx, valid_idx = next(sss)
+
+        # filter out
+        train_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
+        valid_idx = list(filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
+        test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))
+
+        targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
+        for idx in range(len(total_trainset.samples)):
+            if total_trainset.samples[idx][1] not in idx120:
+                continue
+            total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
+        total_trainset = Subset(total_trainset, train_idx)
+        total_trainset.targets = targets
+
+        for idx in range(len(testset.samples)):
+            if testset.samples[idx][1] not in idx120:
+                continue
+            testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
+        testset = Subset(testset, test_idx)
+        print('reduced_imagenet train=', len(total_trainset))
+    else:
+        raise ValueError('invalid dataset name=%s' % dataset)
+
+    if total_aug is not None and augs is not None:
+        total_trainset.set_preaug(augs, total_aug)
+        print('set_preaug-')
+
+    train_sampler = None
+    if split > 0.0:
+        sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
+        sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
+        for _ in range(split_idx + 1):
+            train_idx, valid_idx = next(sss)
+
+        if target_lb >= 0:
+            train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
+            valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]
+
+        train_sampler = SubsetRandomSampler(train_idx)
+        valid_sampler = SubsetSampler(valid_idx)
+
+        if multinode:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(Subset(total_trainset, train_idx), num_replicas=dist.get_world_size(), rank=dist.get_rank())
+    else:
+        valid_sampler = SubsetSampler([])
+
+        if multinode:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(total_trainset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
+            logger.info(f'----- dataset with DistributedSampler {dist.get_rank()}/{dist.get_world_size()}')
+
+    trainloader = torch.utils.data.DataLoader(
+        total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False, num_workers=8, pin_memory=True,
+        sampler=train_sampler, drop_last=True)
+    validloader = torch.utils.data.DataLoader(
+        total_trainset, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True,
+        sampler=valid_sampler, drop_last=False)
+
+    testloader = torch.utils.data.DataLoader(
+        testset, batch_size=batch, shuffle=False, num_workers=8, pin_memory=True,
+        drop_last=False
+    )
+    return train_sampler, trainloader, validloader, testloader
+
+
+class CutoutDefault(object):
+    """
+    Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
+    """
+    def __init__(self, length):
+        self.length = length
+
+    def __call__(self, img):
+        h, w = img.size(1), img.size(2)
+        mask = np.ones((h, w), np.float32)
+        y = np.random.randint(h)
+        x = np.random.randint(w)
+
+        y1 = np.clip(y - self.length // 2, 0, h)
+        y2 = np.clip(y + self.length // 2, 0, h)
+        x1 = np.clip(x - self.length // 2, 0, w)
+        x2 = np.clip(x + self.length // 2, 0, w)
+
+        mask[y1: y2, x1: x2] = 0.
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img *= mask
+        return img
+
+
+class Augmentation(object):
+    def __init__(self, policies):
+        self.policies = policies
+
+    def __call__(self, img):
+        for _ in range(1):
+            policy = random.choice(self.policies)
+            for name, pr, level in policy:
+                if random.random() > pr:
+                    continue
+                img = apply_augment(img, name, level)
+        return img
+
+
+class EfficientNetRandomCrop:
+    def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3), area_range=(0.08, 1.0), max_attempts=10):
+        assert 0.0 < min_covered
+        assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1]
+        assert 0 < area_range[0] <= area_range[1]
+        assert 1 <= max_attempts
+
+        self.min_covered = min_covered
+        self.aspect_ratio_range = aspect_ratio_range
+        self.area_range = area_range
+        self.max_attempts = max_attempts
+        self._fallback = EfficientNetCenterCrop(imgsize)
+
+    def __call__(self, img):
+        # https://github.com/tensorflow/tensorflow/blob/9274bcebb31322370139467039034f8ff852b004/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc#L111
+        original_width, original_height = img.size
+        min_area = self.area_range[0] * (original_width * original_height)
+        max_area = self.area_range[1] * (original_width * original_height)
+
+        for _ in range(self.max_attempts):
+            aspect_ratio = random.uniform(*self.aspect_ratio_range)
+            height = int(round(math.sqrt(min_area / aspect_ratio)))
+            max_height = int(round(math.sqrt(max_area / aspect_ratio)))
+
+            if max_height * aspect_ratio > original_width:
+                max_height = (original_width + 0.5 - 1e-7) / aspect_ratio
+                max_height = int(max_height)
+                if max_height * aspect_ratio > original_width:
+                    max_height -= 1
+
+            if max_height > original_height:
+                max_height = original_height
+
+            if height >= max_height:
+                height = max_height
+
+            height = int(round(random.uniform(height, max_height)))
+            width = int(round(height * aspect_ratio))
+            area = width * height
+
+            if area < min_area or area > max_area:
+                continue
+            if width > original_width or height > original_height:
+                continue
+            if area < self.min_covered * (original_width * original_height):
+                continue
+            if width == original_width and height == original_height:
+                return self._fallback(img)  # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L102
+
+            x = random.randint(0, original_width - width)
+            y = random.randint(0, original_height - height)
+            return img.crop((x, y, x + width, y + height))
+
+        return self._fallback(img)
+
+
+class EfficientNetCenterCrop:
+    def __init__(self, imgsize):
+        self.imgsize = imgsize
+
+    def __call__(self, img):
+        """Crop the given PIL Image and resize it to desired size.
+
+        Args:
+            img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+            output_size (sequence or int): (height, width) of the crop box. If int,
+                it is used for both directions
+        Returns:
+            PIL Image: Cropped image.
+        """
+        image_width, image_height = img.size
+        image_short = min(image_width, image_height)
+
+        crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short
+
+        crop_height, crop_width = crop_size, crop_size
+        crop_top = int(round((image_height - crop_height) / 2.))
+        crop_left = int(round((image_width - crop_width) / 2.))
+        return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
+
+
+class SubsetSampler(Sampler):
+    r"""Samples elements from a given list of indices, without replacement.
+
+    Arguments:
+        indices (sequence): a sequence of indices
+    """
+
+    def __init__(self, indices):
+        self.indices = indices
+
+    def __iter__(self):
+        return (i for i in self.indices)
+
+    def __len__(self):
+        return len(self.indices)
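Note that CutoutDefault operates on an already-tensorized CHW image, which is why get_dataloaders appends it after ToTensor/Normalize rather than inserting it at the front; a quick sketch using the class above with a random stand-in tensor:

    import torch

    cutout = CutoutDefault(length=16)
    img = torch.rand(3, 32, 32)  # CHW tensor, as produced by ToTensor
    img = cutout(img)            # zeroes one random 16x16 patch, clipped at the image borders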
+from __future__ import print_function
+import os
+import shutil
+import torch
+
+ARCHIVE_DICT = {
+    'train': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
+        'md5': '1d675b47d978889d74fa0da5fadfb00e',
+    },
+    'val': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
+        'md5': '29b22e2961454d5413ddabcf34fc5622',
+    },
+    'devkit': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
+        'md5': 'fa75699e90414af021442c21a62c3abf',
+    }
+}
+
+
+import torchvision
+from torchvision.datasets.utils import check_integrity, download_url
+
+
+# copy ILSVRC/ImageSets/CLS-LOC/train_cls.txt to ./root/
+# to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file
+class ImageNet(torchvision.datasets.ImageFolder):
+    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
+
+    Args:
+        root (string): Root directory of the ImageNet Dataset.
+        split (string, optional): The dataset split, supports ``train``, or ``val``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+    Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(self, root, split='train', download=False, **kwargs):
+        root = self.root = os.path.expanduser(root)
+        self.split = self._verify_split(split)
+
+        if download:
+            self.download()
+        wnid_to_classes = self._load_meta_file()[0]
+
+        # to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file
+        listfile = os.path.join(root, 'train_cls.txt')
+        if split == 'train' and os.path.exists(listfile):
+            torchvision.datasets.VisionDataset.__init__(self, root, **kwargs)
+            with open(listfile, 'r') as f:
+                datalist = [
+                    line.strip().split(' ')[0]
+                    for line in f.readlines()
+                    if line.strip()
+                ]
+
+            classes = list(set([line.split('/')[0] for line in datalist]))
+            classes.sort()
+            class_to_idx = {classes[i]: i for i in range(len(classes))}
+
+            samples = [
+                (os.path.join(self.split_folder, line + '.JPEG'), class_to_idx[line.split('/')[0]])
+                for line in datalist
+            ]
+
+            self.loader = torchvision.datasets.folder.default_loader
+            self.extensions = torchvision.datasets.folder.IMG_EXTENSIONS
+
+            self.classes = classes
+            self.class_to_idx = class_to_idx
+            self.samples = samples
+            self.targets = [s[1] for s in samples]
+
+            self.imgs = self.samples
+        else:
+            super(ImageNet, self).__init__(self.split_folder, **kwargs)
+
+        self.root = root
+
+        idcs = [idx for _, idx in self.imgs]
+        self.wnids = self.classes
+        self.wnid_to_idx = {wnid: idx for idx, wnid in zip(idcs, self.wnids)}
+        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+        self.class_to_idx = {cls: idx
+                             for clss, idx in zip(self.classes, idcs)
+                             for cls in clss}
+
+    def download(self):
+        if not check_integrity(self.meta_file):
+            tmpdir = os.path.join(self.root, 'tmp')
+
+            archive_dict = ARCHIVE_DICT['devkit']
+            download_and_extract_tar(archive_dict['url'], self.root,
+                                     extract_root=tmpdir,
+                                     md5=archive_dict['md5'])
+            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
+            meta = parse_devkit(os.path.join(tmpdir, devkit_folder))
+            self._save_meta_file(*meta)
+
+            shutil.rmtree(tmpdir)
+
+        if not os.path.isdir(self.split_folder):
+            archive_dict = ARCHIVE_DICT[self.split]
+            download_and_extract_tar(archive_dict['url'], self.root,
+                                     extract_root=self.split_folder,
+                                     md5=archive_dict['md5'])
+
+            if self.split == 'train':
+                prepare_train_folder(self.split_folder)
+            elif self.split == 'val':
+                val_wnids = self._load_meta_file()[1]
+                prepare_val_folder(self.split_folder, val_wnids)
+        else:
+            msg = ("You set download=True, but a folder '{}' already exists in "
+                   "the root directory. If you want to re-download or re-extract the "
+                   "archive, delete the folder.")
+            print(msg.format(self.split))
+
+    @property
+    def meta_file(self):
+        return os.path.join(self.root, 'meta.bin')
+
+    def _load_meta_file(self):
+        if check_integrity(self.meta_file):
+            return torch.load(self.meta_file)
+        raise RuntimeError("Meta file not found or corrupted.",
+                           "You can use download=True to create it.")
+
+    def _save_meta_file(self, wnid_to_class, val_wnids):
+        torch.save((wnid_to_class, val_wnids), self.meta_file)
+
+    def _verify_split(self, split):
+        if split not in self.valid_splits:
+            msg = "Unknown split {}. ".format(split)
+            msg += "Valid splits are {}.".format(", ".join(self.valid_splits))
+            raise ValueError(msg)
151 | + return split | ||
152 | + | ||
153 | + @property | ||
154 | + def valid_splits(self): | ||
155 | + return 'train', 'val' | ||
156 | + | ||
157 | + @property | ||
158 | + def split_folder(self): | ||
159 | + return os.path.join(self.root, self.split) | ||
160 | + | ||
161 | + def extra_repr(self): | ||
162 | + return "Split: {split}".format(**self.__dict__) | ||
163 | + | ||
164 | + | ||
165 | +def extract_tar(src, dest=None, gzip=None, delete=False): | ||
166 | + import tarfile | ||
167 | + | ||
168 | + if dest is None: | ||
169 | + dest = os.path.dirname(src) | ||
170 | + if gzip is None: | ||
171 | + gzip = src.lower().endswith('.gz') | ||
172 | + | ||
173 | + mode = 'r:gz' if gzip else 'r' | ||
174 | + with tarfile.open(src, mode) as tarfh: | ||
175 | + tarfh.extractall(path=dest) | ||
176 | + | ||
177 | + if delete: | ||
178 | + os.remove(src) | ||
179 | + | ||
180 | + | ||
181 | +def download_and_extract_tar(url, download_root, extract_root=None, filename=None, | ||
182 | + md5=None, **kwargs): | ||
183 | + download_root = os.path.expanduser(download_root) | ||
184 | + if extract_root is None: | ||
185 | + extract_root = download_root | ||
186 | + if filename is None: | ||
187 | + filename = os.path.basename(url) | ||
188 | + | ||
189 | + if not check_integrity(os.path.join(download_root, filename), md5): | ||
190 | + download_url(url, download_root, filename=filename, md5=md5) | ||
191 | + | ||
192 | + extract_tar(os.path.join(download_root, filename), extract_root, **kwargs) | ||
193 | + | ||
194 | + | ||
195 | +def parse_devkit(root): | ||
196 | + idx_to_wnid, wnid_to_classes = parse_meta(root) | ||
197 | + val_idcs = parse_val_groundtruth(root) | ||
198 | + val_wnids = [idx_to_wnid[idx] for idx in val_idcs] | ||
199 | + return wnid_to_classes, val_wnids | ||
200 | + | ||
201 | + | ||
202 | +def parse_meta(devkit_root, path='data', filename='meta.mat'): | ||
203 | + import scipy.io as sio | ||
204 | + | ||
205 | + metafile = os.path.join(devkit_root, path, filename) | ||
206 | + meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] | ||
207 | + nums_children = list(zip(*meta))[4] | ||
208 | + meta = [meta[idx] for idx, num_children in enumerate(nums_children) | ||
209 | + if num_children == 0] | ||
210 | + idcs, wnids, classes = list(zip(*meta))[:3] | ||
211 | + classes = [tuple(clss.split(', ')) for clss in classes] | ||
212 | + idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} | ||
213 | + wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} | ||
214 | + return idx_to_wnid, wnid_to_classes | ||
215 | + | ||
216 | + | ||
217 | +def parse_val_groundtruth(devkit_root, path='data', | ||
218 | + filename='ILSVRC2012_validation_ground_truth.txt'): | ||
219 | + with open(os.path.join(devkit_root, path, filename), 'r') as txtfh: | ||
220 | + val_idcs = txtfh.readlines() | ||
221 | + return [int(val_idx) for val_idx in val_idcs] | ||
222 | + | ||
223 | + | ||
224 | +def prepare_train_folder(folder): | ||
225 | + for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]: | ||
226 | + extract_tar(archive, os.path.splitext(archive)[0], delete=True) | ||
227 | + | ||
228 | + | ||
229 | +def prepare_val_folder(folder, wnids): | ||
230 | + img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)]) | ||
231 | + | ||
232 | + for wnid in set(wnids): | ||
233 | + os.mkdir(os.path.join(folder, wnid)) | ||
234 | + | ||
235 | + for wnid, img_file in zip(wnids, img_files): | ||
236 | + shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) | ||
237 | + | ||
238 | + | ||
239 | +def _splitexts(root): | ||
240 | + exts = [] | ||
241 | + ext = '.' | ||
242 | + while ext: | ||
243 | + root, ext = os.path.splitext(root) | ||
244 | + exts.append(ext) | ||
245 | + return root, ''.join(reversed(exts)) |
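# A worked example of _splitexts: unlike os.path.splitext, it peels off every
# suffix, which is how the devkit folder name is recovered from a multi-suffix
# archive; the filename below is illustrative.
_splitexts('ILSVRC2012_devkit_t12.tar.gz')
# -> ('ILSVRC2012_devkit_t12', '.tar.gz')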
1 | +import torch | ||
2 | + | ||
3 | +from theconf import Config as C | ||
4 | + | ||
5 | + | ||
6 | +def adjust_learning_rate_resnet(optimizer): | ||
7 | + """ | ||
8 | + Sets the learning rate to the initial LR, decayed by 10 at each predefined epoch | ||
9 | + Ref: AutoAugment | ||
10 | + """ | ||
11 | + | ||
12 | + if C.get()['epoch'] == 90: | ||
13 | + return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80]) | ||
14 | + elif C.get()['epoch'] == 270: # autoaugment | ||
15 | + return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240]) | ||
16 | + else: | ||
17 | + raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch']) |
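# A minimal usage sketch, assuming `model` and a theconf configuration with
# 'epoch' set to 90 are already in place; `train_one_epoch` is a hypothetical
# placeholder for the repository's training step.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = adjust_learning_rate_resnet(optimizer)
for epoch in range(C.get()['epoch']):
    train_one_epoch(model, optimizer)  # hypothetical training step
    scheduler.step()                   # LR is divided by 10 after epochs 30, 60, 80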
1 | +import copy | ||
2 | + | ||
3 | +import torch | ||
4 | +import numpy as np | ||
5 | +from collections import defaultdict | ||
6 | + | ||
7 | +from torch import nn | ||
8 | + | ||
9 | + | ||
10 | +def accuracy(output, target, topk=(1,)): | ||
11 | + """Computes the precision@k for the specified values of k""" | ||
12 | + maxk = max(topk) | ||
13 | + batch_size = target.size(0) | ||
14 | + | ||
15 | + _, pred = output.topk(maxk, 1, True, True) | ||
16 | + pred = pred.t() | ||
17 | + correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
18 | + | ||
19 | + res = [] | ||
20 | + for k in topk: | ||
21 | + correct_k = correct[:k].view(-1).float().sum(0) | ||
22 | + res.append(correct_k.mul_(1. / batch_size)) | ||
23 | + return res | ||
24 | + | ||
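# A small worked example of accuracy() on a hand-made batch: the second
# sample's top-1 prediction is wrong but its target is within the top 2,
# so top-1 accuracy is 0.5 and top-2 accuracy is 1.0.
output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.2, 0.5, 0.3]])
target = torch.tensor([1, 2])
top1, top2 = accuracy(output, target, topk=(1, 2))  # tensor(0.5000), tensor(1.)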
25 | + | ||
26 | +class CrossEntropyLabelSmooth(torch.nn.Module): | ||
27 | + def __init__(self, num_classes, epsilon, reduction='mean'): | ||
28 | + super(CrossEntropyLabelSmooth, self).__init__() | ||
29 | + self.num_classes = num_classes | ||
30 | + self.epsilon = epsilon | ||
31 | + self.reduction = reduction | ||
32 | + self.logsoftmax = torch.nn.LogSoftmax(dim=1) | ||
33 | + | ||
34 | + def forward(self, input, target): # pylint: disable=redefined-builtin | ||
35 | + log_probs = self.logsoftmax(input) | ||
36 | + targets = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1) | ||
37 | + if self.epsilon > 0.0: | ||
38 | + targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes | ||
39 | + targets = targets.detach() | ||
40 | + loss = (-targets * log_probs) | ||
41 | + | ||
42 | + if self.reduction in ['avg', 'mean']: | ||
43 | + loss = torch.mean(torch.sum(loss, dim=1)) | ||
44 | + elif self.reduction == 'sum': | ||
45 | + loss = loss.sum() | ||
46 | + return loss | ||
47 | + | ||
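# A usage sketch of the label-smoothing loss above: with epsilon=0.1 and 10
# classes, each one-hot target becomes (1 - eps) + eps/K = 0.91 on the true
# class and eps/K = 0.01 on every other class; shapes are illustrative.
criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = criterion(logits, labels)  # scalar, 'mean' reduction by default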
48 | + | ||
49 | +class Accumulator: | ||
50 | + def __init__(self): | ||
51 | + self.metrics = defaultdict(lambda: 0.) | ||
52 | + | ||
53 | + def add(self, key, value): | ||
54 | + self.metrics[key] += value | ||
55 | + | ||
56 | + def add_dict(self, d): | ||
57 | + for key, value in d.items(): | ||
58 | + self.add(key, value) | ||
59 | + | ||
60 | + def __getitem__(self, item): | ||
61 | + return self.metrics[item] | ||
62 | + | ||
63 | + def __setitem__(self, key, value): | ||
64 | + self.metrics[key] = value | ||
65 | + | ||
66 | + def get_dict(self): | ||
67 | + return copy.deepcopy(dict(self.metrics)) | ||
68 | + | ||
69 | + def items(self): | ||
70 | + return self.metrics.items() | ||
71 | + | ||
72 | + def __str__(self): | ||
73 | + return str(dict(self.metrics)) | ||
74 | + | ||
75 | + def __truediv__(self, other): | ||
76 | + newone = Accumulator() | ||
77 | + for key, value in self.items(): | ||
78 | + if isinstance(other, str): | ||
79 | + if other != key: | ||
80 | + newone[key] = value / self[other] | ||
81 | + else: | ||
82 | + newone[key] = value | ||
83 | + else: | ||
84 | + newone[key] = value / other | ||
85 | + return newone | ||
86 | + | ||
87 | + | ||
88 | +class SummaryWriterDummy: | ||
89 | + def __init__(self, log_dir): | ||
90 | + pass | ||
91 | + | ||
92 | + def add_scalar(self, *args, **kwargs): | ||
93 | + pass |
1 | +import torch | ||
2 | + | ||
3 | +from torch import nn | ||
4 | +from torch.nn import DataParallel | ||
5 | +from torch.nn.parallel import DistributedDataParallel | ||
6 | +import torch.backends.cudnn as cudnn | ||
7 | +# from torchvision import models | ||
8 | +import numpy as np | ||
9 | + | ||
10 | +from FastAutoAugment.networks.resnet import ResNet | ||
11 | +from FastAutoAugment.networks.pyramidnet import PyramidNet | ||
12 | +from FastAutoAugment.networks.shakeshake.shake_resnet import ShakeResNet | ||
13 | +from FastAutoAugment.networks.wideresnet import WideResNet | ||
14 | +from FastAutoAugment.networks.shakeshake.shake_resnext import ShakeResNeXt | ||
15 | +from FastAutoAugment.networks.efficientnet_pytorch import EfficientNet, RoutingFn | ||
16 | +from FastAutoAugment.tf_port.tpu_bn import TpuBatchNormalization | ||
17 | + | ||
18 | + | ||
19 | +def get_model(conf, num_class=10, local_rank=-1): | ||
20 | + name = conf['type'] | ||
21 | + | ||
22 | + if name == 'resnet50': | ||
23 | + model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True) | ||
24 | + elif name == 'resnet200': | ||
25 | + model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True) | ||
26 | + elif name == 'wresnet40_2': | ||
27 | + model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class) | ||
28 | + elif name == 'wresnet28_10': | ||
29 | + model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class) | ||
30 | + | ||
31 | + elif name == 'shakeshake26_2x32d': | ||
32 | + model = ShakeResNet(26, 32, num_class) | ||
33 | + elif name == 'shakeshake26_2x64d': | ||
34 | + model = ShakeResNet(26, 64, num_class) | ||
35 | + elif name == 'shakeshake26_2x96d': | ||
36 | + model = ShakeResNet(26, 96, num_class) | ||
37 | + elif name == 'shakeshake26_2x112d': | ||
38 | + model = ShakeResNet(26, 112, num_class) | ||
39 | + | ||
40 | + elif name == 'shakeshake26_2x96d_next': | ||
41 | + model = ShakeResNeXt(26, 96, 4, num_class) | ||
42 | + | ||
43 | + elif name == 'pyramid': | ||
44 | + model = PyramidNet('cifar10', depth=conf['depth'], alpha=conf['alpha'], num_classes=num_class, bottleneck=conf['bottleneck']) | ||
45 | + | ||
46 | + elif 'efficientnet' in name: | ||
47 | + model = EfficientNet.from_name(name, condconv_num_expert=conf['condconv_num_expert'], norm_layer=None) # TpuBatchNormalization | ||
48 | + if local_rank >= 0: | ||
49 | + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) | ||
50 | + def kernel_initializer(module): | ||
51 | + def get_fan_in_out(module): | ||
52 | + num_input_fmaps = module.weight.size(1) | ||
53 | + num_output_fmaps = module.weight.size(0) | ||
54 | + receptive_field_size = 1 | ||
55 | + if module.weight.dim() > 2: | ||
56 | + receptive_field_size = module.weight[0][0].numel() | ||
57 | + fan_in = num_input_fmaps * receptive_field_size | ||
58 | + fan_out = num_output_fmaps * receptive_field_size | ||
59 | + return fan_in, fan_out | ||
60 | + | ||
61 | + if isinstance(module, torch.nn.Conv2d): | ||
62 | + # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L58 | ||
63 | + fan_in, fan_out = get_fan_in_out(module) | ||
64 | + torch.nn.init.normal_(module.weight, mean=0.0, std=np.sqrt(2.0 / fan_out)) | ||
65 | + if module.bias is not None: | ||
66 | + torch.nn.init.constant_(module.bias, val=0.) | ||
67 | + elif isinstance(module, RoutingFn): | ||
68 | + torch.nn.init.xavier_uniform_(module.weight) | ||
69 | + torch.nn.init.constant_(module.bias, val=0.) | ||
70 | + elif isinstance(module, torch.nn.Linear): | ||
71 | + # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L82 | ||
72 | + fan_in, fan_out = get_fan_in_out(module) | ||
73 | + delta = 1.0 / np.sqrt(fan_out) | ||
74 | + torch.nn.init.uniform_(module.weight, a=-delta, b=delta) | ||
75 | + if module.bias is not None: | ||
76 | + torch.nn.init.constant_(module.bias, val=0.) | ||
77 | + model.apply(kernel_initializer) | ||
78 | + else: | ||
79 | + raise NameError('no model named %s' % name) | ||
80 | + | ||
81 | + if local_rank >= 0: | ||
82 | + device = torch.device('cuda', local_rank) | ||
83 | + model = model.to(device) | ||
84 | + model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) | ||
85 | + else: | ||
86 | + model = model.cuda() | ||
87 | +# model = DataParallel(model) | ||
88 | + | ||
89 | + cudnn.benchmark = True | ||
90 | + return model | ||
91 | + | ||
92 | + | ||
93 | +def num_class(dataset): | ||
94 | + return { | ||
95 | + 'cifar10': 10, | ||
96 | + 'reduced_cifar10': 10, | ||
97 | + 'cifar10.1': 10, | ||
98 | + 'cifar100': 100, | ||
99 | + 'svhn': 10, | ||
100 | + 'reduced_svhn': 10, | ||
101 | + 'imagenet': 1000, | ||
102 | + 'reduced_imagenet': 120, | ||
103 | + }[dataset] |
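# A minimal sketch of building a model, assuming single-process execution
# (local_rank=-1) and an available CUDA device, since get_model moves the
# model to the GPU unconditionally; the config dict follows the branches above.
conf = {'type': 'wresnet28_10'}
model = get_model(conf, num_class=num_class('cifar10'))  # WideResNet-28-10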
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import torch.nn.functional as F | ||
4 | +import collections.abc as container_abcs  # torch._six was removed from recent PyTorch | ||
5 | + | ||
6 | +from itertools import repeat | ||
7 | +from functools import partial | ||
8 | +from typing import Union, List, Tuple, Optional, Callable | ||
9 | +import numpy as np | ||
10 | +import math | ||
11 | + | ||
12 | + | ||
13 | +def _ntuple(n): | ||
14 | + def parse(x): | ||
15 | + if isinstance(x, container_abcs.Iterable): | ||
16 | + return x | ||
17 | + return tuple(repeat(x, n)) | ||
18 | + return parse | ||
19 | + | ||
20 | + | ||
21 | +_single = _ntuple(1) | ||
22 | +_pair = _ntuple(2) | ||
23 | +_triple = _ntuple(3) | ||
24 | +_quadruple = _ntuple(4) | ||
25 | + | ||
26 | + | ||
27 | +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): | ||
28 | + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 | ||
29 | + | ||
30 | + | ||
31 | +def _get_padding(kernel_size, stride=1, dilation=1, **_): | ||
32 | + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 | ||
33 | + return padding | ||
34 | + | ||
35 | + | ||
36 | +def _calc_same_pad(i: int, k: int, s: int, d: int): | ||
37 | + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) | ||
38 | + | ||
39 | + | ||
40 | +def conv2d_same( | ||
41 | + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), | ||
42 | + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): | ||
43 | + ih, iw = x.size()[-2:] | ||
44 | + kh, kw = weight.size()[-2:] | ||
45 | + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) | ||
46 | + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) | ||
47 | + if pad_h > 0 or pad_w > 0: | ||
48 | + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) | ||
49 | + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) | ||
50 | + | ||
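# A quick check of the SAME-padding arithmetic above: for a 224-pixel input,
# 3x3 kernel, stride 2, dilation 1, _calc_same_pad gives
# max((112 - 1) * 2 + 2 + 1 - 224, 0) = 1, split as (0, 1) on each axis,
# so the output is exactly ceil(224 / 2) = 112 pixels.
x = torch.randn(1, 3, 224, 224)
w = torch.randn(8, 3, 3, 3)
y = conv2d_same(x, w, stride=(2, 2))  # -> (1, 8, 112, 112)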
51 | + | ||
52 | +def get_padding_value(padding, kernel_size, **kwargs): | ||
53 | + dynamic = False | ||
54 | + if isinstance(padding, str): | ||
55 | + # for any string padding, the padding will be calculated for you, one of three ways | ||
56 | + padding = padding.lower() | ||
57 | + if padding == 'same': | ||
58 | + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact | ||
59 | + if _is_static_pad(kernel_size, **kwargs): | ||
60 | + # static case, no extra overhead | ||
61 | + padding = _get_padding(kernel_size, **kwargs) | ||
62 | + else: | ||
63 | + # dynamic padding | ||
64 | + padding = 0 | ||
65 | + dynamic = True | ||
66 | + elif padding == 'valid': | ||
67 | + # 'VALID' padding, same as padding=0 | ||
68 | + padding = 0 | ||
69 | + else: | ||
70 | + # Default to PyTorch style 'same'-ish symmetric padding | ||
71 | + padding = _get_padding(kernel_size, **kwargs) | ||
72 | + return padding, dynamic | ||
73 | + | ||
74 | + | ||
75 | +def get_condconv_initializer(initializer, num_experts, expert_shape): | ||
76 | + def condconv_initializer(weight): | ||
77 | + """CondConv initializer function.""" | ||
78 | + num_params = np.prod(expert_shape) | ||
79 | + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or weight.shape[1] != num_params): | ||
80 | + raise ValueError('CondConv variables must have shape [num_experts, num_params]') | ||
81 | + for i in range(num_experts): | ||
82 | + initializer(weight[i].view(expert_shape)) | ||
83 | + return condconv_initializer | ||
84 | + | ||
85 | + | ||
86 | +class CondConv2d(nn.Module): | ||
87 | + """ Conditional Convolution | ||
88 | + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py | ||
89 | + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: | ||
90 | + https://github.com/pytorch/pytorch/issues/17983 | ||
91 | + """ | ||
92 | + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] | ||
93 | + | ||
94 | + def __init__(self, in_channels, out_channels, kernel_size=3, | ||
95 | + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): | ||
96 | + super(CondConv2d, self).__init__() | ||
97 | + assert num_experts > 1 | ||
98 | + | ||
99 | + if isinstance(stride, container_abcs.Iterable) and len(stride) == 1: | ||
100 | + stride = stride[0] | ||
101 | + # print('CondConv', num_experts) | ||
102 | + | ||
103 | + self.in_channels = in_channels | ||
104 | + self.out_channels = out_channels | ||
105 | + self.kernel_size = _pair(kernel_size) | ||
106 | + self.stride = _pair(stride) | ||
107 | + padding_val, is_padding_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) | ||
108 | + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript | ||
109 | + self.padding = _pair(padding_val) | ||
110 | + self.dilation = _pair(dilation) | ||
111 | + self.groups = groups | ||
112 | + self.num_experts = num_experts | ||
113 | + | ||
114 | + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size | ||
115 | + weight_num_param = 1 | ||
116 | + for wd in self.weight_shape: | ||
117 | + weight_num_param *= wd | ||
118 | + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) | ||
119 | + | ||
120 | + if bias: | ||
121 | + self.bias_shape = (self.out_channels,) | ||
122 | + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) | ||
123 | + else: | ||
124 | + self.register_parameter('bias', None) | ||
125 | + | ||
126 | + self.reset_parameters() | ||
127 | + | ||
128 | + def reset_parameters(self): | ||
129 | + num_input_fmaps = self.weight.size(1) | ||
130 | + num_output_fmaps = self.weight.size(0) | ||
131 | + receptive_field_size = 1 | ||
132 | + if self.weight.dim() > 2: | ||
133 | + receptive_field_size = self.weight[0][0].numel() | ||
134 | + fan_in = num_input_fmaps * receptive_field_size | ||
135 | + fan_out = num_output_fmaps * receptive_field_size | ||
136 | + | ||
137 | + init_weight = get_condconv_initializer(partial(nn.init.normal_, mean=0.0, std=np.sqrt(2.0 / fan_out)), self.num_experts, self.weight_shape) | ||
138 | + init_weight(self.weight) | ||
139 | + if self.bias is not None: | ||
140 | + # fan_in = np.prod(self.weight_shape[1:]) | ||
141 | + # bound = 1 / math.sqrt(fan_in) | ||
142 | + init_bias = get_condconv_initializer(partial(nn.init.constant_, val=0), self.num_experts, self.bias_shape) | ||
143 | + init_bias(self.bias) | ||
144 | + | ||
145 | + def forward(self, x, routing_weights): | ||
146 | + x_orig = x | ||
147 | + B, C, H, W = x.shape | ||
148 | + weight = torch.matmul(routing_weights, self.weight) # (Expert x out x in x 3x3) --> (B x out x in x 3x3) | ||
149 | + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size | ||
150 | + weight = weight.view(new_weight_shape) # (B*out x in x 3 x 3) | ||
151 | + bias = None | ||
152 | + if self.bias is not None: | ||
153 | + bias = torch.matmul(routing_weights, self.bias) | ||
154 | + bias = bias.view(B * self.out_channels) | ||
155 | + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel | ||
156 | + x = x.view(1, B * C, H, W) | ||
157 | + if self.dynamic_padding: | ||
158 | + out = conv2d_same( | ||
159 | + x, weight, bias, stride=self.stride, padding=self.padding, | ||
160 | + dilation=self.dilation, groups=self.groups * B) | ||
161 | + else: | ||
162 | + out = F.conv2d( | ||
163 | + x, weight, bias, stride=self.stride, padding=self.padding, | ||
164 | + dilation=self.dilation, groups=self.groups * B) | ||
165 | + | ||
166 | + # out : (1 x B*out x ...) | ||
167 | + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) | ||
168 | + | ||
169 | + # out2 = self.forward_legacy(x_orig, routing_weights) | ||
170 | + # lt = torch.lt(torch.abs(torch.add(out, -out2)), 1e-8) | ||
171 | + # assert torch.all(lt), torch.abs(torch.add(out, -out2))[lt] | ||
172 | + # print('checked') | ||
173 | + return out | ||
174 | + | ||
175 | + def forward_legacy(self, x, routing_weights): | ||
176 | + # Literal port (from TF definition) | ||
177 | + B, C, H, W = x.shape | ||
178 | + weight = torch.matmul(routing_weights, self.weight) # (Expert x out x in x 3x3) --> (B x out x in x 3x3) | ||
179 | + x = torch.split(x, 1, 0) | ||
180 | + weight = torch.split(weight, 1, 0) | ||
181 | + if self.bias is not None: | ||
182 | + bias = torch.matmul(routing_weights, self.bias) | ||
183 | + bias = torch.split(bias, 1, 0) | ||
184 | + else: | ||
185 | + bias = [None] * B | ||
186 | + out = [] | ||
187 | + if self.dynamic_padding: | ||
188 | + conv_fn = conv2d_same | ||
189 | + else: | ||
190 | + conv_fn = F.conv2d | ||
191 | + for xi, wi, bi in zip(x, weight, bias): | ||
192 | + wi = wi.view(*self.weight_shape) | ||
193 | + if bi is not None: | ||
194 | + bi = bi.view(*self.bias_shape) | ||
195 | + out.append(conv_fn( | ||
196 | + xi, wi, bi, stride=self.stride, padding=self.padding, | ||
197 | + dilation=self.dilation, groups=self.groups)) | ||
198 | + out = torch.cat(out, 0) | ||
199 | + return out |
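# A minimal sketch of calling CondConv2d directly; in the MBConv blocks the
# routing weights come from a sigmoid over RoutingFn, so the random values of
# shape (batch, num_experts) below are stand-ins.
conv = CondConv2d(16, 32, kernel_size=3, stride=1, padding='same', num_experts=4)
x = torch.randn(8, 16, 24, 24)
routing_weights = torch.rand(8, 4)  # one mixing weight per expert, per sample
y = conv(x, routing_weights)        # -> (8, 32, 24, 24)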
1 | +import torch | ||
2 | +from torch import nn | ||
3 | +from torch.nn import functional as F | ||
4 | + | ||
5 | +from functools import partial | ||
6 | +from .utils import ( | ||
7 | + round_filters, | ||
8 | + round_repeats, | ||
9 | + drop_connect, | ||
10 | + get_same_padding_conv2d, | ||
11 | + get_model_params, | ||
12 | + efficientnet_params, | ||
13 | + load_pretrained_weights, | ||
14 | + MemoryEfficientSwish, | ||
15 | +) | ||
16 | + | ||
17 | + | ||
18 | +class RoutingFn(nn.Linear): | ||
19 | + pass | ||
20 | + | ||
21 | + | ||
22 | +class MBConvBlock(nn.Module): | ||
23 | + """ | ||
24 | + Mobile Inverted Residual Bottleneck Block | ||
25 | + | ||
26 | + Args: | ||
27 | + block_args (namedtuple): BlockArgs, see above | ||
28 | + global_params (namedtuple): GlobalParam, see above | ||
29 | + | ||
30 | + Attributes: | ||
31 | + has_se (bool): Whether the block contains a Squeeze and Excitation layer. | ||
32 | + """ | ||
33 | + | ||
34 | + def __init__(self, block_args, global_params, norm_layer=None): | ||
35 | + super().__init__() | ||
36 | + self._block_args = block_args | ||
37 | + self._bn_mom = 1 - global_params.batch_norm_momentum | ||
38 | + self._bn_eps = global_params.batch_norm_epsilon | ||
39 | + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) | ||
40 | + self.id_skip = block_args.id_skip # skip connection and drop connect | ||
41 | + if norm_layer is None: | ||
42 | + norm_layer = nn.BatchNorm2d | ||
43 | + | ||
44 | + self.condconv_num_expert = block_args.condconv_num_expert | ||
45 | + if self._is_condconv(): | ||
46 | + self.routing_fn = RoutingFn(self._block_args.input_filters, self.condconv_num_expert) | ||
47 | + | ||
48 | + # Get static or dynamic convolution depending on image size | ||
49 | + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size, condconv_num_expert=block_args.condconv_num_expert) | ||
50 | + Conv2dse = get_same_padding_conv2d(image_size=global_params.image_size) | ||
51 | + | ||
52 | + # Expansion phase | ||
53 | + inp = self._block_args.input_filters # number of input channels | ||
54 | + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels | ||
55 | + if self._block_args.expand_ratio != 1: | ||
56 | + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) | ||
57 | + self._bn0 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) | ||
58 | + | ||
59 | + # Depthwise convolution phase | ||
60 | + k = self._block_args.kernel_size | ||
61 | + s = self._block_args.stride | ||
62 | + self._depthwise_conv = Conv2d( | ||
63 | + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise | ||
64 | + kernel_size=k, stride=s, bias=False) | ||
65 | + self._bn1 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) | ||
66 | + | ||
67 | + # Squeeze and Excitation layer, if desired | ||
68 | + if self.has_se: | ||
69 | + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) | ||
70 | + self._se_reduce = Conv2dse(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) | ||
71 | + self._se_expand = Conv2dse(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) | ||
72 | + | ||
73 | + # Output phase | ||
74 | + final_oup = self._block_args.output_filters | ||
75 | + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) | ||
76 | + self._bn2 = norm_layer(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) | ||
77 | + self._swish = MemoryEfficientSwish() | ||
78 | + | ||
79 | + def _is_condconv(self): | ||
80 | + return self.condconv_num_expert > 1 | ||
81 | + | ||
82 | + def forward(self, inputs, drop_connect_rate=None): | ||
83 | + """ | ||
84 | + :param inputs: input tensor | ||
85 | + :param drop_connect_rate: drop connect rate (float, between 0 and 1) | ||
86 | + :return: output of block | ||
87 | + """ | ||
88 | + | ||
89 | + if self._is_condconv(): | ||
90 | + feat = F.adaptive_avg_pool2d(inputs, 1).flatten(1) | ||
91 | + routing_w = torch.sigmoid(self.routing_fn(feat)) | ||
92 | + | ||
93 | + if self._block_args.expand_ratio != 1: | ||
94 | + _expand_conv = partial(self._expand_conv, routing_weights=routing_w) | ||
95 | + _depthwise_conv = partial(self._depthwise_conv, routing_weights=routing_w) | ||
96 | + _project_conv = partial(self._project_conv, routing_weights=routing_w) | ||
97 | + else: | ||
98 | + if self._block_args.expand_ratio != 1: | ||
99 | + _expand_conv = self._expand_conv | ||
100 | + _depthwise_conv, _project_conv = self._depthwise_conv, self._project_conv | ||
101 | + | ||
102 | + # Expansion and Depthwise Convolution | ||
103 | + x = inputs | ||
104 | + if self._block_args.expand_ratio != 1: | ||
105 | + x = self._swish(self._bn0(_expand_conv(inputs))) | ||
106 | + x = self._swish(self._bn1(_depthwise_conv(x))) | ||
107 | + | ||
108 | + # Squeeze and Excitation | ||
109 | + if self.has_se: | ||
110 | + x_squeezed = F.adaptive_avg_pool2d(x, 1) | ||
111 | + x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) | ||
112 | + x = torch.sigmoid(x_squeezed) * x | ||
113 | + | ||
114 | + x = self._bn2(_project_conv(x)) | ||
115 | + | ||
116 | + # Skip connection and drop connect | ||
117 | + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters | ||
118 | + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: | ||
119 | + if drop_connect_rate: | ||
120 | + x = drop_connect(x, drop_p=drop_connect_rate, training=self.training) | ||
121 | + x = x + inputs # skip connection | ||
122 | + return x | ||
123 | + | ||
124 | + def set_swish(self): | ||
125 | + """Sets swish function as memory efficient (for training) or standard (for export)""" | ||
126 | + self._swish = MemoryEfficientSwish() | ||
127 | + | ||
128 | + | ||
129 | +class EfficientNet(nn.Module): | ||
130 | + """ | ||
131 | + An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods | ||
132 | + | ||
133 | + Args: | ||
134 | + blocks_args (list): A list of BlockArgs to construct blocks | ||
135 | + global_params (namedtuple): A set of GlobalParams shared between blocks | ||
136 | + | ||
137 | + Example: | ||
138 | + model = EfficientNet.from_pretrained('efficientnet-b0') | ||
139 | + | ||
140 | + """ | ||
141 | + | ||
142 | + def __init__(self, blocks_args=None, global_params=None, norm_layer=None): | ||
143 | + super().__init__() | ||
144 | + assert isinstance(blocks_args, list), 'blocks_args should be a list' | ||
145 | + assert len(blocks_args) > 0, 'block args must be greater than 0' | ||
146 | + self._global_params = global_params | ||
147 | + self._blocks_args = blocks_args | ||
148 | + if norm_layer is None: | ||
149 | + norm_layer = nn.BatchNorm2d | ||
150 | + | ||
151 | + # Get static or dynamic convolution depending on image size | ||
152 | + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) | ||
153 | + | ||
154 | + # Batch norm parameters | ||
155 | + bn_mom = 1 - self._global_params.batch_norm_momentum | ||
156 | + bn_eps = self._global_params.batch_norm_epsilon | ||
157 | + | ||
158 | + # Stem | ||
159 | + in_channels = 3 # rgb | ||
160 | + out_channels = round_filters(32, self._global_params) # number of output channels | ||
161 | + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) | ||
162 | + self._bn0 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps) | ||
163 | + | ||
164 | + # Build blocks | ||
165 | + self._blocks = nn.ModuleList([]) | ||
166 | + for idx, block_args in enumerate(self._blocks_args): | ||
167 | + # Update block input and output filters based on depth multiplier. | ||
168 | + block_args = block_args._replace( | ||
169 | + input_filters=round_filters(block_args.input_filters, self._global_params), | ||
170 | + output_filters=round_filters(block_args.output_filters, self._global_params), | ||
171 | + num_repeat=round_repeats(block_args.num_repeat, self._global_params) | ||
172 | + ) | ||
173 | + | ||
174 | + # The first block needs to take care of stride and filter size increase. | ||
175 | + self._blocks.append(MBConvBlock(block_args, self._global_params, norm_layer=norm_layer)) | ||
176 | + if block_args.num_repeat > 1: | ||
177 | + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) | ||
178 | + for _ in range(block_args.num_repeat - 1): | ||
179 | + self._blocks.append(MBConvBlock(block_args, self._global_params, norm_layer=norm_layer)) | ||
180 | + | ||
181 | + # Head | ||
182 | + in_channels = block_args.output_filters # output of final block | ||
183 | + out_channels = round_filters(1280, self._global_params) | ||
184 | + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) | ||
185 | + self._bn1 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps) | ||
186 | + | ||
187 | + # Final linear layer | ||
188 | + self._avg_pooling = nn.AdaptiveAvgPool2d(1) | ||
189 | + self._dropout = nn.Dropout(self._global_params.dropout_rate) | ||
190 | + self._fc = nn.Linear(out_channels, self._global_params.num_classes) | ||
191 | + self._swish = MemoryEfficientSwish() | ||
192 | + | ||
193 | + def set_swish(self): | ||
194 | + """Sets swish function as memory efficient (for training) or standard (for export)""" | ||
195 | + self._swish = MemoryEfficientSwish() | ||
196 | + for block in self._blocks: | ||
197 | + block.set_swish() | ||
198 | + | ||
199 | + def extract_features(self, inputs): | ||
200 | + """ Returns output of the final convolution layer """ | ||
201 | + | ||
202 | + # Stem | ||
203 | + x = self._swish(self._bn0(self._conv_stem(inputs))) | ||
204 | + | ||
205 | + # Blocks | ||
206 | + for idx, block in enumerate(self._blocks): | ||
207 | + drop_connect_rate = self._global_params.drop_connect_rate | ||
208 | + if drop_connect_rate: | ||
209 | + drop_connect_rate *= float(idx) / len(self._blocks) | ||
210 | + x = block(x, drop_connect_rate=drop_connect_rate) | ||
211 | + | ||
212 | + # Head | ||
213 | + x = self._swish(self._bn1(self._conv_head(x))) | ||
214 | + | ||
215 | + return x | ||
216 | + | ||
217 | + def forward(self, inputs): | ||
218 | + """ Calls extract_features to extract features, applies final linear layer, and returns logits. """ | ||
219 | + bs = inputs.size(0) | ||
220 | + # Convolution layers | ||
221 | + x = self.extract_features(inputs) | ||
222 | + | ||
223 | + # Pooling and final linear layer | ||
224 | + x = self._avg_pooling(x) | ||
225 | + x = x.view(bs, -1) | ||
226 | + x = self._dropout(x) | ||
227 | + x = self._fc(x) | ||
228 | + return x | ||
229 | + | ||
230 | + @classmethod | ||
231 | + def from_name(cls, model_name, override_params=None, norm_layer=None, condconv_num_expert=1): | ||
232 | + cls._check_model_name_is_valid(model_name) | ||
233 | + blocks_args, global_params = get_model_params(model_name, override_params, condconv_num_expert=condconv_num_expert) | ||
234 | + return cls(blocks_args, global_params, norm_layer=norm_layer) | ||
235 | + | ||
236 | + @classmethod | ||
237 | + def from_pretrained(cls, model_name, num_classes=1000): | ||
238 | + model = cls.from_name(model_name, override_params={'num_classes': num_classes}) | ||
239 | + load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) | ||
240 | + | ||
241 | + return model | ||
242 | + | ||
243 | + @classmethod | ||
244 | + def get_image_size(cls, model_name): | ||
245 | + cls._check_model_name_is_valid(model_name) | ||
246 | + _, _, res, _ = efficientnet_params(model_name) | ||
247 | + return res | ||
248 | + | ||
249 | + @classmethod | ||
250 | + def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False): | ||
251 | + """ Validates model name. None that pretrained weights are only available for | ||
252 | + the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """ | ||
253 | + num_models = 4 if also_need_pretrained_weights else 8 | ||
254 | + valid_models = ['efficientnet-b'+str(i) for i in range(num_models)] | ||
255 | + if model_name not in valid_models: | ||
256 | + raise ValueError(f'model_name={model_name} should be one of: ' + ', '.join(valid_models)) |
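# A minimal usage sketch: from_name builds an untrained network from the
# scaling coefficients in efficientnet_params (224-pixel inputs for b0),
# while from_pretrained additionally loads weights from url_map.
model = EfficientNet.from_name('efficientnet-b0')
x = torch.randn(2, 3, 224, 224)
logits = model(x)  # -> (2, 1000)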
1 | +""" | ||
2 | +This file contains helper functions for building the model and for loading model parameters. | ||
3 | +These helper functions are built to mirror those in the official TensorFlow implementation. | ||
4 | +""" | ||
5 | + | ||
6 | +import re | ||
7 | +import math | ||
8 | +import collections | ||
9 | +from functools import partial | ||
10 | +import torch | ||
11 | +from torch import nn | ||
12 | +from torch.nn import functional as F | ||
13 | +from torch.utils import model_zoo | ||
14 | + | ||
15 | +######################################################################## | ||
16 | +############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### | ||
17 | +######################################################################## | ||
18 | + | ||
19 | + | ||
20 | +# Parameters for the entire model (stem, all blocks, and head) | ||
21 | +from FastAutoAugment.networks.efficientnet_pytorch.condconv import CondConv2d | ||
22 | + | ||
23 | +GlobalParams = collections.namedtuple('GlobalParams', [ | ||
24 | + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', | ||
25 | + 'num_classes', 'width_coefficient', 'depth_coefficient', | ||
26 | + 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) | ||
27 | + | ||
28 | +# Parameters for an individual model block | ||
29 | +BlockArgs = collections.namedtuple('BlockArgs', [ | ||
30 | + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', | ||
31 | + 'expand_ratio', 'id_skip', 'stride', 'se_ratio', 'condconv_num_expert']) | ||
32 | + | ||
33 | +# Change namedtuple defaults | ||
34 | +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) | ||
35 | +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) | ||
36 | + | ||
37 | + | ||
38 | +class SwishImplementation(torch.autograd.Function): | ||
39 | + @staticmethod | ||
40 | + def forward(ctx, i): | ||
41 | + result = i * torch.sigmoid(i) | ||
42 | + ctx.save_for_backward(i) | ||
43 | + return result | ||
44 | + | ||
45 | + @staticmethod | ||
46 | + def backward(ctx, grad_output): | ||
47 | + i = ctx.saved_tensors[0] | ||
48 | + sigmoid_i = torch.sigmoid(i) | ||
49 | + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) | ||
50 | + | ||
51 | + | ||
52 | +class MemoryEfficientSwish(nn.Module): | ||
53 | + def forward(self, x): | ||
54 | + return SwishImplementation.apply(x) | ||
55 | + | ||
56 | + | ||
57 | +def round_filters(filters, global_params): | ||
58 | + """ Calculate and round number of filters based on depth multiplier. """ | ||
59 | + multiplier = global_params.width_coefficient | ||
60 | + if not multiplier: | ||
61 | + return filters | ||
62 | + divisor = global_params.depth_divisor | ||
63 | + min_depth = global_params.min_depth | ||
64 | + filters *= multiplier | ||
65 | + min_depth = min_depth or divisor | ||
66 | + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) | ||
67 | + if new_filters < 0.9 * filters: # prevent rounding by more than 10% | ||
68 | + new_filters += divisor | ||
69 | + return int(new_filters) | ||
70 | + | ||
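# A worked example of the rounding rule above with illustrative coefficients
# (width 1.1 and divisor 8, as for efficientnet-b2): 32 filters scale to 35.2,
# snap to the nearest multiple of 8 (i.e. 32), and since 32 >= 0.9 * 35.2 no
# divisor correction is added.
gp = GlobalParams(width_coefficient=1.1, depth_divisor=8, min_depth=None)
round_filters(32, gp)  # -> 32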
71 | + | ||
72 | +def round_repeats(repeats, global_params): | ||
73 | + """ Round number of filters based on depth multiplier. """ | ||
74 | + multiplier = global_params.depth_coefficient | ||
75 | + if not multiplier: | ||
76 | + return repeats | ||
77 | + return int(math.ceil(multiplier * repeats)) | ||
78 | + | ||
79 | + | ||
80 | +def drop_connect(inputs, drop_p, training): | ||
81 | + """ Drop connect. """ | ||
82 | + if not training: | ||
83 | + return inputs * (1. - drop_p) | ||
84 | + batch_size = inputs.shape[0] | ||
85 | + random_tensor = torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) | ||
86 | + binary_tensor = random_tensor > drop_p | ||
87 | + output = inputs * binary_tensor.float() | ||
88 | + # output = inputs / (1. - drop_p) * binary_tensor.float() | ||
89 | + return output | ||
90 | + | ||
91 | + # if not training: return inputs | ||
92 | + # batch_size = inputs.shape[0] | ||
93 | + # keep_prob = 1 - drop_p | ||
94 | + # random_tensor = keep_prob | ||
95 | + # random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) | ||
96 | + # binary_tensor = torch.floor(random_tensor) | ||
97 | + # output = inputs / keep_prob * binary_tensor | ||
98 | + # return output | ||
99 | + | ||
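# A sanity check of the variant implemented above: unlike the commented-out
# inverted-scaling version, it zeroes unscaled activations at train time and
# scales by (1 - drop_p) at eval time, so the expected values agree.
x = torch.ones(1000, 1, 1, 1)
drop_connect(x, drop_p=0.2, training=True).mean()   # ~0.8 (random mask)
drop_connect(x, drop_p=0.2, training=False).mean()  # exactly 0.8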
100 | + | ||
101 | +def get_same_padding_conv2d(image_size=None, condconv_num_expert=1): | ||
102 | + """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. | ||
103 | + Static padding is necessary for ONNX exporting of models. """ | ||
104 | + if condconv_num_expert > 1: | ||
105 | + return partial(CondConv2d, num_experts=condconv_num_expert) | ||
106 | + elif image_size is None: | ||
107 | + return Conv2dDynamicSamePadding | ||
108 | + else: | ||
109 | + return partial(Conv2dStaticSamePadding, image_size=image_size) | ||
110 | + | ||
111 | + | ||
112 | +class Conv2dDynamicSamePadding(nn.Conv2d): | ||
113 | + """ 2D Convolutions like TensorFlow, for a dynamic image size """ | ||
114 | + | ||
115 | + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): | ||
116 | + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) | ||
117 | + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 | ||
118 | + | ||
119 | + def forward(self, x): | ||
120 | + ih, iw = x.size()[-2:] | ||
121 | + kh, kw = self.weight.size()[-2:] | ||
122 | + sh, sw = self.stride | ||
123 | + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) | ||
124 | + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) | ||
125 | + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) | ||
126 | + if pad_h > 0 or pad_w > 0: | ||
127 | + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) | ||
128 | + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) | ||
129 | + | ||
130 | + | ||
131 | +class Conv2dStaticSamePadding(nn.Conv2d): | ||
132 | + """ 2D Convolutions like TensorFlow, for a fixed image size""" | ||
133 | + | ||
134 | + def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): | ||
135 | + super().__init__(in_channels, out_channels, kernel_size, **kwargs) | ||
136 | + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 | ||
137 | + | ||
138 | + # Calculate padding based on image size and save it | ||
139 | + assert image_size is not None | ||
140 | + ih, iw = image_size if type(image_size) == list else [image_size, image_size] | ||
141 | + kh, kw = self.weight.size()[-2:] | ||
142 | + sh, sw = self.stride | ||
143 | + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) | ||
144 | + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) | ||
145 | + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) | ||
146 | + if pad_h > 0 or pad_w > 0: | ||
147 | + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) | ||
148 | + else: | ||
149 | + self.static_padding = Identity() | ||
150 | + | ||
151 | + def forward(self, x): | ||
152 | + x = self.static_padding(x) | ||
153 | + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) | ||
154 | + return x | ||
155 | + | ||
156 | + | ||
157 | +class Identity(nn.Module): | ||
158 | + def __init__(self, ): | ||
159 | + super(Identity, self).__init__() | ||
160 | + | ||
161 | + def forward(self, input): | ||
162 | + return input | ||
163 | + | ||
164 | + | ||
165 | +######################################################################## | ||
166 | +############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## | ||
167 | +######################################################################## | ||
168 | + | ||
169 | + | ||
170 | +def efficientnet_params(model_name): | ||
171 | + """ Map EfficientNet model name to parameter coefficients. """ | ||
172 | + params_dict = { | ||
173 | + # Coefficients: width,depth,res,dropout | ||
174 | + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), | ||
175 | + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), | ||
176 | + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), | ||
177 | + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), | ||
178 | + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), | ||
179 | + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), | ||
180 | + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), | ||
181 | + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), | ||
182 | + } | ||
183 | + return params_dict[model_name] | ||
184 | + | ||
185 | + | ||
186 | +class BlockDecoder(object): | ||
187 | + """ Block Decoder for readability, straight from the official TensorFlow repository """ | ||
188 | + | ||
189 | + @staticmethod | ||
190 | + def _decode_block_string(block_string): | ||
191 | + """ Gets a block through a string notation of arguments. """ | ||
192 | + assert isinstance(block_string, str) | ||
193 | + | ||
194 | + ops = block_string.split('_') | ||
195 | + options = {} | ||
196 | + for op in ops: | ||
197 | + splits = re.split(r'(\d.*)', op) | ||
198 | + if len(splits) >= 2: | ||
199 | + key, value = splits[:2] | ||
200 | + options[key] = value | ||
201 | + | ||
202 | + # Check stride | ||
203 | + assert (('s' in options and len(options['s']) == 1) or | ||
204 | + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) | ||
205 | + | ||
206 | + return BlockArgs( | ||
207 | + kernel_size=int(options['k']), | ||
208 | + num_repeat=int(options['r']), | ||
209 | + input_filters=int(options['i']), | ||
210 | + output_filters=int(options['o']), | ||
211 | + expand_ratio=int(options['e']), | ||
212 | + id_skip=('noskip' not in block_string), | ||
213 | + se_ratio=float(options['se']) if 'se' in options else None, | ||
214 | + stride=[int(options['s'][0])], | ||
215 | + condconv_num_expert=0 | ||
216 | + ) | ||
217 | + | ||
218 | + @staticmethod | ||
219 | + def _encode_block_string(block): | ||
220 | + """Encodes a block to a string.""" | ||
221 | + args = [ | ||
222 | + 'r%d' % block.num_repeat, | ||
223 | + 'k%d' % block.kernel_size, | ||
224 | + 's%d%d' % (block.stride[0], block.stride[0]),  # decode() stores a single equal stride | ||
225 | + 'e%s' % block.expand_ratio, | ||
226 | + 'i%d' % block.input_filters, | ||
227 | + 'o%d' % block.output_filters | ||
228 | + ] | ||
229 | + if block.se_ratio is not None and 0 < block.se_ratio <= 1: | ||
230 | + args.append('se%s' % block.se_ratio) | ||
231 | + if block.id_skip is False: | ||
232 | + args.append('noskip') | ||
233 | + return '_'.join(args) | ||
234 | + | ||
235 | + @staticmethod | ||
236 | + def decode(string_list): | ||
237 | + """ | ||
238 | + Decodes a list of string notations to specify blocks inside the network. | ||
239 | + | ||
240 | + :param string_list: a list of strings, each string is a notation of block | ||
241 | + :return: a list of BlockArgs namedtuples of block args | ||
242 | + """ | ||
243 | + assert isinstance(string_list, list) | ||
244 | + blocks_args = [] | ||
245 | + for block_string in string_list: | ||
246 | + blocks_args.append(BlockDecoder._decode_block_string(block_string)) | ||
247 | + return blocks_args | ||
248 | + | ||
249 | + @staticmethod | ||
250 | + def encode(blocks_args): | ||
251 | + """ | ||
252 | + Encodes a list of BlockArgs to a list of strings. | ||
253 | + | ||
254 | + :param blocks_args: a list of BlockArgs namedtuples of block args | ||
255 | + :return: a list of strings, each string is a notation of block | ||
256 | + """ | ||
257 | + block_strings = [] | ||
258 | + for block in blocks_args: | ||
259 | + block_strings.append(BlockDecoder._encode_block_string(block)) | ||
260 | + return block_strings | ||
261 | + | ||
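# Decoding the first block string of the default architecture below shows the
# k/r/i/o/e/s/se notation handled by _decode_block_string.
BlockDecoder._decode_block_string('r1_k3_s11_e1_i32_o16_se0.25')
# -> BlockArgs(kernel_size=3, num_repeat=1, input_filters=32, output_filters=16,
#              expand_ratio=1, id_skip=True, stride=[1], se_ratio=0.25,
#              condconv_num_expert=0)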
262 | + | ||
263 | +def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, | ||
264 | + drop_connect_rate=0.2, image_size=None, num_classes=1000, condconv_num_expert=1): | ||
265 | + """ Creates a efficientnet model. """ | ||
266 | + | ||
267 | + blocks_args = [ | ||
268 | + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', | ||
269 | + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', | ||
270 | + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', | ||
271 | + 'r1_k3_s11_e6_i192_o320_se0.25', | ||
272 | + ] | ||
273 | + blocks_args = BlockDecoder.decode(blocks_args) | ||
274 | + | ||
275 | + blocks_args_new = blocks_args[:-3] | ||
276 | + for blocks_arg in blocks_args[-3:]: | ||
277 | + blocks_arg = blocks_arg._replace(condconv_num_expert=condconv_num_expert) | ||
278 | + blocks_args_new.append(blocks_arg) | ||
279 | + blocks_args = blocks_args_new | ||
280 | + | ||
281 | + global_params = GlobalParams( | ||
282 | + batch_norm_momentum=0.99, | ||
283 | + batch_norm_epsilon=1e-3, | ||
284 | + dropout_rate=dropout_rate, | ||
285 | + drop_connect_rate=drop_connect_rate, | ||
286 | + # data_format='channels_last', # removed, this is always true in PyTorch | ||
287 | + num_classes=num_classes, | ||
288 | + width_coefficient=width_coefficient, | ||
289 | + depth_coefficient=depth_coefficient, | ||
290 | + depth_divisor=8, | ||
291 | + min_depth=None, | ||
292 | + image_size=image_size, | ||
293 | + ) | ||
294 | + | ||
295 | + return blocks_args, global_params | ||
296 | + | ||
297 | + | ||
298 | +def get_model_params(model_name, override_params, condconv_num_expert=1): | ||
299 | + """ Get the block args and global params for a given model """ | ||
300 | + if model_name.startswith('efficientnet'): | ||
301 | + w, d, s, p = efficientnet_params(model_name) | ||
302 | + # note: all models have drop connect rate = 0.2 | ||
303 | + blocks_args, global_params = efficientnet( | ||
304 | + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s, condconv_num_expert=condconv_num_expert) | ||
305 | + else: | ||
306 | + raise NotImplementedError('model name is not pre-defined: %s' % model_name) | ||
307 | + if override_params: | ||
308 | + # ValueError will be raised here if override_params has fields not included in global_params. | ||
309 | + global_params = global_params._replace(**override_params) | ||
310 | + return blocks_args, global_params | ||
311 | + | ||
312 | + | ||
313 | +url_map = { | ||
314 | + 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth', | ||
315 | + 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth', | ||
316 | + 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth', | ||
317 | + 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth', | ||
318 | + 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth', | ||
319 | + 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth', | ||
320 | + 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth', | ||
321 | + 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth', | ||
322 | +} | ||
323 | + | ||
324 | + | ||
325 | +def load_pretrained_weights(model, model_name, load_fc=True): | ||
326 | + """ Loads pretrained weights, and downloads if loading for the first time. """ | ||
327 | + state_dict = model_zoo.load_url(url_map[model_name]) | ||
328 | + if load_fc: | ||
329 | + model.load_state_dict(state_dict) | ||
330 | + else: | ||
331 | + state_dict.pop('_fc.weight') | ||
332 | + state_dict.pop('_fc.bias') | ||
333 | + res = model.load_state_dict(state_dict, strict=False) | ||
334 | + assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' | ||
335 | + print('Loaded pretrained weights for {}'.format(model_name)) |
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import math | ||
4 | + | ||
5 | +from FastAutoAugment.networks.shakedrop import ShakeDrop | ||
6 | + | ||
7 | + | ||
8 | +def conv3x3(in_planes, out_planes, stride=1): | ||
9 | + """ | ||
10 | + 3x3 convolution with padding | ||
11 | + """ | ||
12 | + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
13 | + | ||
14 | + | ||
15 | +class BasicBlock(nn.Module): | ||
16 | + outchannel_ratio = 1 | ||
17 | + | ||
18 | + def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): | ||
19 | + super(BasicBlock, self).__init__() | ||
20 | + self.bn1 = nn.BatchNorm2d(inplanes) | ||
21 | + self.conv1 = conv3x3(inplanes, planes, stride) | ||
22 | + self.bn2 = nn.BatchNorm2d(planes) | ||
23 | + self.conv2 = conv3x3(planes, planes) | ||
24 | + self.bn3 = nn.BatchNorm2d(planes) | ||
25 | + self.relu = nn.ReLU(inplace=True) | ||
26 | + self.downsample = downsample | ||
27 | + self.stride = stride | ||
28 | + self.shake_drop = ShakeDrop(p_shakedrop) | ||
29 | + | ||
30 | + def forward(self, x): | ||
31 | + | ||
32 | + out = self.bn1(x) | ||
33 | + out = self.conv1(out) | ||
34 | + out = self.bn2(out) | ||
35 | + out = self.relu(out) | ||
36 | + out = self.conv2(out) | ||
37 | + out = self.bn3(out) | ||
38 | + | ||
39 | + out = self.shake_drop(out) | ||
40 | + | ||
41 | + if self.downsample is not None: | ||
42 | + shortcut = self.downsample(x) | ||
43 | + featuremap_size = shortcut.size()[2:4] | ||
44 | + else: | ||
45 | + shortcut = x | ||
46 | + featuremap_size = out.size()[2:4] | ||
47 | + | ||
48 | + batch_size = out.size()[0] | ||
49 | + residual_channel = out.size()[1] | ||
50 | + shortcut_channel = shortcut.size()[1] | ||
51 | + | ||
52 | + if residual_channel != shortcut_channel: | ||
53 | + padding = torch.zeros(batch_size, residual_channel - shortcut_channel, | ||
54 | + featuremap_size[0], featuremap_size[1], | ||
55 | + dtype=out.dtype, device=out.device) | ||
56 | + out += torch.cat((shortcut, padding), 1) | ||
57 | + else: | ||
58 | + out += shortcut | ||
59 | + | ||
60 | + return out | ||
61 | + | ||
62 | + | ||
63 | +class Bottleneck(nn.Module): | ||
64 | + outchannel_ratio = 4 | ||
65 | + | ||
66 | + def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): | ||
67 | + super(Bottleneck, self).__init__() | ||
68 | + self.bn1 = nn.BatchNorm2d(inplanes) | ||
69 | + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | ||
70 | + self.bn2 = nn.BatchNorm2d(planes) | ||
71 | + self.conv2 = nn.Conv2d(planes, (planes * 1), kernel_size=3, stride=stride, | ||
72 | + padding=1, bias=False) | ||
73 | + self.bn3 = nn.BatchNorm2d((planes * 1)) | ||
74 | + self.conv3 = nn.Conv2d((planes * 1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) | ||
75 | + self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) | ||
76 | + self.relu = nn.ReLU(inplace=True) | ||
77 | + self.downsample = downsample | ||
78 | + self.stride = stride | ||
79 | + self.shake_drop = ShakeDrop(p_shakedrop) | ||
80 | + | ||
81 | + def forward(self, x): | ||
82 | + | ||
83 | + out = self.bn1(x) | ||
84 | + out = self.conv1(out) | ||
85 | + | ||
86 | + out = self.bn2(out) | ||
87 | + out = self.relu(out) | ||
88 | + out = self.conv2(out) | ||
89 | + | ||
90 | + out = self.bn3(out) | ||
91 | + out = self.relu(out) | ||
92 | + out = self.conv3(out) | ||
93 | + | ||
94 | + out = self.bn4(out) | ||
95 | + | ||
96 | + out = self.shake_drop(out) | ||
97 | + | ||
98 | + if self.downsample is not None: | ||
99 | + shortcut = self.downsample(x) | ||
100 | + featuremap_size = shortcut.size()[2:4] | ||
101 | + else: | ||
102 | + shortcut = x | ||
103 | + featuremap_size = out.size()[2:4] | ||
104 | + | ||
105 | + batch_size = out.size()[0] | ||
106 | + residual_channel = out.size()[1] | ||
107 | + shortcut_channel = shortcut.size()[1] | ||
108 | + | ||
109 | + if residual_channel != shortcut_channel: | ||
110 | + padding = torch.zeros(batch_size, residual_channel - shortcut_channel, | ||
111 | + featuremap_size[0], featuremap_size[1], | ||
112 | + dtype=out.dtype, device=out.device) | ||
113 | + out += torch.cat((shortcut, padding), 1) | ||
114 | + else: | ||
115 | + out += shortcut | ||
116 | + | ||
117 | + return out | ||
118 | + | ||
119 | + | ||
120 | +class PyramidNet(nn.Module): | ||
121 | + | ||
122 | + def __init__(self, dataset, depth, alpha, num_classes, bottleneck=True): | ||
123 | + super(PyramidNet, self).__init__() | ||
124 | + self.dataset = dataset | ||
125 | + if self.dataset.startswith('cifar'): | ||
126 | + self.inplanes = 16 | ||
127 | + if bottleneck: | ||
128 | + n = int((depth - 2) / 9) | ||
129 | + block = Bottleneck | ||
130 | + else: | ||
131 | + n = int((depth - 2) / 6) | ||
132 | + block = BasicBlock | ||
133 | + | ||
134 | + self.addrate = alpha / (3 * n * 1.0) | ||
135 | + self.ps_shakedrop = [1. - (1.0 - (0.5 / (3 * n)) * (i + 1)) for i in range(3 * n)] | ||
136 | + | ||
137 | + self.input_featuremap_dim = self.inplanes | ||
138 | + self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) | ||
139 | + self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) | ||
140 | + | ||
141 | + self.featuremap_dim = self.input_featuremap_dim | ||
142 | + self.layer1 = self.pyramidal_make_layer(block, n) | ||
143 | + self.layer2 = self.pyramidal_make_layer(block, n, stride=2) | ||
144 | + self.layer3 = self.pyramidal_make_layer(block, n, stride=2) | ||
145 | + | ||
146 | + self.final_featuremap_dim = self.input_featuremap_dim | ||
147 | + self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) | ||
148 | + self.relu_final = nn.ReLU(inplace=True) | ||
149 | + self.avgpool = nn.AvgPool2d(8) | ||
150 | + self.fc = nn.Linear(self.final_featuremap_dim, num_classes) | ||
151 | + | ||
152 | + elif dataset == 'imagenet': | ||
153 | + blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} | ||
154 | + layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], | ||
155 | + 200: [3, 24, 36, 3]} | ||
156 | + | ||
157 | + if layers.get(depth) is None: | ||
158 | + if bottleneck: | ||
159 | + blocks[depth] = Bottleneck | ||
160 | + temp_cfg = int((depth - 2) / 12) | ||
161 | + else: | ||
162 | + blocks[depth] = BasicBlock | ||
163 | + temp_cfg = int((depth - 2) / 8) | ||
164 | + | ||
165 | + layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg] | ||
166 | + print('=> the layer configuration for each stage is set to', layers[depth]) | ||
167 | + | ||
168 | + self.inplanes = 64 | ||
169 | + self.addrate = alpha / (sum(layers[depth]) * 1.0) | ||
170 | + | ||
171 | + self.input_featuremap_dim = self.inplanes | ||
172 | + self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) | ||
173 | + self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) | ||
174 | + self.relu = nn.ReLU(inplace=True) | ||
175 | + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
176 | + | ||
177 | + self.featuremap_dim = self.input_featuremap_dim | ||
178 | + self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) | ||
179 | + self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) | ||
180 | + self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) | ||
181 | + self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) | ||
182 | + | ||
183 | + self.final_featuremap_dim = self.input_featuremap_dim | ||
184 | + self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) | ||
185 | + self.relu_final = nn.ReLU(inplace=True) | ||
186 | + self.avgpool = nn.AvgPool2d(7) | ||
187 | + self.fc = nn.Linear(self.final_featuremap_dim, num_classes) | ||
188 | + | ||
189 | + for m in self.modules(): | ||
190 | + if isinstance(m, nn.Conv2d): | ||
191 | + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
192 | + m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
193 | + elif isinstance(m, nn.BatchNorm2d): | ||
194 | + m.weight.data.fill_(1) | ||
195 | + m.bias.data.zero_() | ||
196 | + | ||
197 | + assert len(self.ps_shakedrop) == 0, self.ps_shakedrop | ||
198 | + | ||
199 | + def pyramidal_make_layer(self, block, block_depth, stride=1): | ||
200 | + downsample = None | ||
201 | + if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: | ||
202 | + downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True) | ||
203 | + | ||
204 | + layers = [] | ||
205 | + self.featuremap_dim = self.featuremap_dim + self.addrate | ||
206 | + layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample, p_shakedrop=self.ps_shakedrop.pop(0))) | ||
207 | + for i in range(1, block_depth): | ||
208 | + temp_featuremap_dim = self.featuremap_dim + self.addrate | ||
209 | + layers.append( | ||
210 | + block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1, p_shakedrop=self.ps_shakedrop.pop(0))) | ||
211 | + self.featuremap_dim = temp_featuremap_dim | ||
212 | + self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio | ||
213 | + | ||
214 | + return nn.Sequential(*layers) | ||
215 | + | ||
216 | + def forward(self, x): | ||
217 | + if self.dataset == 'cifar10' or self.dataset == 'cifar100': | ||
218 | + x = self.conv1(x) | ||
219 | + x = self.bn1(x) | ||
220 | + | ||
221 | + x = self.layer1(x) | ||
222 | + x = self.layer2(x) | ||
223 | + x = self.layer3(x) | ||
224 | + | ||
225 | + x = self.bn_final(x) | ||
226 | + x = self.relu_final(x) | ||
227 | + x = self.avgpool(x) | ||
228 | + x = x.view(x.size(0), -1) | ||
229 | + x = self.fc(x) | ||
230 | + | ||
231 | + elif self.dataset == 'imagenet': | ||
232 | + x = self.conv1(x) | ||
233 | + x = self.bn1(x) | ||
234 | + x = self.relu(x) | ||
235 | + x = self.maxpool(x) | ||
236 | + | ||
237 | + x = self.layer1(x) | ||
238 | + x = self.layer2(x) | ||
239 | + x = self.layer3(x) | ||
240 | + x = self.layer4(x) | ||
241 | + | ||
242 | + x = self.bn_final(x) | ||
243 | + x = self.relu_final(x) | ||
244 | + x = self.avgpool(x) | ||
245 | + x = x.view(x.size(0), -1) | ||
246 | + x = self.fc(x) | ||
247 | + | ||
248 | + return x |
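The additive widening above ("addrate") grows the per-block width linearly: starting from 16 channels, each of the 3n CIFAR blocks adds alpha/(3n) channels, so the final block reaches 16 + alpha. A quick check of the schedule, rounding exactly as pyramidal_make_layer does (the depth/alpha values are illustrative):

    depth, alpha = 110, 84            # basic-block PyramidNet-110
    n = (depth - 2) // 6
    addrate = alpha / (3 * n)
    dim, widths = 16.0, []
    for _ in range(3 * n):
        dim += addrate
        widths.append(int(round(dim)))
    print(widths[0], widths[-1])      # 18 100, i.e. 16 + alpha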
1 | +# Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py | ||
2 | + | ||
3 | +import torch.nn as nn | ||
4 | +import math | ||
5 | + | ||
6 | + | ||
7 | +def conv3x3(in_planes, out_planes, stride=1): | ||
8 | + "3x3 convolution with padding" | ||
9 | + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||
10 | + padding=1, bias=False) | ||
11 | + | ||
12 | + | ||
13 | +class BasicBlock(nn.Module): | ||
14 | + expansion = 1 | ||
15 | + | ||
16 | + def __init__(self, inplanes, planes, stride=1, downsample=None): | ||
17 | + super(BasicBlock, self).__init__() | ||
18 | + self.conv1 = conv3x3(inplanes, planes, stride) | ||
19 | + self.bn1 = nn.BatchNorm2d(planes) | ||
20 | + self.conv2 = conv3x3(planes, planes) | ||
21 | + self.bn2 = nn.BatchNorm2d(planes) | ||
22 | + self.relu = nn.ReLU(inplace=True) | ||
23 | + | ||
24 | + self.downsample = downsample | ||
25 | + self.stride = stride | ||
26 | + | ||
27 | + def forward(self, x): | ||
28 | + residual = x | ||
29 | + | ||
30 | + out = self.conv1(x) | ||
31 | + out = self.bn1(out) | ||
32 | + out = self.relu(out) | ||
33 | + | ||
34 | + out = self.conv2(out) | ||
35 | + out = self.bn2(out) | ||
36 | + | ||
37 | + if self.downsample is not None: | ||
38 | + residual = self.downsample(x) | ||
39 | + | ||
40 | + out += residual | ||
41 | + out = self.relu(out) | ||
42 | + | ||
43 | + return out | ||
44 | + | ||
45 | + | ||
46 | +class Bottleneck(nn.Module): | ||
47 | + expansion = 4 | ||
48 | + | ||
49 | + def __init__(self, inplanes, planes, stride=1, downsample=None): | ||
50 | + super(Bottleneck, self).__init__() | ||
51 | + | ||
52 | + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | ||
53 | + self.bn1 = nn.BatchNorm2d(planes) | ||
54 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
55 | + self.bn2 = nn.BatchNorm2d(planes) | ||
56 | + self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) | ||
57 | + self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) | ||
58 | + self.relu = nn.ReLU(inplace=True) | ||
59 | + | ||
60 | + self.downsample = downsample | ||
61 | + self.stride = stride | ||
62 | + | ||
63 | + def forward(self, x): | ||
64 | + residual = x | ||
65 | + | ||
66 | + out = self.conv1(x) | ||
67 | + out = self.bn1(out) | ||
68 | + out = self.relu(out) | ||
69 | + | ||
70 | + out = self.conv2(out) | ||
71 | + out = self.bn2(out) | ||
72 | + out = self.relu(out) | ||
73 | + | ||
74 | + out = self.conv3(out) | ||
75 | + out = self.bn3(out) | ||
76 | + if self.downsample is not None: | ||
77 | + residual = self.downsample(x) | ||
78 | + | ||
79 | + out += residual | ||
80 | + out = self.relu(out) | ||
81 | + | ||
82 | + return out | ||
83 | + | ||
84 | +class ResNet(nn.Module): | ||
85 | + def __init__(self, dataset, depth, num_classes, bottleneck=False): | ||
86 | + super(ResNet, self).__init__() | ||
87 | + self.dataset = dataset | ||
88 | + if self.dataset.startswith('cifar'): | ||
89 | + self.inplanes = 16 | ||
90 | + | ||
91 | +            if bottleneck: | ||
92 | + n = int((depth - 2) / 9) | ||
93 | + block = Bottleneck | ||
94 | + else: | ||
95 | + n = int((depth - 2) / 6) | ||
96 | + block = BasicBlock | ||
97 | + | ||
98 | + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) | ||
99 | + self.bn1 = nn.BatchNorm2d(self.inplanes) | ||
100 | + self.relu = nn.ReLU(inplace=True) | ||
101 | + self.layer1 = self._make_layer(block, 16, n) | ||
102 | + self.layer2 = self._make_layer(block, 32, n, stride=2) | ||
103 | + self.layer3 = self._make_layer(block, 64, n, stride=2) | ||
104 | + # self.avgpool = nn.AvgPool2d(8) | ||
105 | + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||
106 | + self.fc = nn.Linear(64 * block.expansion, num_classes) | ||
107 | + | ||
108 | + elif dataset == 'imagenet': | ||
109 | + blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} | ||
110 | + layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} | ||
111 | +            assert depth in layers, 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' | ||
112 | + | ||
113 | + self.inplanes = 64 | ||
114 | + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) | ||
115 | + self.bn1 = nn.BatchNorm2d(64) | ||
116 | + self.relu = nn.ReLU(inplace=True) | ||
117 | + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
118 | + self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) | ||
119 | + self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) | ||
120 | + self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) | ||
121 | + self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) | ||
122 | + # self.avgpool = nn.AvgPool2d(7) | ||
123 | + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||
124 | + self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) | ||
125 | + | ||
126 | + for m in self.modules(): | ||
127 | + if isinstance(m, nn.Conv2d): | ||
128 | + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
129 | + m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
130 | + elif isinstance(m, nn.BatchNorm2d): | ||
131 | + m.weight.data.fill_(1) | ||
132 | + m.bias.data.zero_() | ||
133 | + | ||
134 | + def _make_layer(self, block, planes, blocks, stride=1): | ||
135 | + downsample = None | ||
136 | + if stride != 1 or self.inplanes != planes * block.expansion: | ||
137 | + downsample = nn.Sequential( | ||
138 | + nn.Conv2d(self.inplanes, planes * block.expansion, | ||
139 | + kernel_size=1, stride=stride, bias=False), | ||
140 | + nn.BatchNorm2d(planes * block.expansion), | ||
141 | + ) | ||
142 | + | ||
143 | + layers = [] | ||
144 | + layers.append(block(self.inplanes, planes, stride, downsample)) | ||
145 | + self.inplanes = planes * block.expansion | ||
146 | + for i in range(1, blocks): | ||
147 | + layers.append(block(self.inplanes, planes)) | ||
148 | + | ||
149 | + return nn.Sequential(*layers) | ||
150 | + | ||
151 | + def forward(self, x): | ||
152 | + if self.dataset == 'cifar10' or self.dataset == 'cifar100': | ||
153 | + x = self.conv1(x) | ||
154 | + x = self.bn1(x) | ||
155 | + x = self.relu(x) | ||
156 | + | ||
157 | + x = self.layer1(x) | ||
158 | + x = self.layer2(x) | ||
159 | + x = self.layer3(x) | ||
160 | + | ||
161 | + x = self.avgpool(x) | ||
162 | + x = x.view(x.size(0), -1) | ||
163 | + x = self.fc(x) | ||
164 | + | ||
165 | + elif self.dataset == 'imagenet': | ||
166 | + x = self.conv1(x) | ||
167 | + x = self.bn1(x) | ||
168 | + x = self.relu(x) | ||
169 | + x = self.maxpool(x) | ||
170 | + | ||
171 | + x = self.layer1(x) | ||
172 | + x = self.layer2(x) | ||
173 | + x = self.layer3(x) | ||
174 | + x = self.layer4(x) | ||
175 | + | ||
176 | + x = self.avgpool(x) | ||
177 | + x = x.view(x.size(0), -1) | ||
178 | + x = self.fc(x) | ||
179 | + | ||
180 | + return x |
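A quick CPU smoke test of the CIFAR branch above (depth must be 6n + 2 for basic blocks, 9n + 2 for bottlenecks):

    import torch

    model = ResNet('cifar10', depth=20, num_classes=10, bottleneck=False)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 3, 32, 32))
    print(logits.shape)  # torch.Size([2, 10])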
1 | +# -*- coding: utf-8 -*- | ||
2 | + | ||
3 | +import torch | ||
4 | +import torch.nn as nn | ||
5 | +import torch.nn.functional as F | ||
6 | +from torch.autograd import Variable | ||
7 | + | ||
8 | + | ||
9 | +class ShakeDropFunction(torch.autograd.Function): | ||
10 | + | ||
11 | + @staticmethod | ||
12 | + def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]): | ||
13 | + if training: | ||
14 | + gate = torch.cuda.FloatTensor([0]).bernoulli_(1 - p_drop) | ||
15 | + ctx.save_for_backward(gate) | ||
16 | + if gate.item() == 0: | ||
17 | + alpha = torch.cuda.FloatTensor(x.size(0)).uniform_(*alpha_range) | ||
18 | + alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x) | ||
19 | + return alpha * x | ||
20 | + else: | ||
21 | + return x | ||
22 | + else: | ||
23 | + return (1 - p_drop) * x | ||
24 | + | ||
25 | + @staticmethod | ||
26 | + def backward(ctx, grad_output): | ||
27 | + gate = ctx.saved_tensors[0] | ||
28 | + if gate.item() == 0: | ||
29 | + beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_(0, 1) | ||
30 | + beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) | ||
31 | + beta = Variable(beta) | ||
32 | + return beta * grad_output, None, None, None | ||
33 | + else: | ||
34 | + return grad_output, None, None, None | ||
35 | + | ||
36 | + | ||
37 | +class ShakeDrop(nn.Module): | ||
38 | + | ||
39 | + def __init__(self, p_drop=0.5, alpha_range=[-1, 1]): | ||
40 | + super(ShakeDrop, self).__init__() | ||
41 | + self.p_drop = p_drop | ||
42 | + self.alpha_range = alpha_range | ||
43 | + | ||
44 | + def forward(self, x): | ||
45 | + return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range) |
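ShakeDrop leaves a block's output untouched with probability 1 - p_drop, otherwise scales it by a random alpha (and by an independent random beta on the backward pass); at test time the expected scale 1 - p_drop is applied deterministically. PyramidNet above feeds each block a linearly increasing p_drop, which this pure-Python sketch reproduces:

    # Block i of L gets p_i = 0.5 * (i + 1) / L: the deepest block is
    # perturbed most often, the first block is almost always identity.
    L = 54  # 3 * n for the basic-block PyramidNet-110 example
    ps = [1. - (1.0 - (0.5 / L) * (i + 1)) for i in range(L)]
    print(round(ps[0], 4), round(ps[-1], 4))  # 0.0093 0.5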
File mode changed
1 | +# -*- coding: utf-8 -*- | ||
2 | + | ||
3 | +import math | ||
4 | + | ||
5 | +import torch.nn as nn | ||
6 | +import torch.nn.functional as F | ||
7 | + | ||
8 | +from FastAutoAugment.networks.shakeshake.shakeshake import ShakeShake | ||
9 | +from FastAutoAugment.networks.shakeshake.shakeshake import Shortcut | ||
10 | + | ||
11 | + | ||
12 | +class ShakeBlock(nn.Module): | ||
13 | + | ||
14 | + def __init__(self, in_ch, out_ch, stride=1): | ||
15 | + super(ShakeBlock, self).__init__() | ||
16 | + self.equal_io = in_ch == out_ch | ||
17 | +        self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) | ||
18 | + | ||
19 | + self.branch1 = self._make_branch(in_ch, out_ch, stride) | ||
20 | + self.branch2 = self._make_branch(in_ch, out_ch, stride) | ||
21 | + | ||
22 | + def forward(self, x): | ||
23 | + h1 = self.branch1(x) | ||
24 | + h2 = self.branch2(x) | ||
25 | + h = ShakeShake.apply(h1, h2, self.training) | ||
26 | + h0 = x if self.equal_io else self.shortcut(x) | ||
27 | + return h + h0 | ||
28 | + | ||
29 | + def _make_branch(self, in_ch, out_ch, stride=1): | ||
30 | + return nn.Sequential( | ||
31 | + nn.ReLU(inplace=False), | ||
32 | + nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False), | ||
33 | + nn.BatchNorm2d(out_ch), | ||
34 | + nn.ReLU(inplace=False), | ||
35 | + nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False), | ||
36 | + nn.BatchNorm2d(out_ch)) | ||
37 | + | ||
38 | + | ||
39 | +class ShakeResNet(nn.Module): | ||
40 | + | ||
41 | + def __init__(self, depth, w_base, label): | ||
42 | + super(ShakeResNet, self).__init__() | ||
43 | +        n_units = (depth - 2) // 6 | ||
44 | + | ||
45 | + in_chs = [16, w_base, w_base * 2, w_base * 4] | ||
46 | + self.in_chs = in_chs | ||
47 | + | ||
48 | + self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1) | ||
49 | + self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1]) | ||
50 | + self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2) | ||
51 | + self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2) | ||
52 | + self.fc_out = nn.Linear(in_chs[3], label) | ||
53 | + | ||
54 | +        # Initialize parameters | ||
55 | + for m in self.modules(): | ||
56 | + if isinstance(m, nn.Conv2d): | ||
57 | + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
58 | + m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
59 | + elif isinstance(m, nn.BatchNorm2d): | ||
60 | + m.weight.data.fill_(1) | ||
61 | + m.bias.data.zero_() | ||
62 | + elif isinstance(m, nn.Linear): | ||
63 | + m.bias.data.zero_() | ||
64 | + | ||
65 | + def forward(self, x): | ||
66 | + h = self.c_in(x) | ||
67 | + h = self.layer1(h) | ||
68 | + h = self.layer2(h) | ||
69 | + h = self.layer3(h) | ||
70 | + h = F.relu(h) | ||
71 | + h = F.avg_pool2d(h, 8) | ||
72 | + h = h.view(-1, self.in_chs[3]) | ||
73 | + h = self.fc_out(h) | ||
74 | + return h | ||
75 | + | ||
76 | + def _make_layer(self, n_units, in_ch, out_ch, stride=1): | ||
77 | + layers = [] | ||
78 | + for i in range(int(n_units)): | ||
79 | + layers.append(ShakeBlock(in_ch, out_ch, stride=stride)) | ||
80 | + in_ch, stride = out_ch, 1 | ||
81 | + return nn.Sequential(*layers) |
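Note that ShakeShake.apply allocates torch.cuda tensors during training, so a CPU sanity check of this network has to run in eval mode, where the two branches are simply averaged with alpha = 0.5:

    import torch

    model = ShakeResNet(depth=26, w_base=32, label=10)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 32, 32))
    print(out.shape)  # torch.Size([2, 10])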
1 | +# -*- coding: utf-8 -*- | ||
2 | + | ||
3 | +import math | ||
4 | + | ||
5 | +import torch.nn as nn | ||
6 | +import torch.nn.functional as F | ||
7 | + | ||
8 | +from FastAutoAugment.networks.shakeshake.shakeshake import ShakeShake | ||
9 | +from FastAutoAugment.networks.shakeshake.shakeshake import Shortcut | ||
10 | + | ||
11 | + | ||
12 | +class ShakeBottleNeck(nn.Module): | ||
13 | + | ||
14 | + def __init__(self, in_ch, mid_ch, out_ch, cardinary, stride=1): | ||
15 | + super(ShakeBottleNeck, self).__init__() | ||
16 | + self.equal_io = in_ch == out_ch | ||
17 | + self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) | ||
18 | + | ||
19 | + self.branch1 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) | ||
20 | + self.branch2 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) | ||
21 | + | ||
22 | + def forward(self, x): | ||
23 | + h1 = self.branch1(x) | ||
24 | + h2 = self.branch2(x) | ||
25 | + h = ShakeShake.apply(h1, h2, self.training) | ||
26 | + h0 = x if self.equal_io else self.shortcut(x) | ||
27 | + return h + h0 | ||
28 | + | ||
29 | + def _make_branch(self, in_ch, mid_ch, out_ch, cardinary, stride=1): | ||
30 | + return nn.Sequential( | ||
31 | + nn.Conv2d(in_ch, mid_ch, 1, padding=0, bias=False), | ||
32 | + nn.BatchNorm2d(mid_ch), | ||
33 | + nn.ReLU(inplace=False), | ||
34 | + nn.Conv2d(mid_ch, mid_ch, 3, padding=1, stride=stride, groups=cardinary, bias=False), | ||
35 | + nn.BatchNorm2d(mid_ch), | ||
36 | + nn.ReLU(inplace=False), | ||
37 | + nn.Conv2d(mid_ch, out_ch, 1, padding=0, bias=False), | ||
38 | + nn.BatchNorm2d(out_ch)) | ||
39 | + | ||
40 | + | ||
41 | +class ShakeResNeXt(nn.Module): | ||
42 | + | ||
43 | + def __init__(self, depth, w_base, cardinary, label): | ||
44 | + super(ShakeResNeXt, self).__init__() | ||
45 | + n_units = (depth - 2) // 9 | ||
46 | + n_chs = [64, 128, 256, 1024] | ||
47 | + self.n_chs = n_chs | ||
48 | + self.in_ch = n_chs[0] | ||
49 | + | ||
50 | + self.c_in = nn.Conv2d(3, n_chs[0], 3, padding=1) | ||
51 | + self.layer1 = self._make_layer(n_units, n_chs[0], w_base, cardinary) | ||
52 | + self.layer2 = self._make_layer(n_units, n_chs[1], w_base, cardinary, 2) | ||
53 | + self.layer3 = self._make_layer(n_units, n_chs[2], w_base, cardinary, 2) | ||
54 | + self.fc_out = nn.Linear(n_chs[3], label) | ||
55 | + | ||
56 | +        # Initialize parameters | ||
57 | + for m in self.modules(): | ||
58 | + if isinstance(m, nn.Conv2d): | ||
59 | + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
60 | + m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
61 | + elif isinstance(m, nn.BatchNorm2d): | ||
62 | + m.weight.data.fill_(1) | ||
63 | + m.bias.data.zero_() | ||
64 | + elif isinstance(m, nn.Linear): | ||
65 | + m.bias.data.zero_() | ||
66 | + | ||
67 | + def forward(self, x): | ||
68 | + h = self.c_in(x) | ||
69 | + h = self.layer1(h) | ||
70 | + h = self.layer2(h) | ||
71 | + h = self.layer3(h) | ||
72 | + h = F.relu(h) | ||
73 | + h = F.avg_pool2d(h, 8) | ||
74 | + h = h.view(-1, self.n_chs[3]) | ||
75 | + h = self.fc_out(h) | ||
76 | + return h | ||
77 | + | ||
78 | + def _make_layer(self, n_units, n_ch, w_base, cardinary, stride=1): | ||
79 | + layers = [] | ||
80 | + mid_ch, out_ch = n_ch * (w_base // 64) * cardinary, n_ch * 4 | ||
81 | + for i in range(n_units): | ||
82 | + layers.append(ShakeBottleNeck(self.in_ch, mid_ch, out_ch, cardinary, stride=stride)) | ||
83 | + self.in_ch, stride = out_ch, 1 | ||
84 | + return nn.Sequential(*layers) |
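The branch width follows the ResNeXt convention, mid_ch = n_ch * (w_base // 64) * cardinality, with a 4x expansion at the block output. For illustrative values w_base=64 and cardinality=2 (n_units = (29 - 2) // 9 = 3 per stage for depth 29), the three stages work out as:

    for n_ch in [64, 128, 256]:
        mid_ch, out_ch = n_ch * (64 // 64) * 2, n_ch * 4
        print(n_ch, mid_ch, out_ch)
    # 64 128 256 / 128 256 512 / 256 512 1024  (matching the fc input of 1024)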
1 | +# -*- coding: utf-8 -*- | ||
2 | + | ||
3 | +import torch | ||
4 | +import torch.nn as nn | ||
5 | +import torch.nn.functional as F | ||
6 | +from torch.autograd import Variable | ||
7 | + | ||
8 | + | ||
9 | +class ShakeShake(torch.autograd.Function): | ||
10 | + | ||
11 | + @staticmethod | ||
12 | + def forward(ctx, x1, x2, training=True): | ||
13 | + if training: | ||
14 | + alpha = torch.cuda.FloatTensor(x1.size(0)).uniform_() | ||
15 | + alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1) | ||
16 | + else: | ||
17 | + alpha = 0.5 | ||
18 | + return alpha * x1 + (1 - alpha) * x2 | ||
19 | + | ||
20 | + @staticmethod | ||
21 | + def backward(ctx, grad_output): | ||
22 | + beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_() | ||
23 | + beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) | ||
24 | + beta = Variable(beta) | ||
25 | + | ||
26 | + return beta * grad_output, (1 - beta) * grad_output, None | ||
27 | + | ||
28 | + | ||
29 | +class Shortcut(nn.Module): | ||
30 | + | ||
31 | + def __init__(self, in_ch, out_ch, stride): | ||
32 | + super(Shortcut, self).__init__() | ||
33 | + self.stride = stride | ||
34 | + self.conv1 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) | ||
35 | + self.conv2 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) | ||
36 | + self.bn = nn.BatchNorm2d(out_ch) | ||
37 | + | ||
38 | + def forward(self, x): | ||
39 | + h = F.relu(x) | ||
40 | + | ||
41 | + h1 = F.avg_pool2d(h, 1, self.stride) | ||
42 | + h1 = self.conv1(h1) | ||
43 | + | ||
44 | + h2 = F.avg_pool2d(F.pad(h, (-1, 1, -1, 1)), 1, self.stride) | ||
45 | + h2 = self.conv2(h2) | ||
46 | + | ||
47 | + h = torch.cat((h1, h2), 1) | ||
48 | + return self.bn(h) |
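Shortcut halves the spatial size by averaging two 1x1-conv paths; the second path sees the input shifted one pixel via negative padding, so stride-2 pooling does not systematically discard the odd rows and columns. The shift itself is just F.pad with cropping values:

    import torch
    import torch.nn.functional as F

    x = torch.arange(9.).view(1, 1, 3, 3)
    # (-1, 1, -1, 1) crops one column/row at the left/top and zero-pads
    # one at the right/bottom, i.e. shifts the map by one pixel.
    print(F.pad(x, (-1, 1, -1, 1))[0, 0])
    # tensor([[4., 5., 0.],
    #         [7., 8., 0.],
    #         [0., 0., 0.]])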
1 | +import torch.nn as nn | ||
2 | +import torch.nn.init as init | ||
3 | +import torch.nn.functional as F | ||
4 | +import numpy as np | ||
5 | + | ||
6 | + | ||
7 | +def conv3x3(in_planes, out_planes, stride=1): | ||
8 | + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) | ||
9 | + | ||
10 | + | ||
11 | +def conv_init(m): | ||
12 | + classname = m.__class__.__name__ | ||
13 | + if classname.find('Conv') != -1: | ||
14 | + init.xavier_uniform_(m.weight, gain=np.sqrt(2)) | ||
15 | + init.constant_(m.bias, 0) | ||
16 | + elif classname.find('BatchNorm') != -1: | ||
17 | + init.constant_(m.weight, 1) | ||
18 | + init.constant_(m.bias, 0) | ||
19 | + | ||
20 | + | ||
21 | +class WideBasic(nn.Module): | ||
22 | + def __init__(self, in_planes, planes, dropout_rate, stride=1): | ||
23 | + super(WideBasic, self).__init__() | ||
24 | + self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.9) | ||
25 | + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) | ||
26 | + self.dropout = nn.Dropout(p=dropout_rate) | ||
27 | + self.bn2 = nn.BatchNorm2d(planes, momentum=0.9) | ||
28 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) | ||
29 | + | ||
30 | + self.shortcut = nn.Sequential() | ||
31 | + if stride != 1 or in_planes != planes: | ||
32 | + self.shortcut = nn.Sequential( | ||
33 | + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), | ||
34 | + ) | ||
35 | + | ||
36 | + def forward(self, x): | ||
37 | + out = self.dropout(self.conv1(F.relu(self.bn1(x)))) | ||
38 | + out = self.conv2(F.relu(self.bn2(out))) | ||
39 | + out += self.shortcut(x) | ||
40 | + | ||
41 | + return out | ||
42 | + | ||
43 | + | ||
44 | +class WideResNet(nn.Module): | ||
45 | + def __init__(self, depth, widen_factor, dropout_rate, num_classes): | ||
46 | + super(WideResNet, self).__init__() | ||
47 | + self.in_planes = 16 | ||
48 | + | ||
49 | + assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' | ||
50 | + n = int((depth - 4) / 6) | ||
51 | + k = widen_factor | ||
52 | + | ||
53 | + nStages = [16, 16*k, 32*k, 64*k] | ||
54 | + | ||
55 | + self.conv1 = conv3x3(3, nStages[0]) | ||
56 | + self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1) | ||
57 | + self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2) | ||
58 | + self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2) | ||
59 | + self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) | ||
60 | + self.linear = nn.Linear(nStages[3], num_classes) | ||
61 | + | ||
62 | + # self.apply(conv_init) | ||
63 | + | ||
64 | + def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): | ||
65 | + strides = [stride] + [1]*(num_blocks-1) | ||
66 | + layers = [] | ||
67 | + | ||
68 | + for stride in strides: | ||
69 | + layers.append(block(self.in_planes, planes, dropout_rate, stride)) | ||
70 | + self.in_planes = planes | ||
71 | + | ||
72 | + return nn.Sequential(*layers) | ||
73 | + | ||
74 | + def forward(self, x): | ||
75 | + out = self.conv1(x) | ||
76 | + out = self.layer1(out) | ||
77 | + out = self.layer2(out) | ||
78 | + out = self.layer3(out) | ||
79 | + out = F.relu(self.bn1(out)) | ||
80 | + # out = F.avg_pool2d(out, 8) | ||
81 | + out = F.adaptive_avg_pool2d(out, (1, 1)) | ||
82 | + out = out.view(out.size(0), -1) | ||
83 | + out = self.linear(out) | ||
84 | + | ||
85 | + return out |
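A CPU smoke test for the WRN-28-10 configuration commonly used with this codebase:

    import torch

    model = WideResNet(depth=28, widen_factor=10, dropout_rate=0.0, num_classes=10)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 32, 32))
    print(out.shape)  # torch.Size([2, 10])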
1 | +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +# ============================================================================== | ||
15 | + | ||
16 | +import os | ||
17 | +import psutil | ||
18 | +import re | ||
19 | +import signal | ||
20 | +import subprocess | ||
21 | +import sys | ||
22 | +import threading | ||
23 | +import time | ||
24 | + | ||
25 | + | ||
26 | +GRACEFUL_TERMINATION_TIME_S = 5 | ||
27 | + | ||
28 | + | ||
29 | +def terminate_executor_shell_and_children(pid): | ||
30 | + print('terminate_executor_shell_and_children+', pid) | ||
31 | + # If the shell already ends, no need to terminate its child. | ||
32 | + try: | ||
33 | + p = psutil.Process(pid) | ||
34 | + except psutil.NoSuchProcess: | ||
35 | + print('nosuchprocess') | ||
36 | + return | ||
37 | + | ||
38 | + # Terminate children gracefully. | ||
39 | + for child in p.children(): | ||
40 | + try: | ||
41 | + child.terminate() | ||
42 | + except psutil.NoSuchProcess: | ||
43 | + pass | ||
44 | + | ||
45 | + # Wait for graceful termination. | ||
46 | + time.sleep(GRACEFUL_TERMINATION_TIME_S) | ||
47 | + | ||
48 | + # Send STOP to executor shell to stop progress. | ||
49 | + p.send_signal(signal.SIGSTOP) | ||
50 | + | ||
51 | + # Kill children recursively. | ||
52 | + for child in p.children(recursive=True): | ||
53 | + try: | ||
54 | + child.kill() | ||
55 | + except psutil.NoSuchProcess: | ||
56 | + pass | ||
57 | + | ||
58 | + # Kill shell itself. | ||
59 | + p.kill() | ||
60 | + print('terminate_executor_shell_and_children-', pid) | ||
61 | + | ||
62 | + | ||
63 | +def forward_stream(src_fd, dst_stream, prefix, index): | ||
64 | + with os.fdopen(src_fd, 'r') as src: | ||
65 | + line_buffer = '' | ||
66 | + while True: | ||
67 | + text = os.read(src.fileno(), 1000) | ||
68 | + if not isinstance(text, str): | ||
69 | + text = text.decode('utf-8') | ||
70 | + if not text: | ||
71 | + break | ||
72 | + | ||
73 | + for line in re.split('([\r\n])', text): | ||
74 | + line_buffer += line | ||
75 | + if line == '\r' or line == '\n': | ||
76 | + if index is not None: | ||
77 | + localtime = time.asctime(time.localtime(time.time())) | ||
78 | + line_buffer = '{time}[{rank}]<{prefix}>:{line}'.format( | ||
79 | + time=localtime, | ||
80 | + rank=str(index), | ||
81 | + prefix=prefix, | ||
82 | + line=line_buffer | ||
83 | + ) | ||
84 | + | ||
85 | + dst_stream.write(line_buffer) | ||
86 | + dst_stream.flush() | ||
87 | + line_buffer = '' | ||
88 | + | ||
89 | + | ||
90 | +def execute(command, env=None, stdout=None, stderr=None, index=None, event=None): | ||
91 | + # Make a pipe for the subprocess stdout/stderr. | ||
92 | + (stdout_r, stdout_w) = os.pipe() | ||
93 | + (stderr_r, stderr_w) = os.pipe() | ||
94 | + | ||
95 | + # Make a pipe for notifying the child that parent has died. | ||
96 | + (r, w) = os.pipe() | ||
97 | + | ||
98 | + middleman_pid = os.fork() | ||
99 | + if middleman_pid == 0: | ||
100 | + # Close unused file descriptors to enforce PIPE behavior. | ||
101 | + os.close(w) | ||
102 | + os.setsid() | ||
103 | + | ||
104 | + executor_shell = subprocess.Popen(command, shell=True, env=env, | ||
105 | + stdout=stdout_w, stderr=stderr_w) | ||
106 | + | ||
107 | + sigterm_received = threading.Event() | ||
108 | + | ||
109 | + def set_sigterm_received(signum, frame): | ||
110 | + sigterm_received.set() | ||
111 | + | ||
112 | + signal.signal(signal.SIGINT, set_sigterm_received) | ||
113 | + signal.signal(signal.SIGTERM, set_sigterm_received) | ||
114 | + | ||
115 | + def kill_executor_children_if_parent_dies(): | ||
116 | + # This read blocks until the pipe is closed on the other side | ||
117 | + # due to the process termination. | ||
118 | + os.read(r, 1) | ||
119 | + terminate_executor_shell_and_children(executor_shell.pid) | ||
120 | + | ||
121 | + bg = threading.Thread(target=kill_executor_children_if_parent_dies) | ||
122 | + bg.daemon = True | ||
123 | + bg.start() | ||
124 | + | ||
125 | + def kill_executor_children_if_sigterm_received(): | ||
126 | + sigterm_received.wait() | ||
127 | + terminate_executor_shell_and_children(executor_shell.pid) | ||
128 | + | ||
129 | + bg = threading.Thread(target=kill_executor_children_if_sigterm_received) | ||
130 | + bg.daemon = True | ||
131 | + bg.start() | ||
132 | + | ||
133 | + exit_code = executor_shell.wait() | ||
134 | + os._exit(exit_code) | ||
135 | + | ||
136 | + # Close unused file descriptors to enforce PIPE behavior. | ||
137 | + os.close(r) | ||
138 | + os.close(stdout_w) | ||
139 | + os.close(stderr_w) | ||
140 | + | ||
141 | + # Redirect command stdout & stderr to provided streams or sys.stdout/sys.stderr. | ||
142 | +    # This is useful for Jupyter Notebooks that use custom sys.stdout/sys.stderr, or | ||
143 | + # for redirecting to a file on disk. | ||
144 | + if stdout is None: | ||
145 | + stdout = sys.stdout | ||
146 | + if stderr is None: | ||
147 | + stderr = sys.stderr | ||
148 | + stdout_fwd = threading.Thread(target=forward_stream, args=(stdout_r, stdout, 'stdout', index)) | ||
149 | + stderr_fwd = threading.Thread(target=forward_stream, args=(stderr_r, stderr, 'stderr', index)) | ||
150 | + stdout_fwd.start() | ||
151 | + stderr_fwd.start() | ||
152 | + | ||
153 | + def kill_middleman_if_master_thread_terminate(): | ||
154 | + event.wait() | ||
155 | + try: | ||
156 | + os.kill(middleman_pid, signal.SIGTERM) | ||
157 | +        except OSError: | ||
158 | + # The process has already been killed elsewhere | ||
159 | + pass | ||
160 | + | ||
161 | +    # TODO: Currently this requires explicit declaration of the event and signal handler to set | ||
162 | + # the event (gloo_run.py:_launch_jobs()). Need to figure out a generalized way to hide this behind | ||
163 | + # interfaces. | ||
164 | + if event is not None: | ||
165 | + bg_thread = threading.Thread(target=kill_middleman_if_master_thread_terminate) | ||
166 | + bg_thread.daemon = True | ||
167 | + bg_thread.start() | ||
168 | + | ||
169 | + try: | ||
170 | + res, status = os.waitpid(middleman_pid, 0) | ||
171 | + except: | ||
172 | + # interrupted, send middleman TERM signal which will terminate children | ||
173 | + os.kill(middleman_pid, signal.SIGTERM) | ||
174 | + while True: | ||
175 | + try: | ||
176 | + _, status = os.waitpid(middleman_pid, 0) | ||
177 | + break | ||
178 | + except: | ||
179 | + # interrupted, wait for middleman to finish | ||
180 | + pass | ||
181 | + | ||
182 | + stdout_fwd.join() | ||
183 | + stderr_fwd.join() | ||
184 | + exit_code = status >> 8 | ||
185 | + return exit_code |
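The middleman process above inherits the read end of a pipe whose write end lives only in the parent, so a blocking os.read returns (EOF) exactly when the parent dies for any reason, not just on a clean shutdown. A stripped-down sketch of that mechanism alone (POSIX only, illustrative):

    import os
    import threading

    r, w = os.pipe()
    pid = os.fork()
    if pid == 0:                      # child
        os.close(w)                   # keep only the read end

        def watch_parent():
            os.read(r, 1)             # blocks until the parent's write end closes
            os._exit(1)               # parent is gone: clean up and exit

        threading.Thread(target=watch_parent, daemon=True).start()
        # ... long-running child work would go here ...
        os._exit(0)
    else:                             # parent
        os.close(r)                   # keep only the write end
        os.waitpid(pid, 0)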
1 | +import copy | ||
2 | +import os | ||
3 | +import sys | ||
4 | +import time | ||
5 | +from collections import OrderedDict, defaultdict | ||
6 | + | ||
7 | +import torch | ||
8 | + | ||
9 | +import numpy as np | ||
10 | +from hyperopt import hp | ||
11 | +import ray | ||
12 | +import gorilla | ||
13 | +from ray.tune.trial import Trial | ||
14 | +from ray.tune.trial_runner import TrialRunner | ||
15 | +from ray.tune.suggest import HyperOptSearch | ||
16 | +from ray.tune import register_trainable, run_experiments | ||
17 | +from tqdm import tqdm | ||
18 | + | ||
19 | +from FastAutoAugment.archive import remove_deplicates, policy_decoder | ||
20 | +from FastAutoAugment.augmentations import augment_list | ||
21 | +from FastAutoAugment.common import get_logger, add_filehandler | ||
22 | +from FastAutoAugment.data import get_dataloaders | ||
23 | +from FastAutoAugment.metrics import Accumulator | ||
24 | +from FastAutoAugment.networks import get_model, num_class | ||
25 | +from FastAutoAugment.train import train_and_eval | ||
26 | +from theconf import Config as C, ConfigArgumentParser | ||
27 | + | ||
28 | + | ||
29 | +top1_valid_by_cv = defaultdict(list) | ||
30 | + | ||
31 | + | ||
32 | +def step_w_log(self): | ||
33 | + original = gorilla.get_original_attribute(ray.tune.trial_runner.TrialRunner, 'step') | ||
34 | + | ||
35 | + # log | ||
36 | + cnts = OrderedDict() | ||
37 | + for status in [Trial.RUNNING, Trial.TERMINATED, Trial.PENDING, Trial.PAUSED, Trial.ERROR]: | ||
38 | + cnt = len(list(filter(lambda x: x.status == status, self._trials))) | ||
39 | + cnts[status] = cnt | ||
40 | + best_top1_acc = 0. | ||
41 | + for trial in filter(lambda x: x.status == Trial.TERMINATED, self._trials): | ||
42 | + if not trial.last_result: | ||
43 | + continue | ||
44 | + best_top1_acc = max(best_top1_acc, trial.last_result['top1_valid']) | ||
45 | + print('iter', self._iteration, 'top1_acc=%.3f' % best_top1_acc, cnts, end='\r') | ||
46 | + return original(self) | ||
47 | + | ||
48 | + | ||
49 | +patch = gorilla.Patch(ray.tune.trial_runner.TrialRunner, 'step', step_w_log, settings=gorilla.Settings(allow_hit=True)) | ||
50 | +gorilla.apply(patch) | ||
51 | + | ||
52 | + | ||
53 | +logger = get_logger('Fast AutoAugment') | ||
54 | + | ||
55 | + | ||
56 | +def _get_path(dataset, model, tag): | ||
57 | + return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models/%s_%s_%s.model' % (dataset, model, tag)) # TODO | ||
58 | + | ||
59 | + | ||
60 | +@ray.remote(num_gpus=4, max_calls=1) | ||
61 | +def train_model(config, dataroot, augment, cv_ratio_test, cv_fold, save_path=None, skip_exist=False): | ||
62 | + C.get() | ||
63 | + C.get().conf = config | ||
64 | + C.get()['aug'] = augment | ||
65 | + | ||
66 | + result = train_and_eval(None, dataroot, cv_ratio_test, cv_fold, save_path=save_path, only_eval=skip_exist) | ||
67 | + return C.get()['model']['type'], cv_fold, result | ||
68 | + | ||
69 | + | ||
70 | +def eval_tta(config, augment, reporter): | ||
71 | + C.get() | ||
72 | + C.get().conf = config | ||
73 | + cv_ratio_test, cv_fold, save_path = augment['cv_ratio_test'], augment['cv_fold'], augment['save_path'] | ||
74 | + | ||
75 | + # setup - provided augmentation rules | ||
76 | + C.get()['aug'] = policy_decoder(augment, augment['num_policy'], augment['num_op']) | ||
77 | + | ||
78 | + # eval | ||
79 | + model = get_model(C.get()['model'], num_class(C.get()['dataset'])) | ||
80 | + ckpt = torch.load(save_path) | ||
81 | + if 'model' in ckpt: | ||
82 | + model.load_state_dict(ckpt['model']) | ||
83 | + else: | ||
84 | + model.load_state_dict(ckpt) | ||
85 | + model.eval() | ||
86 | + | ||
87 | + loaders = [] | ||
88 | + for _ in range(augment['num_policy']): # TODO | ||
89 | + _, tl, validloader, tl2 = get_dataloaders(C.get()['dataset'], C.get()['batch'], augment['dataroot'], cv_ratio_test, split_idx=cv_fold) | ||
90 | + loaders.append(iter(validloader)) | ||
91 | + del tl, tl2 | ||
92 | + | ||
93 | + start_t = time.time() | ||
94 | + metrics = Accumulator() | ||
95 | + loss_fn = torch.nn.CrossEntropyLoss(reduction='none') | ||
96 | + try: | ||
97 | + while True: | ||
98 | + losses = [] | ||
99 | + corrects = [] | ||
100 | + for loader in loaders: | ||
101 | + data, label = next(loader) | ||
102 | + data = data.cuda() | ||
103 | + label = label.cuda() | ||
104 | + | ||
105 | + pred = model(data) | ||
106 | + | ||
107 | + loss = loss_fn(pred, label) | ||
108 | + losses.append(loss.detach().cpu().numpy()) | ||
109 | + | ||
110 | + _, pred = pred.topk(1, 1, True, True) | ||
111 | + pred = pred.t() | ||
112 | + correct = pred.eq(label.view(1, -1).expand_as(pred)).detach().cpu().numpy() | ||
113 | + corrects.append(correct) | ||
114 | + del loss, correct, pred, data, label | ||
115 | + | ||
116 | + losses = np.concatenate(losses) | ||
117 | + losses_min = np.min(losses, axis=0).squeeze() | ||
118 | + | ||
119 | + corrects = np.concatenate(corrects) | ||
120 | + corrects_max = np.max(corrects, axis=0).squeeze() | ||
121 | + metrics.add_dict({ | ||
122 | + 'minus_loss': -1 * np.sum(losses_min), | ||
123 | + 'correct': np.sum(corrects_max), | ||
124 | + 'cnt': len(corrects_max) | ||
125 | + }) | ||
126 | + del corrects, corrects_max | ||
127 | + except StopIteration: | ||
128 | + pass | ||
129 | + | ||
130 | + del model | ||
131 | + metrics = metrics / 'cnt' | ||
132 | + gpu_secs = (time.time() - start_t) * torch.cuda.device_count() | ||
133 | + reporter(minus_loss=metrics['minus_loss'], top1_valid=metrics['correct'], elapsed_time=gpu_secs, done=True) | ||
134 | + return metrics['correct'] | ||
135 | + | ||
136 | + | ||
137 | +if __name__ == '__main__': | ||
138 | + import json | ||
139 | + from pystopwatch2 import PyStopwatch | ||
140 | + w = PyStopwatch() | ||
141 | + | ||
142 | + parser = ConfigArgumentParser(conflict_handler='resolve') | ||
143 | + parser.add_argument('--dataroot', type=str, default='/data/private/pretrainedmodels', help='torchvision data folder') | ||
144 | + parser.add_argument('--until', type=int, default=5) | ||
145 | + parser.add_argument('--num-op', type=int, default=2) | ||
146 | + parser.add_argument('--num-policy', type=int, default=5) | ||
147 | + parser.add_argument('--num-search', type=int, default=200) | ||
148 | + parser.add_argument('--cv-ratio', type=float, default=0.4) | ||
149 | + parser.add_argument('--decay', type=float, default=-1) | ||
150 | + parser.add_argument('--redis', type=str, default='gpu-cloud-vnode30.dakao.io:23655') | ||
151 | + parser.add_argument('--per-class', action='store_true') | ||
152 | + parser.add_argument('--resume', action='store_true') | ||
153 | + parser.add_argument('--smoke-test', action='store_true') | ||
154 | + args = parser.parse_args() | ||
155 | + | ||
156 | + if args.decay > 0: | ||
157 | + logger.info('decay=%.4f' % args.decay) | ||
158 | + C.get()['optimizer']['decay'] = args.decay | ||
159 | + | ||
160 | + add_filehandler(logger, os.path.join('models', '%s_%s_cv%.1f.log' % (C.get()['dataset'], C.get()['model']['type'], args.cv_ratio))) | ||
161 | + logger.info('configuration...') | ||
162 | + logger.info(json.dumps(C.get().conf, sort_keys=True, indent=4)) | ||
163 | + logger.info('initialize ray...') | ||
164 | + ray.init(redis_address=args.redis) | ||
165 | + | ||
166 | + num_result_per_cv = 10 | ||
167 | + cv_num = 5 | ||
168 | + copied_c = copy.deepcopy(C.get().conf) | ||
169 | + | ||
170 | + logger.info('search augmentation policies, dataset=%s model=%s' % (C.get()['dataset'], C.get()['model']['type'])) | ||
171 | + logger.info('----- Train without Augmentations cv=%d ratio(test)=%.1f -----' % (cv_num, args.cv_ratio)) | ||
172 | + w.start(tag='train_no_aug') | ||
173 | + paths = [_get_path(C.get()['dataset'], C.get()['model']['type'], 'ratio%.1f_fold%d' % (args.cv_ratio, i)) for i in range(cv_num)] | ||
174 | + print(paths) | ||
175 | + reqs = [ | ||
176 | + train_model.remote(copy.deepcopy(copied_c), args.dataroot, C.get()['aug'], args.cv_ratio, i, save_path=paths[i], skip_exist=True) | ||
177 | + for i in range(cv_num)] | ||
178 | + | ||
179 | + tqdm_epoch = tqdm(range(C.get()['epoch'])) | ||
180 | + is_done = False | ||
181 | + for epoch in tqdm_epoch: | ||
182 | + while True: | ||
183 | + epochs_per_cv = OrderedDict() | ||
184 | + for cv_idx in range(cv_num): | ||
185 | + try: | ||
186 | + latest_ckpt = torch.load(paths[cv_idx]) | ||
187 | + if 'epoch' not in latest_ckpt: | ||
188 | + epochs_per_cv['cv%d' % (cv_idx + 1)] = C.get()['epoch'] | ||
189 | + continue | ||
190 | + epochs_per_cv['cv%d' % (cv_idx+1)] = latest_ckpt['epoch'] | ||
191 | + except Exception as e: | ||
192 | + continue | ||
193 | + tqdm_epoch.set_postfix(epochs_per_cv) | ||
194 | + if len(epochs_per_cv) == cv_num and min(epochs_per_cv.values()) >= C.get()['epoch']: | ||
195 | + is_done = True | ||
196 | + if len(epochs_per_cv) == cv_num and min(epochs_per_cv.values()) >= epoch: | ||
197 | + break | ||
198 | + time.sleep(10) | ||
199 | + if is_done: | ||
200 | + break | ||
201 | + | ||
202 | + logger.info('getting results...') | ||
203 | + pretrain_results = ray.get(reqs) | ||
204 | + for r_model, r_cv, r_dict in pretrain_results: | ||
205 | + logger.info('model=%s cv=%d top1_train=%.4f top1_valid=%.4f' % (r_model, r_cv+1, r_dict['top1_train'], r_dict['top1_valid'])) | ||
206 | + logger.info('processed in %.4f secs' % w.pause('train_no_aug')) | ||
207 | + | ||
208 | + if args.until == 1: | ||
209 | + sys.exit(0) | ||
210 | + | ||
211 | + logger.info('----- Search Test-Time Augmentation Policies -----') | ||
212 | + w.start(tag='search') | ||
213 | + | ||
214 | + ops = augment_list(False) | ||
215 | + space = {} | ||
216 | + for i in range(args.num_policy): | ||
217 | + for j in range(args.num_op): | ||
218 | + space['policy_%d_%d' % (i, j)] = hp.choice('policy_%d_%d' % (i, j), list(range(0, len(ops)))) | ||
219 | +            space['prob_%d_%d' % (i, j)] = hp.uniform('prob_%d_%d' % (i, j), 0.0, 1.0) | ||
220 | +            space['level_%d_%d' % (i, j)] = hp.uniform('level_%d_%d' % (i, j), 0.0, 1.0) | ||
221 | + | ||
222 | + final_policy_set = [] | ||
223 | + total_computation = 0 | ||
224 | + reward_attr = 'top1_valid' # top1_valid or minus_loss | ||
225 | + for _ in range(1): # run multiple times. | ||
226 | + for cv_fold in range(cv_num): | ||
227 | + name = "search_%s_%s_fold%d_ratio%.1f" % (C.get()['dataset'], C.get()['model']['type'], cv_fold, args.cv_ratio) | ||
228 | + print(name) | ||
229 | + register_trainable(name, lambda augs, rpt: eval_tta(copy.deepcopy(copied_c), augs, rpt)) | ||
230 | + algo = HyperOptSearch(space, max_concurrent=4*20, reward_attr=reward_attr) | ||
231 | + | ||
232 | + exp_config = { | ||
233 | + name: { | ||
234 | + 'run': name, | ||
235 | + 'num_samples': 4 if args.smoke_test else args.num_search, | ||
236 | + 'resources_per_trial': {'gpu': 1}, | ||
237 | + 'stop': {'training_iteration': args.num_policy}, | ||
238 | + 'config': { | ||
239 | + 'dataroot': args.dataroot, 'save_path': paths[cv_fold], | ||
240 | + 'cv_ratio_test': args.cv_ratio, 'cv_fold': cv_fold, | ||
241 | + 'num_op': args.num_op, 'num_policy': args.num_policy | ||
242 | + }, | ||
243 | + } | ||
244 | + } | ||
245 | + results = run_experiments(exp_config, search_alg=algo, scheduler=None, verbose=0, queue_trials=True, resume=args.resume, raise_on_failed_trial=False) | ||
246 | + print() | ||
247 | + results = [x for x in results if x.last_result is not None] | ||
248 | + results = sorted(results, key=lambda x: x.last_result[reward_attr], reverse=True) | ||
249 | + | ||
250 | + # calculate computation usage | ||
251 | + for result in results: | ||
252 | + total_computation += result.last_result['elapsed_time'] | ||
253 | + | ||
254 | + for result in results[:num_result_per_cv]: | ||
255 | + final_policy = policy_decoder(result.config, args.num_policy, args.num_op) | ||
256 | + logger.info('loss=%.12f top1_valid=%.4f %s' % (result.last_result['minus_loss'], result.last_result['top1_valid'], final_policy)) | ||
257 | + | ||
258 | + final_policy = remove_deplicates(final_policy) | ||
259 | + final_policy_set.extend(final_policy) | ||
260 | + | ||
261 | + logger.info(json.dumps(final_policy_set)) | ||
262 | + logger.info('final_policy=%d' % len(final_policy_set)) | ||
263 | + logger.info('processed in %.4f secs, gpu hours=%.4f' % (w.pause('search'), total_computation / 3600.)) | ||
264 | + logger.info('----- Train with Augmentations model=%s dataset=%s aug=%s ratio(test)=%.1f -----' % (C.get()['model']['type'], C.get()['dataset'], C.get()['aug'], args.cv_ratio)) | ||
265 | + w.start(tag='train_aug') | ||
266 | + | ||
267 | + num_experiments = 5 | ||
268 | + default_path = [_get_path(C.get()['dataset'], C.get()['model']['type'], 'ratio%.1f_default%d' % (args.cv_ratio, _)) for _ in range(num_experiments)] | ||
269 | + augment_path = [_get_path(C.get()['dataset'], C.get()['model']['type'], 'ratio%.1f_augment%d' % (args.cv_ratio, _)) for _ in range(num_experiments)] | ||
270 | + reqs = [train_model.remote(copy.deepcopy(copied_c), args.dataroot, C.get()['aug'], 0.0, 0, save_path=default_path[_], skip_exist=True) for _ in range(num_experiments)] + \ | ||
271 | + [train_model.remote(copy.deepcopy(copied_c), args.dataroot, final_policy_set, 0.0, 0, save_path=augment_path[_]) for _ in range(num_experiments)] | ||
272 | + | ||
273 | + tqdm_epoch = tqdm(range(C.get()['epoch'])) | ||
274 | + is_done = False | ||
275 | + for epoch in tqdm_epoch: | ||
276 | + while True: | ||
277 | + epochs = OrderedDict() | ||
278 | + for exp_idx in range(num_experiments): | ||
279 | + try: | ||
280 | + if os.path.exists(default_path[exp_idx]): | ||
281 | + latest_ckpt = torch.load(default_path[exp_idx]) | ||
282 | + epochs['default_exp%d' % (exp_idx + 1)] = latest_ckpt['epoch'] | ||
283 | + except: | ||
284 | + pass | ||
285 | + try: | ||
286 | + if os.path.exists(augment_path[exp_idx]): | ||
287 | + latest_ckpt = torch.load(augment_path[exp_idx]) | ||
288 | + epochs['augment_exp%d' % (exp_idx + 1)] = latest_ckpt['epoch'] | ||
289 | + except: | ||
290 | + pass | ||
291 | + | ||
292 | + tqdm_epoch.set_postfix(epochs) | ||
293 | + if len(epochs) == num_experiments*2 and min(epochs.values()) >= C.get()['epoch']: | ||
294 | + is_done = True | ||
295 | + if len(epochs) == num_experiments*2 and min(epochs.values()) >= epoch: | ||
296 | + break | ||
297 | + time.sleep(10) | ||
298 | + if is_done: | ||
299 | + break | ||
300 | + | ||
301 | + logger.info('getting results...') | ||
302 | + final_results = ray.get(reqs) | ||
303 | + | ||
304 | + for train_mode in ['default', 'augment']: | ||
305 | + avg = 0. | ||
306 | + for _ in range(num_experiments): | ||
307 | + r_model, r_cv, r_dict = final_results.pop(0) | ||
308 | + logger.info('[%s] top1_train=%.4f top1_test=%.4f' % (train_mode, r_dict['top1_train'], r_dict['top1_test'])) | ||
309 | + avg += r_dict['top1_test'] | ||
310 | + avg /= num_experiments | ||
311 | + logger.info('[%s] top1_test average=%.4f (#experiments=%d)' % (train_mode, avg, num_experiments)) | ||
312 | + logger.info('processed in %.4f secs' % w.pause('train_aug')) | ||
313 | + | ||
314 | + logger.info(w) |
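For orientation: hyperopt hands eval_tta a flat dict with keys policy_i_j, prob_i_j and level_i_j, and policy_decoder (defined in FastAutoAugment.archive, not shown in this diff) turns it into a list of sub-policies of (op, probability, magnitude) triples. A hedged sketch of that decoding, assuming this structure; the op names below are stand-ins:

    ops = ['ShearX', 'Rotate', 'Invert']  # stand-ins for augment_list(False)

    config = {
        'policy_0_0': 1, 'prob_0_0': 0.7, 'level_0_0': 0.4,
        'policy_0_1': 2, 'prob_0_1': 0.2, 'level_0_1': 0.9,
    }

    def decode(config, num_policy=1, num_op=2):
        return [[(ops[config['policy_%d_%d' % (i, j)]],
                  config['prob_%d_%d' % (i, j)],
                  config['level_%d_%d' % (i, j)])
                 for j in range(num_op)]
                for i in range(num_policy)]

    print(decode(config))  # [[('Rotate', 0.7, 0.4), ('Invert', 0.2, 0.9)]]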
File mode changed
1 | +import torch | ||
2 | +from torch.optim.optimizer import Optimizer | ||
3 | + | ||
4 | + | ||
5 | +class RMSpropTF(Optimizer): | ||
6 | + r"""Implements RMSprop algorithm. | ||
7 | +    Reimplements the original formulation to match TensorFlow's RMSprop. | ||
8 | + Proposed by G. Hinton in his | ||
9 | + `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_. | ||
10 | + The centered version first appears in `Generating Sequences | ||
11 | + With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_. | ||
12 | +    The implementation here, following TensorFlow, adds epsilon before taking the | ||
13 | +    square root of the gradient average (stock PyTorch RMSprop interchanges these two | ||
14 | +    operations). The effective learning rate is thus :math:`\alpha/\sqrt{v + \epsilon}` | ||
15 | +    rather than :math:`\alpha/(\sqrt{v} + \epsilon)`, where :math:`\alpha` is the scheduled | ||
16 | +    learning rate and :math:`v` is the weighted moving average of the squared gradient. | ||
17 | + Arguments: | ||
18 | + params (iterable): iterable of parameters to optimize or dicts defining | ||
19 | + parameter groups | ||
20 | + lr (float, optional): learning rate (default: 1e-2) | ||
21 | + momentum (float, optional): momentum factor (default: 0) | ||
22 | +        momentum (float, optional): momentum factor; note this implementation requires a value > 0 despite the default of 0 | ||
23 | + eps (float, optional): term added to the denominator to improve | ||
24 | + numerical stability (default: 1e-8) | ||
25 | + centered (bool, optional) : if ``True``, compute the centered RMSProp, | ||
26 | + the gradient is normalized by an estimation of its variance | ||
27 | + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) | ||
28 | + """ | ||
29 | + | ||
30 | + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, momentum=0, weight_decay=0.0): | ||
31 | + if not 0.0 <= lr: | ||
32 | + raise ValueError("Invalid learning rate: {}".format(lr)) | ||
33 | + if not 0.0 <= eps: | ||
34 | + raise ValueError("Invalid epsilon value: {}".format(eps)) | ||
35 | + if not 0.0 < momentum: | ||
36 | +            raise ValueError("Invalid momentum value (this implementation requires momentum > 0): {}".format(momentum)) | ||
37 | + if not 0.0 <= alpha: | ||
38 | + raise ValueError("Invalid alpha value: {}".format(alpha)) | ||
39 | + assert momentum > 0.0 | ||
40 | + | ||
41 | + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, weight_decay=weight_decay) | ||
42 | + super(RMSpropTF, self).__init__(params, defaults) | ||
43 | + self.initialized = False | ||
44 | + | ||
45 | + def __setstate__(self, state): | ||
46 | + super(RMSpropTF, self).__setstate__(state) | ||
47 | + for group in self.param_groups: | ||
48 | + group.setdefault('momentum', 0) | ||
49 | + | ||
50 | + def load_state_dict(self, state_dict): | ||
51 | + super(RMSpropTF, self).load_state_dict(state_dict) | ||
52 | + self.initialized = True | ||
53 | + | ||
54 | + def step(self, closure=None): | ||
55 | + """Performs a single optimization step. | ||
56 | + We modified pytorch's RMSProp to be same as Tensorflow's | ||
57 | + See : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/training_ops.cc#L485 | ||
58 | + | ||
59 | + Arguments: | ||
60 | + closure (callable, optional): A closure that reevaluates the model | ||
61 | + and returns the loss. | ||
62 | + """ | ||
63 | + loss = None | ||
64 | + if closure is not None: | ||
65 | + loss = closure() | ||
66 | + | ||
67 | + for group in self.param_groups: | ||
68 | + for p in group['params']: | ||
69 | + if p.grad is None: | ||
70 | + continue | ||
71 | + grad = p.grad.data | ||
72 | + if grad.is_sparse: | ||
73 | + raise RuntimeError('RMSprop does not support sparse gradients') | ||
74 | + state = self.state[p] | ||
75 | + | ||
76 | + # State initialization | ||
77 | + if len(state) == 0: | ||
78 | + assert not self.initialized | ||
79 | + state['step'] = 0 | ||
80 | + state['ms'] = torch.ones_like(p.data) #, memory_format=torch.preserve_format) | ||
81 | + state['mom'] = torch.zeros_like(p.data) #, memory_format=torch.preserve_format) | ||
82 | + | ||
83 | + # weight decay ----- | ||
84 | + if group['weight_decay'] > 0: | ||
85 | + grad = grad.add(group['weight_decay'], p.data) | ||
86 | + | ||
87 | + rho = group['alpha'] | ||
88 | + ms = state['ms'] | ||
89 | + mom = state['mom'] | ||
90 | + state['step'] += 1 | ||
91 | + | ||
92 | + # ms.mul_(rho).addcmul_(1 - rho, grad, grad) | ||
93 | + ms.add_(torch.mul(grad, grad).add_(-ms) * (1. - rho)) | ||
94 | + assert group['momentum'] > 0 | ||
95 | + | ||
96 | + # new rmsprop | ||
97 | + mom.mul_(group['momentum']).addcdiv_(group['lr'], grad, (ms + group['eps']).sqrt()) | ||
98 | + | ||
99 | + p.data.add_(-1.0, mom) | ||
100 | + | ||
101 | + return loss |
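The practical difference between the two epsilon placements shows up when the running average v is tiny: sqrt(v + eps) caps the effective step near lr / sqrt(eps), while sqrt(v) + eps lets it grow toward lr / eps. A quick numeric comparison:

    import math

    lr, eps = 0.01, 1e-3
    for v in [1e-8, 1e-4, 1e-1]:
        tf_style = lr / math.sqrt(v + eps)    # this optimizer / TensorFlow
        pt_style = lr / (math.sqrt(v) + eps)  # stock torch.optim.RMSprop
        print(v, round(tf_style, 4), round(pt_style, 4))
    # the two agree for large v but differ by roughly 30x at v = 1e-8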
1 | +import torch | ||
2 | +from torch.nn import BatchNorm2d | ||
3 | +from torch.nn.parameter import Parameter | ||
4 | +import torch.distributed as dist | ||
5 | +from torch import nn | ||
6 | + | ||
7 | + | ||
8 | +class TpuBatchNormalization(nn.Module): | ||
9 | + # Ref : https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/utils.py#L113 | ||
10 | + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, | ||
11 | + track_running_stats=True): | ||
12 | + super(TpuBatchNormalization, self).__init__() # num_features, eps, momentum, affine, track_running_stats) | ||
13 | + | ||
14 | + self.weight = Parameter(torch.ones(num_features)) | ||
15 | + self.bias = Parameter(torch.zeros(num_features)) | ||
16 | + | ||
17 | + self.register_buffer('running_mean', torch.zeros(num_features)) | ||
18 | + self.register_buffer('running_var', torch.ones(num_features)) | ||
19 | + self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) | ||
20 | + | ||
21 | + self.eps = eps | ||
22 | + self.momentum = momentum | ||
23 | + | ||
24 | + def _reduce_avg(self, t): | ||
25 | + dist.all_reduce(t, dist.ReduceOp.SUM) | ||
26 | + t.mul_(1. / dist.get_world_size()) | ||
27 | + | ||
28 | + def forward(self, input): | ||
29 | + if not self.training or not dist.is_initialized(): | ||
30 | + bn = (input - self.running_mean.view(1, self.running_mean.shape[0], 1, 1)) / \ | ||
31 | + (torch.sqrt(self.running_var.view(1, self.running_var.shape[0], 1, 1) + self.eps)) | ||
32 | + # print(self.weight.shape, self.bias.shape) | ||
33 | + return bn.mul(self.weight.view(1, self.weight.shape[0], 1, 1)).add(self.bias.view(1, self.bias.shape[0], 1, 1)) | ||
34 | + | ||
35 | + shard_mean, shard_invstd = torch.batch_norm_stats(input, self.eps) | ||
36 | + shard_vars = (1. / shard_invstd) ** 2 - self.eps | ||
37 | + | ||
38 | + shard_square_of_mean = torch.mul(shard_mean, shard_mean) | ||
39 | + shard_mean_of_square = shard_vars + shard_square_of_mean | ||
40 | + | ||
41 | + group_mean = shard_mean.clone().detach() | ||
42 | + self._reduce_avg(group_mean) | ||
43 | + group_mean_of_square = shard_mean_of_square.clone().detach() | ||
44 | + self._reduce_avg(group_mean_of_square) | ||
45 | + group_vars = group_mean_of_square - torch.mul(group_mean, group_mean) | ||
46 | + | ||
47 | + group_mean = group_mean.detach() | ||
48 | + group_vars = group_vars.detach() | ||
49 | + | ||
50 | + # print(self.running_mean.shape, self.running_var.shape) | ||
51 | + self.running_mean.mul_(1. - self.momentum).add_(group_mean.mul(self.momentum)) | ||
52 | + self.running_var.mul_(1. - self.momentum).add_(group_vars.mul(self.momentum)) | ||
53 | + self.num_batches_tracked.add_(1) | ||
54 | + | ||
55 | + # print(input.shape, group_mean.view(1, group_mean.shape[0], 1, 1).shape, group_vars.view(1, group_vars.shape[0], 1, 1).shape, self.eps) | ||
56 | + bn = (input - group_mean.view(1, group_mean.shape[0], 1, 1)) / (torch.sqrt(group_vars.view(1, group_vars.shape[0], 1, 1) + self.eps)) | ||
57 | + # print(self.weight.shape, self.bias.shape) | ||
58 | + return bn.mul(self.weight.view(1, self.weight.shape[0], 1, 1)).add(self.bias.view(1, self.bias.shape[0], 1, 1)) |
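The cross-replica statistics rely on Var[x] = E[x^2] - E[x]^2: each shard contributes its mean and mean-of-squares, and averaging those over equally sized shards reproduces the global biased variance. A single-process check of the identity:

    import torch

    x1, x2 = torch.randn(8), torch.randn(8)       # two equally sized 'shards'
    means = torch.stack([x1.mean(), x2.mean()])
    msqs = torch.stack([(x1 ** 2).mean(), (x2 ** 2).mean()])

    group_mean = means.mean()
    group_var = msqs.mean() - group_mean ** 2     # E[x^2] - E[x]^2

    direct = torch.cat([x1, x2]).var(unbiased=False)
    print(torch.allclose(group_var, direct, atol=1e-6))  # True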
1 | +import sys | ||
2 | + | ||
3 | + | ||
4 | +sys.path.append('/data/private/fast-autoaugment-public') # TODO | ||
5 | + | ||
6 | +import itertools | ||
7 | +import json | ||
8 | +import logging | ||
9 | +import math | ||
10 | +import os | ||
11 | +from collections import OrderedDict | ||
12 | + | ||
13 | +import torch | ||
14 | +from torch import nn, optim | ||
15 | +from torch.nn.parallel.data_parallel import DataParallel | ||
16 | +from torch.nn.parallel import DistributedDataParallel | ||
17 | +import torch.distributed as dist | ||
18 | + | ||
19 | +from tqdm import tqdm | ||
20 | +from theconf import Config as C, ConfigArgumentParser | ||
21 | + | ||
22 | +from FastAutoAugment.common import get_logger, EMA, add_filehandler | ||
23 | +from FastAutoAugment.data import get_dataloaders | ||
24 | +from FastAutoAugment.lr_scheduler import adjust_learning_rate_resnet | ||
25 | +from FastAutoAugment.metrics import accuracy, Accumulator, CrossEntropyLabelSmooth | ||
26 | +from FastAutoAugment.networks import get_model, num_class | ||
27 | +from FastAutoAugment.tf_port.rmsprop import RMSpropTF | ||
28 | +from FastAutoAugment.aug_mixup import CrossEntropyMixUpLabelSmooth, mixup | ||
29 | +from warmup_scheduler import GradualWarmupScheduler | ||
30 | + | ||
31 | +logger = get_logger('Fast AutoAugment') | ||
32 | +logger.setLevel(logging.INFO) | ||
33 | + | ||
34 | + | ||
35 | +def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None, is_master=True, ema=None, wd=0.0, tqdm_disabled=False): | ||
36 | + if verbose: | ||
37 | + loader = tqdm(loader, disable=tqdm_disabled) | ||
38 | + loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch'])) | ||
39 | + | ||
40 | + params_without_bn = [params for name, params in model.named_parameters() if not ('_bn' in name or '.bn' in name)] | ||
41 | + | ||
42 | + loss_ema = None | ||
43 | + metrics = Accumulator() | ||
44 | + cnt = 0 | ||
45 | + total_steps = len(loader) | ||
46 | + steps = 0 | ||
47 | + for data, label in loader: | ||
48 | + steps += 1 | ||
49 | + data, label = data.cuda(), label.cuda() | ||
50 | + | ||
51 | + if C.get().conf.get('mixup', 0.0) <= 0.0 or optimizer is None: | ||
52 | + preds = model(data) | ||
53 | + loss = loss_fn(preds, label) | ||
54 | + else: # mixup | ||
55 | + data, targets, shuffled_targets, lam = mixup(data, label, C.get()['mixup']) | ||
56 | + preds = model(data) | ||
57 | + loss = loss_fn(preds, targets, shuffled_targets, lam) | ||
58 | + del shuffled_targets, lam | ||
59 | + | ||
60 | + if optimizer: | ||
61 | + loss += wd * (1. / 2.) * sum([torch.sum(p ** 2) for p in params_without_bn]) | ||
62 | + loss.backward() | ||
63 | + if C.get()['optimizer']['clip'] > 0: | ||
64 | + nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer']['clip']) | ||
65 | + optimizer.step() | ||
66 | + optimizer.zero_grad() | ||
67 | + | ||
68 | + if ema is not None: | ||
69 | + ema(model, (epoch - 1) * total_steps + steps) | ||
70 | + | ||
71 | + top1, top5 = accuracy(preds, label, (1, 5)) | ||
72 | + metrics.add_dict({ | ||
73 | + 'loss': loss.item() * len(data), | ||
74 | + 'top1': top1.item() * len(data), | ||
75 | + 'top5': top5.item() * len(data), | ||
76 | + }) | ||
77 | + cnt += len(data) | ||
78 | + if loss_ema: | ||
79 | + loss_ema = loss_ema * 0.9 + loss.item() * 0.1 | ||
80 | + else: | ||
81 | + loss_ema = loss.item() | ||
82 | + if verbose: | ||
83 | + postfix = metrics / cnt | ||
84 | + if optimizer: | ||
85 | + postfix['lr'] = optimizer.param_groups[0]['lr'] | ||
86 | + postfix['loss_ema'] = loss_ema | ||
87 | + loader.set_postfix(postfix) | ||
88 | + | ||
89 | + if scheduler is not None: | ||
90 | + scheduler.step(epoch - 1 + float(steps) / total_steps) | ||
91 | + | ||
92 | + del preds, loss, top1, top5, data, label | ||
93 | + | ||
94 | + if tqdm_disabled and verbose: | ||
95 | + if optimizer: | ||
96 | + logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr']) | ||
97 | + else: | ||
98 | + logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt) | ||
99 | + | ||
100 | + metrics /= cnt | ||
101 | + if optimizer: | ||
102 | + metrics.metrics['lr'] = optimizer.param_groups[0]['lr'] | ||
103 | + if verbose: | ||
104 | + for key, value in metrics.items(): | ||
105 | + writer.add_scalar(key, value, epoch) | ||
106 | + return metrics | ||
107 | + | ||
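Note that `run_epoch` applies weight decay by hand, adding `wd * 1/2 * sum(p**2)` over every parameter whose name does not contain `'_bn'` or `'.bn'`, while the optimizers below are constructed with `weight_decay=0.0`. A minimal sketch of the same idea, filtering by module type instead of by name substring (an assumption of equivalent intent, not the repo's code):

```python
import torch
from torch import nn

def l2_penalty(model, wd):
    # explicit L2 regularization: wd/2 * sum of squared parameters,
    # skipping normalization layers (their scale/shift should not decay)
    penalty = 0.
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            continue
        for p in module.parameters(recurse=False):
            penalty = penalty + p.pow(2).sum()
    return wd * 0.5 * penalty

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
loss = model(torch.randn(2, 4)).sum() + l2_penalty(model, wd=5e-4)
loss.backward()
```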
108 | + | ||
109 | +def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metric='last', save_path=None, only_eval=False, local_rank=-1, evaluation_interval=5): | ||
110 | + total_batch = C.get()["batch"] | ||
111 | + if local_rank >= 0: | ||
112 | + dist.init_process_group(backend='nccl', init_method='env://', world_size=int(os.environ['WORLD_SIZE'])) | ||
113 | + device = torch.device('cuda', local_rank) | ||
114 | + torch.cuda.set_device(device) | ||
115 | + | ||
116 | + C.get()['lr'] *= dist.get_world_size() | ||
117 | + logger.info(f'local batch={C.get()["batch"]} world_size={dist.get_world_size()} ----> total batch={C.get()["batch"] * dist.get_world_size()}') | ||
118 | + total_batch = C.get()["batch"] * dist.get_world_size() | ||
119 | + | ||
120 | + is_master = local_rank < 0 or dist.get_rank() == 0 | ||
121 | + if is_master: | ||
122 | + add_filehandler(logger, args.save + '.log') | ||
123 | + | ||
124 | + if not reporter: | ||
125 | + reporter = lambda **kwargs: 0 | ||
126 | + | ||
127 | + max_epoch = C.get()['epoch'] | ||
128 | + trainsampler, trainloader, validloader, testloader_ = get_dataloaders(C.get()['dataset'], C.get()['batch'], dataroot, test_ratio, split_idx=cv_fold, multinode=(local_rank >= 0)) | ||
129 | + | ||
130 | + # create a model & an optimizer | ||
131 | + model = get_model(C.get()['model'], num_class(C.get()['dataset']), local_rank=local_rank) | ||
132 | + model_ema = get_model(C.get()['model'], num_class(C.get()['dataset']), local_rank=-1) | ||
133 | + model_ema.eval() | ||
134 | + | ||
135 | + criterion_ce = criterion = CrossEntropyLabelSmooth(num_class(C.get()['dataset']), C.get().conf.get('lb_smooth', 0)) | ||
136 | + if C.get().conf.get('mixup', 0.0) > 0.0: | ||
137 | + criterion = CrossEntropyMixUpLabelSmooth(num_class(C.get()['dataset']), C.get().conf.get('lb_smooth', 0)) | ||
138 | + if C.get()['optimizer']['type'] == 'sgd': | ||
139 | + optimizer = optim.SGD( | ||
140 | + model.parameters(), | ||
141 | + lr=C.get()['lr'], | ||
142 | + momentum=C.get()['optimizer'].get('momentum', 0.9), | ||
143 | + weight_decay=0.0, | ||
144 | + nesterov=C.get()['optimizer'].get('nesterov', True) | ||
145 | + ) | ||
146 | + elif C.get()['optimizer']['type'] == 'rmsprop': | ||
147 | + optimizer = RMSpropTF( | ||
148 | + model.parameters(), | ||
149 | + lr=C.get()['lr'], | ||
150 | + weight_decay=0.0, | ||
151 | + alpha=0.9, momentum=0.9, | ||
152 | + eps=0.001 | ||
153 | + ) | ||
154 | + else: | ||
155 | + raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type']) | ||
156 | + | ||
157 | + lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine') | ||
158 | + if lr_scheduler_type == 'cosine': | ||
159 | + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'], eta_min=0.) | ||
160 | + elif lr_scheduler_type == 'resnet': | ||
161 | + scheduler = adjust_learning_rate_resnet(optimizer) | ||
162 | + elif lr_scheduler_type == 'efficientnet': | ||
163 | + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 0.97 ** int((x + C.get()['lr_schedule']['warmup']['epoch']) / 2.4)) | ||
164 | + else: | ||
 165 | + raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type) | ||
166 | + | ||
167 | + if C.get()['lr_schedule'].get('warmup', None) and C.get()['lr_schedule']['warmup']['epoch'] > 0: | ||
168 | + scheduler = GradualWarmupScheduler( | ||
169 | + optimizer, | ||
170 | + multiplier=C.get()['lr_schedule']['warmup']['multiplier'], | ||
171 | + total_epoch=C.get()['lr_schedule']['warmup']['epoch'], | ||
172 | + after_scheduler=scheduler | ||
173 | + ) | ||
174 | + | ||
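The combined schedule is a linear warmup over `warmup.epoch` epochs into the wrapped scheduler (e.g. cosine annealing to zero). A rough sketch of the resulting LR curve, assuming `multiplier=1` means warming up linearly from zero to the base LR (the exact semantics live in the external `warmup_scheduler` package):

```python
import math

def lr_at(epoch, base_lr=0.1, warmup_epochs=5, total_epochs=200):
    # linear warmup, then cosine annealing down to zero
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t))

print([round(lr_at(e), 4) for e in (0, 4, 5, 100, 199)])
```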
175 | + if not tag or not is_master: | ||
176 | + from FastAutoAugment.metrics import SummaryWriterDummy as SummaryWriter | ||
177 | + logger.warning('tag not provided, no tensorboard log.') | ||
178 | + else: | ||
179 | + from tensorboardX import SummaryWriter | ||
180 | + writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']] | ||
181 | + | ||
182 | + if C.get()['optimizer']['ema'] > 0.0 and is_master: | ||
183 | + # https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4?u=ildoonet | ||
184 | + ema = EMA(C.get()['optimizer']['ema']) | ||
185 | + else: | ||
186 | + ema = None | ||
187 | + | ||
188 | + result = OrderedDict() | ||
189 | + epoch_start = 1 | ||
 190 | + if save_path != 'test.pth': # note: every rank loads the checkpoint here, since the loaded state cannot be broadcast from the master alone | ||
191 | + if save_path and os.path.exists(save_path): | ||
192 | + logger.info('%s file found. loading...' % save_path) | ||
193 | + data = torch.load(save_path) | ||
194 | + key = 'model' if 'model' in data else 'state_dict' | ||
195 | + | ||
196 | + if 'epoch' not in data: | ||
197 | + model.load_state_dict(data) | ||
198 | + else: | ||
199 | + logger.info('checkpoint epoch@%d' % data['epoch']) | ||
200 | + if not isinstance(model, (DataParallel, DistributedDataParallel)): | ||
201 | + model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()}) | ||
202 | + else: | ||
203 | + model.load_state_dict({k if 'module.' in k else 'module.'+k: v for k, v in data[key].items()}) | ||
204 | + logger.info('optimizer.load_state_dict+') | ||
205 | + optimizer.load_state_dict(data['optimizer']) | ||
206 | + if data['epoch'] < C.get()['epoch']: | ||
207 | + epoch_start = data['epoch'] | ||
208 | + else: | ||
209 | + only_eval = True | ||
210 | + if ema is not None: | ||
211 | + ema.shadow = data.get('ema', {}) if isinstance(data.get('ema', {}), dict) else data['ema'].state_dict() | ||
212 | + del data | ||
213 | + else: | ||
214 | + logger.info('"%s" file not found. skip to pretrain weights...' % save_path) | ||
215 | + if only_eval: | ||
216 | + logger.warning('model checkpoint not found. only-evaluation mode is off.') | ||
217 | + only_eval = False | ||
218 | + | ||
219 | + if local_rank >= 0: | ||
220 | + for name, x in model.state_dict().items(): | ||
221 | + dist.broadcast(x, 0) | ||
222 | + logger.info(f'multinode init. local_rank={dist.get_rank()} is_master={is_master}') | ||
223 | + torch.cuda.synchronize() | ||
224 | + | ||
225 | + tqdm_disabled = bool(os.environ.get('TASK_NAME', '')) and local_rank != 0 # KakaoBrain Environment | ||
226 | + | ||
227 | + if only_eval: | ||
228 | + logger.info('evaluation only+') | ||
229 | + model.eval() | ||
230 | + rs = dict() | ||
231 | + rs['train'] = run_epoch(model, trainloader, criterion, None, desc_default='train', epoch=0, writer=writers[0], is_master=is_master) | ||
232 | + | ||
233 | + with torch.no_grad(): | ||
234 | + rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=0, writer=writers[1], is_master=is_master) | ||
235 | + rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=0, writer=writers[2], is_master=is_master) | ||
236 | + if ema is not None and len(ema) > 0: | ||
237 | + model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()}) | ||
238 | + rs['valid'] = run_epoch(model_ema, validloader, criterion_ce, None, desc_default='valid(EMA)', epoch=0, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled) | ||
239 | + rs['test'] = run_epoch(model_ema, testloader_, criterion_ce, None, desc_default='*test(EMA)', epoch=0, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled) | ||
240 | + for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']): | ||
241 | + if setname not in rs: | ||
242 | + continue | ||
243 | + result['%s_%s' % (key, setname)] = rs[setname][key] | ||
244 | + result['epoch'] = 0 | ||
245 | + return result | ||
246 | + | ||
247 | + # train loop | ||
248 | + best_top1 = 0 | ||
249 | + for epoch in range(epoch_start, max_epoch + 1): | ||
250 | + if local_rank >= 0: | ||
251 | + trainsampler.set_epoch(epoch) | ||
252 | + | ||
253 | + model.train() | ||
254 | + rs = dict() | ||
255 | + rs['train'] = run_epoch(model, trainloader, criterion, optimizer, desc_default='train', epoch=epoch, writer=writers[0], verbose=(is_master and local_rank <= 0), scheduler=scheduler, ema=ema, wd=C.get()['optimizer']['decay'], tqdm_disabled=tqdm_disabled) | ||
256 | + model.eval() | ||
257 | + | ||
258 | + if math.isnan(rs['train']['loss']): | ||
259 | + raise Exception('train loss is NaN.') | ||
260 | + | ||
 261 | + if ema is not None and C.get()['optimizer']['ema_interval'] > 0 and epoch % C.get()['optimizer']['ema_interval'] == 0: | ||
 262 | + logger.info(f'ema synced+ rank={dist.get_rank()}') | ||
 263 | + model.load_state_dict(ema.state_dict()) | ||
 264 | + for name, x in model.state_dict().items(): | ||
 265 | + dist.broadcast(x, 0) | ||
 266 | + torch.cuda.synchronize() | ||
 267 | + logger.info(f'ema synced- rank={dist.get_rank()}') | ||
270 | + | ||
271 | + if is_master and (epoch % evaluation_interval == 0 or epoch == max_epoch): | ||
272 | + with torch.no_grad(): | ||
273 | + rs['valid'] = run_epoch(model, validloader, criterion_ce, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled) | ||
274 | + rs['test'] = run_epoch(model, testloader_, criterion_ce, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled) | ||
275 | + | ||
276 | + if ema is not None: | ||
277 | + model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()}) | ||
278 | + rs['valid'] = run_epoch(model_ema, validloader, criterion_ce, None, desc_default='valid(EMA)', epoch=epoch, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled) | ||
279 | + rs['test'] = run_epoch(model_ema, testloader_, criterion_ce, None, desc_default='*test(EMA)', epoch=epoch, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled) | ||
280 | + | ||
281 | + logger.info( | ||
282 | + f'epoch={epoch} ' | ||
283 | + f'[train] loss={rs["train"]["loss"]:.4f} top1={rs["train"]["top1"]:.4f} ' | ||
284 | + f'[valid] loss={rs["valid"]["loss"]:.4f} top1={rs["valid"]["top1"]:.4f} ' | ||
285 | + f'[test] loss={rs["test"]["loss"]:.4f} top1={rs["test"]["top1"]:.4f} ' | ||
286 | + ) | ||
287 | + | ||
288 | + if metric == 'last' or rs[metric]['top1'] > best_top1: | ||
289 | + if metric != 'last': | ||
290 | + best_top1 = rs[metric]['top1'] | ||
291 | + for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']): | ||
292 | + result['%s_%s' % (key, setname)] = rs[setname][key] | ||
293 | + result['epoch'] = epoch | ||
294 | + | ||
295 | + writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch) | ||
296 | + writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch) | ||
297 | + | ||
298 | + reporter( | ||
299 | + loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'], | ||
300 | + loss_test=rs['test']['loss'], top1_test=rs['test']['top1'] | ||
301 | + ) | ||
302 | + | ||
303 | + # save checkpoint | ||
304 | + if is_master and save_path: | ||
305 | + logger.info('save model@%d to %s, err=%.4f' % (epoch, save_path, 1 - best_top1)) | ||
306 | + torch.save({ | ||
307 | + 'epoch': epoch, | ||
308 | + 'log': { | ||
309 | + 'train': rs['train'].get_dict(), | ||
310 | + 'valid': rs['valid'].get_dict(), | ||
311 | + 'test': rs['test'].get_dict(), | ||
312 | + }, | ||
313 | + 'optimizer': optimizer.state_dict(), | ||
314 | + 'model': model.state_dict(), | ||
315 | + 'ema': ema.state_dict() if ema is not None else None, | ||
316 | + }, save_path) | ||
317 | + | ||
318 | + del model | ||
319 | + | ||
320 | + result['top1_test'] = best_top1 | ||
321 | + return result | ||
322 | + | ||
323 | + | ||
324 | +if __name__ == '__main__': | ||
325 | + parser = ConfigArgumentParser(conflict_handler='resolve') | ||
326 | + parser.add_argument('--tag', type=str, default='') | ||
327 | + parser.add_argument('--dataroot', type=str, default='/data/private/pretrainedmodels', help='torchvision data folder') | ||
328 | + parser.add_argument('--save', type=str, default='test.pth') | ||
329 | + parser.add_argument('--cv-ratio', type=float, default=0.0) | ||
330 | + parser.add_argument('--cv', type=int, default=0) | ||
331 | + parser.add_argument('--local_rank', type=int, default=-1) | ||
332 | + parser.add_argument('--evaluation-interval', type=int, default=5) | ||
333 | + parser.add_argument('--only-eval', action='store_true') | ||
334 | + args = parser.parse_args() | ||
335 | + | ||
336 | + assert (args.only_eval and args.save) or not args.only_eval, 'checkpoint path not provided in evaluation mode.' | ||
337 | + | ||
338 | + if not args.only_eval: | ||
339 | + if args.save: | ||
340 | + logger.info('checkpoint will be saved at %s' % args.save) | ||
341 | + else: | ||
342 | + logger.warning('Provide --save argument to save the checkpoint. Without it, training result will not be saved!') | ||
343 | + | ||
344 | + import time | ||
345 | + t = time.time() | ||
346 | + result = train_and_eval(args.tag, args.dataroot, test_ratio=args.cv_ratio, cv_fold=args.cv, save_path=args.save, only_eval=args.only_eval, local_rank=args.local_rank, metric='test', evaluation_interval=args.evaluation_interval) | ||
347 | + elapsed = time.time() - t | ||
348 | + | ||
349 | + logger.info('done.') | ||
350 | + logger.info('model: %s' % C.get()['model']) | ||
351 | + logger.info('augmentation: %s' % C.get()['aug']) | ||
352 | + logger.info('\n' + json.dumps(result, indent=4)) | ||
353 | + logger.info('elapsed time: %.3f Hours' % (elapsed / 3600.)) | ||
354 | + logger.info('top1 error in testset: %.4f' % (1. - result['top1_test'])) | ||
355 | + logger.info(args.save) |
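`train.py` maintains an exponential moving average of the weights when `optimizer.ema > 0`: it calls `ema(model, step)` after each update and reads `ema.state_dict()` back for evaluation and checkpointing. The `EMA` class itself comes from `FastAutoAugment.common` and is not shown in this diff; the following is a minimal sketch of a compatible interface (the step-dependent warmup rule is an assumption):

```python
import torch

class ModelEMA:
    """Minimal sketch of weight EMA; the repo's EMA class is not shown in this diff."""
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        # the shadow dict holds the exponentially smoothed parameters
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def __call__(self, model, step=None):
        # assumed warm-up: use a smaller effective decay early in training
        d = min(self.decay, (1 + step) / (10 + step)) if step is not None else self.decay
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(d).add_(v.detach(), alpha=1 - d)
            else:
                self.shadow[k].copy_(v)  # e.g. num_batches_tracked counters

    def state_dict(self):
        return self.shadow
```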
1 | +import sys | ||
2 | +sys.path.append('/data/private/fast-autoaugment-public') # TODO | ||
3 | + | ||
4 | +import time | ||
5 | +import os | ||
6 | +import threading | ||
7 | +import six | ||
8 | +from six.moves import queue | ||
9 | + | ||
10 | +from FastAutoAugment import safe_shell_exec | ||
11 | + | ||
12 | + | ||
13 | +def _exec_command(command): | ||
14 | + host_output = six.StringIO() | ||
15 | + try: | ||
16 | + exit_code = safe_shell_exec.execute(command, | ||
17 | + stdout=host_output, | ||
18 | + stderr=host_output) | ||
19 | + if exit_code != 0: | ||
20 | + print('Launching task function was not successful:\n{host_output}'.format(host_output=host_output.getvalue())) | ||
21 | + os._exit(exit_code) | ||
22 | + finally: | ||
23 | + host_output.close() | ||
24 | + return exit_code | ||
25 | + | ||
26 | + | ||
27 | +def execute_function_multithreaded(fn, | ||
28 | + args_list, | ||
29 | + block_until_all_done=True, | ||
30 | + max_concurrent_executions=1000): | ||
31 | + """ | ||
32 | + Executes fn in multiple threads each with one set of the args in the | ||
33 | + args_list. | ||
 34 | + :param fn: function to be executed | ||
 35 | + :type fn: callable | ||
 36 | + :param args_list: one list of positional arguments per invocation of fn | ||
 37 | + :type args_list: list(list) | ||
 38 | + :param block_until_all_done: if True, the function will block until all the | ||
 39 | + threads are done and will return the results of each thread's execution. | ||
 40 | + :type block_until_all_done: bool | ||
 41 | + :param max_concurrent_executions: maximum number of worker threads to spawn | ||
 42 | + :type max_concurrent_executions: int | ||
43 | + :return: | ||
44 | + If block_until_all_done is False, returns None. If block_until_all_done is | ||
45 | + True, function returns the dict of results. | ||
46 | + { | ||
47 | + index: execution result of fn with args_list[index] | ||
48 | + } | ||
49 | + :rtype: dict | ||
50 | + """ | ||
51 | + result_queue = queue.Queue() | ||
52 | + worker_queue = queue.Queue() | ||
53 | + | ||
54 | + for i, arg in enumerate(args_list): | ||
55 | + arg.append(i) | ||
56 | + worker_queue.put(arg) | ||
57 | + | ||
58 | + def fn_execute(): | ||
59 | + while True: | ||
60 | + try: | ||
61 | + arg = worker_queue.get(block=False) | ||
62 | + except queue.Empty: | ||
63 | + return | ||
64 | + exec_index = arg[-1] | ||
65 | + res = fn(*arg[:-1]) | ||
66 | + result_queue.put((exec_index, res)) | ||
67 | + | ||
68 | + threads = [] | ||
69 | + number_of_threads = min(max_concurrent_executions, len(args_list)) | ||
70 | + | ||
71 | + for _ in range(number_of_threads): | ||
72 | + thread = threading.Thread(target=fn_execute) | ||
73 | + if not block_until_all_done: | ||
74 | + thread.daemon = True | ||
75 | + thread.start() | ||
76 | + threads.append(thread) | ||
77 | + | ||
78 | + # Returns the results only if block_until_all_done is set. | ||
79 | + results = None | ||
80 | + if block_until_all_done: | ||
81 | + # Because join() cannot be interrupted by signal, a single join() | ||
82 | + # needs to be separated into join()s with timeout in a while loop. | ||
83 | + have_alive_child = True | ||
84 | + while have_alive_child: | ||
85 | + have_alive_child = False | ||
86 | + for t in threads: | ||
87 | + t.join(0.1) | ||
88 | + if t.is_alive(): | ||
89 | + have_alive_child = True | ||
90 | + | ||
91 | + results = {} | ||
92 | + while not result_queue.empty(): | ||
93 | + item = result_queue.get() | ||
94 | + results[item[0]] = item[1] | ||
95 | + | ||
96 | + if len(results) != len(args_list): | ||
97 | + raise RuntimeError( | ||
98 | + 'Some threads for func {func} did not complete ' | ||
99 | + 'successfully.'.format(func=fn.__name__)) | ||
100 | + return results | ||
101 | + | ||
102 | + | ||
103 | +if __name__ == '__main__': | ||
104 | + import argparse | ||
105 | + | ||
106 | + parser = argparse.ArgumentParser() | ||
107 | + parser.add_argument('--host', type=str) | ||
108 | + parser.add_argument('--num-gpus', type=int, default=4) | ||
109 | + parser.add_argument('--master', type=str, default='task1') | ||
110 | + parser.add_argument('--port', type=int, default=1958) | ||
111 | + parser.add_argument('-c', '--conf', type=str) | ||
112 | + parser.add_argument('--args', type=str, default='') | ||
113 | + | ||
114 | + args = parser.parse_args() | ||
115 | + | ||
 116 | + try: | ||
 117 | + hosts = ['task%d' % (x + 1) for x in range(int(args.host))] | ||
 118 | + except (TypeError, ValueError): # --host is either a node count or a comma-separated host list | ||
 119 | + hosts = args.host.split(',') | ||
120 | + | ||
121 | + cwd = os.getcwd() | ||
122 | + command_list = [] | ||
123 | + for node_rank, host in enumerate(hosts): | ||
124 | + ssh_cmd = f'ssh -t -t -o StrictHostKeyChecking=no {host} -p 22 ' \ | ||
125 | + f'\'bash -O huponexit -c "cd {cwd} && ' \ | ||
126 | + f'python -m torch.distributed.launch --nproc_per_node={args.num_gpus} --nnodes={len(hosts)} ' \ | ||
127 | + f'--master_addr={args.master} --master_port={args.port} --node_rank={node_rank} ' \ | ||
128 | + f'FastAutoAugment/train.py -c {args.conf} {args.args}"' \ | ||
129 | + '\'' | ||
130 | + print(ssh_cmd) | ||
131 | + | ||
132 | + command_list.append([ssh_cmd]) | ||
133 | + | ||
134 | + execute_function_multithreaded(_exec_command, | ||
135 | + command_list[1:], | ||
136 | + block_until_all_done=False) | ||
137 | + | ||
 138 | + print(command_list[0]) | ||
 139 | + | ||
 140 | + # run the master (rank-0) command in the foreground and wait for it to finish; | ||
 141 | + # the worker threads are daemons and exit with the process | ||
 142 | + thread = threading.Thread(target=safe_shell_exec.execute, args=(command_list[0][0],)) | ||
 143 | + thread.start() | ||
 144 | + thread.join() | ||
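A hypothetical usage of `execute_function_multithreaded`, assuming this module is importable: each entry of `args_list` must be a mutable list of positional arguments (the function appends an index to it), and results come back as a dict keyed by entry index.

```python
import subprocess

def run(cmd):
    # run one shell command and return its exit code
    return subprocess.call(cmd, shell=True)

commands = [['echo host-a'], ['echo host-b'], ['echo host-c']]
results = execute_function_multithreaded(run, commands, block_until_all_done=True)
print(results)  # e.g. {0: 0, 1: 0, 2: 0}
```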
code/fast-autoaugment-master/LICENSE
0 → 100644
1 | +MIT License | ||
2 | + | ||
3 | +Copyright (c) 2019 Ildoo Kim | ||
4 | + | ||
5 | +Permission is hereby granted, free of charge, to any person obtaining a copy | ||
6 | +of this software and associated documentation files (the "Software"), to deal | ||
7 | +in the Software without restriction, including without limitation the rights | ||
8 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
9 | +copies of the Software, and to permit persons to whom the Software is | ||
10 | +furnished to do so, subject to the following conditions: | ||
11 | + | ||
12 | +The above copyright notice and this permission notice shall be included in all | ||
13 | +copies or substantial portions of the Software. | ||
14 | + | ||
15 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
16 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
17 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
18 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
19 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
20 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
21 | +SOFTWARE. |
code/fast-autoaugment-master/README.md
0 → 100644
1 | +# Fast AutoAugment **(Accepted at NeurIPS 2019)** | ||
2 | + | ||
3 | +Official [Fast AutoAugment](https://arxiv.org/abs/1905.00397) implementation in PyTorch. | ||
4 | + | ||
5 | +- Fast AutoAugment learns augmentation policies using a more efficient search strategy based on density matching. | ||
 6 | +- Fast AutoAugment speeds up the search time by orders of magnitude while maintaining comparable performance. | ||
7 | + | ||
8 | +<p align="center"> | ||
9 | +<img src="etc/search.jpg" height=350> | ||
10 | +</p> | ||
11 | + | ||
12 | +## Results | ||
13 | + | ||
14 | +### CIFAR-10 / 100 | ||
15 | + | ||
16 | +Search : **3.5 GPU Hours (1428x faster than AutoAugment)**, WResNet-40x2 on Reduced CIFAR-10 | ||
17 | + | ||
18 | +| Model(CIFAR-10) | Baseline | Cutout | AutoAugment | Fast AutoAugment<br/>(transfer/direct) | | | ||
19 | +|-------------------------|------------|------------|-------------|------------------|----| | ||
20 | +| Wide-ResNet-40-2 | 5.3 | 4.1 | 3.7 | 3.6 / 3.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_wresnet40x2_top1_3.52.pth) | | ||
21 | +| Wide-ResNet-28-10 | 3.9 | 3.1 | 2.6 | 2.7 / 2.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_wresnet28x10_top1.pth) | | ||
22 | +| Shake-Shake(26 2x32d) | 3.6 | 3.0 | 2.5 | 2.7 / 2.5 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_shake26_2x32d_top1_2.68.pth) | | ||
23 | +| Shake-Shake(26 2x96d) | 2.9 | 2.6 | 2.0 | 2.0 / 2.0 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_shake26_2x96d_top1_1.97.pth) | | ||
24 | +| Shake-Shake(26 2x112d) | 2.8 | 2.6 | 1.9 | 2.0 / 1.9 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_shake26_2x112d_top1_2.04.pth) | | ||
25 | +| PyramidNet+ShakeDrop | 2.7 | 2.3 | 1.5 | 1.8 / 1.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar10_pyramid272_top1_1.44.pth) | | ||
26 | + | ||
27 | +| Model(CIFAR-100) | Baseline | Cutout | AutoAugment | Fast AutoAugment<br/>(transfer/direct) | | | ||
28 | +|-----------------------|------------|------------|-------------|------------------|----| | ||
29 | +| Wide-ResNet-40-2 | 26.0 | 25.2 | 20.7 | 20.7 / 20.6 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_wresnet40x2_top1_20.43.pth) | | ||
30 | +| Wide-ResNet-28-10 | 18.8 | 18.4 | 17.1 | 17.3 / 17.3 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_wresnet28x10_top1_17.17.pth) | | ||
31 | +| Shake-Shake(26 2x96d) | 17.1 | 16.0 | 14.3 | 14.9 / 14.6 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_shake26_2x96d_top1_15.15.pth) | | ||
32 | +| PyramidNet+ShakeDrop | 14.0 | 12.2 | 10.7 | 11.9 / 11.7 | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/cifar100_pyramid272_top1_11.74.pth) | | ||
33 | + | ||
34 | +### ImageNet | ||
35 | + | ||
36 | +Search : **450 GPU Hours (33x faster than AutoAugment)**, ResNet-50 on Reduced ImageNet | ||
37 | + | ||
38 | +| Model | Baseline | AutoAugment | Fast AutoAugment<br/>(Top1/Top5) | | | ||
39 | +|------------|------------|-------------|------------------|----| | ||
40 | +| ResNet-50 | 23.7 / 6.9 | 22.4 / 6.2 | **22.4 / 6.3** | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/imagenet_resnet50_top1_22.2.pth) | | ||
41 | +| ResNet-200 | 21.5 / 5.8 | 20.0 / 5.0 | **19.4 / 4.7** | [Download](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/imagenet_resnet200_top1_19.4.pth) | | ||
42 | + | ||
43 | +Notes | ||
 44 | +* We evaluated ResNet-50 and ResNet-200 at resolutions of 224 and 320, respectively. According to the original ResNet paper, ResNet-200 was tested at a resolution of 320, and our ResNet-200 baseline performed similarly at that resolution. | ||
 45 | +* However, after our recent code clean-up and bugfixes, we found that ResNet-200 performs similarly to that baseline even when using 224x224. | ||
46 | +* When we use 224x224, resnet-200 performs **20.0 / 5.2**. Download link for the trained model is [here](https://arena.kakaocdn.net/brainrepo/fast-autoaugment/imagenet_resnet200_res224.pth). | ||
47 | + | ||
48 | +We have conducted additional experiments with EfficientNet. | ||
49 | + | ||
50 | +| Model | Baseline | AutoAugment | | Our Baseline(Batch) | +Fast AA | | ||
51 | +|-------|------------|-------------|---|---------------------|----------| | ||
52 | +| B0 | 23.2 | 22.7 | | 22.96 | 22.68 | | ||
53 | + | ||
54 | +### SVHN Test | ||
55 | + | ||
56 | +Search : **1.5 GPU Hours** | ||
57 | + | ||
58 | +| | Baseline | AutoAug / Our | Fast AutoAugment | | ||
59 | +|----------------------------------|---------:|--------------:|--------:| | ||
60 | +| Wide-Resnet28x10 | 1.5 | 1.1 | 1.1 | | ||
61 | + | ||
62 | +## Run | ||
63 | + | ||
64 | +We conducted experiments under | ||
65 | + | ||
66 | +- python 3.6.9 | ||
67 | +- pytorch 1.2.0, torchvision 0.4.0, cuda10 | ||
68 | + | ||
 69 | +### Search an augmentation policy | ||
70 | + | ||
 71 | +Please read Ray's documentation to construct a proper Ray cluster (https://github.com/ray-project/ray), then run search.py with the master's redis address. | ||
72 | + | ||
73 | +``` | ||
74 | +$ python search.py -c confs/wresnet40x2_cifar10_b512.yaml --dataroot ... --redis ... | ||
75 | +``` | ||
76 | + | ||
77 | +### Train a model with found policies | ||
78 | + | ||
79 | +You can train network architectures on CIFAR-10 / 100 and ImageNet with our searched policies. | ||
80 | + | ||
81 | +- fa_reduced_cifar10 : reduced CIFAR-10(4k images), WResNet-40x2 | ||
82 | +- fa_reduced_imagenet : reduced ImageNet(50k images, 120 classes), ResNet-50 | ||
83 | + | ||
84 | +``` | ||
85 | +$ export PYTHONPATH=$PYTHONPATH:$PWD | ||
86 | +$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10 | ||
87 | +$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100 | ||
88 | +$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10 | ||
89 | +$ python FastAutoAugment/train.py -c confs/wresnet28x10_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar100 | ||
90 | +... | ||
91 | +$ python FastAutoAugment/train.py -c confs/resnet50_b512.yaml --aug fa_reduced_imagenet | ||
92 | +$ python FastAutoAugment/train.py -c confs/resnet200_b512.yaml --aug fa_reduced_imagenet | ||
93 | +``` | ||
94 | + | ||
 95 | +By adding the --only-eval and --save arguments, you can evaluate trained models without retraining. | ||
96 | + | ||
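For example, to evaluate a downloaded CIFAR-10 checkpoint (the checkpoint filename here is illustrative):

```
$ python FastAutoAugment/train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10 --only-eval --save cifar10_wresnet40x2_top1_3.52.pth
```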
97 | +If you want to train with multi-gpu/node, use `torch.distributed.launch` such as | ||
98 | + | ||
99 | +```bash | ||
100 | +$ python -m torch.distributed.launch --nproc_per_node={num_gpu_per_node} --nnodes={num_node} --master_addr={master} --master_port={master_port} --node_rank={0,1,2,...,num_node} FastAutoAugment/train.py -c confs/efficientnet_b4.yaml --aug fa_reduced_imagenet | ||
101 | +``` | ||
102 | + | ||
103 | +## Citation | ||
104 | + | ||
105 | +If you use this code in your research, please cite our [paper](https://arxiv.org/abs/1905.00397). | ||
106 | + | ||
107 | +``` | ||
108 | +@inproceedings{lim2019fast, | ||
109 | + title={Fast AutoAugment}, | ||
110 | + author={Lim, Sungbin and Kim, Ildoo and Kim, Taesup and Kim, Chiheon and Kim, Sungwoong}, | ||
111 | + booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, | ||
112 | + year={2019} | ||
113 | +} | ||
114 | +``` | ||
115 | + | ||
116 | +## Contact for Issues | ||
117 | +- Ildoo Kim, ildoo.kim@kakaobrain.com | ||
118 | + | ||
119 | +## References & Opensources | ||
120 | + | ||
 121 | +We increase the batch size and scale the learning rate accordingly to speed up training. Otherwise, we set hyperparameters equal to AutoAugment's where possible. For the unknown hyperparameters, we follow values from the original references or tune them to match baseline performances. | ||
122 | + | ||
123 | +- **ResNet** : [paper1](https://arxiv.org/abs/1512.03385), [paper2](https://arxiv.org/abs/1603.05027), [code](https://github.com/osmr/imgclsmob/tree/master/pytorch/pytorchcv/models) | ||
124 | +- **PyramidNet** : [paper](https://arxiv.org/abs/1610.02915), [code](https://github.com/dyhan0920/PyramidNet-PyTorch) | ||
125 | +- **Wide-ResNet** : [code](https://github.com/meliketoy/wide-resnet.pytorch) | ||
126 | +- **Shake-Shake** : [code](https://github.com/owruby/shake-shake_pytorch) | ||
127 | +- **ShakeDrop Regularization** : [paper](https://arxiv.org/abs/1802.02375), [code](https://github.com/owruby/shake-drop_pytorch) | ||
128 | +- **AutoAugment** : [code](https://github.com/tensorflow/models/tree/master/research/autoaugment) | ||
129 | +- **Ray** : [code](https://github.com/ray-project/ray) | ||
130 | +- **HyperOpt** : [code](https://github.com/hyperopt/hyperopt) |
code/fast-autoaugment-master/__init__.py
0 → 100644
File mode changed
code/fast-autoaugment-master/archive.py
0 → 100644
This diff could not be displayed because it is too large.
1 | +model: | ||
2 | + type: efficientnet-b0 | ||
3 | + condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. | ||
4 | +dataset: imagenet | ||
5 | +aug: fa_reduced_imagenet | ||
6 | +cutout: 0 | ||
7 | +batch: 128 # per gpu | ||
8 | +epoch: 350 | ||
9 | +lr: 0.008 # 0.256 for 4096 batch | ||
10 | +lr_schedule: | ||
11 | + type: 'efficientnet' | ||
12 | + warmup: | ||
13 | + multiplier: 1 | ||
14 | + epoch: 5 | ||
15 | +optimizer: | ||
16 | + type: rmsprop | ||
17 | + decay: 0.00001 | ||
18 | + clip: 0 | ||
19 | + ema: 0.9999 | ||
20 | + ema_interval: -1 | ||
21 | +lb_smooth: 0.1 | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
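The `lr` comment encodes the linear scaling rule: the reference recipe uses 0.256 at a total batch of 4096, so a 128-image batch starts at 0.256 * 128 / 4096 = 0.008, and `train.py` additionally multiplies the LR by the world size in distributed runs. A small sanity-check sketch of this interpretation (not code from the repo):

```python
# Linear LR scaling as implied by the config comments.
def scaled_lr(reference_lr=0.256, reference_batch=4096, per_gpu_batch=128, world_size=1):
    return reference_lr * (per_gpu_batch * world_size) / reference_batch

print(scaled_lr())              # 0.008, matching the yaml above
print(scaled_lr(world_size=8))  # 0.064 for an 8-GPU run
```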
1 | +model: | ||
2 | + type: efficientnet-b0 | ||
3 | + condconv_num_expert: 8 # if this is greater than 1(eg. 4), it activates condconv. | ||
4 | +dataset: imagenet | ||
5 | +aug: fa_reduced_imagenet | ||
6 | +cutout: 0 | ||
7 | +batch: 128 # per gpu | ||
8 | +epoch: 350 | ||
9 | +lr: 0.008 # 0.256 for 4096 batch | ||
10 | +lr_schedule: | ||
11 | + type: 'efficientnet' | ||
12 | + warmup: | ||
13 | + multiplier: 1 | ||
14 | + epoch: 5 | ||
15 | +optimizer: | ||
16 | + type: rmsprop | ||
17 | + decay: 0.00001 | ||
18 | + clip: 0 | ||
19 | + ema: 0.9999 | ||
20 | + ema_interval: -1 | ||
21 | +lb_smooth: 0.1 | ||
22 | +mixup: 0.2 |
1 | +model: | ||
2 | + type: efficientnet-b1 | ||
3 | + condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. | ||
4 | +dataset: imagenet | ||
5 | +aug: fa_reduced_imagenet | ||
6 | +cutout: 0 | ||
7 | +batch: 128 # per gpu | ||
8 | +epoch: 350 | ||
9 | +lr: 0.008 # 0.256 for 4096 batch | ||
10 | +lr_schedule: | ||
11 | + type: 'efficientnet' | ||
12 | + warmup: | ||
13 | + multiplier: 1 | ||
14 | + epoch: 5 | ||
15 | +optimizer: | ||
16 | + type: rmsprop | ||
17 | + decay: 0.00001 | ||
18 | + clip: 0 | ||
19 | + ema: 0.9999 | ||
20 | + ema_interval: -1 | ||
21 | +lb_smooth: 0.1 |
1 | +model: | ||
2 | + type: efficientnet-b2 | ||
3 | + condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. | ||
4 | +dataset: imagenet | ||
5 | +aug: fa_reduced_imagenet | ||
6 | +cutout: 0 | ||
7 | +batch: 128 # per gpu | ||
8 | +epoch: 350 | ||
9 | +lr: 0.008 # 0.256 for 4096 batch | ||
10 | +lr_schedule: | ||
11 | + type: 'efficientnet' | ||
12 | + warmup: | ||
13 | + multiplier: 1 | ||
14 | + epoch: 5 | ||
15 | +optimizer: | ||
16 | + type: rmsprop | ||
17 | + decay: 0.00001 | ||
18 | + clip: 0 | ||
19 | + ema: 0.9999 | ||
20 | + ema_interval: -1 | ||
21 | +lb_smooth: 0.1 |
1 | +model: | ||
2 | + type: efficientnet-b3 | ||
3 | + condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. | ||
4 | +dataset: imagenet | ||
5 | +aug: fa_reduced_imagenet | ||
6 | +cutout: 0 | ||
7 | +batch: 64 # per gpu | ||
8 | +epoch: 350 | ||
9 | +lr: 0.004 # 0.256 for 4096 batch | ||
10 | +lr_schedule: | ||
11 | + type: 'efficientnet' | ||
12 | + warmup: | ||
13 | + multiplier: 1 | ||
14 | + epoch: 5 | ||
15 | +optimizer: | ||
16 | + type: rmsprop | ||
17 | + decay: 0.00001 | ||
18 | + clip: 0 | ||
19 | + ema: 0.9999 | ||
20 | + ema_interval: -1 | ||
21 | +lb_smooth: 0.1 |
1 | +model: | ||
2 | + type: efficientnet-b4 | ||
3 | + condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. | ||
4 | +dataset: imagenet | ||
5 | +aug: fa_reduced_imagenet | ||
6 | +cutout: 0 | ||
7 | +batch: 32 # per gpu | ||
8 | +epoch: 350 | ||
9 | +lr: 0.002 # 0.256 for 4096 batch | ||
10 | +lr_schedule: | ||
11 | + type: 'efficientnet' | ||
12 | + warmup: | ||
13 | + multiplier: 1 | ||
14 | + epoch: 5 | ||
15 | +optimizer: | ||
16 | + type: rmsprop | ||
17 | + decay: 0.00001 | ||
18 | + clip: 0 | ||
19 | + ema: 0.9999 | ||
20 | + ema_interval: -1 | ||
21 | +lb_smooth: 0.1 |
1 | +model: | ||
2 | + type: pyramid | ||
3 | + depth: 272 | ||
4 | + alpha: 200 | ||
5 | + bottleneck: True | ||
6 | +dataset: cifar10 | ||
7 | +aug: fa_reduced_cifar10 | ||
8 | +cutout: 16 | ||
9 | +batch: 64 | ||
10 | +epoch: 1800 | ||
11 | +lr: 0.05 | ||
12 | +lr_schedule: | ||
13 | + type: 'cosine' | ||
14 | + warmup: | ||
15 | + multiplier: 1 | ||
16 | + epoch: 5 | ||
17 | +optimizer: | ||
18 | + type: sgd | ||
19 | + nesterov: True | ||
20 | + decay: 0.00005 |
1 | +model: | ||
2 | + type: resnet50 | ||
3 | +dataset: imagenet | ||
4 | +aug: fa_reduced_imagenet | ||
5 | +cutout: 0 | ||
6 | +batch: 128 | ||
7 | +epoch: 270 | ||
8 | +lr: 0.05 | ||
9 | +lr_schedule: | ||
10 | + type: 'resnet' | ||
11 | + warmup: | ||
12 | + multiplier: 1 | ||
13 | + epoch: 5 | ||
14 | +optimizer: | ||
15 | + type: sgd | ||
16 | + nesterov: True | ||
17 | + decay: 0.0001 | ||
18 | + clip: 0 | ||
19 | + ema: 0 |
1 | +model: | ||
2 | + type: resnet50 | ||
3 | +dataset: imagenet | ||
4 | +aug: fa_reduced_imagenet | ||
5 | +cutout: 0 | ||
6 | +batch: 128 | ||
7 | +epoch: 270 | ||
8 | +lr: 0.05 | ||
9 | +lr_schedule: | ||
10 | + type: 'resnet' | ||
11 | + warmup: | ||
12 | + multiplier: 1 | ||
13 | + epoch: 5 | ||
14 | +optimizer: | ||
15 | + type: sgd | ||
16 | + nesterov: True | ||
17 | + decay: 0.0001 | ||
18 | + clip: 0 | ||
19 | + ema: 0 | ||
20 | +#lb_smooth: 0.1 | ||
21 | +mixup: 0.2 |
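Setting `mixup: 0.2` switches `train.py` to its mixup branch, which expects `mixup(data, label, alpha)` to return the blended batch, both target sets, and the mixing coefficient, plus a matching four-argument loss. The repo's `aug_mixup` module is not part of this diff; below is a standard mixup sketch (Zhang et al., 2018) with the same calling convention:

```python
import numpy as np
import torch
import torch.nn.functional as F

def mixup(data, label, alpha=0.2):
    # blend each image with a randomly chosen partner from the same batch
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(data.size(0), device=data.device)
    mixed = lam * data + (1.0 - lam) * data[perm]
    return mixed, label, label[perm], lam

def mixup_cross_entropy(preds, targets, shuffled_targets, lam):
    # the loss is the same convex combination of the two targets' losses
    return (lam * F.cross_entropy(preds, targets)
            + (1.0 - lam) * F.cross_entropy(preds, shuffled_targets))
```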
1 | +model: | ||
2 | + type: wresnet28_10 | ||
3 | +dataset: cifar10 | ||
4 | +aug: fa_reduced_cifar10 | ||
5 | +cutout: 16 | ||
6 | +batch: 128 | ||
7 | +epoch: 200 | ||
8 | +lr: 0.1 | ||
9 | +lr_schedule: | ||
10 | + type: 'cosine' | ||
11 | + warmup: | ||
12 | + multiplier: 1 | ||
13 | + epoch: 5 | ||
14 | +optimizer: | ||
15 | + type: sgd | ||
16 | + nesterov: True | ||
17 | + decay: 0.0005 | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
1 | +model: | ||
2 | + type: wresnet28_10 | ||
3 | +dataset: svhn | ||
4 | +aug: fa_reduced_svhn | ||
5 | +cutout: 20 | ||
6 | +batch: 128 | ||
7 | +epoch: 200 | ||
8 | +lr: 0.01 | ||
9 | +lr_schedule: | ||
10 | + type: 'cosine' | ||
11 | + warmup: | ||
12 | + multiplier: 1 | ||
13 | + epoch: 5 | ||
14 | +optimizer: | ||
15 | + type: sgd | ||
16 | + nesterov: True | ||
17 | + decay: 0.0005 | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/fast-autoaugment-master/etc/search.jpg
0 → 100644
967 KB
1 | +git+https://github.com/wbaek/theconf | ||
2 | +git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git@08f7d5e | ||
3 | +git+https://github.com/ildoonet/pystopwatch2.git | ||
4 | +git+https://github.com/hyperopt/hyperopt.git | ||
5 | +git+https://github.com/kakaobrain/torchlars | ||
6 | + | ||
7 | +pretrainedmodels | ||
8 | +tqdm | ||
9 | +tensorboardx | ||
 10 | +scikit-learn | ||
11 | +ray | ||
12 | +matplotlib | ||
13 | +psutil | ||
14 | +requests | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/filter_normal_csv.ipynb
0 → 100644
1 | +{ | ||
2 | + "cells": [ | ||
3 | + { | ||
4 | + "cell_type": "code", | ||
5 | + "execution_count": 1, | ||
6 | + "metadata": {}, | ||
7 | + "outputs": [], | ||
8 | + "source": [ | ||
9 | + "import pandas as pd" | ||
10 | + ] | ||
11 | + }, | ||
12 | + { | ||
13 | + "cell_type": "code", | ||
14 | + "execution_count": 3, | ||
15 | + "metadata": {}, | ||
16 | + "outputs": [ | ||
17 | + { | ||
18 | + "data": { | ||
19 | + "text/html": [ | ||
20 | + "<div>\n", | ||
21 | + "<style scoped>\n", | ||
22 | + " .dataframe tbody tr th:only-of-type {\n", | ||
23 | + " vertical-align: middle;\n", | ||
24 | + " }\n", | ||
25 | + "\n", | ||
26 | + " .dataframe tbody tr th {\n", | ||
27 | + " vertical-align: top;\n", | ||
28 | + " }\n", | ||
29 | + "\n", | ||
30 | + " .dataframe thead th {\n", | ||
31 | + " text-align: right;\n", | ||
32 | + " }\n", | ||
33 | + "</style>\n", | ||
34 | + "<table border=\"1\" class=\"dataframe\">\n", | ||
35 | + " <thead>\n", | ||
36 | + " <tr style=\"text-align: right;\">\n", | ||
37 | + " <th></th>\n", | ||
38 | + " <th>id</th>\n", | ||
39 | + " <th>Label</th>\n", | ||
40 | + " <th>Subject</th>\n", | ||
41 | + " <th>Date</th>\n", | ||
42 | + " <th>Gender</th>\n", | ||
43 | + " <th>Age</th>\n", | ||
44 | + " <th>mmse</th>\n", | ||
45 | + " <th>ageAtEntry</th>\n", | ||
46 | + " <th>cdr</th>\n", | ||
47 | + " <th>commun</th>\n", | ||
48 | + " <th>...</th>\n", | ||
49 | + " <th>memory</th>\n", | ||
50 | + " <th>orient</th>\n", | ||
51 | + " <th>perscare</th>\n", | ||
52 | + " <th>apoe</th>\n", | ||
53 | + " <th>sumbox</th>\n", | ||
54 | + " <th>acsparnt</th>\n", | ||
55 | + " <th>height</th>\n", | ||
56 | + " <th>weight</th>\n", | ||
57 | + " <th>primStudy</th>\n", | ||
58 | + " <th>acsStudy</th>\n", | ||
59 | + " </tr>\n", | ||
60 | + " </thead>\n", | ||
61 | + " <tbody>\n", | ||
62 | + " <tr>\n", | ||
63 | + " <th>0</th>\n", | ||
64 | + " <td>/@WEBAPP/images/r.gif</td>\n", | ||
65 | + " <td>OAS30001_ClinicalData_d3025</td>\n", | ||
66 | + " <td>OAS30001</td>\n", | ||
67 | + " <td>NaN</td>\n", | ||
68 | + " <td>female</td>\n", | ||
69 | + " <td>NaN</td>\n", | ||
70 | + " <td>30.0</td>\n", | ||
71 | + " <td>65.149895</td>\n", | ||
72 | + " <td>0.0</td>\n", | ||
73 | + " <td>0.0</td>\n", | ||
74 | + " <td>...</td>\n", | ||
75 | + " <td>0.0</td>\n", | ||
76 | + " <td>0.0</td>\n", | ||
77 | + " <td>0.0</td>\n", | ||
78 | + " <td>23.0</td>\n", | ||
79 | + " <td>0.0</td>\n", | ||
80 | + " <td>NaN</td>\n", | ||
81 | + " <td>64.0</td>\n", | ||
82 | + " <td>180.0</td>\n", | ||
83 | + " <td>NaN</td>\n", | ||
84 | + " <td>NaN</td>\n", | ||
85 | + " </tr>\n", | ||
86 | + " <tr>\n", | ||
87 | + " <th>1</th>\n", | ||
88 | + " <td>/@WEBAPP/images/r.gif</td>\n", | ||
89 | + " <td>OAS30001_ClinicalData_d3977</td>\n", | ||
90 | + " <td>OAS30001</td>\n", | ||
91 | + " <td>NaN</td>\n", | ||
92 | + " <td>female</td>\n", | ||
93 | + " <td>NaN</td>\n", | ||
94 | + " <td>29.0</td>\n", | ||
95 | + " <td>65.149895</td>\n", | ||
96 | + " <td>0.0</td>\n", | ||
97 | + " <td>0.0</td>\n", | ||
98 | + " <td>...</td>\n", | ||
99 | + " <td>0.0</td>\n", | ||
100 | + " <td>0.0</td>\n", | ||
101 | + " <td>0.0</td>\n", | ||
102 | + " <td>23.0</td>\n", | ||
103 | + " <td>0.0</td>\n", | ||
104 | + " <td>NaN</td>\n", | ||
105 | + " <td>NaN</td>\n", | ||
106 | + " <td>NaN</td>\n", | ||
107 | + " <td>NaN</td>\n", | ||
108 | + " <td>NaN</td>\n", | ||
109 | + " </tr>\n", | ||
110 | + " <tr>\n", | ||
111 | + " <th>2</th>\n", | ||
112 | + " <td>/@WEBAPP/images/r.gif</td>\n", | ||
113 | + " <td>OAS30001_ClinicalData_d3332</td>\n", | ||
114 | + " <td>OAS30001</td>\n", | ||
115 | + " <td>NaN</td>\n", | ||
116 | + " <td>female</td>\n", | ||
117 | + " <td>NaN</td>\n", | ||
118 | + " <td>30.0</td>\n", | ||
119 | + " <td>65.149895</td>\n", | ||
120 | + " <td>0.0</td>\n", | ||
121 | + " <td>0.0</td>\n", | ||
122 | + " <td>...</td>\n", | ||
123 | + " <td>0.0</td>\n", | ||
124 | + " <td>0.0</td>\n", | ||
125 | + " <td>0.0</td>\n", | ||
126 | + " <td>23.0</td>\n", | ||
127 | + " <td>0.0</td>\n", | ||
128 | + " <td>NaN</td>\n", | ||
129 | + " <td>63.0</td>\n", | ||
130 | + " <td>185.0</td>\n", | ||
131 | + " <td>NaN</td>\n", | ||
132 | + " <td>NaN</td>\n", | ||
133 | + " </tr>\n", | ||
134 | + " <tr>\n", | ||
135 | + " <th>3</th>\n", | ||
136 | + " <td>/@WEBAPP/images/r.gif</td>\n", | ||
137 | + " <td>OAS30001_ClinicalData_d0000</td>\n", | ||
138 | + " <td>OAS30001</td>\n", | ||
139 | + " <td>NaN</td>\n", | ||
140 | + " <td>female</td>\n", | ||
141 | + " <td>NaN</td>\n", | ||
142 | + " <td>28.0</td>\n", | ||
143 | + " <td>65.149895</td>\n", | ||
144 | + " <td>0.0</td>\n", | ||
145 | + " <td>0.0</td>\n", | ||
146 | + " <td>...</td>\n", | ||
147 | + " <td>0.0</td>\n", | ||
148 | + " <td>0.0</td>\n", | ||
149 | + " <td>0.0</td>\n", | ||
150 | + " <td>23.0</td>\n", | ||
151 | + " <td>0.0</td>\n", | ||
152 | + " <td>NaN</td>\n", | ||
153 | + " <td>NaN</td>\n", | ||
154 | + " <td>NaN</td>\n", | ||
155 | + " <td>NaN</td>\n", | ||
156 | + " <td>NaN</td>\n", | ||
157 | + " </tr>\n", | ||
158 | + " <tr>\n", | ||
159 | + " <th>4</th>\n", | ||
160 | + " <td>/@WEBAPP/images/r.gif</td>\n", | ||
161 | + " <td>OAS30001_ClinicalData_d1456</td>\n", | ||
162 | + " <td>OAS30001</td>\n", | ||
163 | + " <td>NaN</td>\n", | ||
164 | + " <td>female</td>\n", | ||
165 | + " <td>NaN</td>\n", | ||
166 | + " <td>30.0</td>\n", | ||
167 | + " <td>65.149895</td>\n", | ||
168 | + " <td>0.0</td>\n", | ||
169 | + " <td>0.0</td>\n", | ||
170 | + " <td>...</td>\n", | ||
171 | + " <td>0.0</td>\n", | ||
172 | + " <td>0.0</td>\n", | ||
173 | + " <td>0.0</td>\n", | ||
174 | + " <td>23.0</td>\n", | ||
175 | + " <td>0.0</td>\n", | ||
176 | + " <td>NaN</td>\n", | ||
177 | + " <td>63.0</td>\n", | ||
178 | + " <td>173.0</td>\n", | ||
179 | + " <td>NaN</td>\n", | ||
180 | + " <td>NaN</td>\n", | ||
181 | + " </tr>\n", | ||
182 | + " </tbody>\n", | ||
183 | + "</table>\n", | ||
184 | + "<p>5 rows × 27 columns</p>\n", | ||
185 | + "</div>" | ||
186 | + ], | ||
187 | + "text/plain": [ | ||
188 | + " id Label Subject Date Gender \\\n", | ||
189 | + "0 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3025 OAS30001 NaN female \n", | ||
190 | + "1 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3977 OAS30001 NaN female \n", | ||
191 | + "2 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d3332 OAS30001 NaN female \n", | ||
192 | + "3 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d0000 OAS30001 NaN female \n", | ||
193 | + "4 /@WEBAPP/images/r.gif OAS30001_ClinicalData_d1456 OAS30001 NaN female \n", | ||
194 | + "\n", | ||
195 | + " Age mmse ageAtEntry cdr commun ... memory orient perscare apoe \\\n", | ||
196 | + "0 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n", | ||
197 | + "1 NaN 29.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n", | ||
198 | + "2 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n", | ||
199 | + "3 NaN 28.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n", | ||
200 | + "4 NaN 30.0 65.149895 0.0 0.0 ... 0.0 0.0 0.0 23.0 \n", | ||
201 | + "\n", | ||
202 | + " sumbox acsparnt height weight primStudy acsStudy \n", | ||
203 | + "0 0.0 NaN 64.0 180.0 NaN NaN \n", | ||
204 | + "1 0.0 NaN NaN NaN NaN NaN \n", | ||
205 | + "2 0.0 NaN 63.0 185.0 NaN NaN \n", | ||
206 | + "3 0.0 NaN NaN NaN NaN NaN \n", | ||
207 | + "4 0.0 NaN 63.0 173.0 NaN NaN \n", | ||
208 | + "\n", | ||
209 | + "[5 rows x 27 columns]" | ||
210 | + ] | ||
211 | + }, | ||
212 | + "execution_count": 3, | ||
213 | + "metadata": {}, | ||
214 | + "output_type": "execute_result" | ||
215 | + } | ||
216 | + ], | ||
217 | + "source": [ | ||
218 | + "all_data = pd.read_csv(\"..\\data\\ADRC clinical data_all.csv\")\n", | ||
219 | + "\n", | ||
220 | + "all_data.head()" | ||
221 | + ] | ||
222 | + }, | ||
223 | + { | ||
224 | + "cell_type": "code", | ||
225 | + "execution_count": 18, | ||
226 | + "metadata": {}, | ||
227 | + "outputs": [ | ||
228 | + { | ||
229 | + "name": "stdout", | ||
230 | + "output_type": "stream", | ||
231 | + "text": [ | ||
232 | + "Subject\n", | ||
233 | + "OAS30001 0.0\n", | ||
234 | + "OAS30002 0.0\n", | ||
235 | + "OAS30003 0.0\n", | ||
236 | + "OAS30004 0.0\n", | ||
237 | + "OAS30005 0.0\n", | ||
238 | + " ... \n", | ||
239 | + "OAS31168 0.0\n", | ||
240 | + "OAS31169 3.0\n", | ||
241 | + "OAS31170 2.0\n", | ||
242 | + "OAS31171 2.0\n", | ||
243 | + "OAS31172 0.0\n", | ||
244 | + "Name: cdr, Length: 1098, dtype: float64\n" | ||
245 | + ] | ||
246 | + } | ||
247 | + ], | ||
248 | + "source": [ | ||
249 | + "ad = all_data.groupby(['Subject'])['cdr'].max()\n", | ||
250 | + "print(ad)" | ||
251 | + ] | ||
252 | + }, | ||
253 | + { | ||
254 | + "cell_type": "code", | ||
255 | + "execution_count": 21, | ||
256 | + "metadata": {}, | ||
257 | + "outputs": [ | ||
258 | + { | ||
259 | + "data": { | ||
260 | + "text/plain": [ | ||
261 | + "'OAS30001'" | ||
262 | + ] | ||
263 | + }, | ||
264 | + "execution_count": 21, | ||
265 | + "metadata": {}, | ||
266 | + "output_type": "execute_result" | ||
267 | + } | ||
268 | + ], | ||
269 | + "source": [ | ||
270 | + "ad.index[0]" | ||
271 | + ] | ||
272 | + }, | ||
273 | + { | ||
274 | + "cell_type": "code", | ||
275 | + "execution_count": 22, | ||
276 | + "metadata": {}, | ||
277 | + "outputs": [ | ||
278 | + { | ||
279 | + "ename": "SyntaxError", | ||
280 | + "evalue": "unexpected EOF while parsing (<ipython-input-22-b8a078b72aca>, line 5)", | ||
281 | + "output_type": "error", | ||
282 | + "traceback": [ | ||
283 | + "\u001b[1;36m File \u001b[1;32m\"<ipython-input-22-b8a078b72aca>\"\u001b[1;36m, line \u001b[1;32m5\u001b[0m\n\u001b[1;33m #print(filtered)\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m unexpected EOF while parsing\n" | ||
284 | + ] | ||
285 | + } | ||
286 | + ], | ||
287 | + "source": [ | ||
288 | + "filtered = []\n", | ||
289 | + "for i, val in enumerate(ad):\n", | ||
290 | + " if ad[i] == 0:\n", | ||
291 | + " filtered.append(ad.index[i])\n", | ||
292 | + "#print(filtered)" | ||
293 | + ] | ||
294 | + }, | ||
295 | + { | ||
296 | + "cell_type": "code", | ||
297 | + "execution_count": 23, | ||
298 | + "metadata": {}, | ||
299 | + "outputs": [], | ||
300 | + "source": [ | ||
301 | + "df_filtered = pd.DataFrame(filtered)\n", | ||
302 | + "df_filtered.to_csv('..\\data\\ADRC clinical data_normal.csv')" | ||
303 | + ] | ||
304 | + }, | ||
305 | + { | ||
306 | + "cell_type": "code", | ||
307 | + "execution_count": null, | ||
308 | + "metadata": {}, | ||
309 | + "outputs": [], | ||
310 | + "source": [] | ||
311 | + } | ||
312 | + ], | ||
313 | + "metadata": { | ||
314 | + "kernelspec": { | ||
315 | + "display_name": "ML", | ||
316 | + "language": "python", | ||
317 | + "name": "ml" | ||
318 | + }, | ||
319 | + "language_info": { | ||
320 | + "codemirror_mode": { | ||
321 | + "name": "ipython", | ||
322 | + "version": 3 | ||
323 | + }, | ||
324 | + "file_extension": ".py", | ||
325 | + "mimetype": "text/x-python", | ||
326 | + "name": "python", | ||
327 | + "nbconvert_exporter": "python", | ||
328 | + "pygments_lexer": "ipython3", | ||
329 | + "version": "3.7.4" | ||
330 | + } | ||
331 | + }, | ||
332 | + "nbformat": 4, | ||
333 | + "nbformat_minor": 2 | ||
334 | +} |
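The notebook's groupby/filter/save steps can be written as a few idiomatic pandas lines. A sketch under the same assumptions as the notebook (columns `Subject` and `cdr`, the same relative paths); note the output format differs slightly from `pd.DataFrame(filtered).to_csv(...)`:

```python
import pandas as pd

# keep subjects whose CDR never rises above 0 across all clinical visits
all_data = pd.read_csv("../data/ADRC clinical data_all.csv")
max_cdr = all_data.groupby("Subject")["cdr"].max()
normal_subjects = max_cdr[max_cdr == 0].index.to_series(name="Subject")
normal_subjects.to_csv("../data/ADRC clinical data_normal.csv", index=False)
```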
code/filter_normal_nii.ipynb
0 → 100644
This diff could not be displayed because it is too large.
code/flair2seg_1.m
0 → 100644
1 | +%load('..\data\MICCAI_BraTS_2019_Data_Training\name_mapping.csv') | ||
2 | + | ||
3 | +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\'; | ||
4 | +outfolder = '..\testfolder\'; | ||
5 | +id = 'BraTS19_2013_2_1'; | ||
6 | + | ||
7 | +type = 'flair.nii'; | ||
8 | +filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii | ||
9 | +flair_path = strcat(inputheader, id, '\', filename,'\', filename); | ||
10 | +%disp(path); | ||
11 | +flair = niftiread(flair_path); %size 240x240x155 | ||
12 | +cp_flair = flair; | ||
13 | + | ||
14 | +type = 'seg.nii'; | ||
15 | +filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii | ||
16 | +seg_path = strcat(inputheader, id, '\', filename, '\', filename); | ||
17 | +seg = niftiread(seg_path); | ||
18 | + | ||
19 | +[x,y,z] = size(seg); | ||
20 | + | ||
21 | +% copy flair, segment flair data | ||
22 | + | ||
23 | +cp_flair(seg == 0) = 0; | ||
24 | + | ||
25 | +% save a segmented data | ||
26 | +type = 'seg_flair.nii'; | ||
27 | +filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg_flair.nii | ||
28 | +outpath = strcat(outfolder, filename); | ||
29 | +%niftiwrite(cp_flair, outpath); | ||
30 | + | ||
31 | + | ||
32 | +%whos seg | ||
33 | + | ||
34 | + | ||
35 | +%extract = seg(84, :, 86); | ||
36 | + | ||
37 | + | ||
38 | + | ||
39 | + | ||
40 | +% cp84 = cp_flair(84,:,86); | ||
41 | +% flair84 = flair(84,:, 86); | ||
42 | + | ||
43 | + | ||
44 | + | ||
45 | +%[flair,info] = ReadData3D(filename) | ||
46 | +% whos flair | ||
47 | + | ||
48 | +%volumeViewer(flair); | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/flair2seg_all.m
0 → 100644
1 | +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\'; | ||
2 | +outfolder = strcat('..\data\MICCAI_BraTS_2019_Data_Training\HGG_seg_flair\'); | ||
3 | + | ||
4 | +files = dir(inputheader); | ||
5 | +id = {files.name}; | ||
 6 | +% keep only the case subdirectories, excluding '.' and '..' | ||
7 | +dirFlag = [files.isdir] & ~strcmp(id, '.') & ~strcmp(id, '..'); | ||
8 | +subFolders = files(dirFlag); | ||
9 | +disp(length(subFolders)); | ||
10 | + | ||
11 | +% for k = 1 : length(subFolders) | ||
12 | +% fprintf('Sub folder #%d = %s\n', k, subFolders(k).name); | ||
13 | +% end | ||
14 | + | ||
15 | +for i = 1 : length(subFolders) | ||
16 | + | ||
17 | + id = subFolders(i).name; | ||
18 | + fprintf('Sub folder #%d = %s\n', i, id); | ||
19 | + | ||
20 | + type = 'flair.nii'; | ||
21 | + filename = strcat(id,'_', type); % BraTS19_2013_2_1_flair.nii | ||
22 | + flair_path = strcat(inputheader, id, '\', filename,'\', filename); | ||
23 | + flair = niftiread(flair_path); %size 240x240x155 | ||
24 | + cp_flair = flair; | ||
25 | + | ||
26 | + type = 'seg.nii'; | ||
27 | + filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg.nii | ||
28 | + seg_path = strcat(inputheader, id, '\', filename, '\', filename); | ||
29 | + seg = niftiread(seg_path); | ||
30 | + | ||
31 | + [x,y,z] = size(seg); | ||
32 | + | ||
33 | + % copy flair, segment flair data | ||
34 | + | ||
35 | + cp_flair(seg == 0) = 0; | ||
36 | + | ||
37 | + % save a segmented data | ||
38 | + type = 'seg_flair.nii'; | ||
39 | + filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg_flair.nii | ||
40 | + outpath = strcat(outfolder, filename); | ||
41 | + niftiwrite(cp_flair, outpath); | ||
42 | + | ||
43 | +end |
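For reference, the same masking step can be sketched in Python with nibabel (this is not part of the repo, and it assumes a flat per-case directory layout rather than the nested `<id>\<file>\<file>` layout the MATLAB scripts read):

```python
import os
import glob
import nibabel as nib

# Zero out every FLAIR voxel that falls outside the tumour segmentation,
# mirroring flair2seg_all.m; paths and layout are assumptions.
input_root = '../data/MICCAI_BraTS_2019_Data_Training/HGG'
out_folder = '../data/MICCAI_BraTS_2019_Data_Training/HGG_seg_flair'
os.makedirs(out_folder, exist_ok=True)

for case_dir in sorted(glob.glob(os.path.join(input_root, 'BraTS19_*'))):
    case_id = os.path.basename(case_dir)
    flair = nib.load(os.path.join(case_dir, f'{case_id}_flair.nii'))
    seg = nib.load(os.path.join(case_dir, f'{case_id}_seg.nii'))
    masked = flair.get_fdata()
    masked[seg.get_fdata() == 0] = 0
    nib.save(nib.Nifti1Image(masked, flair.affine, flair.header),
             os.path.join(out_folder, f'{case_id}_seg_flair.nii'))
```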
code/loadfile.m
0 → 100644
1 | +inputheader = '..\data\MICCAI_BraTS_2019_Data_Training\HGG\'; | ||
2 | +outfolder = '..\testfolder\'; | ||
3 | +id = 'BraTS19_2013_2_1'; | ||
4 | + | ||
 5 | +type = 'seg_flair.nii'; | ||
 6 | +filename = strcat(id,'_', type); % BraTS19_2013_2_1_seg_flair.nii | ||
 7 | +flair_path = strcat(outfolder, filename); | ||
 8 | +%disp(flair_path); | ||
 9 | +seg_flair = niftiread(flair_path); % size 240x240x155 | ||
 10 | + | ||
 11 | +line84 = seg_flair(84, :, 86); | ||
... | \ No newline at end of file | ... | \ No newline at end of file |