배경호

add source code

1 +# # Created by .ignore support plugin (hsz.mobi)
2 +# ### Example user template template
3 +# ### Example user template
4 +
5 +# # IntelliJ project files
6 +# .idea
7 +# *.iml
8 +# out
9 +# gen
10 +# ### Python template
11 +# # Byte-compiled / optimized / DLL files
12 +__pycache__/
13 +# *.py[cod]
14 +# *$py.class
15 +
16 +# # C extensions
17 +# *.so
18 +
19 +# # Distribution / packaging
20 +# .Python
21 +# build/
22 +# develop-eggs/
23 +# dist/
24 +# downloads/
25 +# eggs/
26 +# .eggs/
27 +# lib/
28 +# lib64/
29 +# parts/
30 +# sdist/
31 +# var/
32 +# wheels/
33 +# pip-wheel-metadata/
34 +# share/python-wheels/
35 +# *.egg-info/
36 +# .installed.cfg
37 +# *.egg
38 +# MANIFEST
39 +
40 +# # PyInstaller
41 +# # Usually these files are written by a python script from a template
42 +# # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 +# *.manifest
44 +# *.spec
45 +
46 +# # Installer logs
47 +# pip-log.txt
48 +# pip-delete-this-directory.txt
49 +
50 +# # Unit test / coverage reports
51 +# htmlcov/
52 +# .tox/
53 +# .nox/
54 +# .coverage
55 +# .coverage.*
56 +# .cache
57 +# nosetests.xml
58 +# coverage.xml
59 +# *.cover
60 +# .hypothesis/
61 +# .pytest_cache/
62 +
63 +# # # Translations
64 +# *.mo
65 +# *.pot
66 +
67 +# # Django stuff:
68 +# *.log
69 +# local_settings.py
70 +# db.sqlite3
71 +
72 +# # Flask stuff:
73 +# instance/
74 +# .webassets-cache
75 +
76 +# # Scrapy stuff:
77 +# .scrapy
78 +
79 +# # Sphinx documentation
80 +# docs/_build/
81 +
82 +# # PyBuilder
83 +# target/
84 +
85 +# # Jupyter Notebook
86 +# .ipynb_checkpoints
87 +
88 +# # IPython
89 +# profile_default/
90 +# ipython_config.py
91 +
92 +# # pyenv
93 +# .python-version
94 +
95 +# # pipenv
96 +# # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 +# # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 +# # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 +# # install all needed dependencies.
100 +# #Pipfile.lock
101 +
102 +# # celery beat schedule file
103 +# celerybeat-schedule
104 +
105 +# # SageMath parsed files
106 +# *.sage.py
107 +
108 +# # Environments
109 +# .env
110 +# .venv
111 +# env/
112 +# venv/
113 +# ENV/
114 +# env.bak/
115 +# venv.bak/
116 +
117 +# # Spyder project settings
118 +# .spyderproject
119 +# .spyproject
120 +
121 +# # Rope project settings
122 +# .ropeproject
123 +
124 +# # mkdocs documentation
125 +# /site
126 +
127 +# # mypy
128 +# .mypy_cache/
129 +# .dmypy.json
130 +# dmypy.json
131 +
132 +# # Pyre type checker
133 +# .pyre/
134 +
135 +# ### JetBrains template
136 +# # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
137 +# # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
138 +
139 +# # User-specific stuff
140 +# .idea/**/workspace.xml
141 +# .idea/**/tasks.xml
142 +# .idea/**/usage.statistics.xml
143 +# .idea/**/dictionaries
144 +# .idea/**/shelf
145 +
146 +# # Generated files
147 +# .idea/**/contentModel.xml
148 +
149 +# # Sensitive or high-churn files
150 +# .idea/**/dataSources/
151 +# .idea/**/dataSources.ids
152 +# .idea/**/dataSources.local.xml
153 +# .idea/**/sqlDataSources.xml
154 +# .idea/**/dynamic.xml
155 +# .idea/**/uiDesigner.xml
156 +# .idea/**/dbnavigator.xml
157 +
158 +# # Gradle
159 +# .idea/**/gradle.xml
160 +# .idea/**/libraries
161 +
162 +# # Gradle and Maven with auto-import
163 +# # When using Gradle or Maven with auto-import, you should exclude module files,
164 +# # since they will be recreated, and may cause churn. Uncomment if using
165 +# # auto-import.
166 +# # .idea/modules.xml
167 +# # .idea/*.iml
168 +# # .idea/modules
169 +# # *.iml
170 +# # *.ipr
171 +
172 +# # CMake
173 +# cmake-build-*/
174 +
175 +# # Mongo Explorer plugin
176 +# .idea/**/mongoSettings.xml
177 +
178 +# # File-based project format
179 +# *.iws
180 +
181 +# # IntelliJ
182 +# out/
183 +
184 +# # mpeltonen/sbt-idea plugin
185 +# .idea_modules/
186 +
187 +# # JIRA plugin
188 +# atlassian-ide-plugin.xml
189 +
190 +# # Cursive Clojure plugin
191 +# .idea/replstate.xml
192 +
193 +# # Crashlytics plugin (for Android Studio and IntelliJ)
194 +# com_crashlytics_export_strings.xml
195 +# crashlytics.properties
196 +# crashlytics-build.properties
197 +# fabric.properties
198 +
199 +# # Editor-based Rest Client
200 +# .idea/httpRequests
201 +
202 +# # Android studio 3.1+ serialized cache file
203 +# .idea/caches/build_file_checksums.ser
204 +
205 +./data/
206 +data/
207 +# .idea/
208 +generated/
1 +# PyTorch-ADDA-mixup
2 +A PyTorch implementation added MIXUO for Adversarial Discriminative Domain Adaptation.
3 +
4 +Confirmed improved performance by mixing up target domian and source domain
5 +
6 +# Usage
7 +It works on MNIST -> USPS , SVHN -> MNIST , USPS -> MNIST, MNIST -> MNIST-M
8 +Only 10,000 of the total data were used.(usps excluded)
9 +
10 +<pre>
11 +<code>
12 +python main.py
13 +</code>
14 +</pre>
15 +
16 +## adda
17 +This repo is based on https://github.com/corenel/pytorch-adda , https://github.com/Fujiki-Nakamura/ADDA.PyTorch
18 +
19 +
20 +
21 +Reference
22 +https://arxiv.org/abs/1702.05464
23 +
24 +
25 +#
1 +import os
2 +
3 +import torch
4 +import torch.optim as optim
5 +from torch import nn
6 +from core import test
7 +import params
8 +from utils import make_cuda, mixup_data
9 +
10 +
11 +
12 +def train_tgt(source_cnn, target_cnn, critic,
13 + src_data_loader, tgt_data_loader):
14 + """Train encoder for target domain."""
15 + ####################
16 + # 1. setup network #
17 + ####################
18 +
19 + source_cnn.eval()
20 + target_cnn.encoder.train()
21 + critic.train()
22 + isbest = 0
23 + # setup criterion and optimizer
24 + criterion = nn.CrossEntropyLoss()
25 + #target encoder
26 + optimizer_tgt = optim.Adam(target_cnn.parameters(),
27 + lr=params.adp_c_learning_rate,
28 + betas=(params.beta1, params.beta2),
29 + weight_decay=params.weight_decay
30 + )
31 + #Discriminator
32 + optimizer_critic = optim.Adam(critic.parameters(),
33 + lr=params.d_learning_rate,
34 + betas=(params.beta1, params.beta2),
35 + weight_decay=params.weight_decay
36 +
37 +
38 + )
39 +
40 + ####################
41 + # 2. train network #
42 + ####################
43 + len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
44 +
45 + for epoch in range(params.num_epochs):
46 + # zip source and target data pair
47 + data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
48 + for step, ((images_src, _), (images_tgt, _)) in data_zip:
49 +
50 + # make images variable
51 + images_src = make_cuda(images_src)
52 + images_tgt = make_cuda(images_tgt)
53 +
54 +
55 +
56 +
57 + ###########################
58 + # 2.1 train discriminator #
59 + ###########################
60 +
61 + # zero gradients for optimizer
62 + optimizer_critic.zero_grad()
63 +
64 + # extract and concat features
65 + feat_src = source_cnn.encoder(images_src)
66 + feat_tgt = target_cnn.encoder(images_tgt)
67 + feat_concat = torch.cat((feat_src, feat_tgt), 0)
68 +
69 + # predict on discriminator
70 + pred_concat = critic(feat_concat.detach())
71 +
72 + # prepare real and fake label
73 + label_src = make_cuda(torch.zeros(feat_src.size(0)).long())
74 + label_tgt = make_cuda(torch.ones(feat_tgt.size(0)).long())
75 + label_concat = torch.cat((label_src, label_tgt), 0)
76 +
77 + # compute loss for critic
78 + loss_critic = criterion(pred_concat, label_concat)
79 + loss_critic.backward()
80 +
81 + # optimize critic
82 + optimizer_critic.step()
83 +
84 + pred_cls = torch.squeeze(pred_concat.max(1)[1])
85 + acc = (pred_cls == label_concat).float().mean()
86 +
87 +
88 + ############################
89 + # 2.2 train target encoder #
90 + ############################
91 +
92 + # zero gradients for optimizer
93 + optimizer_critic.zero_grad()
94 + optimizer_tgt.zero_grad()
95 +
96 + # extract and target features
97 + feat_tgt = target_cnn.encoder(images_tgt)
98 +
99 + # predict on discriminator
100 + pred_tgt = critic(feat_tgt)
101 +
102 + # prepare fake labels
103 + label_tgt = make_cuda(torch.zeros(feat_tgt.size(0)).long())
104 +
105 + # compute loss for target encoder
106 + loss_tgt = criterion(pred_tgt, label_tgt)
107 + loss_tgt.backward()
108 +
109 + # optimize target encoder
110 + optimizer_tgt.step()
111 + #######################
112 + # 2.3 print step info #
113 + #######################
114 + if ((epoch % 10 ==0 )&((step + 1) % len_data_loader== 0)):
115 + print("Epoch [{}/{}] Step [{}/{}]:"
116 + "d_loss={:.5f} g_loss={:.5f} acc={:.5f}"
117 + .format(epoch,
118 + params.num_epochs,
119 + step + 1,
120 + len_data_loader,
121 + loss_critic.item(),
122 + loss_tgt.item(),
123 + acc.item()))
124 +
125 +
126 + torch.save(critic.state_dict(), os.path.join(
127 + params.model_root,
128 + "ADDA-critic-final.pt"))
129 + torch.save(target_cnn.state_dict(), os.path.join(
130 + params.model_root,
131 + "ADDA-target_cnn-final.pt"))
132 + return target_cnn
...\ No newline at end of file ...\ No newline at end of file
1 +import torch.nn as nn
2 +import torch
3 +import torch.optim as optim
4 +
5 +import params
6 +from utils import make_cuda, save_model, LabelSmoothingCrossEntropy,mixup_data,mixup_criterion
7 +from random import *
8 +import sys
9 +
10 +from torch.utils.data import Dataset,DataLoader
11 +import os
12 +from core import test
13 +from utils import make_cuda, mixup_data
14 +
15 +
16 +
17 +class CustomDataset(Dataset):
18 + def __init__(self,img,label):
19 + self.x_data = img
20 + self.y_data = label
21 + def __len__(self):
22 + return len(self.x_data)
23 +
24 + def __getitem__(self, idx):
25 + x = self.x_data[idx]
26 + y = self.y_data[idx]
27 + return x, y
28 +
29 +
30 +def train_src(model, source_data_loader,target_data_loader,valid_loader):
31 + """Train classifier for source domain."""
32 + ####################
33 + # 1. setup network #
34 + ####################
35 +
36 + model.train()
37 +
38 +
39 + target_data_loader = list(target_data_loader)
40 +
41 + # setup criterion and optimizer
42 + optimizer = optim.Adam(
43 + model.parameters(),
44 + lr=params.pre_c_learning_rate,
45 + betas=(params.beta1, params.beta2),
46 + weight_decay=params.weight_decay
47 + )
48 +
49 +
50 +
51 + if params.labelsmoothing:
52 + criterion = LabelSmoothingCrossEntropy(smoothing= params.smoothing)
53 + else:
54 + criterion = nn.CrossEntropyLoss()
55 +
56 +
57 + ####################
58 + # 2. train network #
59 + ####################
60 + len_data_loader = min(len(source_data_loader), len(target_data_loader))
61 +
62 + for epoch in range(params.num_epochs_pre+1):
63 + data_zip = enumerate(zip(source_data_loader, target_data_loader))
64 + for step, ((images, labels), (images_tgt, _)) in data_zip:
65 + # make images and labels variable
66 + images = make_cuda(images)
67 + labels = make_cuda(labels.squeeze_())
68 + # zero gradients for optimizer
69 + optimizer.zero_grad()
70 + target=target_data_loader[randint(0, len(target_data_loader)-1)]
71 + images, lam = mixup_data(images,target[0])
72 +
73 + # compute loss for critic
74 + preds = model(images)
75 + # loss = mixup_criterion(criterion, preds, labels, labels_tgt, lam)
76 + loss = criterion(preds, labels)
77 +
78 + # optimize source classifier
79 + loss.backward()
80 + optimizer.step()
81 +
82 +
83 +
84 + # # eval model on test set
85 + if ((epoch) % params.eval_step_pre == 0):
86 + print(f"Epoch [{epoch}/{params.num_epochs_pre}]",end='')
87 + if valid_loader is not None:
88 + test.eval_tgt(model, valid_loader)
89 + else:
90 + test.eval_tgt(model, source_data_loader)
91 +
92 + # save model parameters
93 + if ((epoch + 1) % params.save_step_pre == 0):
94 + save_model(model, "our-source_cnn-{}.pt".format(epoch + 1))
95 +
96 + # # save final model
97 + save_model(model, "our-source_cnn-final.pt")
98 +
99 + return model
100 +
101 +
102 +
103 +
104 +def train_tgt(source_cnn, target_cnn, critic,
105 + src_data_loader, tgt_data_loader,valid_loader):
106 + """Train encoder for target domain."""
107 + ####################
108 + # 1. setup network #
109 + ####################
110 +
111 + source_cnn.eval()
112 + target_cnn.encoder.train()
113 + critic.train()
114 + isbest = 0
115 + # setup criterion and optimizer
116 + criterion = nn.CrossEntropyLoss()
117 + #target encoder
118 + optimizer_tgt = optim.Adam(target_cnn.parameters(),
119 + lr=params.adp_c_learning_rate,
120 + betas=(params.beta1, params.beta2),
121 + weight_decay=params.weight_decay
122 + )
123 + #Discriminator
124 + optimizer_critic = optim.Adam(critic.parameters(),
125 + lr=params.d_learning_rate,
126 + betas=(params.beta1, params.beta2),
127 + weight_decay=params.weight_decay
128 +
129 +
130 + )
131 +
132 + ####################
133 + # 2. train network #
134 + ####################
135 + data_len = min(len(src_data_loader), len(tgt_data_loader))
136 +
137 + for epoch in range(params.num_epochs+1):
138 + # zip source and target data pair
139 + data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
140 + for step, ((images_src, _), (images_tgt, _)) in data_zip:
141 +
142 + # make images variable
143 + images_src = make_cuda(images_src)
144 + images_tgt = make_cuda(images_tgt)
145 +
146 + ###########################
147 + # 2.1 train discriminator #
148 + ###########################
149 +
150 + # mixup data
151 + images_src, _ = mixup_data(images_src,images_tgt)
152 +
153 + # zero gradients for optimizer
154 + optimizer_critic.zero_grad()
155 +
156 + # extract and concat features
157 + feat_src = source_cnn.encoder(images_src)
158 + feat_tgt = target_cnn.encoder(images_tgt)
159 + feat_concat = torch.cat((feat_src, feat_tgt), 0)
160 +
161 + # predict on discriminator
162 + pred_concat = critic(feat_concat.detach())
163 +
164 + # prepare real and fake label
165 + label_src = make_cuda(torch.zeros(feat_src.size(0)).long())
166 + label_tgt = make_cuda(torch.ones(feat_tgt.size(0)).long())
167 + label_concat = torch.cat((label_src, label_tgt), 0)
168 +
169 + # compute loss for critic
170 + loss_critic = criterion(pred_concat, label_concat)
171 + loss_critic.backward()
172 +
173 + # optimize critic
174 + optimizer_critic.step()
175 +
176 + pred_cls = torch.squeeze(pred_concat.max(1)[1])
177 + acc = (pred_cls == label_concat).float().mean()
178 +
179 +
180 + ############################
181 + # 2.2 train target encoder #
182 + ############################
183 +
184 + # zero gradients for optimizer
185 + optimizer_critic.zero_grad()
186 + optimizer_tgt.zero_grad()
187 +
188 + # extract and target features
189 + feat_tgt = target_cnn.encoder(images_tgt)
190 +
191 + # predict on discriminator
192 + pred_tgt = critic(feat_tgt)
193 +
194 + # prepare fake labels
195 + label_tgt = make_cuda(torch.zeros(feat_tgt.size(0)).long())
196 +
197 + # compute loss for target encoder
198 + loss_tgt = criterion(pred_tgt, label_tgt)
199 + loss_tgt.backward()
200 +
201 + # optimize target encoder
202 + optimizer_tgt.step()
203 + #######################
204 + # 2.3 print step info #
205 + #######################
206 + if ((epoch % 10 ==0 )& ((step + 1) % data_len == 0)):
207 + print("Epoch [{}/{}] Step [{}/{}]:"
208 + "d_loss={:.5f} g_loss={:.5f} acc={:.5f}"
209 + .format(epoch ,
210 + params.num_epochs,
211 + step + 1,
212 + data_len,
213 + loss_critic.item(),
214 + loss_tgt.item(),
215 + acc.item()))
216 + if valid_loader is not None:
217 + test.eval_tgt(target_cnn,valid_loader)
218 +
219 + torch.save(critic.state_dict(), os.path.join(
220 + params.model_root,
221 + "our-critic-final.pt"))
222 + torch.save(target_cnn.state_dict(), os.path.join(
223 + params.model_root,
224 + "our-target_cnn-final.pt"))
225 + return target_cnn
...\ No newline at end of file ...\ No newline at end of file
1 +import torch.nn as nn
2 +import torch.optim as optim
3 +import torch
4 +import params
5 +from utils import make_cuda, save_model, LabelSmoothingCrossEntropy,mixup_data
6 +from random import *
7 +import sys
8 +
9 +def train_src(model, source_data_loader):
10 + """Train classifier for source domain."""
11 + ####################
12 + # 1. setup network #
13 + ####################
14 +
15 + model.train()
16 +
17 +
18 +
19 + # setup criterion and optimizer
20 + optimizer = optim.Adam(
21 + model.parameters(),
22 + lr=params.pre_c_learning_rate,
23 + betas=(params.beta1, params.beta2),
24 + weight_decay=params.weight_decay
25 + )
26 +
27 +
28 + if params.labelsmoothing:
29 + criterion = LabelSmoothingCrossEntropy(smoothing= params.smoothing)
30 + else:
31 + criterion = nn.CrossEntropyLoss()
32 +
33 +
34 + ####################
35 + # 2. train network #
36 + ####################
37 +
38 + for epoch in range(params.num_epochs_pre):
39 + for step, (images, labels) in enumerate(source_data_loader):
40 + # make images and labels variable
41 + images = make_cuda(images)
42 + labels = make_cuda(labels.squeeze_())
43 + # zero gradients for optimizer
44 + optimizer.zero_grad()
45 +
46 + # compute loss for critic
47 + preds = model(images)
48 + loss = criterion(preds, labels)
49 +
50 + # optimize source classifier
51 + loss.backward()
52 + optimizer.step()
53 +
54 +
55 +
56 + # # eval model on test set
57 + if ((epoch ) % params.eval_step_pre == 0):
58 + print(f"Epoch [{epoch}/{params.num_epochs_pre}]",end='')
59 + eval_src(model, source_data_loader)
60 +
61 + # save model parameters
62 + if ((epoch + 1) % params.save_step_pre == 0):
63 + save_model(model, "ADDA-source_cnn-{}.pt".format(epoch + 1))
64 +
65 + # # save final model
66 + save_model(model, "ADDA-source_cnn-final.pt")
67 +
68 + return model
69 +
70 +def eval_src(model, data_loader):
71 + """Evaluate classifier for source domain."""
72 + # set eval state for Dropout and BN layers
73 + model.eval()
74 + with torch.no_grad():
75 + # init loss and accuracy
76 + loss = 0
77 + acc = 0
78 +
79 + # evaluate network
80 + for (images, labels) in data_loader:
81 +
82 + images = make_cuda(images)
83 + labels = make_cuda(labels).squeeze_()
84 +
85 + preds = model(images)
86 +
87 + pred_cls = preds.data.max(1)[1]
88 + acc += pred_cls.eq(labels.data).cpu().sum().item()
89 +
90 + acc /= len(data_loader.dataset)
91 +
92 + print("Avg Accuracy = {:2%}".format( acc))
93 +
94 +
1 +#!/usr/bin/env python3
2 +# -*- coding: utf-8 -*-
3 +"""
4 +Created on Wed Dec 5 15:03:50 2018
5 +
6 +@author: gaoyi
7 +"""
8 +
9 +import torch
10 +import torch.nn as nn
11 +
12 +from utils import make_cuda
13 +
14 +
15 +def eval_tgt(model, data_loader):
16 + """Evaluation for target encoder by source classifier on target dataset."""
17 + # set eval state for Dropout and BN layers
18 + model.eval()
19 + # init loss and accuracy
20 + loss = 0
21 + acc = 0
22 + with torch.no_grad():
23 + # evaluate network
24 + for (images, labels) in data_loader:
25 + images = make_cuda(images)
26 + labels = make_cuda(labels).squeeze_()
27 +
28 + preds = model(images)
29 + _, preds = torch.max(preds.data, 1)
30 + acc += (preds == labels).float().sum()/images.shape[0]
31 +
32 + acc /= len(data_loader)
33 +
34 + print("Avg Accuracy = {:2%}".format(acc))
...\ No newline at end of file ...\ No newline at end of file
1 +from .mnist import get_mnist
2 +from .mnist_m import get_mnist_m
3 +from .usps import get_usps
4 +from .svhn import get_svhn
5 +from .customdata import get_custom
6 +__all__ = (get_mnist_m, get_mnist,get_usps,get_svhn,get_custom)
1 +from torchvision import transforms, datasets
2 +import torch
3 +import params
4 +
5 +def get_custom(train,adp=False,size = 0):
6 +
7 + pre_process = transforms.Compose([transforms.Resize(params.image_size),
8 + transforms.ToTensor(),
9 + # transforms.Normalize((0.5),(0.5)),
10 + ])
11 + custom_dataset = datasets.ImageFolder(
12 + root = params.custom_dataset_root ,
13 + transform = pre_process,
14 + )
15 + length = len(custom_dataset)
16 + train_set, val_set = torch.utils.data.random_split(custom_dataset, [int(length*0.9), length-int(length*0.9)])
17 +
18 + if train:
19 + train_set,_ = torch.utils.data.random_split(train_set, [size,len(train_set)-size])
20 +
21 +
22 +
23 + custom_data_loader = torch.utils.data.DataLoader(
24 + train_set if train else val_set,
25 + batch_size= params.adp_batch_size if adp else params.batch_size,
26 + shuffle=True,
27 + drop_last=True
28 +
29 + )
30 +
31 + return custom_data_loader
...\ No newline at end of file ...\ No newline at end of file
1 +
2 +
3 +import torch
4 +from torchvision import datasets, transforms
5 +import torch.utils.data as data_utils
6 +
7 +import params
8 +
9 +
10 +def get_mnist(train,adp = False,size = 0):
11 + """Get MNIST dataset loader."""
12 + # image pre-processing
13 + pre_process = transforms.Compose([transforms.Resize(params.image_size),
14 + transforms.ToTensor(),
15 +# transforms.Normalize((0.5),(0.5)),
16 + transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
17 +
18 +
19 + ])
20 +
21 +
22 +
23 +
24 + # dataset and data loader
25 + mnist_dataset = datasets.MNIST(root=params.mnist_dataset_root,
26 + train=train,
27 + transform=pre_process,
28 +
29 + download=True)
30 + if train:
31 + # perm = torch.randperm(len(mnist_dataset))
32 + # indices = perm[:10000]
33 + mnist_dataset,_ = data_utils.random_split(mnist_dataset, [size,len(mnist_dataset)-size])
34 + # size = len(mnist_dataset)
35 + # train, valid = data_utils.random_split(mnist_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)])
36 + # train_loader = torch.utils.data.DataLoader(
37 + # dataset=train,
38 + # batch_size= params.adp_batch_size if adp else params.batch_size,
39 + # shuffle=True,
40 + # drop_last=True)
41 + # valid_loader = torch.utils.data.DataLoader(
42 + # dataset=valid,
43 + # batch_size= params.adp_batch_size if adp else params.batch_size,
44 + # shuffle=True,
45 + # drop_last=True)
46 +
47 + # return train_loader,valid_loader
48 +
49 + mnist_data_loader = torch.utils.data.DataLoader(
50 + dataset=mnist_dataset,
51 + batch_size= params.adp_batch_size if adp else params.batch_size,
52 + shuffle=True,
53 + drop_last=True)
54 + return mnist_data_loader
...\ No newline at end of file ...\ No newline at end of file
1 +import torch.utils.data as data
2 +from PIL import Image
3 +import os
4 +import params
5 +from torchvision import transforms
6 +import torch
7 +
8 +import torch.utils.data as data_utils
9 +
10 +
11 +class GetLoader(data.Dataset):
12 + def __init__(self, data_root, data_list, transform=None):
13 + self.root = data_root
14 + self.transform = transform
15 +
16 + f = open(data_list, 'r')
17 + data_list = f.readlines()
18 + f.close()
19 +
20 + self.n_data = len(data_list)
21 +
22 + self.img_paths = []
23 + self.img_labels = []
24 +
25 + for data_ in data_list:
26 + self.img_paths.append(data_[:-3])
27 + self.img_labels.append(data_[-2])
28 +
29 + def __getitem__(self, item):
30 + img_paths, labels = self.img_paths[item], self.img_labels[item]
31 + imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')
32 +
33 + if self.transform is not None:
34 + imgs = self.transform(imgs)
35 + labels = int(labels)
36 +
37 + return imgs, labels
38 +
39 + def __len__(self):
40 + return self.n_data
41 +
42 +
43 +def get_mnist_m(train,adp=False,size= 0 ):
44 +
45 + if train == True:
46 + mode = 'train'
47 + else:
48 + mode = 'test'
49 +
50 + train_list = os.path.join(params.mnist_m_dataset_root, 'mnist_m_{}_labels.txt'.format(mode))
51 + # image pre-processing
52 + pre_process = transforms.Compose([
53 + transforms.Resize(params.image_size),
54 + # transforms.Grayscale(3),
55 +
56 + transforms.ToTensor(),
57 +
58 +# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
59 +# transforms.Grayscale(1),
60 + ]
61 + )
62 +
63 + dataset_target = GetLoader(
64 + data_root=os.path.join(params.mnist_m_dataset_root, 'mnist_m_{}'.format(mode)),
65 + data_list=train_list,
66 + transform=pre_process)
67 +
68 + if train:
69 + # perm = torch.randperm(len(dataset_target))
70 + # indices = perm[:10000]
71 + dataset_target,_ = data_utils.random_split(dataset_target, [size,len(dataset_target)-size])
72 + # size = len(dataset_target)
73 + # train, valid = data_utils.random_split(dataset_target,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)])
74 + # train_loader = torch.utils.data.DataLoader(
75 + # dataset=train,
76 + # batch_size= params.adp_batch_size if adp else params.batch_size,
77 + # shuffle=True,
78 + # drop_last=True)
79 + # valid_loader = torch.utils.data.DataLoader(
80 + # dataset=valid,
81 + # batch_size= params.adp_batch_size if adp else params.batch_size,
82 + # shuffle=True,
83 + # drop_last=True)
84 +
85 + # return train_loader,valid_loader
86 +
87 +
88 + dataloader = torch.utils.data.DataLoader(
89 + dataset=dataset_target,
90 + batch_size= params.adp_batch_size if adp else params.batch_size,
91 +
92 + shuffle=True,
93 + drop_last=True)
94 +
95 + return dataloader
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +from torchvision import datasets, transforms
3 +
4 +import params
5 +
6 +import torch.utils.data as data_utils
7 +
8 +def get_svhn(train,adp=False,size=0):
9 + """Get SVHN dataset loader."""
10 + # image pre-processing
11 + pre_process = transforms.Compose([
12 + transforms.Resize(params.image_size),
13 + transforms.Grayscale(3),
14 + transforms.ToTensor(),
15 + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
16 +
17 + ])
18 +
19 +
20 + # dataset and data loader
21 + svhn_dataset = datasets.SVHN(root=params.svhn_dataset_root,
22 + split='train' if train else 'test',
23 + transform=pre_process,
24 + download=True)
25 + if train:
26 + # perm = torch.randperm(len(svhn_dataset))
27 + # indices = perm[:10000]
28 + svhn_dataset,_ = data_utils.random_split(svhn_dataset, [size,len(svhn_dataset)-size])
29 + # size = len(svhn_dataset)
30 + # train, valid = data_utils.random_split(svhn_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)])
31 +
32 + # train_loader = torch.utils.data.DataLoader(
33 + # dataset=train,
34 + # batch_size= params.adp_batch_size if adp else params.batch_size,
35 +
36 + # shuffle=True,
37 + # drop_last=True)
38 +
39 + # valid_loader = torch.utils.data.DataLoader(
40 + # dataset=valid,
41 + # batch_size= params.adp_batch_size if adp else params.batch_size,
42 +
43 + # shuffle=True,
44 + # drop_last=True)
45 + # return train_loader,valid_loader
46 +
47 + svhn_data_loader = torch.utils.data.DataLoader(
48 + dataset=svhn_dataset,
49 + batch_size= params.adp_batch_size if adp else params.batch_size,
50 +
51 + shuffle=True,
52 + drop_last=True)
53 +
54 + return svhn_data_loader
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +from torchvision import datasets, transforms
3 +from torch.utils.data.dataset import random_split
4 +
5 +import params
6 +import torch.utils.data as data_utils
7 +
8 +
9 +def get_usps(train,adp=False,size=0):
10 + """Get usps dataset loader."""
11 + # image pre-processing
12 + pre_process = transforms.Compose([transforms.Resize(params.image_size),
13 + transforms.ToTensor(),
14 + # transforms.Normalize((0.5),(0.5)),
15 + transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
16 + # transforms.Grayscale(1),
17 +
18 +
19 +
20 + ])
21 +
22 +
23 + # dataset and data loader
24 + usps_dataset = datasets.USPS(root=params.usps_dataset_root,
25 + train=train,
26 + transform=pre_process,
27 + download=True)
28 +
29 +
30 + if train:
31 + usps_dataset, _ = data_utils.random_split(usps_dataset, [size,len(usps_dataset)-size])
32 + # size = len(usps_dataset)
33 + # train, valid = data_utils.random_split(usps_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)])
34 + # train_loader = torch.utils.data.DataLoader(
35 + # dataset=train,
36 + # batch_size= params.adp_batch_size if adp else params.batch_size,
37 + # shuffle=True,
38 + # drop_last=True)
39 + # valid_loader = torch.utils.data.DataLoader(
40 + # dataset=valid,
41 + # batch_size= params.adp_batch_size if adp else params.batch_size,
42 + # shuffle=True,
43 + # drop_last=True)
44 + # return train_loader,valid_loader
45 +
46 + usps_data_loader = torch.utils.data.DataLoader(
47 + dataset=usps_dataset,
48 + batch_size= params.adp_batch_size if adp else params.batch_size,
49 +
50 + shuffle=True,
51 + drop_last=True)
52 + return usps_data_loader
...\ No newline at end of file ...\ No newline at end of file
1 +
2 +import params
3 +from utils import get_data_loader, init_model, init_random_seed,mixup_data
4 +from core import pretrain , adapt , test,mixup
5 +import torch
6 +from models.models import *
7 +import numpy as np
8 +import sys
9 +
10 +
11 +
12 +if __name__ == '__main__':
13 + # init random seed
14 + init_random_seed(params.manual_seed)
15 + print(f"Is cuda availabel? {torch.cuda.is_available()}")
16 +
17 +
18 +
19 + #set loader
20 + print("src data loader....")
21 + src_data_loader = get_data_loader(params.src_dataset,adp=False,size = 10000)
22 + src_data_loader_eval = get_data_loader(params.src_dataset,train=False)
23 + print("tgt data loader....")
24 + tgt_data_loader = get_data_loader(params.tgt_dataset,adp=False,size = 50000)
25 + tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)
26 + print(f"scr data size : {len(src_data_loader.dataset)}")
27 + print(f"tgt data size : {len(tgt_data_loader.dataset)}")
28 +
29 +
30 + print("start training")
31 + source_cnn = CNN(in_channels=3).to("cuda")
32 + target_cnn = CNN(in_channels=3, target=True).to("cuda")
33 + discriminator = Discriminator().to("cuda")
34 +
35 + source_cnn = mixup.train_src(source_cnn, src_data_loader,tgt_data_loader,None)
36 + # source_cnn.load_state_dict(torch.load("./generated/models/our-source_cnn-final.pt"))
37 +
38 + test.eval_tgt(source_cnn, tgt_data_loader_eval)
39 +
40 + target_cnn.load_state_dict(source_cnn.state_dict())
41 +
42 + tgt_encoder = mixup.train_tgt(source_cnn, target_cnn, discriminator,
43 + src_data_loader,tgt_data_loader,None)
44 +
45 +
46 +
47 + print("=== Evaluating classifier for encoded target domain ===")
48 + print(f"mixup : {params.lammax} {params.src_dataset} -> {params.tgt_dataset} ")
49 + print("Eval | source_cnn | src_data_loader_eval")
50 + test.eval_tgt(source_cnn, src_data_loader_eval)
51 + print(">>> Eval | source_cnn | tgt_data_loader_eval <<<")
52 + test.eval_tgt(source_cnn, tgt_data_loader_eval)
53 + print(">>> Eval | target_cnn | tgt_data_loader_eval <<<")
54 + test.eval_tgt(target_cnn, tgt_data_loader_eval)
55 +
56 +
57 +
58 +
1 +from torch import nn
2 +import torch.nn.functional as F
3 +import params
4 +
5 +class Encoder(nn.Module):
6 + def __init__(self, in_channels=1, h=256, dropout=0.5):
7 + super(Encoder, self).__init__()
8 + self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1)
9 + self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
10 + self.bn1 = nn.BatchNorm2d(20)
11 + self.bn2 = nn.BatchNorm2d(50)
12 + self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
13 + self.relu = nn.ReLU()
14 + self.dropout =nn.Dropout2d(p= dropout)
15 + # self.dropout = nn.Dropout(dropout)
16 + self.fc1 = nn.Linear(800, 500)
17 +
18 + # for m in self.modules():
19 + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
20 + # nn.init.kaiming_normal_(m.weight)
21 +
22 + def forward(self, x):
23 + bs = x.size(0)
24 + x = self.pool(self.relu(self.bn1(self.conv1(x))))
25 + x = self.pool(self.relu(self.bn2(self.dropout(self.conv2(x)))))
26 + x = x.view(bs, -1)
27 + # x = self.dropout(x)W
28 + x = self.fc1(x)
29 + return x
30 +
31 +
32 +class Classifier(nn.Module):
33 + def __init__(self, n_classes, dropout=0.5):
34 + super(Classifier, self).__init__()
35 + self.l1 = nn.Linear(500, n_classes)
36 +
37 + # for m in self.modules():
38 + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
39 + # nn.init.kaiming_normal_(m.weight)
40 +
41 + def forward(self, x):
42 + x = self.l1(x)
43 + return x
44 +
45 +
46 +class CNN(nn.Module):
47 + def __init__(self, in_channels=1, n_classes=10, target=False):
48 + super(CNN, self).__init__()
49 + self.encoder = Encoder(in_channels=in_channels)
50 + self.classifier = Classifier(n_classes)
51 + if target:
52 + for param in self.classifier.parameters():
53 + param.requires_grad = False
54 +
55 + def forward(self, x):
56 + x = self.encoder(x)
57 + x = self.classifier(x)
58 + return x
59 +
60 +
61 +class Discriminator(nn.Module):
62 + def __init__(self, h=500):
63 + super(Discriminator, self).__init__()
64 + self.l1 = nn.Linear(500, h)
65 + self.l2 = nn.Linear(h, h)
66 + self.l3 = nn.Linear(h, 2)
67 + # self.slope =params.slope
68 +
69 + self.relu = nn.ReLU()
70 +
71 + # for m in self.modules():
72 + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
73 + # nn.init.kaiming_normal_(m.weight)
74 +
75 + def forward(self, x):
76 + x = self.relu(self.l1(x))
77 + x = self.relu(self.l2(x))
78 + x = self.l3(x)
79 + return x
1 +import torch
2 +
3 +# params for dataset and data loader
4 +data_root = "data"
5 +image_size = 28
6 +
7 +#restore
8 +model_root = 'generated\\models'
9 +
10 +
11 +# params for target dataset
12 +# 'mnist_m', 'usps', 'svhn' "custom"
13 +
14 +#dataset root
15 +mnist_dataset_root = data_root
16 +mnist_m_dataset_root = data_root+'\\mnist_m'
17 +usps_dataset_root = data_root+'\\usps'
18 +svhn_dataset_root = data_root+'\\svhn'
19 +custom_dataset_root = data_root+'\\custom\\'
20 +
21 +# params for training network
22 +num_gpu = 1
23 +
24 +log_step_pre = 10
25 +log_step = 10
26 +eval_step_pre = 10
27 +
28 +##epoch
29 +save_step_pre = 100
30 +manual_seed = 1234
31 +
32 +d_input_dims = 500
33 +d_hidden_dims = 500
34 +d_output_dims = 2
35 +d_model_restore = 'generated\\models\\ADDA-critic-final.pt'
36 +
37 +## sorce target
38 +src_dataset = 'custom'
39 +tgt_dataset = 'custom'
40 +
41 +
42 +# params for optimizing models
43 +# # lam 0.3
44 +#mnist -> custom
45 +num_epochs_pre = 20
46 +num_epochs = 50
47 +batch_size = 128
48 +adp_batch_size = 128
49 +pre_c_learning_rate = 2e-4
50 +adp_c_learning_rate = 1e-4
51 +d_learning_rate = 1e-4
52 +beta1 = 0.5
53 +beta2 = 0.999
54 +weight_decay = 0
55 +
56 +
57 +# #usps -> custom
58 +# #lam 0.1
59 +# num_epochs_pre = 5
60 +# num_epochs = 20
61 +# batch_size = 256
62 +# pre_c_learning_rate = 1e-4
63 +# adp_c_learning_rate = 2e-5
64 +# d_learning_rate = 1e-5
65 +# beta1 = 0.5
66 +# beta2 = 0.999
67 +# weight_decay = 2e-4
68 +
69 +# #mnist_m -> custom
70 +# #lam 0.1
71 +# num_epochs_pre = 30
72 +# num_epochs = 50
73 +# batch_size = 256
74 +# adp_batch_size = 256
75 +# pre_c_learning_rate = 1e-3
76 +# adp_c_learning_rate = 1e-4
77 +# d_learning_rate = 1e-4
78 +# beta1 = 0.5
79 +# beta2 = 0.999
80 +# weight_decay = 2e-4
81 +
82 +# # params for optimizing models
83 +#lam 0.3
84 +# #mnist -> mnist_m
85 +# num_epochs_pre = 50
86 +# num_epochs = 100
87 +# batch_size = 256
88 +# adp_batch_size = 256
89 +# pre_c_learning_rate = 2e-4
90 +# adp_c_learning_rate = 2e-4
91 +# d_learning_rate = 2e-4
92 +# beta1 = 0.5
93 +# beta2 = 0.999
94 +# weight_decay = 0
95 +
96 +# # source 10000 target 50000
97 +# # params for optimizing models
98 +# #svhn -> mnist
99 +# num_epochs_pre = 20
100 +# num_epochs = 30
101 +# batch_size = 128
102 +# adp_batch_size = 128
103 +# pre_c_learning_rate = 2e-4
104 +# adp_c_learning_rate = 1e-4
105 +# d_learning_rate = 1e-4
106 +# beta1 = 0.5
107 +# beta2 = 0.999
108 +# weight_decay = 2.5e-4
109 +
110 +# # mnist->usps
111 +# num_epochs_pre = 50
112 +# num_epochs = 100
113 +# batch_size = 256
114 +# adp_batch_size = 256
115 +# pre_c_learning_rate = 2e-4
116 +# adp_c_learning_rate = 2e-4
117 +# d_learning_rate = 2e-4
118 +# beta1 = 0.5
119 +# beta2 = 0.999
120 +# weight_decay =0
121 +
122 +
123 +# # usps->mnist
124 +# num_epochs_pre = 50
125 +# num_epochs = 100
126 +# batch_size = 256
127 +# pre_c_learning_rate = 2e-4
128 +# adp_c_learning_rate = 2e-4
129 +# d_learning_rate =2e-4
130 +# beta1 = 0.5
131 +# beta2 = 0.999
132 +# weight_decay =0
133 +
134 +
135 +
136 +#
137 +use_load = False
138 +train =False
139 +
140 +#ratio mix target
141 +lammax = 0.0
142 +lammin = 0.0
143 +
144 +
145 +labelsmoothing = False
146 +smoothing = 0.3
...\ No newline at end of file ...\ No newline at end of file
1 +
2 +
3 +import os
4 +import random
5 +import torch
6 +import torch.backends.cudnn as cudnn
7 +from torch.autograd import Variable
8 +import params
9 +from dataset import get_mnist, get_mnist_m, get_usps,get_svhn,get_custom
10 +import numpy as np
11 +import itertools
12 +import torch.nn.functional as F
13 +import params
14 +
15 +def make_cuda(tensor):
16 + """Use CUDA if it's available."""
17 + if torch.cuda.is_available():
18 + tensor = tensor.cuda()
19 + return tensor
20 +
21 +
22 +def denormalize(x, std, mean):
23 + """Invert normalization, and then convert array into image."""
24 + out = x * std + mean
25 + return out.clamp(0, 1)
26 +
27 +
28 +def init_weights(layer):
29 + """Init weights for layers w.r.t. the original paper."""
30 + layer_name = layer.__class__.__name__
31 + if layer_name.find("Conv") != -1:
32 + layer.weight.data.normal_(0.0, 0.02)
33 + elif layer_name.find("BatchNorm") != -1:
34 + layer.weight.data.normal_(1.0, 0.02)
35 + layer.bias.data.fill_(0)
36 +
37 +
38 +def init_random_seed(manual_seed):
39 + """Init random seed."""
40 + seed = None
41 + if manual_seed is None:
42 + seed = random.randint(1, 10000)
43 + else:
44 + seed = manual_seed
45 + #for REPRODUCIBILITY
46 + torch.backends.cudnn.deterministic = True
47 + torch.backends.cudnn.benchmark = False
48 + print("use random seed: {}".format(seed))
49 + random.seed(seed)
50 + torch.manual_seed(seed)
51 + np.random.seed(seed)
52 +
53 + if torch.cuda.is_available():
54 + torch.cuda.manual_seed_all(seed)
55 +
56 +
57 +def get_data_loader(name,train=True,adp=False,size = 0):
58 + """Get data loader by name."""
59 + if name == "mnist":
60 + return get_mnist(train,adp,size)
61 + elif name == "mnist_m":
62 + return get_mnist_m(train,adp,size)
63 + elif name == "usps":
64 + return get_usps(train,adp,size)
65 + elif name == "svhn":
66 + return get_svhn(train,adp,size)
67 + elif name == "custom":
68 + return get_custom(train,adp,size)
69 +
70 +def init_model(net, restore=None):
71 + """Init models with cuda and weights."""
72 + # init weights of model
73 + # net.apply(init_weights)
74 +
75 + print(f'restore file : {restore}')
76 + # restore model weights
77 + if restore is not None and os.path.exists(restore):
78 + net.load_state_dict(torch.load(restore))
79 + net.restored = True
80 + print("Restore model from: {}".format(os.path.abspath(restore)))
81 +
82 + # check if cuda is available
83 + if torch.cuda.is_available():
84 + net.cuda()
85 +
86 + return net
87 +
88 +def save_model(net, filename):
89 + """Save trained model."""
90 + if not os.path.exists(params.model_root):
91 + os.makedirs(params.model_root)
92 + torch.save(net.state_dict(),
93 + os.path.join(params.model_root, filename))
94 + print("save pretrained model to: {}".format(os.path.join(params.model_root,
95 + filename)))
96 +
97 +class LabelSmoothingCrossEntropy(torch.nn.Module):
98 + def __init__(self,smoothing):
99 + super(LabelSmoothingCrossEntropy, self).__init__()
100 + self.smoothing = smoothing
101 + def forward(self, y, targets,smoothing=0.1):
102 + confidence = 1. - self.smoothing
103 + log_probs = F.log_softmax(y, dim=-1) # 예측 확률 계산
104 + true_probs = torch.zeros_like(log_probs)
105 + true_probs.fill_(self.smoothing / (y.shape[1] - 1))
106 + true_probs.scatter_(1, targets.data.unsqueeze(1), confidence) # 정답 인덱스의 정답 확률을 confidence로 변경
107 + return torch.mean(torch.sum(true_probs * -log_probs, dim=-1)) # negative log likelihood
108 +
109 +#mixup only data, not label
110 +def mixup_data(source,target):
111 + max = params.lammax
112 + min = params.lammin
113 + lam = (max-min)*torch.rand((1))+min
114 + lam=lam.cuda()
115 + target = target.cuda()
116 + mixed_source = (1 - lam) * source + lam* target
117 +
118 +
119 + return mixed_source, lam
120 +
121 +
122 +
123 +
124 +def mixup_criterion(criterion, pred, y_a, y_b, lam):
125 + return 0.9* criterion(pred, y_a) + 0.1 * criterion(pred, y_b)
...\ No newline at end of file ...\ No newline at end of file