Showing 18 changed files with 1420 additions and 0 deletions
source code/adda_mixup/.gitignore
0 → 100644
1 | +# # Created by .ignore support plugin (hsz.mobi) | ||
2 | +# ### Example user template template | ||
3 | +# ### Example user template | ||
4 | + | ||
5 | +# # IntelliJ project files | ||
6 | +# .idea | ||
7 | +# *.iml | ||
8 | +# out | ||
9 | +# gen | ||
10 | +# ### Python template | ||
11 | +# # Byte-compiled / optimized / DLL files | ||
12 | +__pycache__/ | ||
13 | +# *.py[cod] | ||
14 | +# *$py.class | ||
15 | + | ||
16 | +# # C extensions | ||
17 | +# *.so | ||
18 | + | ||
19 | +# # Distribution / packaging | ||
20 | +# .Python | ||
21 | +# build/ | ||
22 | +# develop-eggs/ | ||
23 | +# dist/ | ||
24 | +# downloads/ | ||
25 | +# eggs/ | ||
26 | +# .eggs/ | ||
27 | +# lib/ | ||
28 | +# lib64/ | ||
29 | +# parts/ | ||
30 | +# sdist/ | ||
31 | +# var/ | ||
32 | +# wheels/ | ||
33 | +# pip-wheel-metadata/ | ||
34 | +# share/python-wheels/ | ||
35 | +# *.egg-info/ | ||
36 | +# .installed.cfg | ||
37 | +# *.egg | ||
38 | +# MANIFEST | ||
39 | + | ||
40 | +# # PyInstaller | ||
41 | +# # Usually these files are written by a python script from a template | ||
42 | +# # before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
43 | +# *.manifest | ||
44 | +# *.spec | ||
45 | + | ||
46 | +# # Installer logs | ||
47 | +# pip-log.txt | ||
48 | +# pip-delete-this-directory.txt | ||
49 | + | ||
50 | +# # Unit test / coverage reports | ||
51 | +# htmlcov/ | ||
52 | +# .tox/ | ||
53 | +# .nox/ | ||
54 | +# .coverage | ||
55 | +# .coverage.* | ||
56 | +# .cache | ||
57 | +# nosetests.xml | ||
58 | +# coverage.xml | ||
59 | +# *.cover | ||
60 | +# .hypothesis/ | ||
61 | +# .pytest_cache/ | ||
62 | + | ||
63 | +# # # Translations | ||
64 | +# *.mo | ||
65 | +# *.pot | ||
66 | + | ||
67 | +# # Django stuff: | ||
68 | +# *.log | ||
69 | +# local_settings.py | ||
70 | +# db.sqlite3 | ||
71 | + | ||
72 | +# # Flask stuff: | ||
73 | +# instance/ | ||
74 | +# .webassets-cache | ||
75 | + | ||
76 | +# # Scrapy stuff: | ||
77 | +# .scrapy | ||
78 | + | ||
79 | +# # Sphinx documentation | ||
80 | +# docs/_build/ | ||
81 | + | ||
82 | +# # PyBuilder | ||
83 | +# target/ | ||
84 | + | ||
85 | +# # Jupyter Notebook | ||
86 | +# .ipynb_checkpoints | ||
87 | + | ||
88 | +# # IPython | ||
89 | +# profile_default/ | ||
90 | +# ipython_config.py | ||
91 | + | ||
92 | +# # pyenv | ||
93 | +# .python-version | ||
94 | + | ||
95 | +# # pipenv | ||
96 | +# # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
97 | +# # However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
98 | +# # having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
99 | +# # install all needed dependencies. | ||
100 | +# #Pipfile.lock | ||
101 | + | ||
102 | +# # celery beat schedule file | ||
103 | +# celerybeat-schedule | ||
104 | + | ||
105 | +# # SageMath parsed files | ||
106 | +# *.sage.py | ||
107 | + | ||
108 | +# # Environments | ||
109 | +# .env | ||
110 | +# .venv | ||
111 | +# env/ | ||
112 | +# venv/ | ||
113 | +# ENV/ | ||
114 | +# env.bak/ | ||
115 | +# venv.bak/ | ||
116 | + | ||
117 | +# # Spyder project settings | ||
118 | +# .spyderproject | ||
119 | +# .spyproject | ||
120 | + | ||
121 | +# # Rope project settings | ||
122 | +# .ropeproject | ||
123 | + | ||
124 | +# # mkdocs documentation | ||
125 | +# /site | ||
126 | + | ||
127 | +# # mypy | ||
128 | +# .mypy_cache/ | ||
129 | +# .dmypy.json | ||
130 | +# dmypy.json | ||
131 | + | ||
132 | +# # Pyre type checker | ||
133 | +# .pyre/ | ||
134 | + | ||
135 | +# ### JetBrains template | ||
136 | +# # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm | ||
137 | +# # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | ||
138 | + | ||
139 | +# # User-specific stuff | ||
140 | +# .idea/**/workspace.xml | ||
141 | +# .idea/**/tasks.xml | ||
142 | +# .idea/**/usage.statistics.xml | ||
143 | +# .idea/**/dictionaries | ||
144 | +# .idea/**/shelf | ||
145 | + | ||
146 | +# # Generated files | ||
147 | +# .idea/**/contentModel.xml | ||
148 | + | ||
149 | +# # Sensitive or high-churn files | ||
150 | +# .idea/**/dataSources/ | ||
151 | +# .idea/**/dataSources.ids | ||
152 | +# .idea/**/dataSources.local.xml | ||
153 | +# .idea/**/sqlDataSources.xml | ||
154 | +# .idea/**/dynamic.xml | ||
155 | +# .idea/**/uiDesigner.xml | ||
156 | +# .idea/**/dbnavigator.xml | ||
157 | + | ||
158 | +# # Gradle | ||
159 | +# .idea/**/gradle.xml | ||
160 | +# .idea/**/libraries | ||
161 | + | ||
162 | +# # Gradle and Maven with auto-import | ||
163 | +# # When using Gradle or Maven with auto-import, you should exclude module files, | ||
164 | +# # since they will be recreated, and may cause churn. Uncomment if using | ||
165 | +# # auto-import. | ||
166 | +# # .idea/modules.xml | ||
167 | +# # .idea/*.iml | ||
168 | +# # .idea/modules | ||
169 | +# # *.iml | ||
170 | +# # *.ipr | ||
171 | + | ||
172 | +# # CMake | ||
173 | +# cmake-build-*/ | ||
174 | + | ||
175 | +# # Mongo Explorer plugin | ||
176 | +# .idea/**/mongoSettings.xml | ||
177 | + | ||
178 | +# # File-based project format | ||
179 | +# *.iws | ||
180 | + | ||
181 | +# # IntelliJ | ||
182 | +# out/ | ||
183 | + | ||
184 | +# # mpeltonen/sbt-idea plugin | ||
185 | +# .idea_modules/ | ||
186 | + | ||
187 | +# # JIRA plugin | ||
188 | +# atlassian-ide-plugin.xml | ||
189 | + | ||
190 | +# # Cursive Clojure plugin | ||
191 | +# .idea/replstate.xml | ||
192 | + | ||
193 | +# # Crashlytics plugin (for Android Studio and IntelliJ) | ||
194 | +# com_crashlytics_export_strings.xml | ||
195 | +# crashlytics.properties | ||
196 | +# crashlytics-build.properties | ||
197 | +# fabric.properties | ||
198 | + | ||
199 | +# # Editor-based Rest Client | ||
200 | +# .idea/httpRequests | ||
201 | + | ||
202 | +# # Android studio 3.1+ serialized cache file | ||
203 | +# .idea/caches/build_file_checksums.ser | ||
204 | + | ||
205 | +./data/ | ||
206 | +data/ | ||
207 | +# .idea/ | ||
208 | +generated/ |
source code/adda_mixup/README.md
0 → 100644
1 | +# PyTorch-ADDA-mixup | ||
2 | +A PyTorch implementation of Adversarial Discriminative Domain Adaptation (ADDA) with mixup added. | ||
3 | + | ||
4 | +Mixing the target domain into the source domain was confirmed to improve performance. | ||
5 | + | ||
6 | +# Usage | ||
7 | +It works on MNIST -> USPS, SVHN -> MNIST, USPS -> MNIST, and MNIST -> MNIST-M. | ||
8 | +Only 10,000 samples of the data were used (USPS excluded). | ||
9 | + | ||
10 | +<pre> | ||
11 | +<code> | ||
12 | +python main.py | ||
13 | +</code> | ||
14 | +</pre> | ||
15 | + | ||
16 | +## ADDA | ||
17 | +This repo is based on https://github.com/corenel/pytorch-adda and https://github.com/Fujiki-Nakamura/ADDA.PyTorch | ||
18 | + | ||
19 | + | ||
20 | + | ||
21 | +## Reference | ||
22 | +https://arxiv.org/abs/1702.05464 | ||
23 | + | ||
24 | + | ||
25 | +# |
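Datasets and the mixup ratio are chosen by editing `params.py` before running `main.py` (there is no command-line interface). Below is a minimal sketch of the relevant settings, assuming an MNIST -> MNIST-M run; the values are only illustrative, and the matching per-experiment hyperparameter block in `params.py` still has to be uncommented.

<pre>
<code>
# params.py (excerpt, illustrative values)
src_dataset = 'mnist'     # one of 'mnist', 'mnist_m', 'usps', 'svhn', 'custom'
tgt_dataset = 'mnist_m'
lammax = 0.3              # upper bound of the mixup ratio lambda
lammin = 0.0              # lower bound of the mixup ratio lambda
</code>
</pre>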
source code/adda_mixup/core/__init__.py
0 → 100644
source code/adda_mixup/core/adapt.py
0 → 100644
1 | +import os | ||
2 | + | ||
3 | +import torch | ||
4 | +import torch.optim as optim | ||
5 | +from torch import nn | ||
6 | +from core import test | ||
7 | +import params | ||
8 | +from utils import make_cuda, mixup_data | ||
9 | + | ||
10 | + | ||
11 | + | ||
12 | +def train_tgt(source_cnn, target_cnn, critic, | ||
13 | + src_data_loader, tgt_data_loader): | ||
14 | + """Train encoder for target domain.""" | ||
15 | + #################### | ||
16 | + # 1. setup network # | ||
17 | + #################### | ||
18 | + | ||
19 | + source_cnn.eval() | ||
20 | + target_cnn.encoder.train() | ||
21 | + critic.train() | ||
22 | + isbest = 0 | ||
23 | + # setup criterion and optimizer | ||
24 | + criterion = nn.CrossEntropyLoss() | ||
25 | + #target encoder | ||
26 | + optimizer_tgt = optim.Adam(target_cnn.parameters(), | ||
27 | + lr=params.adp_c_learning_rate, | ||
28 | + betas=(params.beta1, params.beta2), | ||
29 | + weight_decay=params.weight_decay | ||
30 | + ) | ||
31 | + #Discriminator | ||
32 | + optimizer_critic = optim.Adam(critic.parameters(), | ||
33 | + lr=params.d_learning_rate, | ||
34 | + betas=(params.beta1, params.beta2), | ||
35 | + weight_decay=params.weight_decay | ||
36 | + | ||
37 | + | ||
38 | + ) | ||
39 | + | ||
40 | + #################### | ||
41 | + # 2. train network # | ||
42 | + #################### | ||
43 | + len_data_loader = min(len(src_data_loader), len(tgt_data_loader)) | ||
44 | + | ||
45 | + for epoch in range(params.num_epochs): | ||
46 | + # zip source and target data pair | ||
47 | + data_zip = enumerate(zip(src_data_loader, tgt_data_loader)) | ||
48 | + for step, ((images_src, _), (images_tgt, _)) in data_zip: | ||
49 | + | ||
50 | + # make images variable | ||
51 | + images_src = make_cuda(images_src) | ||
52 | + images_tgt = make_cuda(images_tgt) | ||
53 | + | ||
54 | + | ||
55 | + | ||
56 | + | ||
57 | + ########################### | ||
58 | + # 2.1 train discriminator # | ||
59 | + ########################### | ||
60 | + | ||
61 | + # zero gradients for optimizer | ||
62 | + optimizer_critic.zero_grad() | ||
63 | + | ||
64 | + # extract and concat features | ||
65 | + feat_src = source_cnn.encoder(images_src) | ||
66 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
67 | + feat_concat = torch.cat((feat_src, feat_tgt), 0) | ||
68 | + | ||
69 | + # predict on discriminator | ||
70 | + pred_concat = critic(feat_concat.detach()) | ||
71 | + | ||
72 | + # prepare real and fake label | ||
73 | + label_src = make_cuda(torch.zeros(feat_src.size(0)).long()) | ||
74 | + label_tgt = make_cuda(torch.ones(feat_tgt.size(0)).long()) | ||
75 | + label_concat = torch.cat((label_src, label_tgt), 0) | ||
76 | + | ||
77 | + # compute loss for critic | ||
78 | + loss_critic = criterion(pred_concat, label_concat) | ||
79 | + loss_critic.backward() | ||
80 | + | ||
81 | + # optimize critic | ||
82 | + optimizer_critic.step() | ||
83 | + | ||
84 | + pred_cls = torch.squeeze(pred_concat.max(1)[1]) | ||
85 | + acc = (pred_cls == label_concat).float().mean() | ||
86 | + | ||
87 | + | ||
88 | + ############################ | ||
89 | + # 2.2 train target encoder # | ||
90 | + ############################ | ||
91 | + | ||
92 | + # zero gradients for optimizer | ||
93 | + optimizer_critic.zero_grad() | ||
94 | + optimizer_tgt.zero_grad() | ||
95 | + | ||
96 | +            # extract target features | ||
97 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
98 | + | ||
99 | + # predict on discriminator | ||
100 | + pred_tgt = critic(feat_tgt) | ||
101 | + | ||
102 | + # prepare fake labels | ||
103 | + label_tgt = make_cuda(torch.zeros(feat_tgt.size(0)).long()) | ||
104 | + | ||
105 | + # compute loss for target encoder | ||
106 | + loss_tgt = criterion(pred_tgt, label_tgt) | ||
107 | + loss_tgt.backward() | ||
108 | + | ||
109 | + # optimize target encoder | ||
110 | + optimizer_tgt.step() | ||
111 | + ####################### | ||
112 | + # 2.3 print step info # | ||
113 | + ####################### | ||
114 | +            if (epoch % 10 == 0) and ((step + 1) % len_data_loader == 0): | ||
115 | + print("Epoch [{}/{}] Step [{}/{}]:" | ||
116 | + "d_loss={:.5f} g_loss={:.5f} acc={:.5f}" | ||
117 | + .format(epoch, | ||
118 | + params.num_epochs, | ||
119 | + step + 1, | ||
120 | + len_data_loader, | ||
121 | + loss_critic.item(), | ||
122 | + loss_tgt.item(), | ||
123 | + acc.item())) | ||
124 | + | ||
125 | + | ||
126 | + torch.save(critic.state_dict(), os.path.join( | ||
127 | + params.model_root, | ||
128 | + "ADDA-critic-final.pt")) | ||
129 | + torch.save(target_cnn.state_dict(), os.path.join( | ||
130 | + params.model_root, | ||
131 | + "ADDA-target_cnn-final.pt")) | ||
132 | + return target_cnn | ||
\ No newline at end of file
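`core/adapt.py` keeps the plain ADDA adaptation loop without mixup (`main.py` currently calls the mixup variant in `core/mixup.py` instead). A minimal usage sketch, assuming the models and data loaders are built exactly as in `main.py`:

<pre>
<code>
from core import adapt, test

# source_cnn, target_cnn, discriminator, src_data_loader, tgt_data_loader and
# tgt_data_loader_eval are assumed to be created as in main.py
target_cnn = adapt.train_tgt(source_cnn, target_cnn, discriminator,
                             src_data_loader, tgt_data_loader)
test.eval_tgt(target_cnn, tgt_data_loader_eval)
</code>
</pre>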
source code/adda_mixup/core/mixup.py
0 → 100644
1 | +import torch.nn as nn | ||
2 | +import torch | ||
3 | +import torch.optim as optim | ||
4 | + | ||
5 | +import params | ||
6 | +from utils import make_cuda, save_model, LabelSmoothingCrossEntropy,mixup_data,mixup_criterion | ||
7 | +from random import * | ||
8 | +import sys | ||
9 | + | ||
10 | +from torch.utils.data import Dataset,DataLoader | ||
11 | +import os | ||
12 | +from core import test | ||
13 | +from utils import make_cuda, mixup_data | ||
14 | + | ||
15 | + | ||
16 | + | ||
17 | +class CustomDataset(Dataset): | ||
18 | + def __init__(self,img,label): | ||
19 | + self.x_data = img | ||
20 | + self.y_data = label | ||
21 | + def __len__(self): | ||
22 | + return len(self.x_data) | ||
23 | + | ||
24 | + def __getitem__(self, idx): | ||
25 | + x = self.x_data[idx] | ||
26 | + y = self.y_data[idx] | ||
27 | + return x, y | ||
28 | + | ||
29 | + | ||
30 | +def train_src(model, source_data_loader,target_data_loader,valid_loader): | ||
31 | + """Train classifier for source domain.""" | ||
32 | + #################### | ||
33 | + # 1. setup network # | ||
34 | + #################### | ||
35 | + | ||
36 | + model.train() | ||
37 | + | ||
38 | + | ||
39 | + target_data_loader = list(target_data_loader) | ||
40 | + | ||
41 | + # setup criterion and optimizer | ||
42 | + optimizer = optim.Adam( | ||
43 | + model.parameters(), | ||
44 | + lr=params.pre_c_learning_rate, | ||
45 | + betas=(params.beta1, params.beta2), | ||
46 | + weight_decay=params.weight_decay | ||
47 | + ) | ||
48 | + | ||
49 | + | ||
50 | + | ||
51 | + if params.labelsmoothing: | ||
52 | + criterion = LabelSmoothingCrossEntropy(smoothing= params.smoothing) | ||
53 | + else: | ||
54 | + criterion = nn.CrossEntropyLoss() | ||
55 | + | ||
56 | + | ||
57 | + #################### | ||
58 | + # 2. train network # | ||
59 | + #################### | ||
60 | + len_data_loader = min(len(source_data_loader), len(target_data_loader)) | ||
61 | + | ||
62 | + for epoch in range(params.num_epochs_pre+1): | ||
63 | + data_zip = enumerate(zip(source_data_loader, target_data_loader)) | ||
64 | + for step, ((images, labels), (images_tgt, _)) in data_zip: | ||
65 | + # make images and labels variable | ||
66 | + images = make_cuda(images) | ||
67 | + labels = make_cuda(labels.squeeze_()) | ||
68 | + # zero gradients for optimizer | ||
69 | + optimizer.zero_grad() | ||
70 | + target=target_data_loader[randint(0, len(target_data_loader)-1)] | ||
71 | + images, lam = mixup_data(images,target[0]) | ||
72 | + | ||
73 | + # compute loss for critic | ||
74 | + preds = model(images) | ||
75 | + # loss = mixup_criterion(criterion, preds, labels, labels_tgt, lam) | ||
76 | + loss = criterion(preds, labels) | ||
77 | + | ||
78 | + # optimize source classifier | ||
79 | + loss.backward() | ||
80 | + optimizer.step() | ||
81 | + | ||
82 | + | ||
83 | + | ||
84 | +        # eval model during training (on the validation set if given, else on the source set) | ||
85 | +        if epoch % params.eval_step_pre == 0: | ||
86 | +            print(f"Epoch [{epoch}/{params.num_epochs_pre}]", end='') | ||
87 | +            eval_loader = valid_loader if valid_loader is not None else source_data_loader | ||
88 | +            test.eval_tgt(model, eval_loader) | ||
89 | +            # eval_tgt switches the model to eval mode; restore training mode for the next epoch | ||
90 | +            model.train() | ||
91 | + | ||
92 | + # save model parameters | ||
93 | + if ((epoch + 1) % params.save_step_pre == 0): | ||
94 | + save_model(model, "our-source_cnn-{}.pt".format(epoch + 1)) | ||
95 | + | ||
96 | + # # save final model | ||
97 | + save_model(model, "our-source_cnn-final.pt") | ||
98 | + | ||
99 | + return model | ||
100 | + | ||
101 | + | ||
102 | + | ||
103 | + | ||
104 | +def train_tgt(source_cnn, target_cnn, critic, | ||
105 | + src_data_loader, tgt_data_loader,valid_loader): | ||
106 | + """Train encoder for target domain.""" | ||
107 | + #################### | ||
108 | + # 1. setup network # | ||
109 | + #################### | ||
110 | + | ||
111 | + source_cnn.eval() | ||
112 | + target_cnn.encoder.train() | ||
113 | + critic.train() | ||
114 | + isbest = 0 | ||
115 | + # setup criterion and optimizer | ||
116 | + criterion = nn.CrossEntropyLoss() | ||
117 | + #target encoder | ||
118 | + optimizer_tgt = optim.Adam(target_cnn.parameters(), | ||
119 | + lr=params.adp_c_learning_rate, | ||
120 | + betas=(params.beta1, params.beta2), | ||
121 | + weight_decay=params.weight_decay | ||
122 | + ) | ||
123 | + #Discriminator | ||
124 | + optimizer_critic = optim.Adam(critic.parameters(), | ||
125 | + lr=params.d_learning_rate, | ||
126 | + betas=(params.beta1, params.beta2), | ||
127 | + weight_decay=params.weight_decay | ||
128 | + | ||
129 | + | ||
130 | + ) | ||
131 | + | ||
132 | + #################### | ||
133 | + # 2. train network # | ||
134 | + #################### | ||
135 | + data_len = min(len(src_data_loader), len(tgt_data_loader)) | ||
136 | + | ||
137 | + for epoch in range(params.num_epochs+1): | ||
138 | + # zip source and target data pair | ||
139 | + data_zip = enumerate(zip(src_data_loader, tgt_data_loader)) | ||
140 | + for step, ((images_src, _), (images_tgt, _)) in data_zip: | ||
141 | + | ||
142 | + # make images variable | ||
143 | + images_src = make_cuda(images_src) | ||
144 | + images_tgt = make_cuda(images_tgt) | ||
145 | + | ||
146 | + ########################### | ||
147 | + # 2.1 train discriminator # | ||
148 | + ########################### | ||
149 | + | ||
150 | +            # mixup: blend a lam-fraction of the target batch into the source batch (images only) | ||
151 | + images_src, _ = mixup_data(images_src,images_tgt) | ||
152 | + | ||
153 | + # zero gradients for optimizer | ||
154 | + optimizer_critic.zero_grad() | ||
155 | + | ||
156 | + # extract and concat features | ||
157 | + feat_src = source_cnn.encoder(images_src) | ||
158 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
159 | + feat_concat = torch.cat((feat_src, feat_tgt), 0) | ||
160 | + | ||
161 | + # predict on discriminator | ||
162 | + pred_concat = critic(feat_concat.detach()) | ||
163 | + | ||
164 | + # prepare real and fake label | ||
165 | + label_src = make_cuda(torch.zeros(feat_src.size(0)).long()) | ||
166 | + label_tgt = make_cuda(torch.ones(feat_tgt.size(0)).long()) | ||
167 | + label_concat = torch.cat((label_src, label_tgt), 0) | ||
168 | + | ||
169 | + # compute loss for critic | ||
170 | + loss_critic = criterion(pred_concat, label_concat) | ||
171 | + loss_critic.backward() | ||
172 | + | ||
173 | + # optimize critic | ||
174 | + optimizer_critic.step() | ||
175 | + | ||
176 | + pred_cls = torch.squeeze(pred_concat.max(1)[1]) | ||
177 | + acc = (pred_cls == label_concat).float().mean() | ||
178 | + | ||
179 | + | ||
180 | + ############################ | ||
181 | + # 2.2 train target encoder # | ||
182 | + ############################ | ||
183 | + | ||
184 | + # zero gradients for optimizer | ||
185 | + optimizer_critic.zero_grad() | ||
186 | + optimizer_tgt.zero_grad() | ||
187 | + | ||
188 | +            # extract target features | ||
189 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
190 | + | ||
191 | + # predict on discriminator | ||
192 | + pred_tgt = critic(feat_tgt) | ||
193 | + | ||
194 | + # prepare fake labels | ||
195 | + label_tgt = make_cuda(torch.zeros(feat_tgt.size(0)).long()) | ||
196 | + | ||
197 | + # compute loss for target encoder | ||
198 | + loss_tgt = criterion(pred_tgt, label_tgt) | ||
199 | + loss_tgt.backward() | ||
200 | + | ||
201 | + # optimize target encoder | ||
202 | + optimizer_tgt.step() | ||
203 | + ####################### | ||
204 | + # 2.3 print step info # | ||
205 | + ####################### | ||
206 | +            if (epoch % 10 == 0) and ((step + 1) % data_len == 0): | ||
207 | + print("Epoch [{}/{}] Step [{}/{}]:" | ||
208 | + "d_loss={:.5f} g_loss={:.5f} acc={:.5f}" | ||
209 | + .format(epoch , | ||
210 | + params.num_epochs, | ||
211 | + step + 1, | ||
212 | + data_len, | ||
213 | + loss_critic.item(), | ||
214 | + loss_tgt.item(), | ||
215 | + acc.item())) | ||
216 | + if valid_loader is not None: | ||
217 | + test.eval_tgt(target_cnn,valid_loader) | ||
218 | + | ||
219 | + torch.save(critic.state_dict(), os.path.join( | ||
220 | + params.model_root, | ||
221 | + "our-critic-final.pt")) | ||
222 | + torch.save(target_cnn.state_dict(), os.path.join( | ||
223 | + params.model_root, | ||
224 | + "our-target_cnn-final.pt")) | ||
225 | + return target_cnn | ||
\ No newline at end of file
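The difference from `core/adapt.py` is the `mixup_data` call in step 2.1: only the image batches are mixed (labels stay untouched), with a ratio `lam` drawn uniformly from `[lammin, lammax]` in `params.py`. A small illustrative sketch of that operation, using hypothetical tensors with the shapes this code works with:

<pre>
<code>
import torch

lammin, lammax = 0.0, 0.3                      # illustrative values from the params.py comments
lam = (lammax - lammin) * torch.rand(1) + lammin
images_src = torch.rand(128, 3, 28, 28)        # a source batch
images_tgt = torch.rand(128, 3, 28, 28)        # a target batch
mixed_src = (1 - lam) * images_src + lam * images_tgt   # what mixup_data returns (together with lam)
</code>
</pre>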
source code/adda_mixup/core/pretrain.py
0 → 100644
1 | +import torch.nn as nn | ||
2 | +import torch.optim as optim | ||
3 | +import torch | ||
4 | +import params | ||
5 | +from utils import make_cuda, save_model, LabelSmoothingCrossEntropy,mixup_data | ||
6 | +from random import * | ||
7 | +import sys | ||
8 | + | ||
9 | +def train_src(model, source_data_loader): | ||
10 | + """Train classifier for source domain.""" | ||
11 | + #################### | ||
12 | + # 1. setup network # | ||
13 | + #################### | ||
14 | + | ||
15 | + model.train() | ||
16 | + | ||
17 | + | ||
18 | + | ||
19 | + # setup criterion and optimizer | ||
20 | + optimizer = optim.Adam( | ||
21 | + model.parameters(), | ||
22 | + lr=params.pre_c_learning_rate, | ||
23 | + betas=(params.beta1, params.beta2), | ||
24 | + weight_decay=params.weight_decay | ||
25 | + ) | ||
26 | + | ||
27 | + | ||
28 | + if params.labelsmoothing: | ||
29 | + criterion = LabelSmoothingCrossEntropy(smoothing= params.smoothing) | ||
30 | + else: | ||
31 | + criterion = nn.CrossEntropyLoss() | ||
32 | + | ||
33 | + | ||
34 | + #################### | ||
35 | + # 2. train network # | ||
36 | + #################### | ||
37 | + | ||
38 | + for epoch in range(params.num_epochs_pre): | ||
39 | + for step, (images, labels) in enumerate(source_data_loader): | ||
40 | + # make images and labels variable | ||
41 | + images = make_cuda(images) | ||
42 | + labels = make_cuda(labels.squeeze_()) | ||
43 | + # zero gradients for optimizer | ||
44 | + optimizer.zero_grad() | ||
45 | + | ||
46 | + # compute loss for critic | ||
47 | + preds = model(images) | ||
48 | + loss = criterion(preds, labels) | ||
49 | + | ||
50 | + # optimize source classifier | ||
51 | + loss.backward() | ||
52 | + optimizer.step() | ||
53 | + | ||
54 | + | ||
55 | + | ||
56 | +        # eval model on the source set | ||
57 | +        if epoch % params.eval_step_pre == 0: | ||
58 | +            print(f"Epoch [{epoch}/{params.num_epochs_pre}]", end='') | ||
59 | +            eval_src(model, source_data_loader) | ||
60 | +            model.train()  # eval_src switches to eval mode; restore training mode for the next epoch | ||
61 | + # save model parameters | ||
62 | + if ((epoch + 1) % params.save_step_pre == 0): | ||
63 | + save_model(model, "ADDA-source_cnn-{}.pt".format(epoch + 1)) | ||
64 | + | ||
65 | + # # save final model | ||
66 | + save_model(model, "ADDA-source_cnn-final.pt") | ||
67 | + | ||
68 | + return model | ||
69 | + | ||
70 | +def eval_src(model, data_loader): | ||
71 | + """Evaluate classifier for source domain.""" | ||
72 | + # set eval state for Dropout and BN layers | ||
73 | + model.eval() | ||
74 | + with torch.no_grad(): | ||
75 | + # init loss and accuracy | ||
76 | + loss = 0 | ||
77 | + acc = 0 | ||
78 | + | ||
79 | + # evaluate network | ||
80 | + for (images, labels) in data_loader: | ||
81 | + | ||
82 | + images = make_cuda(images) | ||
83 | + labels = make_cuda(labels).squeeze_() | ||
84 | + | ||
85 | + preds = model(images) | ||
86 | + | ||
87 | + pred_cls = preds.data.max(1)[1] | ||
88 | + acc += pred_cls.eq(labels.data).cpu().sum().item() | ||
89 | + | ||
90 | + acc /= len(data_loader.dataset) | ||
91 | + | ||
92 | + print("Avg Accuracy = {:2%}".format( acc)) | ||
93 | + | ||
94 | + |
source code/adda_mixup/core/test.py
0 → 100644
1 | +#!/usr/bin/env python3 | ||
2 | +# -*- coding: utf-8 -*- | ||
3 | +""" | ||
4 | +Created on Wed Dec 5 15:03:50 2018 | ||
5 | + | ||
6 | +@author: gaoyi | ||
7 | +""" | ||
8 | + | ||
9 | +import torch | ||
10 | +import torch.nn as nn | ||
11 | + | ||
12 | +from utils import make_cuda | ||
13 | + | ||
14 | + | ||
15 | +def eval_tgt(model, data_loader): | ||
16 | + """Evaluation for target encoder by source classifier on target dataset.""" | ||
17 | + # set eval state for Dropout and BN layers | ||
18 | + model.eval() | ||
19 | + # init loss and accuracy | ||
20 | + loss = 0 | ||
21 | + acc = 0 | ||
22 | + with torch.no_grad(): | ||
23 | + # evaluate network | ||
24 | + for (images, labels) in data_loader: | ||
25 | + images = make_cuda(images) | ||
26 | + labels = make_cuda(labels).squeeze_() | ||
27 | + | ||
28 | + preds = model(images) | ||
29 | + _, preds = torch.max(preds.data, 1) | ||
30 | + acc += (preds == labels).float().sum()/images.shape[0] | ||
31 | + | ||
32 | + acc /= len(data_loader) | ||
33 | + | ||
34 | + print("Avg Accuracy = {:2%}".format(acc)) | ||
\ No newline at end of file
source code/adda_mixup/dataset/__init__.py
0 → 100644
source code/adda_mixup/dataset/customdata.py
0 → 100644
1 | +from torchvision import transforms, datasets | ||
2 | +import torch | ||
3 | +import params | ||
4 | + | ||
5 | +def get_custom(train,adp=False,size = 0): | ||
6 | + | ||
7 | + pre_process = transforms.Compose([transforms.Resize(params.image_size), | ||
8 | + transforms.ToTensor(), | ||
9 | + # transforms.Normalize((0.5),(0.5)), | ||
10 | + ]) | ||
11 | + custom_dataset = datasets.ImageFolder( | ||
12 | + root = params.custom_dataset_root , | ||
13 | + transform = pre_process, | ||
14 | + ) | ||
15 | + length = len(custom_dataset) | ||
16 | + train_set, val_set = torch.utils.data.random_split(custom_dataset, [int(length*0.9), length-int(length*0.9)]) | ||
17 | + | ||
18 | + if train: | ||
19 | + train_set,_ = torch.utils.data.random_split(train_set, [size,len(train_set)-size]) | ||
20 | + | ||
21 | + | ||
22 | + | ||
23 | + custom_data_loader = torch.utils.data.DataLoader( | ||
24 | + train_set if train else val_set, | ||
25 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
26 | + shuffle=True, | ||
27 | + drop_last=True | ||
28 | + | ||
29 | + ) | ||
30 | + | ||
31 | + return custom_data_loader | ||
\ No newline at end of file
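`get_custom` wraps `torchvision.datasets.ImageFolder`, so the directory given by `params.custom_dataset_root` must contain one sub-folder per class, each holding that class's images. A hypothetical layout (folder and file names are examples only):

<pre>
<code>
data/custom/
    class_0/  img_0001.png  img_0002.png  ...
    class_1/  img_0101.png  ...
    ...
</code>
</pre>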
source code/adda_mixup/dataset/mnist.py
0 → 100644
1 | + | ||
2 | + | ||
3 | +import torch | ||
4 | +from torchvision import datasets, transforms | ||
5 | +import torch.utils.data as data_utils | ||
6 | + | ||
7 | +import params | ||
8 | + | ||
9 | + | ||
10 | +def get_mnist(train,adp = False,size = 0): | ||
11 | + """Get MNIST dataset loader.""" | ||
12 | + # image pre-processing | ||
13 | + pre_process = transforms.Compose([transforms.Resize(params.image_size), | ||
14 | + transforms.ToTensor(), | ||
15 | +# transforms.Normalize((0.5),(0.5)), | ||
16 | + transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||
17 | + | ||
18 | + | ||
19 | + ]) | ||
20 | + | ||
21 | + | ||
22 | + | ||
23 | + | ||
24 | + # dataset and data loader | ||
25 | + mnist_dataset = datasets.MNIST(root=params.mnist_dataset_root, | ||
26 | + train=train, | ||
27 | + transform=pre_process, | ||
28 | + | ||
29 | + download=True) | ||
30 | + if train: | ||
31 | + # perm = torch.randperm(len(mnist_dataset)) | ||
32 | + # indices = perm[:10000] | ||
33 | + mnist_dataset,_ = data_utils.random_split(mnist_dataset, [size,len(mnist_dataset)-size]) | ||
34 | + # size = len(mnist_dataset) | ||
35 | + # train, valid = data_utils.random_split(mnist_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
36 | + # train_loader = torch.utils.data.DataLoader( | ||
37 | + # dataset=train, | ||
38 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
39 | + # shuffle=True, | ||
40 | + # drop_last=True) | ||
41 | + # valid_loader = torch.utils.data.DataLoader( | ||
42 | + # dataset=valid, | ||
43 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
44 | + # shuffle=True, | ||
45 | + # drop_last=True) | ||
46 | + | ||
47 | + # return train_loader,valid_loader | ||
48 | + | ||
49 | + mnist_data_loader = torch.utils.data.DataLoader( | ||
50 | + dataset=mnist_dataset, | ||
51 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
52 | + shuffle=True, | ||
53 | + drop_last=True) | ||
54 | + return mnist_data_loader | ||
\ No newline at end of file
source code/adda_mixup/dataset/mnist_m.py
0 → 100644
1 | +import torch.utils.data as data | ||
2 | +from PIL import Image | ||
3 | +import os | ||
4 | +import params | ||
5 | +from torchvision import transforms | ||
6 | +import torch | ||
7 | + | ||
8 | +import torch.utils.data as data_utils | ||
9 | + | ||
10 | + | ||
11 | +class GetLoader(data.Dataset): | ||
12 | + def __init__(self, data_root, data_list, transform=None): | ||
13 | + self.root = data_root | ||
14 | + self.transform = transform | ||
15 | + | ||
16 | + f = open(data_list, 'r') | ||
17 | + data_list = f.readlines() | ||
18 | + f.close() | ||
19 | + | ||
20 | + self.n_data = len(data_list) | ||
21 | + | ||
22 | + self.img_paths = [] | ||
23 | + self.img_labels = [] | ||
24 | + | ||
25 | + for data_ in data_list: | ||
26 | + self.img_paths.append(data_[:-3]) | ||
27 | + self.img_labels.append(data_[-2]) | ||
28 | + | ||
29 | + def __getitem__(self, item): | ||
30 | + img_paths, labels = self.img_paths[item], self.img_labels[item] | ||
31 | + imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB') | ||
32 | + | ||
33 | + if self.transform is not None: | ||
34 | + imgs = self.transform(imgs) | ||
35 | + labels = int(labels) | ||
36 | + | ||
37 | + return imgs, labels | ||
38 | + | ||
39 | + def __len__(self): | ||
40 | + return self.n_data | ||
41 | + | ||
42 | + | ||
43 | +def get_mnist_m(train,adp=False,size= 0 ): | ||
44 | + | ||
45 | + if train == True: | ||
46 | + mode = 'train' | ||
47 | + else: | ||
48 | + mode = 'test' | ||
49 | + | ||
50 | + train_list = os.path.join(params.mnist_m_dataset_root, 'mnist_m_{}_labels.txt'.format(mode)) | ||
51 | + # image pre-processing | ||
52 | + pre_process = transforms.Compose([ | ||
53 | + transforms.Resize(params.image_size), | ||
54 | + # transforms.Grayscale(3), | ||
55 | + | ||
56 | + transforms.ToTensor(), | ||
57 | + | ||
58 | +# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
59 | +# transforms.Grayscale(1), | ||
60 | + ] | ||
61 | + ) | ||
62 | + | ||
63 | + dataset_target = GetLoader( | ||
64 | + data_root=os.path.join(params.mnist_m_dataset_root, 'mnist_m_{}'.format(mode)), | ||
65 | + data_list=train_list, | ||
66 | + transform=pre_process) | ||
67 | + | ||
68 | + if train: | ||
69 | + # perm = torch.randperm(len(dataset_target)) | ||
70 | + # indices = perm[:10000] | ||
71 | + dataset_target,_ = data_utils.random_split(dataset_target, [size,len(dataset_target)-size]) | ||
72 | + # size = len(dataset_target) | ||
73 | + # train, valid = data_utils.random_split(dataset_target,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
74 | + # train_loader = torch.utils.data.DataLoader( | ||
75 | + # dataset=train, | ||
76 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
77 | + # shuffle=True, | ||
78 | + # drop_last=True) | ||
79 | + # valid_loader = torch.utils.data.DataLoader( | ||
80 | + # dataset=valid, | ||
81 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
82 | + # shuffle=True, | ||
83 | + # drop_last=True) | ||
84 | + | ||
85 | + # return train_loader,valid_loader | ||
86 | + | ||
87 | + | ||
88 | + dataloader = torch.utils.data.DataLoader( | ||
89 | + dataset=dataset_target, | ||
90 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
91 | + | ||
92 | + shuffle=True, | ||
93 | + drop_last=True) | ||
94 | + | ||
95 | + return dataloader | ||
\ No newline at end of file
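`GetLoader` splits each line of `mnist_m_{train,test}_labels.txt` by fixed positions (`data_[:-3]` for the image path, `data_[-2]` for the single-digit label), which assumes lines of the form `filename<space>label`. Under that same assumption, a whitespace-based parse would be equivalent; shown only as a sketch, not as the author's code:

<pre>
<code>
def parse_label_list(lines):
    """Split 'filename label' lines by whitespace instead of fixed positions."""
    paths, labels = [], []
    for line in lines:
        path, label = line.rsplit(' ', 1)   # e.g. "00000001.png 5"
        paths.append(path)
        labels.append(int(label))
    return paths, labels

# example
paths, labels = parse_label_list(["00000001.png 5\n", "00000002.png 0\n"])
</code>
</pre>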
source code/adda_mixup/dataset/svhn.py
0 → 100644
1 | +import torch | ||
2 | +from torchvision import datasets, transforms | ||
3 | + | ||
4 | +import params | ||
5 | + | ||
6 | +import torch.utils.data as data_utils | ||
7 | + | ||
8 | +def get_svhn(train,adp=False,size=0): | ||
9 | + """Get SVHN dataset loader.""" | ||
10 | + # image pre-processing | ||
11 | + pre_process = transforms.Compose([ | ||
12 | + transforms.Resize(params.image_size), | ||
13 | + transforms.Grayscale(3), | ||
14 | + transforms.ToTensor(), | ||
15 | + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
16 | + | ||
17 | + ]) | ||
18 | + | ||
19 | + | ||
20 | + # dataset and data loader | ||
21 | + svhn_dataset = datasets.SVHN(root=params.svhn_dataset_root, | ||
22 | + split='train' if train else 'test', | ||
23 | + transform=pre_process, | ||
24 | + download=True) | ||
25 | + if train: | ||
26 | + # perm = torch.randperm(len(svhn_dataset)) | ||
27 | + # indices = perm[:10000] | ||
28 | + svhn_dataset,_ = data_utils.random_split(svhn_dataset, [size,len(svhn_dataset)-size]) | ||
29 | + # size = len(svhn_dataset) | ||
30 | + # train, valid = data_utils.random_split(svhn_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
31 | + | ||
32 | + # train_loader = torch.utils.data.DataLoader( | ||
33 | + # dataset=train, | ||
34 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
35 | + | ||
36 | + # shuffle=True, | ||
37 | + # drop_last=True) | ||
38 | + | ||
39 | + # valid_loader = torch.utils.data.DataLoader( | ||
40 | + # dataset=valid, | ||
41 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
42 | + | ||
43 | + # shuffle=True, | ||
44 | + # drop_last=True) | ||
45 | + # return train_loader,valid_loader | ||
46 | + | ||
47 | + svhn_data_loader = torch.utils.data.DataLoader( | ||
48 | + dataset=svhn_dataset, | ||
49 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
50 | + | ||
51 | + shuffle=True, | ||
52 | + drop_last=True) | ||
53 | + | ||
54 | + return svhn_data_loader | ||
\ No newline at end of file
source code/adda_mixup/dataset/usps.py
0 → 100644
1 | +import torch | ||
2 | +from torchvision import datasets, transforms | ||
3 | +from torch.utils.data.dataset import random_split | ||
4 | + | ||
5 | +import params | ||
6 | +import torch.utils.data as data_utils | ||
7 | + | ||
8 | + | ||
9 | +def get_usps(train,adp=False,size=0): | ||
10 | + """Get usps dataset loader.""" | ||
11 | + # image pre-processing | ||
12 | + pre_process = transforms.Compose([transforms.Resize(params.image_size), | ||
13 | + transforms.ToTensor(), | ||
14 | + # transforms.Normalize((0.5),(0.5)), | ||
15 | + transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||
16 | + # transforms.Grayscale(1), | ||
17 | + | ||
18 | + | ||
19 | + | ||
20 | + ]) | ||
21 | + | ||
22 | + | ||
23 | + # dataset and data loader | ||
24 | + usps_dataset = datasets.USPS(root=params.usps_dataset_root, | ||
25 | + train=train, | ||
26 | + transform=pre_process, | ||
27 | + download=True) | ||
28 | + | ||
29 | + | ||
30 | + if train: | ||
31 | + usps_dataset, _ = data_utils.random_split(usps_dataset, [size,len(usps_dataset)-size]) | ||
32 | + # size = len(usps_dataset) | ||
33 | + # train, valid = data_utils.random_split(usps_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
34 | + # train_loader = torch.utils.data.DataLoader( | ||
35 | + # dataset=train, | ||
36 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
37 | + # shuffle=True, | ||
38 | + # drop_last=True) | ||
39 | + # valid_loader = torch.utils.data.DataLoader( | ||
40 | + # dataset=valid, | ||
41 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
42 | + # shuffle=True, | ||
43 | + # drop_last=True) | ||
44 | + # return train_loader,valid_loader | ||
45 | + | ||
46 | + usps_data_loader = torch.utils.data.DataLoader( | ||
47 | + dataset=usps_dataset, | ||
48 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
49 | + | ||
50 | + shuffle=True, | ||
51 | + drop_last=True) | ||
52 | + return usps_data_loader | ||
\ No newline at end of file
source code/adda_mixup/main.py
0 → 100644
1 | + | ||
2 | +import params | ||
3 | +from utils import get_data_loader, init_model, init_random_seed,mixup_data | ||
4 | +from core import pretrain , adapt , test,mixup | ||
5 | +import torch | ||
6 | +from models.models import * | ||
7 | +import numpy as np | ||
8 | +import sys | ||
9 | + | ||
10 | + | ||
11 | + | ||
12 | +if __name__ == '__main__': | ||
13 | + # init random seed | ||
14 | + init_random_seed(params.manual_seed) | ||
15 | + print(f"Is cuda availabel? {torch.cuda.is_available()}") | ||
16 | + | ||
17 | + | ||
18 | + | ||
19 | + #set loader | ||
20 | + print("src data loader....") | ||
21 | + src_data_loader = get_data_loader(params.src_dataset,adp=False,size = 10000) | ||
22 | + src_data_loader_eval = get_data_loader(params.src_dataset,train=False) | ||
23 | + print("tgt data loader....") | ||
24 | + tgt_data_loader = get_data_loader(params.tgt_dataset,adp=False,size = 50000) | ||
25 | + tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False) | ||
26 | + print(f"scr data size : {len(src_data_loader.dataset)}") | ||
27 | + print(f"tgt data size : {len(tgt_data_loader.dataset)}") | ||
28 | + | ||
29 | + | ||
30 | + print("start training") | ||
31 | + source_cnn = CNN(in_channels=3).to("cuda") | ||
32 | + target_cnn = CNN(in_channels=3, target=True).to("cuda") | ||
33 | + discriminator = Discriminator().to("cuda") | ||
34 | + | ||
35 | + source_cnn = mixup.train_src(source_cnn, src_data_loader,tgt_data_loader,None) | ||
36 | + # source_cnn.load_state_dict(torch.load("./generated/models/our-source_cnn-final.pt")) | ||
37 | + | ||
38 | + test.eval_tgt(source_cnn, tgt_data_loader_eval) | ||
39 | + | ||
40 | + target_cnn.load_state_dict(source_cnn.state_dict()) | ||
41 | + | ||
42 | + tgt_encoder = mixup.train_tgt(source_cnn, target_cnn, discriminator, | ||
43 | + src_data_loader,tgt_data_loader,None) | ||
44 | + | ||
45 | + | ||
46 | + | ||
47 | + print("=== Evaluating classifier for encoded target domain ===") | ||
48 | + print(f"mixup : {params.lammax} {params.src_dataset} -> {params.tgt_dataset} ") | ||
49 | + print("Eval | source_cnn | src_data_loader_eval") | ||
50 | + test.eval_tgt(source_cnn, src_data_loader_eval) | ||
51 | + print(">>> Eval | source_cnn | tgt_data_loader_eval <<<") | ||
52 | + test.eval_tgt(source_cnn, tgt_data_loader_eval) | ||
53 | + print(">>> Eval | target_cnn | tgt_data_loader_eval <<<") | ||
54 | + test.eval_tgt(target_cnn, tgt_data_loader_eval) | ||
55 | + | ||
56 | + | ||
57 | + | ||
58 | + |
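After training, the adapted target model and the critic are saved under `params.model_root` (e.g. `generated\models\our-target_cnn-final.pt`). A minimal sketch for re-loading the adapted model for evaluation only, assuming the same `params.py` settings and a CUDA device, as in `main.py`:

<pre>
<code>
import os
import torch
import params
from models.models import CNN
from core import test
from utils import get_data_loader

target_cnn = CNN(in_channels=3, target=True).to("cuda")
state = torch.load(os.path.join(params.model_root, "our-target_cnn-final.pt"))
target_cnn.load_state_dict(state)

tgt_eval_loader = get_data_loader(params.tgt_dataset, train=False)
test.eval_tgt(target_cnn, tgt_eval_loader)
</code>
</pre>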
source code/adda_mixup/models/__init__.py
0 → 100644
File mode changed
source code/adda_mixup/models/models.py
0 → 100644
1 | +from torch import nn | ||
2 | +import torch.nn.functional as F | ||
3 | +import params | ||
4 | + | ||
5 | +class Encoder(nn.Module): | ||
6 | + def __init__(self, in_channels=1, h=256, dropout=0.5): | ||
7 | + super(Encoder, self).__init__() | ||
8 | + self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1) | ||
9 | + self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1) | ||
10 | + self.bn1 = nn.BatchNorm2d(20) | ||
11 | + self.bn2 = nn.BatchNorm2d(50) | ||
12 | + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) | ||
13 | + self.relu = nn.ReLU() | ||
14 | + self.dropout =nn.Dropout2d(p= dropout) | ||
15 | + # self.dropout = nn.Dropout(dropout) | ||
16 | + self.fc1 = nn.Linear(800, 500) | ||
17 | + | ||
18 | + # for m in self.modules(): | ||
19 | + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | ||
20 | + # nn.init.kaiming_normal_(m.weight) | ||
21 | + | ||
22 | + def forward(self, x): | ||
23 | + bs = x.size(0) | ||
24 | + x = self.pool(self.relu(self.bn1(self.conv1(x)))) | ||
25 | + x = self.pool(self.relu(self.bn2(self.dropout(self.conv2(x))))) | ||
26 | + x = x.view(bs, -1) | ||
27 | +        # x = self.dropout(x) | ||
28 | + x = self.fc1(x) | ||
29 | + return x | ||
30 | + | ||
31 | + | ||
32 | +class Classifier(nn.Module): | ||
33 | + def __init__(self, n_classes, dropout=0.5): | ||
34 | + super(Classifier, self).__init__() | ||
35 | + self.l1 = nn.Linear(500, n_classes) | ||
36 | + | ||
37 | + # for m in self.modules(): | ||
38 | + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | ||
39 | + # nn.init.kaiming_normal_(m.weight) | ||
40 | + | ||
41 | + def forward(self, x): | ||
42 | + x = self.l1(x) | ||
43 | + return x | ||
44 | + | ||
45 | + | ||
46 | +class CNN(nn.Module): | ||
47 | + def __init__(self, in_channels=1, n_classes=10, target=False): | ||
48 | + super(CNN, self).__init__() | ||
49 | + self.encoder = Encoder(in_channels=in_channels) | ||
50 | + self.classifier = Classifier(n_classes) | ||
51 | + if target: | ||
52 | + for param in self.classifier.parameters(): | ||
53 | + param.requires_grad = False | ||
54 | + | ||
55 | + def forward(self, x): | ||
56 | + x = self.encoder(x) | ||
57 | + x = self.classifier(x) | ||
58 | + return x | ||
59 | + | ||
60 | + | ||
61 | +class Discriminator(nn.Module): | ||
62 | + def __init__(self, h=500): | ||
63 | + super(Discriminator, self).__init__() | ||
64 | + self.l1 = nn.Linear(500, h) | ||
65 | + self.l2 = nn.Linear(h, h) | ||
66 | + self.l3 = nn.Linear(h, 2) | ||
67 | + # self.slope =params.slope | ||
68 | + | ||
69 | + self.relu = nn.ReLU() | ||
70 | + | ||
71 | + # for m in self.modules(): | ||
72 | + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | ||
73 | + # nn.init.kaiming_normal_(m.weight) | ||
74 | + | ||
75 | + def forward(self, x): | ||
76 | + x = self.relu(self.l1(x)) | ||
77 | + x = self.relu(self.l2(x)) | ||
78 | + x = self.l3(x) | ||
79 | + return x |
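As a sanity check of the hard-coded `800` in `Encoder.fc1`: with `params.image_size = 28` and a 3-channel input, conv1 (5x5) gives 20x24x24, max-pooling gives 20x12x12, conv2 (5x5) gives 50x8x8, and the second pool gives 50x4x4 = 800 features, which `fc1` maps to 500. A short sketch that verifies the shapes:

<pre>
<code>
import torch
from models.models import CNN

model = CNN(in_channels=3, n_classes=10)
x = torch.randn(2, 3, 28, 28)     # a batch of 2 images at params.image_size = 28
feat = model.encoder(x)           # flattened 50*4*4 = 800 features -> fc1 -> 500
out = model(x)
print(feat.shape, out.shape)      # torch.Size([2, 500]) torch.Size([2, 10])
</code>
</pre>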
source code/adda_mixup/params.py
0 → 100644
1 | +import torch | ||
2 | + | ||
3 | +# params for dataset and data loader | ||
4 | +data_root = "data" | ||
5 | +image_size = 28 | ||
6 | + | ||
7 | +#restore | ||
8 | +model_root = 'generated\\models' | ||
9 | + | ||
10 | + | ||
11 | +# params for target dataset | ||
12 | +# 'mnist_m', 'usps', 'svhn' "custom" | ||
13 | + | ||
14 | +#dataset root | ||
15 | +mnist_dataset_root = data_root | ||
16 | +mnist_m_dataset_root = data_root+'\\mnist_m' | ||
17 | +usps_dataset_root = data_root+'\\usps' | ||
18 | +svhn_dataset_root = data_root+'\\svhn' | ||
19 | +custom_dataset_root = data_root+'\\custom\\' | ||
20 | + | ||
21 | +# params for training network | ||
22 | +num_gpu = 1 | ||
23 | + | ||
24 | +log_step_pre = 10 | ||
25 | +log_step = 10 | ||
26 | +eval_step_pre = 10 | ||
27 | + | ||
28 | +##epoch | ||
29 | +save_step_pre = 100 | ||
30 | +manual_seed = 1234 | ||
31 | + | ||
32 | +d_input_dims = 500 | ||
33 | +d_hidden_dims = 500 | ||
34 | +d_output_dims = 2 | ||
35 | +d_model_restore = 'generated\\models\\ADDA-critic-final.pt' | ||
36 | + | ||
37 | +## source / target datasets | ||
38 | +src_dataset = 'custom' | ||
39 | +tgt_dataset = 'custom' | ||
40 | + | ||
41 | + | ||
42 | +# params for optimizing models | ||
43 | +# # lam 0.3 | ||
44 | +#mnist -> custom | ||
45 | +num_epochs_pre = 20 | ||
46 | +num_epochs = 50 | ||
47 | +batch_size = 128 | ||
48 | +adp_batch_size = 128 | ||
49 | +pre_c_learning_rate = 2e-4 | ||
50 | +adp_c_learning_rate = 1e-4 | ||
51 | +d_learning_rate = 1e-4 | ||
52 | +beta1 = 0.5 | ||
53 | +beta2 = 0.999 | ||
54 | +weight_decay = 0 | ||
55 | + | ||
56 | + | ||
57 | +# #usps -> custom | ||
58 | +# #lam 0.1 | ||
59 | +# num_epochs_pre = 5 | ||
60 | +# num_epochs = 20 | ||
61 | +# batch_size = 256 | ||
62 | +# pre_c_learning_rate = 1e-4 | ||
63 | +# adp_c_learning_rate = 2e-5 | ||
64 | +# d_learning_rate = 1e-5 | ||
65 | +# beta1 = 0.5 | ||
66 | +# beta2 = 0.999 | ||
67 | +# weight_decay = 2e-4 | ||
68 | + | ||
69 | +# #mnist_m -> custom | ||
70 | +# #lam 0.1 | ||
71 | +# num_epochs_pre = 30 | ||
72 | +# num_epochs = 50 | ||
73 | +# batch_size = 256 | ||
74 | +# adp_batch_size = 256 | ||
75 | +# pre_c_learning_rate = 1e-3 | ||
76 | +# adp_c_learning_rate = 1e-4 | ||
77 | +# d_learning_rate = 1e-4 | ||
78 | +# beta1 = 0.5 | ||
79 | +# beta2 = 0.999 | ||
80 | +# weight_decay = 2e-4 | ||
81 | + | ||
82 | +# # params for optimizing models | ||
83 | +#lam 0.3 | ||
84 | +# #mnist -> mnist_m | ||
85 | +# num_epochs_pre = 50 | ||
86 | +# num_epochs = 100 | ||
87 | +# batch_size = 256 | ||
88 | +# adp_batch_size = 256 | ||
89 | +# pre_c_learning_rate = 2e-4 | ||
90 | +# adp_c_learning_rate = 2e-4 | ||
91 | +# d_learning_rate = 2e-4 | ||
92 | +# beta1 = 0.5 | ||
93 | +# beta2 = 0.999 | ||
94 | +# weight_decay = 0 | ||
95 | + | ||
96 | +# # source 10000 target 50000 | ||
97 | +# # params for optimizing models | ||
98 | +# #svhn -> mnist | ||
99 | +# num_epochs_pre = 20 | ||
100 | +# num_epochs = 30 | ||
101 | +# batch_size = 128 | ||
102 | +# adp_batch_size = 128 | ||
103 | +# pre_c_learning_rate = 2e-4 | ||
104 | +# adp_c_learning_rate = 1e-4 | ||
105 | +# d_learning_rate = 1e-4 | ||
106 | +# beta1 = 0.5 | ||
107 | +# beta2 = 0.999 | ||
108 | +# weight_decay = 2.5e-4 | ||
109 | + | ||
110 | +# # mnist->usps | ||
111 | +# num_epochs_pre = 50 | ||
112 | +# num_epochs = 100 | ||
113 | +# batch_size = 256 | ||
114 | +# adp_batch_size = 256 | ||
115 | +# pre_c_learning_rate = 2e-4 | ||
116 | +# adp_c_learning_rate = 2e-4 | ||
117 | +# d_learning_rate = 2e-4 | ||
118 | +# beta1 = 0.5 | ||
119 | +# beta2 = 0.999 | ||
120 | +# weight_decay =0 | ||
121 | + | ||
122 | + | ||
123 | +# # usps->mnist | ||
124 | +# num_epochs_pre = 50 | ||
125 | +# num_epochs = 100 | ||
126 | +# batch_size = 256 | ||
127 | +# pre_c_learning_rate = 2e-4 | ||
128 | +# adp_c_learning_rate = 2e-4 | ||
129 | +# d_learning_rate =2e-4 | ||
130 | +# beta1 = 0.5 | ||
131 | +# beta2 = 0.999 | ||
132 | +# weight_decay =0 | ||
133 | + | ||
134 | + | ||
135 | + | ||
136 | +# | ||
137 | +use_load = False | ||
138 | +train =False | ||
139 | + | ||
140 | +#ratio mix target | ||
141 | +lammax = 0.0 | ||
142 | +lammin = 0.0 | ||
143 | + | ||
144 | + | ||
145 | +labelsmoothing = False | ||
146 | +smoothing = 0.3 | ||
\ No newline at end of file
source code/adda_mixup/utils.py
0 → 100644
1 | + | ||
2 | + | ||
3 | +import os | ||
4 | +import random | ||
5 | +import torch | ||
6 | +import torch.backends.cudnn as cudnn | ||
7 | +from torch.autograd import Variable | ||
8 | +import params | ||
9 | +from dataset import get_mnist, get_mnist_m, get_usps,get_svhn,get_custom | ||
10 | +import numpy as np | ||
11 | +import itertools | ||
12 | +import torch.nn.functional as F | ||
13 | +import params | ||
14 | + | ||
15 | +def make_cuda(tensor): | ||
16 | + """Use CUDA if it's available.""" | ||
17 | + if torch.cuda.is_available(): | ||
18 | + tensor = tensor.cuda() | ||
19 | + return tensor | ||
20 | + | ||
21 | + | ||
22 | +def denormalize(x, std, mean): | ||
23 | + """Invert normalization, and then convert array into image.""" | ||
24 | + out = x * std + mean | ||
25 | + return out.clamp(0, 1) | ||
26 | + | ||
27 | + | ||
28 | +def init_weights(layer): | ||
29 | + """Init weights for layers w.r.t. the original paper.""" | ||
30 | + layer_name = layer.__class__.__name__ | ||
31 | + if layer_name.find("Conv") != -1: | ||
32 | + layer.weight.data.normal_(0.0, 0.02) | ||
33 | + elif layer_name.find("BatchNorm") != -1: | ||
34 | + layer.weight.data.normal_(1.0, 0.02) | ||
35 | + layer.bias.data.fill_(0) | ||
36 | + | ||
37 | + | ||
38 | +def init_random_seed(manual_seed): | ||
39 | + """Init random seed.""" | ||
40 | + seed = None | ||
41 | + if manual_seed is None: | ||
42 | + seed = random.randint(1, 10000) | ||
43 | + else: | ||
44 | + seed = manual_seed | ||
45 | + #for REPRODUCIBILITY | ||
46 | + torch.backends.cudnn.deterministic = True | ||
47 | + torch.backends.cudnn.benchmark = False | ||
48 | + print("use random seed: {}".format(seed)) | ||
49 | + random.seed(seed) | ||
50 | + torch.manual_seed(seed) | ||
51 | + np.random.seed(seed) | ||
52 | + | ||
53 | + if torch.cuda.is_available(): | ||
54 | + torch.cuda.manual_seed_all(seed) | ||
55 | + | ||
56 | + | ||
57 | +def get_data_loader(name,train=True,adp=False,size = 0): | ||
58 | + """Get data loader by name.""" | ||
59 | + if name == "mnist": | ||
60 | + return get_mnist(train,adp,size) | ||
61 | + elif name == "mnist_m": | ||
62 | + return get_mnist_m(train,adp,size) | ||
63 | + elif name == "usps": | ||
64 | + return get_usps(train,adp,size) | ||
65 | + elif name == "svhn": | ||
66 | + return get_svhn(train,adp,size) | ||
67 | + elif name == "custom": | ||
68 | + return get_custom(train,adp,size) | ||
69 | + | ||
70 | +def init_model(net, restore=None): | ||
71 | + """Init models with cuda and weights.""" | ||
72 | + # init weights of model | ||
73 | + # net.apply(init_weights) | ||
74 | + | ||
75 | + print(f'restore file : {restore}') | ||
76 | + # restore model weights | ||
77 | + if restore is not None and os.path.exists(restore): | ||
78 | + net.load_state_dict(torch.load(restore)) | ||
79 | + net.restored = True | ||
80 | + print("Restore model from: {}".format(os.path.abspath(restore))) | ||
81 | + | ||
82 | + # check if cuda is available | ||
83 | + if torch.cuda.is_available(): | ||
84 | + net.cuda() | ||
85 | + | ||
86 | + return net | ||
87 | + | ||
88 | +def save_model(net, filename): | ||
89 | + """Save trained model.""" | ||
90 | + if not os.path.exists(params.model_root): | ||
91 | + os.makedirs(params.model_root) | ||
92 | + torch.save(net.state_dict(), | ||
93 | + os.path.join(params.model_root, filename)) | ||
94 | + print("save pretrained model to: {}".format(os.path.join(params.model_root, | ||
95 | + filename))) | ||
96 | + | ||
97 | +class LabelSmoothingCrossEntropy(torch.nn.Module): | ||
98 | + def __init__(self,smoothing): | ||
99 | + super(LabelSmoothingCrossEntropy, self).__init__() | ||
100 | + self.smoothing = smoothing | ||
101 | +    def forward(self, y, targets): | ||
102 | +        confidence = 1. - self.smoothing | ||
103 | +        log_probs = F.log_softmax(y, dim=-1)  # predicted log-probabilities | ||
104 | +        true_probs = torch.zeros_like(log_probs) | ||
105 | +        true_probs.fill_(self.smoothing / (y.shape[1] - 1)) | ||
106 | +        true_probs.scatter_(1, targets.data.unsqueeze(1), confidence)  # put weight `confidence` on the true class | ||
107 | +        return torch.mean(torch.sum(true_probs * -log_probs, dim=-1))  # negative log likelihood against the smoothed targets | ||
108 | + | ||
109 | +# mixup on the images only, not on the labels | ||
110 | +def mixup_data(source, target): | ||
111 | +    lam_max = params.lammax | ||
112 | +    lam_min = params.lammin | ||
113 | +    lam = (lam_max - lam_min) * torch.rand(1) + lam_min  # lam ~ U(lammin, lammax) | ||
114 | +    lam = make_cuda(lam) | ||
115 | +    target = make_cuda(target) | ||
116 | +    mixed_source = (1 - lam) * source + lam * target | ||
117 | + | ||
118 | + | ||
119 | +    return mixed_source, lam | ||
120 | + | ||
121 | + | ||
122 | + | ||
123 | + | ||
124 | +def mixup_criterion(criterion, pred, y_a, y_b, lam):  # note: lam is unused here; fixed 0.9 / 0.1 weights are applied | ||
125 | +    return 0.9 * criterion(pred, y_a) + 0.1 * criterion(pred, y_b) | ||
\ No newline at end of file
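For reference, `LabelSmoothingCrossEntropy` above builds a smoothed target distribution: with `smoothing = 0.3` (the value in `params.py`) and 10 classes, the true class keeps weight 0.7 and the remaining 0.3 is split evenly over the other 9 classes. A small usage sketch, assuming `utils.py` is importable:

<pre>
<code>
import torch
from utils import LabelSmoothingCrossEntropy

criterion = LabelSmoothingCrossEntropy(smoothing=0.3)
logits = torch.randn(4, 10)              # batch of 4 samples, 10 classes
targets = torch.tensor([0, 3, 7, 9])
loss = criterion(logits, targets)        # true class weight 0.7, other classes 0.3 / 9 each
</code>
</pre>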