Showing 3 changed files with 101 additions and 51 deletions
1 | import tensorflow as tf | 1 | import tensorflow as tf |
2 | from ops import * | 2 | from ops import * |
3 | +from config import * | ||
4 | + | ||
3 | 5 | ||
4 | class DTN(object): | 6 | class DTN(object): |
5 | """Domain Transfer Network for unsupervised cross-domain image generation | 7 | """Domain Transfer Network for unsupervised cross-domain image generation |
... | @@ -7,7 +9,7 @@ class DTN(object): | ... | @@ -7,7 +9,7 @@ class DTN(object): |
7 | Construct discriminator and generator to prepare for training. | 9 | Construct discriminator and generator to prepare for training. |
8 | """ | 10 | """ |
9 | 11 | ||
10 | - def __init__(self, batch_size=100, learning_rate=0.0002, image_size=32, output_size=32, | 12 | + def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32, |
11 | dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64): | 13 | dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64): |
12 | """ | 14 | """ |
13 | Args: | 15 | Args: |
... | @@ -39,6 +41,7 @@ class DTN(object): | ... | @@ -39,6 +41,7 @@ class DTN(object): |
39 | self.d_bn1 = batch_norm(name='d_bn1') | 41 | self.d_bn1 = batch_norm(name='d_bn1') |
40 | self.d_bn2 = batch_norm(name='d_bn2') | 42 | self.d_bn2 = batch_norm(name='d_bn2') |
41 | self.d_bn3 = batch_norm(name='d_bn3') | 43 | self.d_bn3 = batch_norm(name='d_bn3') |
44 | + self.d_bn4 = batch_norm(name='d_bn4') | ||
42 | 45 | ||
43 | self.g_bn1 = batch_norm(name='g_bn1') | 46 | self.g_bn1 = batch_norm(name='g_bn1') |
44 | self.g_bn2 = batch_norm(name='g_bn2') | 47 | self.g_bn2 = batch_norm(name='g_bn2') |
... | @@ -52,7 +55,7 @@ class DTN(object): | ... | @@ -52,7 +55,7 @@ class DTN(object): |
52 | 55 | ||
53 | 56 | ||
54 | 57 | ||
55 | - def function_f(self, images, reuse=False): | 58 | + def function_f(self, images, reuse=False, train=True): |
56 | """f consistancy | 59 | """f consistancy |
57 | 60 | ||
58 | Args: | 61 | Args: |
... | @@ -62,10 +65,10 @@ class DTN(object): | ... | @@ -62,10 +65,10 @@ class DTN(object): |
62 | out: output vectors, of shape (batch_size, dim_f_out) | 65 | out: output vectors, of shape (batch_size, dim_f_out) |
63 | """ | 66 | """ |
64 | with tf.variable_scope('function_f', reuse=reuse): | 67 | with tf.variable_scope('function_f', reuse=reuse): |
65 | - h1 = lrelu(conv2d(images, self.dim_ff, name='f_h1')) # (batch_size, 16, 16, 64) | 68 | + h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train)) # (batch_size, 16, 16, 64) |
66 | - h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_ff*2, name='f_h2'))) # (batch_size, 8, 8 128) | 69 | + h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train)) # (batch_size, 8, 8 128) |
67 | - h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_ff*4, name='f_h3'))) # (batch_size, 4, 4, 256) | 70 | + h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train)) # (batch_size, 4, 4, 256) |
68 | - h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_ff*8, name='f_h4'))) # (batch_size, 2, 2, 512) | 71 | + h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train)) # (batch_size, 2, 2, 512) |
69 | 72 | ||
70 | h4 = tf.reshape(h4, [self.batch_size,-1]) | 73 | h4 = tf.reshape(h4, [self.batch_size,-1]) |
71 | out = linear(h4, self.dim_fout, name='f_out') | 74 | out = linear(h4, self.dim_fout, name='f_out') |
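Note: the new `train` argument threaded through `function_f` only has an effect if the `batch_norm` helper in `ops.py` distinguishes training from inference; that file is not part of this diff, and the `f_bn1`–`f_bn4` objects used above are presumably constructed alongside the `d_bn*`/`g_bn*` ones in a hunk not shown here. A minimal sketch of such a wrapper (class name, defaults, and use of `tf.contrib.layers.batch_norm` are assumptions, in the style of the common DCGAN-tensorflow helper):

```python
# Sketch only: assumed batch_norm wrapper from ops.py. The train flag switches
# between batch statistics (training) and the moving averages accumulated by
# tf.contrib.layers.batch_norm (inference / sampling).
import tensorflow as tf

class batch_norm(object):
    def __init__(self, epsilon=1e-5, momentum=0.9, name='batch_norm'):
        self.epsilon = epsilon
        self.momentum = momentum
        self.name = name

    def __call__(self, x, train=True):
        return tf.contrib.layers.batch_norm(x,
                                            decay=self.momentum,
                                            epsilon=self.epsilon,
                                            scale=True,
                                            is_training=train,
                                            scope=self.name)
```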
... | @@ -128,10 +131,10 @@ class DTN(object): | ... | @@ -128,10 +131,10 @@ class DTN(object): |
128 | with tf.variable_scope('discriminator', reuse=reuse): | 131 | with tf.variable_scope('discriminator', reuse=reuse): |
129 | 132 | ||
130 | # convolution layer | 133 | # convolution layer |
131 | - h1 = lrelu(conv2d(images, self.dim_df, name='d_h1')) # (batch_size, 16, 16, 64) | 134 | + h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1'))) # (batch_size, 16, 16, 64) |
132 | - h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128) | 135 | + h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128) |
133 | - h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256) | 136 | + h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256) |
134 | - h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512) | 137 | + h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512) |
135 | 138 | ||
136 | # fully connected layer | 139 | # fully connected layer |
137 | h4 = tf.reshape(h4, [self.batch_size, -1]) | 140 | h4 = tf.reshape(h4, [self.batch_size, -1]) |
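The shape comments in the discriminator and in `function_f` (32 → 16 → 8 → 4 → 2) imply stride-2 convolutions. `conv2d` and `lrelu` also live in `ops.py`, which is not in this diff; a sketch consistent with those shapes (kernel size, init stddev, and bias handling are assumptions):

```python
# Sketch only: assumed ops.py helpers. A stride-2 5x5 convolution halves the
# spatial size at every layer, matching the shape comments in the diff.
import tensorflow as tf

def conv2d(x, dim_out, k=5, stride=2, stddev=0.02, name='conv2d'):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k, k, int(x.get_shape()[-1]), dim_out],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        b = tf.get_variable('b', [dim_out], initializer=tf.constant_initializer(0.0))
        return tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding='SAME') + b

def lrelu(x, leak=0.2):
    # Leaky ReLU, the usual activation for GAN discriminators/encoders.
    return tf.maximum(x, leak * x)
```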
... | @@ -149,7 +152,8 @@ class DTN(object): | ... | @@ -149,7 +152,8 @@ class DTN(object): |
149 | self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,) | 152 | self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,) |
150 | self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f) | 153 | self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f) |
151 | 154 | ||
152 | - # construct generator for test phase | 155 | + # construct generator for test phase (use moving average and variance for batch norm) |
156 | + self.f_x = self.function_f(self.images, reuse=True, train=False) | ||
153 | self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3) | 157 | self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3) |
154 | 158 | ||
155 | 159 | ||
... | @@ -159,7 +163,7 @@ class DTN(object): | ... | @@ -159,7 +163,7 @@ class DTN(object): |
159 | self.d_loss = self.d_loss_real + self.d_loss_fake | 163 | self.d_loss = self.d_loss_real + self.d_loss_fake |
160 | self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake))) | 164 | self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake))) |
161 | self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID | 165 | self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID |
162 | - self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) # L_CONST | 166 | + self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15 # L_CONST |
163 | 167 | ||
164 | # divide variables for discriminator and generator | 168 | # divide variables for discriminator and generator |
165 | t_vars = tf.trainable_variables() | 169 | t_vars = tf.trainable_variables() |
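For reference, the loss terms above correspond to the DTN objective: `g_loss` is the adversarial generator term, `f_const_loss` (now scaled by 0.15) is L_CONST, and `g_const_loss` is L_TID. A hedged sketch of the single combined generator objective they imply — the `alpha`/`beta` names and the one-op formulation are illustrative only; the diff instead runs separate Adam steps per term. Note also that on TF ≥ 1.0 `sigmoid_cross_entropy_with_logits` requires keyword arguments, unlike the positional call kept in the diff:

```python
# Sketch only: equivalent single-objective view of the generator losses above.
import tensorflow as tf

def generator_objective(logits_fake, f_x, fgf_x, images, fake_images,
                        alpha=0.15, beta=1.0):
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits_fake, labels=tf.ones_like(logits_fake)))   # adversarial term
    f_const = tf.reduce_mean(tf.square(f_x - fgf_x))             # L_CONST
    g_const = tf.reduce_mean(tf.square(images - fake_images))    # L_TID
    return g_loss + alpha * f_const + beta * g_const
```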
... | @@ -173,22 +177,27 @@ class DTN(object): | ... | @@ -173,22 +177,27 @@ class DTN(object): |
173 | self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars) | 177 | self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars) |
174 | self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars) | 178 | self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars) |
175 | self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars) | 179 | self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars) |
176 | - self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.f_vars+self.g_vars) | 180 | + self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars) |
177 | 181 | ||
178 | 182 | ||
179 | # summary ops for tensorboard visualization | 183 | # summary ops for tensorboard visualization |
180 | - tf.scalar_summary('d_loss_real', self.d_loss_real) | 184 | + scalar_summary('d_loss_real', self.d_loss_real) |
181 | - tf.scalar_summary('d_loss_fake', self.d_loss_fake) | 185 | + scalar_summary('d_loss_fake', self.d_loss_fake) |
182 | - tf.scalar_summary('d_loss', self.d_loss) | 186 | + scalar_summary('d_loss', self.d_loss) |
183 | - tf.scalar_summary('g_loss', self.g_loss) | 187 | + scalar_summary('g_loss', self.g_loss) |
184 | - tf.scalar_summary('g_const_loss', self.g_const_loss) | 188 | + scalar_summary('g_const_loss', self.g_const_loss) |
185 | - tf.scalar_summary('f_const_loss', self.f_const_loss) | 189 | + scalar_summary('f_const_loss', self.f_const_loss) |
186 | - tf.image_summary('original_images', self.images, max_images=6) | 190 | + |
187 | - tf.image_summary('sampled_images', self.sampled_images, max_images=6) | 191 | + try: |
192 | + image_summary('original_images', self.images, max_outputs=4) | ||
193 | + image_summary('sampled_images', self.sampled_images, max_outputs=4) | ||
194 | + except: | ||
195 | + image_summary('original_images', self.images, max_images=4) | ||
196 | + image_summary('sampled_images', self.sampled_images, max_images=4) | ||
188 | 197 | ||
189 | for var in tf.trainable_variables(): | 198 | for var in tf.trainable_variables(): |
190 | - tf.histogram_summary(var.op.name, var) | 199 | + histogram_summary(var.op.name, var) |
191 | 200 | ||
192 | - self.summary_op = tf.merge_all_summaries() | 201 | + self.summary_op = merge_summary() |
193 | 202 | ||
194 | self.saver = tf.train.Saver() | 203 | self.saver = tf.train.Saver() |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
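The `scalar_summary`, `image_summary`, `histogram_summary`, and `merge_summary` names used above (and `SummaryWriter` in the solver) come from the new `from config import *`, but `config.py` itself is not included in this diff. A plausible sketch of the version-compatibility shims it would contain:

```python
# Sketch only: assumed contents of config.py. Maps the renamed TF summary API
# (TF >= 0.12 / 1.x) and the old one (TF <= 0.11) onto common names, which is
# also why the diff wraps image_summary in try/except for max_outputs vs max_images.
import tensorflow as tf

try:  # TF >= 0.12 / 1.x
    scalar_summary = tf.summary.scalar
    histogram_summary = tf.summary.histogram
    image_summary = tf.summary.image
    merge_summary = tf.summary.merge_all
    SummaryWriter = tf.summary.FileWriter
except AttributeError:  # older TF
    scalar_summary = tf.scalar_summary
    histogram_summary = tf.histogram_summary
    image_summary = tf.image_summary
    merge_summary = tf.merge_all_summaries
    SummaryWriter = tf.train.SummaryWriter
```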
... | @@ -3,33 +3,36 @@ import numpy as np | ... | @@ -3,33 +3,36 @@ import numpy as np |
3 | import os | 3 | import os |
4 | import scipy.io | 4 | import scipy.io |
5 | import hickle | 5 | import hickle |
6 | -from scipy import ndimage | 6 | +import scipy.misc |
7 | +from config import SummaryWriter | ||
7 | 8 | ||
8 | 9 | ||
9 | class Solver(object): | 10 | class Solver(object): |
10 | - """Load dataset and train DCGAN""" | 11 | + """Load dataset and train and test the model""" |
11 | 12 | ||
12 | - def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', log_path='log/'): | 13 | + def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', |
14 | + log_path='log/', sample_path='sample/', test_model_path=None, sample_iter=100): | ||
13 | self.model = model | 15 | self.model = model |
14 | self.num_epoch = num_epoch | 16 | self.num_epoch = num_epoch |
15 | self.mnist_path = mnist_path | 17 | self.mnist_path = mnist_path |
16 | self.svhn_path = svhn_path | 18 | self.svhn_path = svhn_path |
17 | self.model_save_path = model_save_path | 19 | self.model_save_path = model_save_path |
18 | self.log_path = log_path | 20 | self.log_path = log_path |
21 | + self.sample_path = sample_path | ||
22 | + self.test_model_path = test_model_path | ||
23 | + self.sample_iter = sample_iter | ||
19 | 24 | ||
20 | # create directory if not exists | 25 | # create directory if not exists |
21 | if not os.path.exists(log_path): | 26 | if not os.path.exists(log_path): |
22 | os.makedirs(log_path) | 27 | os.makedirs(log_path) |
23 | if not os.path.exists(model_save_path): | 28 | if not os.path.exists(model_save_path): |
24 | os.makedirs(model_save_path) | 29 | os.makedirs(model_save_path) |
30 | + if not os.path.exists(sample_path): | ||
31 | + os.makedirs(sample_path) | ||
25 | 32 | ||
26 | # construct the dcgan model | 33 | # construct the dcgan model |
27 | model.build_model() | 34 | model.build_model() |
28 | 35 | ||
29 | - # load dataset | ||
30 | - self.svhn = self.load_svhn(self.svhn_path) | ||
31 | - self.mnist = self.load_mnist(self.mnist_path) | ||
32 | - | ||
33 | 36 | ||
34 | def load_svhn(self, image_path, split='train'): | 37 | def load_svhn(self, image_path, split='train'): |
35 | print ('loading svhn image dataset..') | 38 | print ('loading svhn image dataset..') |
... | @@ -61,12 +64,27 @@ class Solver(object): | ... | @@ -61,12 +64,27 @@ class Solver(object): |
61 | return images | 64 | return images |
62 | 65 | ||
63 | 66 | ||
67 | + def merge_images(self, sources, targets, k=10): | ||
68 | + _, h, w, _ = sources.shape | ||
69 | + row = int(np.sqrt(self.model.batch_size)) | ||
70 | + merged = np.zeros([row*h, row*w*2, 3]) | ||
71 | + | ||
72 | + for idx, (s, t) in enumerate(zip(sources, targets)): | ||
73 | + i = idx // row | ||
74 | + j = idx % row | ||
75 | + merged[i*h:(i+1)*h, (j*2)*h:(j*2+1)*h, :] = s | ||
76 | + merged[i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h, :] = t | ||
77 | + | ||
78 | + return merged | ||
79 | + | ||
80 | + | ||
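A standalone sanity check of the grid layout `merge_images` produces. It assumes square images (h == w == 32), which is why indexing the column offsets with `h` in the diff still works; with batch_size = 100 the grid holds exactly 10 × 10 (source, sampled) pairs:

```python
# Sketch only: verify the merged-grid shape produced by merge_images above.
import numpy as np

batch_size, h, w = 100, 32, 32
sources = np.random.rand(batch_size, h, w, 3)
targets = np.random.rand(batch_size, h, w, 3)

row = int(np.sqrt(batch_size))                 # 10
merged = np.zeros([row * h, row * w * 2, 3])   # source/sampled pairs side by side
for idx, (s, t) in enumerate(zip(sources, targets)):
    i, j = idx // row, idx % row
    merged[i*h:(i+1)*h, (j*2)*w:(j*2+1)*w, :] = s
    merged[i*h:(i+1)*h, (j*2+1)*w:(j*2+2)*w, :] = t
print(merged.shape)  # (320, 640, 3): 10 rows of 10 source/sampled pairs
```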
64 | def train(self): | 81 | def train(self): |
65 | model=self.model | 82 | model=self.model |
66 | 83 | ||
67 | - #load image dataset | 84 | + # load image dataset |
68 | - svhn = self.svhn | 85 | + svhn = self.load_svhn(self.svhn_path) |
69 | - mnist = self.mnist | 86 | + mnist = self.load_mnist(self.mnist_path) |
87 | + | ||
70 | 88 | ||
71 | num_iter_per_epoch = int(mnist.shape[0] / model.batch_size) | 89 | num_iter_per_epoch = int(mnist.shape[0] / model.batch_size) |
72 | 90 | ||
... | @@ -74,18 +92,24 @@ class Solver(object): | ... | @@ -74,18 +92,24 @@ class Solver(object): |
74 | config.gpu_options.allow_growth = True | 92 | config.gpu_options.allow_growth = True |
75 | with tf.Session(config=config) as sess: | 93 | with tf.Session(config=config) as sess: |
76 | # initialize parameters | 94 | # initialize parameters |
95 | + try: | ||
96 | + tf.global_variables_initializer().run() | ||
97 | + except: | ||
77 | tf.initialize_all_variables().run() | 98 | tf.initialize_all_variables().run() |
78 | - summary_writer = tf.train.SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph()) | 99 | + |
100 | + summary_writer = SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph()) | ||
79 | 101 | ||
80 | for e in range(self.num_epoch): | 102 | for e in range(self.num_epoch): |
81 | for i in range(num_iter_per_epoch): | 103 | for i in range(num_iter_per_epoch): |
82 | 104 | ||
83 | - # train model for domain S | 105 | + # train model for source domain S |
84 | image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size] | 106 | image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size] |
85 | feed_dict = {model.images: image_batch} | 107 | feed_dict = {model.images: image_batch} |
86 | sess.run(model.d_optimizer_fake, feed_dict) | 108 | sess.run(model.d_optimizer_fake, feed_dict) |
87 | - sess.run(model.f_optimizer_const, feed_dict) | ||
88 | sess.run(model.g_optimizer, feed_dict) | 109 | sess.run(model.g_optimizer, feed_dict) |
110 | + sess.run(model.g_optimizer, feed_dict) | ||
111 | + if i % 3 == 0: | ||
112 | + sess.run(model.f_optimizer_const, feed_dict) | ||
89 | 113 | ||
90 | if i % 10 == 0: | 114 | if i % 10 == 0: |
91 | feed_dict = {model.images: image_batch} | 115 | feed_dict = {model.images: image_batch} |
... | @@ -93,16 +117,44 @@ class Solver(object): | ... | @@ -93,16 +117,44 @@ class Solver(object): |
93 | summary_writer.add_summary(summary, e*num_iter_per_epoch + i) | 117 | summary_writer.add_summary(summary, e*num_iter_per_epoch + i) |
94 | print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss)) | 118 | print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss)) |
95 | 119 | ||
96 | - # train model for domain T | 120 | + # train model for target domain T |
97 | image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size] | 121 | image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size] |
98 | feed_dict = {model.images: image_batch} | 122 | feed_dict = {model.images: image_batch} |
99 | sess.run(model.d_optimizer_real, feed_dict) | 123 | sess.run(model.d_optimizer_real, feed_dict) |
100 | sess.run(model.d_optimizer_fake, feed_dict) | 124 | sess.run(model.d_optimizer_fake, feed_dict) |
101 | sess.run(model.g_optimizer, feed_dict) | 125 | sess.run(model.g_optimizer, feed_dict) |
126 | + sess.run(model.g_optimizer, feed_dict) | ||
127 | + sess.run(model.g_optimizer, feed_dict) | ||
128 | + sess.run(model.g_optimizer_const, feed_dict) | ||
102 | sess.run(model.g_optimizer_const, feed_dict) | 129 | sess.run(model.g_optimizer_const, feed_dict) |
103 | 130 | ||
131 | + if i % 500 == 0: | ||
132 | + model.saver.save(sess, os.path.join(self.model_save_path, 'dtn-%d' %(e+1)), global_step=i+1) | ||
133 | + print ('model/dtn-%d-%d saved' %(e+1, i+1)) | ||
104 | 134 | ||
105 | 135 | ||
106 | - if i % 500 == 0: | ||
107 | - model.saver.save(sess, os.path.join(self.model_save_path, 'dcgan-%d' %(e+1)), global_step=i+1) | ||
108 | - print ('model/dcgan-%d-%d saved' %(e+1, i+1)) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
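For readability, the per-iteration update schedule the revised loop implements, written out as (op, count) pairs. The extra generator steps and the every-third-iteration `f_optimizer_const` update are the substantive changes versus the previous version:

```python
# Summary of the update schedule in train() above (not executable training code).
SVHN_BATCH_UPDATES = [          # source domain S
    ('d_optimizer_fake', 1),
    ('g_optimizer', 2),
    ('f_optimizer_const', 1),   # only when i % 3 == 0
]
MNIST_BATCH_UPDATES = [         # target domain T
    ('d_optimizer_real', 1),
    ('d_optimizer_fake', 1),
    ('g_optimizer', 3),
    ('g_optimizer_const', 2),
]
```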
136 | + def test(self): | ||
137 | + model = self.model | ||
138 | + | ||
139 | + # load dataset | ||
140 | + svhn = self.load_svhn(self.svhn_path) | ||
141 | + num_iter = int(svhn.shape[0] / model.batch_size) | ||
142 | + | ||
143 | + config = tf.ConfigProto(allow_soft_placement = True) | ||
144 | + config.gpu_options.allow_growth = True | ||
145 | + with tf.Session(config=config) as sess: | ||
146 | + # load trained parameters | ||
147 | + saver = tf.train.Saver() | ||
148 | + saver.restore(sess, self.test_model_path) | ||
149 | + | ||
150 | + for i in range(self.sample_iter): | ||
151 | + # train model for source domain S | ||
152 | + image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size] | ||
153 | + feed_dict = {model.images: image_batch} | ||
154 | + sampled_image_batch = sess.run(model.sampled_images, feed_dict) | ||
155 | + | ||
156 | + # merge and save source images and sampled target images | ||
157 | + merged = self.merge_images(image_batch, sampled_image_batch) | ||
158 | + path = os.path.join(self.sample_path, 'sample-%d-to-%d.png' %(i*model.batch_size, (i+1)*model.batch_size)) | ||
159 | + scipy.misc.imsave(path, merged) | ||
160 | + print ('saved %s' %path) | ... | ... |
train.py deleted 100644 → 0
1 | -from model import DTN | ||
2 | -from solver import Solver | ||
3 | - | ||
4 | -def main(): | ||
5 | - model = DTN() | ||
6 | - solver = Solver(model, num_epoch=10, svhn_path='svhn/', model_save_path='model/', log_path='log/') | ||
7 | - solver.train() | ||
8 | - | ||
9 | - | ||
10 | -if __name__ == "__main__": | ||
11 | - main() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
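`train.py` is removed here and the diff does not show a replacement entry point. A minimal driver consistent with the new `Solver` signature might look like the following; the checkpoint passed as `test_model_path` is a placeholder, not a file produced by this diff:

```python
# Sketch only: hypothetical replacement for the deleted train.py.
from model import DTN
from solver import Solver

def main():
    model = DTN()
    solver = Solver(model, num_epoch=10, mnist_path='mnist/', svhn_path='svhn/',
                    model_save_path='model/', log_path='log/', sample_path='sample/',
                    test_model_path='model/dtn-10-501')   # placeholder checkpoint path
    solver.train()
    # solver.test()  # sample SVHN -> MNIST translations from the saved model

if __name__ == '__main__':
    main()
```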