Toggle navigation
Toggle navigation
This project
Loading...
Sign in
최강혁
/
dddd
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Graphs
Network
Create a new issue
Commits
Issue Boards
Authored by
yunjey
2017-01-22 05:43:59 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
7457f7c3e56d6cc63f727f7849d5c7910b0d5d90
7457f7c3
1 parent
ae88cf03
.
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
9 additions
and
19 deletions
download.sh
main.py
model.py
solver.py
download.sh
View file @
7457f7c
mkdir -p mnist
mkdir -p svhn
wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat
wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat
wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat
...
...
main.py
View file @
7457f7c
...
...
@@ -3,15 +3,12 @@ from model import DTN
from
solver
import
Solver
flags
=
tf
.
app
.
flags
flags
.
DEFINE_string
(
'mode'
,
'train'
,
"'pretrain', 'train' or 'eval'"
)
FLAGS
=
flags
.
FLAGS
def
main
(
_
):
model
=
DTN
(
mode
=
FLAGS
.
mode
,
learning_rate
=
0.0003
)
solver
=
Solver
(
model
,
batch_size
=
100
,
pretrain_iter
=
5000
,
train_iter
=
2000
,
sample_iter
=
100
,
svhn_dir
=
'svhn'
,
mnist_dir
=
'mnist'
,
log_dir
=
'logs'
,
model_save_path
=
'model'
)
...
...
@@ -25,6 +22,3 @@ def main(_):
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
\ No newline at end of file
\ No newline at end of file
...
...
model.py
View file @
7457f7c
...
...
@@ -33,7 +33,6 @@ class DTN(object):
if
self
.
mode
==
'pretrain'
:
net
=
slim
.
conv2d
(
net
,
10
,
[
1
,
1
],
padding
=
'VALID'
,
scope
=
'out'
)
net
=
slim
.
flatten
(
net
)
return
net
def
generator
(
self
,
inputs
,
reuse
=
False
):
...
...
@@ -106,7 +105,6 @@ class DTN(object):
# source domain (svhn to mnist)
with
tf
.
name_scope
(
'model_for_source_domain'
):
self
.
fx
=
self
.
content_extractor
(
self
.
src_images
)
self
.
fake_images
=
self
.
generator
(
self
.
fx
)
self
.
logits
=
self
.
discriminator
(
self
.
fake_images
)
...
...
@@ -128,7 +126,6 @@ class DTN(object):
f_vars
=
[
var
for
var
in
t_vars
if
'content_extractor'
in
var
.
name
]
# train op
with
tf
.
name_scope
(
'source_train_op'
):
self
.
d_train_op_src
=
slim
.
learning
.
create_train_op
(
self
.
d_loss_src
,
self
.
d_optimizer_src
,
variables_to_train
=
d_vars
)
self
.
g_train_op_src
=
slim
.
learning
.
create_train_op
(
self
.
g_loss_src
,
self
.
g_optimizer_src
,
variables_to_train
=
g_vars
)
self
.
f_train_op_src
=
slim
.
learning
.
create_train_op
(
self
.
f_loss_src
,
self
.
f_optimizer_src
,
variables_to_train
=
f_vars
)
...
...
@@ -144,7 +141,6 @@ class DTN(object):
sampled_images_summary
])
# target domain (mnist)
with
tf
.
name_scope
(
'model_for_target_domain'
):
self
.
fx
=
self
.
content_extractor
(
self
.
trg_images
,
reuse
=
True
)
self
.
reconst_images
=
self
.
generator
(
self
.
fx
,
reuse
=
True
)
self
.
logits_fake
=
self
.
discriminator
(
self
.
reconst_images
,
reuse
=
True
)
...
...
@@ -162,13 +158,7 @@ class DTN(object):
self
.
d_optimizer_trg
=
tf
.
train
.
AdamOptimizer
(
self
.
learning_rate
)
self
.
g_optimizer_trg
=
tf
.
train
.
AdamOptimizer
(
self
.
learning_rate
)
t_vars
=
tf
.
trainable_variables
()
d_vars
=
[
var
for
var
in
t_vars
if
'discriminator'
in
var
.
name
]
g_vars
=
[
var
for
var
in
t_vars
if
'generator'
in
var
.
name
]
f_vars
=
[
var
for
var
in
t_vars
if
'content_extractor'
in
var
.
name
]
# train op
with
tf
.
name_scope
(
'target_train_op'
):
self
.
d_train_op_trg
=
slim
.
learning
.
create_train_op
(
self
.
d_loss_trg
,
self
.
d_optimizer_trg
,
variables_to_train
=
d_vars
)
self
.
g_train_op_trg
=
slim
.
learning
.
create_train_op
(
self
.
g_loss_trg
,
self
.
g_optimizer_trg
,
variables_to_train
=
g_vars
)
...
...
solver.py
View file @
7457f7c
...
...
@@ -9,7 +9,7 @@ import scipy.misc
class
Solver
(
object
):
def
__init__
(
self
,
model
,
batch_size
=
100
,
pretrain_iter
=
5
000
,
train_iter
=
2000
,
sample_iter
=
100
,
def
__init__
(
self
,
model
,
batch_size
=
100
,
pretrain_iter
=
10
000
,
train_iter
=
2000
,
sample_iter
=
100
,
svhn_dir
=
'svhn'
,
mnist_dir
=
'mnist'
,
log_dir
=
'logs'
,
sample_save_path
=
'sample'
,
model_save_path
=
'model'
,
pretrained_model
=
'model/svhn_model-10000'
,
test_model
=
'model/dtn-2000'
):
self
.
model
=
model
...
...
@@ -29,7 +29,12 @@ class Solver(object):
def
load_svhn
(
self
,
image_dir
,
split
=
'train'
):
print
(
'loading svhn image dataset..'
)
if
self
.
model
.
mode
==
'pretrain'
:
image_file
=
'extra_32x32.mat'
if
split
==
'train'
else
'test_32x32.mat'
else
:
image_file
=
'train_32x32.mat'
if
split
==
'train'
else
'test_32x32.mat'
image_dir
=
os
.
path
.
join
(
image_dir
,
image_file
)
svhn
=
scipy
.
io
.
loadmat
(
image_dir
)
images
=
np
.
transpose
(
svhn
[
'X'
],
[
3
,
0
,
1
,
2
])
/
127.5
-
1
...
...
@@ -136,10 +141,10 @@ class Solver(object):
sess
.
run
([
model
.
g_train_op_src
],
feed_dict
)
sess
.
run
([
model
.
g_train_op_src
],
feed_dict
)
sess
.
run
([
model
.
g_train_op_src
],
feed_dict
)
if
i
%
15
==
0
:
sess
.
run
(
model
.
f_train_op_src
,
feed_dict
)
if
(
step
+
1
)
%
10
==
0
:
summary
,
dl
,
gl
,
fl
=
sess
.
run
([
model
.
summary_op_src
,
\
model
.
d_loss_src
,
model
.
g_loss_src
,
model
.
f_loss_src
],
feed_dict
)
...
...
@@ -169,7 +174,6 @@ class Solver(object):
saver
.
save
(
sess
,
os
.
path
.
join
(
self
.
model_save_path
,
'dtn'
),
global_step
=
step
+
1
)
print
(
'model/dtn-
%
d saved'
%
(
step
+
1
))
def
eval
(
self
):
# build model
model
=
self
.
model
...
...
Please
register
or
login
to post a comment