yunjey

model added

model.py
@@ -1,5 +1,7 @@
 import tensorflow as tf
 from ops import *
+from config import *
+
 
 class DTN(object):
     """Domain Transfer Network for unsupervised cross-domain image generation
@@ -7,7 +9,7 @@ class DTN(object):
     Construct discriminator and generator to prepare for training.
     """
 
-    def __init__(self, batch_size=100, learning_rate=0.0002, image_size=32, output_size=32,
+    def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32,
                  dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
        """
        Args:
@@ -39,6 +41,7 @@ class DTN(object):
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')
        self.d_bn3 = batch_norm(name='d_bn3')
+       self.d_bn4 = batch_norm(name='d_bn4')
 
        self.g_bn1 = batch_norm(name='g_bn1')
        self.g_bn2 = batch_norm(name='g_bn2')
@@ -52,7 +55,7 @@ class DTN(object):
 
 
 
-    def function_f(self, images, reuse=False):
+    def function_f(self, images, reuse=False, train=True):
        """f consistency
 
        Args:
@@ -62,10 +65,10 @@ class DTN(object):
            out: output vectors, of shape (batch_size, dim_f_out)
        """
        with tf.variable_scope('function_f', reuse=reuse):
-           h1 = lrelu(conv2d(images, self.dim_ff, name='f_h1'))                           # (batch_size, 16, 16, 64)
-           h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_ff*2, name='f_h2')))                 # (batch_size, 8, 8, 128)
-           h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_ff*4, name='f_h3')))                 # (batch_size, 4, 4, 256)
-           h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_ff*8, name='f_h4')))                 # (batch_size, 2, 2, 512)
+           h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train))  # (batch_size, 16, 16, 64)
+           h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train))    # (batch_size, 8, 8, 128)
+           h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train))    # (batch_size, 4, 4, 256)
+           h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train))    # (batch_size, 2, 2, 512)
 
            h4 = tf.reshape(h4, [self.batch_size, -1])
            out = linear(h4, self.dim_fout, name='f_out')
@@ -128,10 +131,10 @@ class DTN(object):
        with tf.variable_scope('discriminator', reuse=reuse):
 
            # convolution layer
-           h1 = lrelu(conv2d(images, self.dim_df, name='d_h1'))              # (batch_size, 16, 16, 64)
-           h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_df*2, name='d_h2')))    # (batch_size, 8, 8, 128)
-           h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_df*4, name='d_h3')))    # (batch_size, 4, 4, 256)
-           h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_df*8, name='d_h4')))    # (batch_size, 2, 2, 512)
+           h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1')))  # (batch_size, 16, 16, 64)
+           h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2')))    # (batch_size, 8, 8, 128)
+           h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3')))    # (batch_size, 4, 4, 256)
+           h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4')))    # (batch_size, 2, 2, 512)
 
            # fully connected layer
            h4 = tf.reshape(h4, [self.batch_size, -1])
@@ -149,7 +152,8 @@ class DTN(object):
        self.logits_fake = self.discriminator(self.fake_images, reuse=True)   # (batch_size,)
        self.fgf_x = self.function_f(self.fake_images, reuse=True)            # (batch_size, dim_f)
 
-       # construct generator for test phase
+       # construct generator for test phase (use moving average and variance for batch norm)
+       self.f_x = self.function_f(self.images, reuse=True, train=False)
        self.sampled_images = self.generator(self.f_x, reuse=True)            # (batch_size, 32, 32, 3)
 
 
@@ -159,7 +163,7 @@ class DTN(object):
        self.d_loss = self.d_loss_real + self.d_loss_fake
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake)))
        self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images))          # L_TID
-       self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x))                   # L_CONST
+       self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15            # L_CONST
 
        # divide variables for discriminator and generator
        t_vars = tf.trainable_variables()
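Note on the new 0.15 factor: it bakes the f-constancy weight directly into f_const_loss instead of exposing it as a hyperparameter. Assuming the weighting follows the DTN paper's combined generator objective (an assumption; the diff itself does not cite it), the decomposition is roughly

    L_G = L_GAN + \alpha L_CONST + \beta L_TID

with 0.15 playing the role of \alpha, while \beta stays effectively 1 since g_const_loss is left unscaled.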
@@ -173,22 +177,27 @@ class DTN(object):
        self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
        self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars)
        self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
-       self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.f_vars+self.g_vars)
+       self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars)
 
 
        # summary ops for tensorboard visualization
-       tf.scalar_summary('d_loss_real', self.d_loss_real)
-       tf.scalar_summary('d_loss_fake', self.d_loss_fake)
-       tf.scalar_summary('d_loss', self.d_loss)
-       tf.scalar_summary('g_loss', self.g_loss)
-       tf.scalar_summary('g_const_loss', self.g_const_loss)
-       tf.scalar_summary('f_const_loss', self.f_const_loss)
-       tf.image_summary('original_images', self.images, max_images=6)
-       tf.image_summary('sampled_images', self.sampled_images, max_images=6)
+       scalar_summary('d_loss_real', self.d_loss_real)
+       scalar_summary('d_loss_fake', self.d_loss_fake)
+       scalar_summary('d_loss', self.d_loss)
+       scalar_summary('g_loss', self.g_loss)
+       scalar_summary('g_const_loss', self.g_const_loss)
+       scalar_summary('f_const_loss', self.f_const_loss)
+
+       try:
+           image_summary('original_images', self.images, max_outputs=4)
+           image_summary('sampled_images', self.sampled_images, max_outputs=4)
+       except:
+           image_summary('original_images', self.images, max_images=4)
+           image_summary('sampled_images', self.sampled_images, max_images=4)
 
        for var in tf.trainable_variables():
-           tf.histogram_summary(var.op.name, var)
+           histogram_summary(var.op.name, var)
 
-       self.summary_op = tf.merge_all_summaries()
+       self.summary_op = merge_summary()
 
        self.saver = tf.train.Saver()
\ No newline at end of file
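The bare scalar_summary / image_summary / histogram_summary / merge_summary calls above come in via the new `from config import *`. config.py is not part of this diff; a minimal sketch of the TensorFlow version-compatibility shim it presumably contains (the try/except over max_outputs vs. max_images points the same way), assuming it simply aliases the pre- and post-r0.12 summary APIs:

    # config.py -- hypothetical sketch, not shown in this commit.
    # TensorFlow r0.12 moved the summary ops under tf.summary; these
    # aliases let model.py and solver.py run on either side of the rename.
    import tensorflow as tf

    try:  # TensorFlow r0.12+
        scalar_summary = tf.summary.scalar
        histogram_summary = tf.summary.histogram
        image_summary = tf.summary.image            # takes max_outputs=
        merge_summary = tf.summary.merge_all
        SummaryWriter = tf.summary.FileWriter
    except AttributeError:  # older TensorFlow
        scalar_summary = tf.scalar_summary
        histogram_summary = tf.histogram_summary
        image_summary = tf.image_summary            # takes max_images=
        merge_summary = tf.merge_all_summaries
        SummaryWriter = tf.train.SummaryWriter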
solver.py
@@ -3,33 +3,36 @@ import numpy as np
 import os
 import scipy.io
 import hickle
-from scipy import ndimage
+import scipy.misc
+from config import SummaryWriter
 
 
 class Solver(object):
-    """Load dataset and train DCGAN"""
+    """Load dataset and train and test the model"""
 
-    def __init__(self, model, num_epoch=10, mnist_path='mnist/', svhn_path='svhn/', model_save_path='model/', log_path='log/'):
+    def __init__(self, model, num_epoch=10, mnist_path='mnist/', svhn_path='svhn/', model_save_path='model/',
+                 log_path='log/', sample_path='sample/', test_model_path=None, sample_iter=100):
        self.model = model
        self.num_epoch = num_epoch
        self.mnist_path = mnist_path
        self.svhn_path = svhn_path
        self.model_save_path = model_save_path
        self.log_path = log_path
+       self.sample_path = sample_path
+       self.test_model_path = test_model_path
+       self.sample_iter = sample_iter
 
        # create directory if not exists
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        if not os.path.exists(model_save_path):
            os.makedirs(model_save_path)
+       if not os.path.exists(sample_path):
+           os.makedirs(sample_path)
 
        # construct the dcgan model
        model.build_model()
 
-       # load dataset
-       self.svhn = self.load_svhn(self.svhn_path)
-       self.mnist = self.load_mnist(self.mnist_path)
-
 
    def load_svhn(self, image_path, split='train'):
        print ('loading svhn image dataset..')
@@ -61,12 +64,27 @@ class Solver(object):
        return images
 
 
+   def merge_images(self, sources, targets, k=10):
+       _, h, w, _ = sources.shape
+       row = int(np.sqrt(self.model.batch_size))
+       merged = np.zeros([row*h, row*w*2, 3])
+
+       for idx, (s, t) in enumerate(zip(sources, targets)):
+           i = idx // row
+           j = idx % row
+           merged[i*h:(i+1)*h, (j*2)*w:(j*2+1)*w, :] = s
+           merged[i*h:(i+1)*h, (j*2+1)*w:(j*2+2)*w, :] = t
+
+       return merged
+
+
    def train(self):
        model = self.model
 
-       #load image dataset
-       svhn = self.svhn
-       mnist = self.mnist
+       # load image dataset
+       svhn = self.load_svhn(self.svhn_path)
+       mnist = self.load_mnist(self.mnist_path)
+
 
        num_iter_per_epoch = int(mnist.shape[0] / model.batch_size)
 
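For intuition, merge_images lays a batch out as a sqrt(batch_size) x sqrt(batch_size) grid of side-by-side (source, target) pairs, so batch_size needs to be a perfect square. A standalone shape check under the defaults used here (batch_size=100, 32x32 RGB):

    import numpy as np

    batch_size, h, w = 100, 32, 32     # DTN defaults above
    row = int(np.sqrt(batch_size))     # 10 rows of 10 (source, target) pairs
    canvas = np.zeros([row * h, row * w * 2, 3])
    print(canvas.shape)                # (320, 640, 3)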
@@ -74,18 +92,24 @@ class Solver(object):
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            # initialize parameters
-           tf.initialize_all_variables().run()
-           summary_writer = tf.train.SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
+           try:
+               tf.global_variables_initializer().run()
+           except:
+               tf.initialize_all_variables().run()
+
+           summary_writer = SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
 
            for e in range(self.num_epoch):
                for i in range(num_iter_per_epoch):
 
-                   # train model for domain S
+                   # train model for source domain S
                    image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
                    feed_dict = {model.images: image_batch}
                    sess.run(model.d_optimizer_fake, feed_dict)
-                   sess.run(model.f_optimizer_const, feed_dict)
                    sess.run(model.g_optimizer, feed_dict)
+                   sess.run(model.g_optimizer, feed_dict)
+                   if i % 3 == 0:
+                       sess.run(model.f_optimizer_const, feed_dict)
 
                    if i % 10 == 0:
                        feed_dict = {model.images: image_batch}
@@ -93,16 +117,44 @@ class Solver(object):
                        summary_writer.add_summary(summary, e*num_iter_per_epoch + i)
                        print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss))
 
-                   # train model for domain T
+                   # train model for target domain T
                    image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size]
                    feed_dict = {model.images: image_batch}
                    sess.run(model.d_optimizer_real, feed_dict)
                    sess.run(model.d_optimizer_fake, feed_dict)
                    sess.run(model.g_optimizer, feed_dict)
+                   sess.run(model.g_optimizer, feed_dict)
+                   sess.run(model.g_optimizer, feed_dict)
+                   sess.run(model.g_optimizer_const, feed_dict)
                    sess.run(model.g_optimizer_const, feed_dict)
 
+                   if i % 500 == 0:
+                       model.saver.save(sess, os.path.join(self.model_save_path, 'dtn-%d' %(e+1)), global_step=i+1)
+                       print ('model/dtn-%d-%d saved' %(e+1, i+1))
 
 
-                   if i % 500 == 0:
-                       model.saver.save(sess, os.path.join(self.model_save_path, 'dcgan-%d' %(e+1)), global_step=i+1)
-                       print ('model/dcgan-%d-%d saved' %(e+1, i+1))
\ No newline at end of file
+   def test(self):
+       model = self.model
+
+       # load dataset
+       svhn = self.load_svhn(self.svhn_path)
+       num_iter = int(svhn.shape[0] / model.batch_size)
+
+       config = tf.ConfigProto(allow_soft_placement=True)
+       config.gpu_options.allow_growth = True
+       with tf.Session(config=config) as sess:
+           # load trained parameters
+           saver = tf.train.Saver()
+           saver.restore(sess, self.test_model_path)
+
+           for i in range(self.sample_iter):
+               # sample translated images for source domain S
+               image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
+               feed_dict = {model.images: image_batch}
+               sampled_image_batch = sess.run(model.sampled_images, feed_dict)
+
+               # merge and save source images and sampled target images
+               merged = self.merge_images(image_batch, sampled_image_batch)
+               path = os.path.join(self.sample_path, 'sample-%d-to-%d.png' %(i*model.batch_size, (i+1)*model.batch_size))
+               scipy.misc.imsave(path, merged)
+               print ('saved %s' %path)
main.py (deleted)
@@ -1,11 +0,0 @@
-from model import DTN
-from solver import Solver
-
-def main():
-    model = DTN()
-    solver = Solver(model, num_epoch=10, svhn_path='svhn/', model_save_path='model/', log_path='log/')
-    solver.train()
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
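With main.py deleted, the commit leaves no entry point in view. A minimal sketch of how the updated Solver would now be driven, assuming a replacement script along the old main.py's lines (the checkpoint name is illustrative, following the 'dtn-%d' save pattern above):

    # hypothetical entry point; not part of this commit
    from model import DTN
    from solver import Solver

    def main():
        model = DTN(batch_size=100, learning_rate=0.0001)
        solver = Solver(model, num_epoch=10, mnist_path='mnist/', svhn_path='svhn/',
                        model_save_path='model/', log_path='log/', sample_path='sample/',
                        test_model_path='model/dtn-10-501')   # illustrative checkpoint
        solver.train()
        # solver.test()   # restores test_model_path and writes grids to sample/

    if __name__ == "__main__":
        main()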