yunjey

.

1 +mkdir -p mnist
1 mkdir -p svhn 2 mkdir -p svhn
2 wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat 3 wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat
3 wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat 4 wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat
5 +wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat
4 6
5 7
......
...@@ -3,15 +3,12 @@ from model import DTN ...@@ -3,15 +3,12 @@ from model import DTN
3 from solver import Solver 3 from solver import Solver
4 4
5 5
6 -
7 flags = tf.app.flags 6 flags = tf.app.flags
8 flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") 7 flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'")
9 -
10 FLAGS = flags.FLAGS 8 FLAGS = flags.FLAGS
11 9
12 def main(_): 10 def main(_):
13 11
14 -
15 model = DTN(mode=FLAGS.mode, learning_rate=0.0003) 12 model = DTN(mode=FLAGS.mode, learning_rate=0.0003)
16 solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, 13 solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100,
17 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') 14 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model')
...@@ -22,9 +19,6 @@ def main(_): ...@@ -22,9 +19,6 @@ def main(_):
22 solver.train() 19 solver.train()
23 else: 20 else:
24 solver.eval() 21 solver.eval()
25 - 22 +
26 if __name__ == '__main__': 23 if __name__ == '__main__':
27 - tf.app.run()
28 -
29 -
30 -
...\ No newline at end of file ...\ No newline at end of file
24 + tf.app.run()
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -33,7 +33,6 @@ class DTN(object): ...@@ -33,7 +33,6 @@ class DTN(object):
33 if self.mode == 'pretrain': 33 if self.mode == 'pretrain':
34 net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') 34 net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out')
35 net = slim.flatten(net) 35 net = slim.flatten(net)
36 -
37 return net 36 return net
38 37
39 def generator(self, inputs, reuse=False): 38 def generator(self, inputs, reuse=False):
...@@ -106,12 +105,11 @@ class DTN(object): ...@@ -106,12 +105,11 @@ class DTN(object):
106 105
107 106
108 # source domain (svhn to mnist) 107 # source domain (svhn to mnist)
109 - with tf.name_scope('model_for_source_domain'): 108 + self.fx = self.content_extractor(self.src_images)
110 - self.fx = self.content_extractor(self.src_images) 109 + self.fake_images = self.generator(self.fx)
111 - self.fake_images = self.generator(self.fx) 110 + self.logits = self.discriminator(self.fake_images)
112 - self.logits = self.discriminator(self.fake_images) 111 + self.fgfx = self.content_extractor(self.fake_images, reuse=True)
113 - self.fgfx = self.content_extractor(self.fake_images, reuse=True) 112 +
114 -
115 # loss 113 # loss
116 self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits)) 114 self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits))
117 self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits)) 115 self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits))
...@@ -128,10 +126,9 @@ class DTN(object): ...@@ -128,10 +126,9 @@ class DTN(object):
128 f_vars = [var for var in t_vars if 'content_extractor' in var.name] 126 f_vars = [var for var in t_vars if 'content_extractor' in var.name]
129 127
130 # train op 128 # train op
131 - with tf.name_scope('source_train_op'): 129 + self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars)
132 - self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) 130 + self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars)
133 - self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) 131 + self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars)
134 - self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars)
135 132
136 # summary op 133 # summary op
137 d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src) 134 d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src)
...@@ -144,11 +141,10 @@ class DTN(object): ...@@ -144,11 +141,10 @@ class DTN(object):
144 sampled_images_summary]) 141 sampled_images_summary])
145 142
146 # target domain (mnist) 143 # target domain (mnist)
147 - with tf.name_scope('model_for_target_domain'): 144 + self.fx = self.content_extractor(self.trg_images, reuse=True)
148 - self.fx = self.content_extractor(self.trg_images, reuse=True) 145 + self.reconst_images = self.generator(self.fx, reuse=True)
149 - self.reconst_images = self.generator(self.fx, reuse=True) 146 + self.logits_fake = self.discriminator(self.reconst_images, reuse=True)
150 - self.logits_fake = self.discriminator(self.reconst_images, reuse=True) 147 + self.logits_real = self.discriminator(self.trg_images, reuse=True)
151 - self.logits_real = self.discriminator(self.trg_images, reuse=True)
152 148
153 # loss 149 # loss
154 self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake)) 150 self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake))
...@@ -161,16 +157,10 @@ class DTN(object): ...@@ -161,16 +157,10 @@ class DTN(object):
161 # optimizer 157 # optimizer
162 self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) 158 self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
163 self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) 159 self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate)
164 -
165 - t_vars = tf.trainable_variables()
166 - d_vars = [var for var in t_vars if 'discriminator' in var.name]
167 - g_vars = [var for var in t_vars if 'generator' in var.name]
168 - f_vars = [var for var in t_vars if 'content_extractor' in var.name]
169 160
170 # train op 161 # train op
171 - with tf.name_scope('target_train_op'): 162 + self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars)
172 - self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) 163 + self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars)
173 - self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars)
174 164
175 # summary op 165 # summary op
176 d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg) 166 d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg)
......
...@@ -9,7 +9,7 @@ import scipy.misc ...@@ -9,7 +9,7 @@ import scipy.misc
9 9
10 class Solver(object): 10 class Solver(object):
11 11
12 - def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, 12 + def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100,
13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', 13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
14 model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): 14 model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'):
15 self.model = model 15 self.model = model
...@@ -29,7 +29,12 @@ class Solver(object): ...@@ -29,7 +29,12 @@ class Solver(object):
29 29
30 def load_svhn(self, image_dir, split='train'): 30 def load_svhn(self, image_dir, split='train'):
31 print ('loading svhn image dataset..') 31 print ('loading svhn image dataset..')
32 - image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat' 32 +
33 + if self.model.mode == 'pretrain':
34 + image_file = 'extra_32x32.mat' if split=='train' else 'test_32x32.mat'
35 + else:
36 + image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat'
37 +
33 image_dir = os.path.join(image_dir, image_file) 38 image_dir = os.path.join(image_dir, image_file)
34 svhn = scipy.io.loadmat(image_dir) 39 svhn = scipy.io.loadmat(image_dir)
35 images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1 40 images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1
...@@ -136,10 +141,10 @@ class Solver(object): ...@@ -136,10 +141,10 @@ class Solver(object):
136 sess.run([model.g_train_op_src], feed_dict) 141 sess.run([model.g_train_op_src], feed_dict)
137 sess.run([model.g_train_op_src], feed_dict) 142 sess.run([model.g_train_op_src], feed_dict)
138 sess.run([model.g_train_op_src], feed_dict) 143 sess.run([model.g_train_op_src], feed_dict)
144 +
139 if i % 15 == 0: 145 if i % 15 == 0:
140 sess.run(model.f_train_op_src, feed_dict) 146 sess.run(model.f_train_op_src, feed_dict)
141 147
142 -
143 if (step+1) % 10 == 0: 148 if (step+1) % 10 == 0:
144 summary, dl, gl, fl = sess.run([model.summary_op_src, \ 149 summary, dl, gl, fl = sess.run([model.summary_op_src, \
145 model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict) 150 model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict)
...@@ -169,7 +174,6 @@ class Solver(object): ...@@ -169,7 +174,6 @@ class Solver(object):
169 saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1) 174 saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1)
170 print ('model/dtn-%d saved' %(step+1)) 175 print ('model/dtn-%d saved' %(step+1))
171 176
172 -
173 def eval(self): 177 def eval(self):
174 # build model 178 # build model
175 model = self.model 179 model = self.model
...@@ -195,4 +199,4 @@ class Solver(object): ...@@ -195,4 +199,4 @@ class Solver(object):
195 merged = self.merge_images(batch_images, sampled_batch_images) 199 merged = self.merge_images(batch_images, sampled_batch_images)
196 path = os.path.join(self.sample_save_path, 'sample-%d-to-%d.png' %(i*self.batch_size, (i+1)*self.batch_size)) 200 path = os.path.join(self.sample_save_path, 'sample-%d-to-%d.png' %(i*self.batch_size, (i+1)*self.batch_size))
197 scipy.misc.imsave(path, merged) 201 scipy.misc.imsave(path, merged)
198 - print ('saved %s' %path) 202 + print ('saved %s' %path)
...\ No newline at end of file ...\ No newline at end of file
......