yunjey

.

mkdir -p mnist
mkdir -p svhn
wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat
wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat
wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat
......
......@@ -3,15 +3,12 @@ from model import DTN
from solver import Solver
flags = tf.app.flags
flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'")
FLAGS = flags.FLAGS
def main(_):
model = DTN(mode=FLAGS.mode, learning_rate=0.0003)
solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model')
......@@ -25,6 +22,3 @@ def main(_):
if __name__ == '__main__':
tf.app.run()
\ No newline at end of file
\ No newline at end of file
......
......@@ -33,7 +33,6 @@ class DTN(object):
if self.mode == 'pretrain':
net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
net = slim.flatten(net)
return net
def generator(self, inputs, reuse=False):
......@@ -106,7 +105,6 @@ class DTN(object):
# source domain (svhn to mnist)
with tf.name_scope('model_for_source_domain'):
self.fx = self.content_extractor(self.src_images)
self.fake_images = self.generator(self.fx)
self.logits = self.discriminator(self.fake_images)
......@@ -128,7 +126,6 @@ class DTN(object):
f_vars = [var for var in t_vars if 'content_extractor' in var.name]
# train op
with tf.name_scope('source_train_op'):
self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars)
self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars)
self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars)
......@@ -144,7 +141,6 @@ class DTN(object):
sampled_images_summary])
# target domain (mnist)
with tf.name_scope('model_for_target_domain'):
self.fx = self.content_extractor(self.trg_images, reuse=True)
self.reconst_images = self.generator(self.fx, reuse=True)
self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
......@@ -162,13 +158,7 @@ class DTN(object):
self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
f_vars = [var for var in t_vars if 'content_extractor' in var.name]
# train op
with tf.name_scope('target_train_op'):
self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars)
self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars)
......
......@@ -9,7 +9,7 @@ import scipy.misc
class Solver(object):
def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100,
svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'):
self.model = model
......@@ -29,7 +29,12 @@ class Solver(object):
def load_svhn(self, image_dir, split='train'):
print ('loading svhn image dataset..')
if self.model.mode == 'pretrain':
image_file = 'extra_32x32.mat' if split=='train' else 'test_32x32.mat'
else:
image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat'
image_dir = os.path.join(image_dir, image_file)
svhn = scipy.io.loadmat(image_dir)
images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1
......@@ -136,10 +141,10 @@ class Solver(object):
sess.run([model.g_train_op_src], feed_dict)
sess.run([model.g_train_op_src], feed_dict)
sess.run([model.g_train_op_src], feed_dict)
if i % 15 == 0:
sess.run(model.f_train_op_src, feed_dict)
if (step+1) % 10 == 0:
summary, dl, gl, fl = sess.run([model.summary_op_src, \
model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict)
......@@ -169,7 +174,6 @@ class Solver(object):
saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1)
print ('model/dtn-%d saved' %(step+1))
def eval(self):
# build model
model = self.model
......