Showing 4 changed files with 27 additions and 37 deletions.
1 | +mkdir -p mnist | ||
1 | mkdir -p svhn | 2 | mkdir -p svhn |
2 | wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat | 3 | wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat |
3 | wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat | 4 | wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat |
5 | +wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat | ||
4 | 6 | ||
5 | 7 | ... | ... |
... | @@ -3,15 +3,12 @@ from model import DTN | ... | @@ -3,15 +3,12 @@ from model import DTN |
3 | from solver import Solver | 3 | from solver import Solver |
4 | 4 | ||
5 | 5 | ||
6 | - | ||
7 | flags = tf.app.flags | 6 | flags = tf.app.flags |
8 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") | 7 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") |
9 | - | ||
10 | FLAGS = flags.FLAGS | 8 | FLAGS = flags.FLAGS |
11 | 9 | ||
12 | def main(_): | 10 | def main(_): |
13 | 11 | ||
14 | - | ||
15 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) | 12 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) |
16 | solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, | 13 | solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, |
17 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') | 14 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') |
... | @@ -22,9 +19,6 @@ def main(_): | ... | @@ -22,9 +19,6 @@ def main(_): |
22 | solver.train() | 19 | solver.train() |
23 | else: | 20 | else: |
24 | solver.eval() | 21 | solver.eval() |
25 | - | 22 | + |
26 | if __name__ == '__main__': | 23 | if __name__ == '__main__': |
27 | - tf.app.run() | ||
28 | - | ||
29 | - | ||
30 | - | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
24 | + tf.app.run() | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
... | @@ -33,7 +33,6 @@ class DTN(object): | ... | @@ -33,7 +33,6 @@ class DTN(object): |
33 | if self.mode == 'pretrain': | 33 | if self.mode == 'pretrain': |
34 | net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') | 34 | net = slim.conv2d(net, 10, [1, 1], padding='VALID', scope='out') |
35 | net = slim.flatten(net) | 35 | net = slim.flatten(net) |
36 | - | ||
37 | return net | 36 | return net |
38 | 37 | ||
39 | def generator(self, inputs, reuse=False): | 38 | def generator(self, inputs, reuse=False): |
... | @@ -106,12 +105,11 @@ class DTN(object): | ... | @@ -106,12 +105,11 @@ class DTN(object): |
106 | 105 | ||
107 | 106 | ||
108 | # source domain (svhn to mnist) | 107 | # source domain (svhn to mnist) |
109 | - with tf.name_scope('model_for_source_domain'): | 108 | + self.fx = self.content_extractor(self.src_images) |
110 | - self.fx = self.content_extractor(self.src_images) | 109 | + self.fake_images = self.generator(self.fx) |
111 | - self.fake_images = self.generator(self.fx) | 110 | + self.logits = self.discriminator(self.fake_images) |
112 | - self.logits = self.discriminator(self.fake_images) | 111 | + self.fgfx = self.content_extractor(self.fake_images, reuse=True) |
113 | - self.fgfx = self.content_extractor(self.fake_images, reuse=True) | 112 | + |
114 | - | ||
115 | # loss | 113 | # loss |
116 | self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits)) | 114 | self.d_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.zeros_like(self.logits)) |
117 | self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits)) | 115 | self.g_loss_src = slim.losses.sigmoid_cross_entropy(self.logits, tf.ones_like(self.logits)) |
... | @@ -128,10 +126,9 @@ class DTN(object): | ... | @@ -128,10 +126,9 @@ class DTN(object): |
128 | f_vars = [var for var in t_vars if 'content_extractor' in var.name] | 126 | f_vars = [var for var in t_vars if 'content_extractor' in var.name] |
129 | 127 | ||
130 | # train op | 128 | # train op |
131 | - with tf.name_scope('source_train_op'): | 129 | + self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) |
132 | - self.d_train_op_src = slim.learning.create_train_op(self.d_loss_src, self.d_optimizer_src, variables_to_train=d_vars) | 130 | + self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) |
133 | - self.g_train_op_src = slim.learning.create_train_op(self.g_loss_src, self.g_optimizer_src, variables_to_train=g_vars) | 131 | + self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars) |
134 | - self.f_train_op_src = slim.learning.create_train_op(self.f_loss_src, self.f_optimizer_src, variables_to_train=f_vars) | ||
135 | 132 | ||
136 | # summary op | 133 | # summary op |
137 | d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src) | 134 | d_loss_src_summary = tf.summary.scalar('src_d_loss', self.d_loss_src) |
... | @@ -144,11 +141,10 @@ class DTN(object): | ... | @@ -144,11 +141,10 @@ class DTN(object): |
144 | sampled_images_summary]) | 141 | sampled_images_summary]) |
145 | 142 | ||
146 | # target domain (mnist) | 143 | # target domain (mnist) |
147 | - with tf.name_scope('model_for_target_domain'): | 144 | + self.fx = self.content_extractor(self.trg_images, reuse=True) |
148 | - self.fx = self.content_extractor(self.trg_images, reuse=True) | 145 | + self.reconst_images = self.generator(self.fx, reuse=True) |
149 | - self.reconst_images = self.generator(self.fx, reuse=True) | 146 | + self.logits_fake = self.discriminator(self.reconst_images, reuse=True) |
150 | - self.logits_fake = self.discriminator(self.reconst_images, reuse=True) | 147 | + self.logits_real = self.discriminator(self.trg_images, reuse=True) |
151 | - self.logits_real = self.discriminator(self.trg_images, reuse=True) | ||
152 | 148 | ||
153 | # loss | 149 | # loss |
154 | self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake)) | 150 | self.d_loss_fake_trg = slim.losses.sigmoid_cross_entropy(self.logits_fake, tf.zeros_like(self.logits_fake)) |
... | @@ -161,16 +157,10 @@ class DTN(object): | ... | @@ -161,16 +157,10 @@ class DTN(object): |
161 | # optimizer | 157 | # optimizer |
162 | self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | 158 | self.d_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) |
163 | self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) | 159 | self.g_optimizer_trg = tf.train.AdamOptimizer(self.learning_rate) |
164 | - | ||
165 | - t_vars = tf.trainable_variables() | ||
166 | - d_vars = [var for var in t_vars if 'discriminator' in var.name] | ||
167 | - g_vars = [var for var in t_vars if 'generator' in var.name] | ||
168 | - f_vars = [var for var in t_vars if 'content_extractor' in var.name] | ||
169 | 160 | ||
170 | # train op | 161 | # train op |
171 | - with tf.name_scope('target_train_op'): | 162 | + self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) |
172 | - self.d_train_op_trg = slim.learning.create_train_op(self.d_loss_trg, self.d_optimizer_trg, variables_to_train=d_vars) | 163 | + self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars) |
173 | - self.g_train_op_trg = slim.learning.create_train_op(self.g_loss_trg, self.g_optimizer_trg, variables_to_train=g_vars) | ||
174 | 164 | ||
175 | # summary op | 165 | # summary op |
176 | d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg) | 166 | d_loss_fake_trg_summary = tf.summary.scalar('trg_d_loss_fake', self.d_loss_fake_trg) | ... | ... |
... | @@ -9,7 +9,7 @@ import scipy.misc | ... | @@ -9,7 +9,7 @@ import scipy.misc |
9 | 9 | ||
10 | class Solver(object): | 10 | class Solver(object): |
11 | 11 | ||
12 | - def __init__(self, model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, | 12 | + def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100, |
13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', | 13 | svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', |
14 | model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): | 14 | model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): |
15 | self.model = model | 15 | self.model = model |
... | @@ -29,7 +29,12 @@ class Solver(object): | ... | @@ -29,7 +29,12 @@ class Solver(object): |
29 | 29 | ||
30 | def load_svhn(self, image_dir, split='train'): | 30 | def load_svhn(self, image_dir, split='train'): |
31 | print ('loading svhn image dataset..') | 31 | print ('loading svhn image dataset..') |
32 | - image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat' | 32 | + |
33 | + if self.model.mode == 'pretrain': | ||
34 | + image_file = 'extra_32x32.mat' if split=='train' else 'test_32x32.mat' | ||
35 | + else: | ||
36 | + image_file = 'train_32x32.mat' if split=='train' else 'test_32x32.mat' | ||
37 | + | ||
33 | image_dir = os.path.join(image_dir, image_file) | 38 | image_dir = os.path.join(image_dir, image_file) |
34 | svhn = scipy.io.loadmat(image_dir) | 39 | svhn = scipy.io.loadmat(image_dir) |
35 | images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1 | 40 | images = np.transpose(svhn['X'], [3, 0, 1, 2]) / 127.5 - 1 |
... | @@ -136,10 +141,10 @@ class Solver(object): | ... | @@ -136,10 +141,10 @@ class Solver(object): |
136 | sess.run([model.g_train_op_src], feed_dict) | 141 | sess.run([model.g_train_op_src], feed_dict) |
137 | sess.run([model.g_train_op_src], feed_dict) | 142 | sess.run([model.g_train_op_src], feed_dict) |
138 | sess.run([model.g_train_op_src], feed_dict) | 143 | sess.run([model.g_train_op_src], feed_dict) |
144 | + | ||
139 | if i % 15 == 0: | 145 | if i % 15 == 0: |
140 | sess.run(model.f_train_op_src, feed_dict) | 146 | sess.run(model.f_train_op_src, feed_dict) |
141 | 147 | ||
142 | - | ||
143 | if (step+1) % 10 == 0: | 148 | if (step+1) % 10 == 0: |
144 | summary, dl, gl, fl = sess.run([model.summary_op_src, \ | 149 | summary, dl, gl, fl = sess.run([model.summary_op_src, \ |
145 | model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict) | 150 | model.d_loss_src, model.g_loss_src, model.f_loss_src], feed_dict) |
... | @@ -169,7 +174,6 @@ class Solver(object): | ... | @@ -169,7 +174,6 @@ class Solver(object): |
169 | saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1) | 174 | saver.save(sess, os.path.join(self.model_save_path, 'dtn'), global_step=step+1) |
170 | print ('model/dtn-%d saved' %(step+1)) | 175 | print ('model/dtn-%d saved' %(step+1)) |
171 | 176 | ||
172 | - | ||
173 | def eval(self): | 177 | def eval(self): |
174 | # build model | 178 | # build model |
175 | model = self.model | 179 | model = self.model |
... | @@ -195,4 +199,4 @@ class Solver(object): | ... | @@ -195,4 +199,4 @@ class Solver(object): |
195 | merged = self.merge_images(batch_images, sampled_batch_images) | 199 | merged = self.merge_images(batch_images, sampled_batch_images) |
196 | path = os.path.join(self.sample_save_path, 'sample-%d-to-%d.png' %(i*self.batch_size, (i+1)*self.batch_size)) | 200 | path = os.path.join(self.sample_save_path, 'sample-%d-to-%d.png' %(i*self.batch_size, (i+1)*self.batch_size)) |
197 | scipy.misc.imsave(path, merged) | 201 | scipy.misc.imsave(path, merged) |
198 | - print ('saved %s' %path) | 202 | + print ('saved %s' %path) |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
Please register or log in to post a comment.