Showing
4 changed files
with
412 additions
and
0 deletions
model.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +from ops import * | ||
3 | + | ||
class DTN(object):
    """Domain Transfer Network for unsupervised cross-domain image generation.

    Holds the hyper-parameters and builds three sub-networks -- the feature
    extractor ``function_f``, the ``generator`` and the ``discriminator`` --
    together with the losses, optimizers and summaries used for training.
    Shape comments throughout assume the defaults (32x32 images, dim_* = 64).
    """

    def __init__(self, batch_size=100, learning_rate=0.0002, image_size=32, output_size=32,
                 dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
        """
        Args:
            batch_size: (optional) number of images per mini-batch; the graph is built for this fixed size
            learning_rate: (optional) learning rate for every Adam optimizer
            image_size: (optional) spatial size of input image for discriminator and function f
            output_size: (optional) spatial size of image generated by generator
            dim_color: (optional) dimension of image color; default is 3 for rgb
            dim_fout: (optional) dimension of function f's output (the input vector of the generator)
            dim_df: (optional) dimension of discriminator's filter in first convolution layer
            dim_gf: (optional) dimension of generator's filter in last convolution layer
            dim_ff: (optional) dimension of function f's filter in first convolution layer
        """
        # hyper parameters
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.image_size = image_size
        self.output_size = output_size
        self.dim_color = dim_color
        self.dim_fout = dim_fout
        self.dim_df = dim_df
        self.dim_gf = dim_gf
        self.dim_ff = dim_ff

        # placeholder for a mini-batch of images from either domain
        self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images')

        # batch normalization layers for discriminator, generator and function f
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')
        self.d_bn3 = batch_norm(name='d_bn3')

        self.g_bn1 = batch_norm(name='g_bn1')
        self.g_bn2 = batch_norm(name='g_bn2')
        self.g_bn3 = batch_norm(name='g_bn3')
        self.g_bn4 = batch_norm(name='g_bn4')

        self.f_bn1 = batch_norm(name='f_bn1')
        self.f_bn2 = batch_norm(name='f_bn2')
        self.f_bn3 = batch_norm(name='f_bn3')
        self.f_bn4 = batch_norm(name='f_bn4')  # not used by any layer yet; kept as a public attribute

    def function_f(self, images, reuse=False):
        """Function f: convolutional feature extractor used for the f-constancy loss.

        First conv layer does not use batch normalization.

        Args:
            images: images for domain S and T, of shape (batch_size, image_size, image_size, dim_color)
            reuse: (optional) if True, reuse the variables already created in scope 'function_f'

        Returns:
            out: tanh-squashed output vectors, of shape (batch_size, dim_fout)
        """
        with tf.variable_scope('function_f', reuse=reuse):
            # NOTE: these layers use f's own batch-norm objects (f_bn*), not the
            # discriminator's (d_bn*), so the variables live under function_f/f_bn*.
            h1 = lrelu(conv2d(images, self.dim_ff, name='f_h1'))                # (batch_size, 16, 16, 64)
            h2 = lrelu(self.f_bn1(conv2d(h1, self.dim_ff*2, name='f_h2')))      # (batch_size, 8, 8, 128)
            h3 = lrelu(self.f_bn2(conv2d(h2, self.dim_ff*4, name='f_h3')))      # (batch_size, 4, 4, 256)
            h4 = lrelu(self.f_bn3(conv2d(h3, self.dim_ff*8, name='f_h4')))      # (batch_size, 2, 2, 512)

            # flatten and project to the feature vector
            h4 = tf.reshape(h4, [self.batch_size, -1])
            out = linear(h4, self.dim_fout, name='f_out')

            return tf.nn.tanh(out)

    def generator(self, z, reuse=False):
        """Generator: Deconvolutional neural network with relu activations.

        Last deconv layer does not use batch normalization.

        Args:
            z: input vectors (the output of function f), of shape (batch_size, dim_fout)
            reuse: (optional) if True, reuse the variables in scope 'generator' and
                run batch normalization in test mode

        Returns:
            out: generated images in [-1, 1], of shape (batch_size, output_size, output_size, dim_color)
        """
        # the reused (second) instance is the sampling graph, so it must not use batch statistics
        train = not reuse

        with tf.variable_scope('generator', reuse=reuse):

            # spatial sizes for the deconvolutions; floor division keeps them
            # integers on Python 3, where `/` would produce floats and break the shapes
            s = self.output_size
            s2, s4, s8, s16 = s // 2, s // 4, s // 8, s // 16    # 16, 8, 4, 2

            # project and reshape z
            h1 = linear(z, s16*s16*self.dim_gf*8, name='g_h1')       # (batch_size, 2*2*512)
            h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8])       # (batch_size, 2, 2, 512)
            h1 = relu(self.g_bn1(h1, train=train))

            h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2')    # (batch_size, 4, 4, 256)
            h2 = relu(self.g_bn2(h2, train=train))

            h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3')    # (batch_size, 8, 8, 128)
            h3 = relu(self.g_bn3(h3, train=train))

            h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4')      # (batch_size, 16, 16, 64)
            h4 = relu(self.g_bn4(h4, train=train))

            out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out')   # (batch_size, 32, 32, dim_color)

            return tf.nn.tanh(out)

    def discriminator(self, images, reuse=False):
        """Discriminator: Convolutional neural network with leaky relu activations.

        First conv layer does not use batch normalization.

        Args:
            images: real or fake images of shape (batch_size, image_size, image_size, dim_color)
            reuse: (optional) if True, reuse the variables already created in scope 'discriminator'

        Returns:
            out: unnormalized scores (logits) for whether it is a real image or a fake image,
                of shape (batch_size, 1)
        """
        with tf.variable_scope('discriminator', reuse=reuse):

            # convolution layers
            h1 = lrelu(conv2d(images, self.dim_df, name='d_h1'))                # (batch_size, 16, 16, 64)
            h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_df*2, name='d_h2')))      # (batch_size, 8, 8, 128)
            h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_df*4, name='d_h3')))      # (batch_size, 4, 4, 256)
            h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_df*8, name='d_h4')))      # (batch_size, 2, 2, 512)

            # fully connected layer
            h4 = tf.reshape(h4, [self.batch_size, -1])
            out = linear(h4, 1, name='d_out')    # (batch_size, 1)

            return out

    def build_model(self):
        """Assemble the full graph: forward passes, losses, optimizers, summaries and saver.

        Results are stored as attributes on ``self`` (e.g. ``d_loss``, ``g_optimizer``,
        ``summary_op``, ``saver``) for the training loop to run.
        """
        # construct function f, generator and discriminator for training phase
        self.f_x = self.function_f(self.images)
        self.fake_images = self.generator(self.f_x)                             # (batch_size, 32, 32, 3)
        self.logits_real = self.discriminator(self.images)                      # (batch_size, 1)
        self.logits_fake = self.discriminator(self.fake_images, reuse=True)     # (batch_size, 1)
        self.fgf_x = self.function_f(self.fake_images, reuse=True)              # (batch_size, dim_fout)

        # construct generator for test phase (batch norm in test mode)
        self.sampled_images = self.generator(self.f_x, reuse=True)              # (batch_size, 32, 32, 3)

        # compute loss
        # (arguments are passed positionally as (logits, targets) to match the TensorFlow 0.x signature)
        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_real, tf.ones_like(self.logits_real)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake)))
        self.d_loss = self.d_loss_real + self.d_loss_fake
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake)))
        self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images))    # L_TID
        self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x))             # L_CONST

        # divide variables for discriminator, generator and function f by scope name
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]
        self.f_vars = [var for var in t_vars if 'function_f' in var.name]

        # one Adam optimizer per loss term so the solver can step them independently
        with tf.name_scope('optimizer'):
            self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars)
            self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
            self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars)
            self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
            self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.f_vars+self.g_vars)

        # summary ops for tensorboard visualization
        tf.scalar_summary('d_loss_real', self.d_loss_real)
        tf.scalar_summary('d_loss_fake', self.d_loss_fake)
        tf.scalar_summary('d_loss', self.d_loss)
        tf.scalar_summary('g_loss', self.g_loss)
        tf.scalar_summary('g_const_loss', self.g_const_loss)
        tf.scalar_summary('f_const_loss', self.f_const_loss)
        tf.image_summary('original_images', self.images, max_images=6)
        tf.image_summary('sampled_images', self.sampled_images, max_images=6)

        for var in tf.trainable_variables():
            tf.histogram_summary(var.op.name, var)

        self.summary_op = tf.merge_all_summaries()

        self.saver = tf.train.Saver()
... | \ No newline at end of file | ... | \ No newline at end of file |
ops.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | + | ||
3 | + | ||
class batch_norm(object):
    """Callable batch normalization layer bound to a fixed variable scope name.

    Calling the instance forwards to ``tf.contrib.layers.batch_norm``.

    Args (of the call):
        x: input tensor of shape (batch_size, width, height, channels_in) or (batch_size, dim_in)
        train: True or False; in train mode the input is normalized with mini-batch statistics,
            in test mode with the moving averages and variances

    Returns (of the call):
        out: batch normalized output of the same shape as x
    """
    def __init__(self, name):
        # scope name under which the layer's variables are created
        self.name = name

    def __call__(self, x, train=True):
        # fixed layer settings; only the input, the mode and the scope vary per call
        settings = dict(decay=0.99, center=True, scale=True, activation_fn=None,
                        updates_collections=None)
        return tf.contrib.layers.batch_norm(x, is_training=train, scope=self.name, **settings)
22 | + | ||
23 | + | ||
def conv2d(x, channel_out, k_w=5, k_h=5, s_w=2, s_h=2, name=None):
    """Apply a 2-D convolution with a learned kernel and bias.

    The variables 'w' (Xavier-initialized) and 'b' (zero-initialized) are
    created, or reused, inside the variable scope ``name``.

    Args:
        x: input tensor of shape (batch_size, width_in, height_in, channel_in)
        channel_out: number of channels of the output tensor
        k_w: kernel width size; default is 5
        k_h: kernel height size; default is 5
        s_w: stride size for width; default is 2
        s_h: stride size for height; default is 2
        name: variable scope name for this layer

    Returns:
        out: output tensor of shape (batch_size, width_out, height_out, channel_out)
    """
    in_channels = x.get_shape()[-1]
    kernel_shape = [k_w, k_h, in_channels, channel_out]
    stride = [1, s_w, s_h, 1]

    with tf.variable_scope(name):
        kernel = tf.get_variable('w', shape=kernel_shape,
                                 initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('b', shape=[channel_out], initializer=tf.constant_initializer(0.0))

        convolved = tf.nn.conv2d(x, kernel, strides=stride, padding='SAME')
        return convolved + bias
48 | + | ||
49 | + | ||
def deconv2d(x, output_shape, k_w=5, k_h=5, s_w=2, s_h=2, name=None):
    """Apply a 2-D transposed convolution (deconvolution) with a learned kernel and bias.

    The variables 'w' (Xavier-initialized) and 'b' (zero-initialized) are
    created, or reused, inside the variable scope ``name``. Padding is left
    at the default of ``tf.nn.conv2d_transpose``.

    Args:
        x: input tensor of shape (batch_size, width_in, height_in, channel_in)
        output_shape: list corresponding to [batch_size, width_out, height_out, channel_out]
        k_w: kernel width size; default is 5
        k_h: kernel height size; default is 5
        s_w: stride size for width; default is 2
        s_h: stride size for height; default is 2
        name: variable scope name for this layer

    Returns:
        out: output tensor of shape (batch_size, width_out, height_out, channel_out)
    """
    in_channels = x.get_shape()[-1]
    out_channels = output_shape[-1]
    # note the (out, in) channel order that the transposed convolution expects for its filter
    kernel_shape = [k_w, k_h, out_channels, in_channels]
    stride = [1, s_w, s_h, 1]

    with tf.variable_scope(name):
        kernel = tf.get_variable('w', shape=kernel_shape,
                                 initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('b', shape=[out_channels], initializer=tf.constant_initializer(0.0))

        upsampled = tf.nn.conv2d_transpose(x, filter=kernel, output_shape=output_shape, strides=stride)
        return upsampled + bias
76 | + | ||
def linear(x, dim_out, name=None):
    """Apply a linear transform (fully-connected layer): ``x * w + b``.

    The variables 'w' (Xavier-initialized) and 'b' (zero-initialized) are
    created, or reused, inside the variable scope ``name``.

    Args:
        x: input tensor of shape (batch_size, dim_in)
        dim_out: dimension of the output tensor
        name: variable scope name for this layer

    Returns:
        out: output tensor of shape (batch_size, dim_out)
    """
    input_dim = x.get_shape()[-1]

    with tf.variable_scope(name):
        weights = tf.get_variable('w', shape=[input_dim, dim_out],
                                  initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('b', shape=[dim_out], initializer=tf.constant_initializer(0.0))

        projected = tf.matmul(x, weights)
        return projected + bias
96 | + | ||
97 | + | ||
def relu(x):
    """Return the element-wise rectified linear activation of x."""
    activated = tf.nn.relu(x)
    return activated
100 | + | ||
101 | + | ||
def lrelu(x, leak=0.2):
    """Return the element-wise leaky relu of x: the maximum of x and leak*x.

    Args:
        x: input tensor
        leak: slope applied to the scaled copy of x; default is 0.2
    """
    scaled = leak*x
    return tf.maximum(x, scaled)
... | \ No newline at end of file | ... | \ No newline at end of file |
solver.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy as np | ||
3 | +import os | ||
4 | +import scipy.io | ||
5 | +import hickle | ||
6 | +from scipy import ndimage | ||
7 | + | ||
8 | + | ||
class Solver(object):
    """Load the SVHN and MNIST datasets and train a DTN model on them."""

    def __init__(self, model, num_epoch=10, mnist_path='mnist/', svhn_path='svhn/',
                 model_save_path='model/', log_path='log/'):
        """
        Args:
            model: model object exposing build_model(), batch_size, images, the optimizer ops,
                the loss tensors, summary_op and saver (a DTN instance)
            num_epoch: (optional) number of passes over the data
            mnist_path: (optional) directory containing the MNIST '*.images.hkl' files
            svhn_path: (optional) directory containing the SVHN '*_32x32.mat' files
            model_save_path: (optional) directory where checkpoints are written
            log_path: (optional) directory where tensorboard summaries are written
        """
        self.model = model
        self.num_epoch = num_epoch
        self.mnist_path = mnist_path
        self.svhn_path = svhn_path
        self.model_save_path = model_save_path
        self.log_path = log_path

        # create directory if not exists
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        if not os.path.exists(model_save_path):
            os.makedirs(model_save_path)

        # construct the model graph
        model.build_model()

        # load dataset
        self.svhn = self.load_svhn(self.svhn_path)
        self.mnist = self.load_mnist(self.mnist_path)

    def load_svhn(self, image_path, split='train'):
        """Load SVHN images from a .mat file and scale them to [-1, 1].

        Args:
            image_path: directory containing 'train_32x32.mat' / 'test_32x32.mat'
            split: 'train' loads the training file; any other value loads the test file

        Returns:
            images: array with the sample axis moved to the front, i.e. the stored
                'X' array transposed with axes [3, 0, 1, 2]
        """
        print('loading svhn image dataset..')
        file_name = 'train_32x32.mat' if split == 'train' else 'test_32x32.mat'
        svhn = scipy.io.loadmat(os.path.join(image_path, file_name))

        # the file stores samples along the last axis; move that axis to the front
        images = np.transpose(svhn['X'], [3, 0, 1, 2])
        images = images / 127.5 - 1
        print('finished loading svhn image dataset..!')
        return images

    def load_mnist(self, image_path, split='train'):
        """Load MNIST images from a hickle file and scale them to [-1, 1].

        Args:
            image_path: directory containing 'train.images.hkl' / 'test.images.hkl'
            split: 'train' loads the training file; any other value loads the test file

        Returns:
            images: the stored array divided by 127.5 and shifted by -1
        """
        print('loading mnist image dataset..')
        file_name = 'train.images.hkl' if split == 'train' else 'test.images.hkl'
        image_file = os.path.join(image_path, file_name)

        images = hickle.load(image_file)
        images = images / 127.5 - 1
        print('finished loading mnist image dataset..!')
        return images

    def train(self):
        """Run the training loop, alternating SVHN (domain S) and MNIST (domain T) batches.

        Writes tensorboard summaries every 10 iterations and a checkpoint every 500.
        """
        model = self.model

        # load image dataset
        svhn = self.svhn
        mnist = self.mnist
        batch_size = model.batch_size

        # Both datasets are sliced with the same index and the placeholder has a fixed
        # batch dimension, so only iterate over batches that are full in BOTH datasets.
        num_iter_per_epoch = min(svhn.shape[0], mnist.shape[0]) // batch_size

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            # initialize parameters
            tf.initialize_all_variables().run()
            summary_writer = tf.train.SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())

            try:
                for e in range(self.num_epoch):
                    for i in range(num_iter_per_epoch):
                        start, end = i*batch_size, (i+1)*batch_size

                        # train model for domain S
                        feed_dict = {model.images: svhn[start:end]}
                        sess.run(model.d_optimizer_fake, feed_dict)
                        sess.run(model.f_optimizer_const, feed_dict)
                        sess.run(model.g_optimizer, feed_dict)

                        if i % 10 == 0:
                            # summaries are computed on the domain S batch
                            summary, d_loss, g_loss = sess.run([model.summary_op, model.d_loss, model.g_loss], feed_dict)
                            summary_writer.add_summary(summary, e*num_iter_per_epoch + i)
                            print('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' % (e+1, i+1, num_iter_per_epoch, d_loss, g_loss))

                        # train model for domain T
                        feed_dict = {model.images: mnist[start:end]}
                        sess.run(model.d_optimizer_real, feed_dict)
                        sess.run(model.d_optimizer_fake, feed_dict)
                        sess.run(model.g_optimizer, feed_dict)
                        sess.run(model.g_optimizer_const, feed_dict)

                        if i % 500 == 0:
                            checkpoint_prefix = os.path.join(self.model_save_path, 'dcgan-%d' % (e+1))
                            saved_path = model.saver.save(sess, checkpoint_prefix, global_step=i+1)
                            # report the real location instead of a hard-coded 'model/' prefix
                            print('%s saved' % saved_path)
            finally:
                # flush buffered events and release the file even if training is interrupted
                summary_writer.close()
... | \ No newline at end of file | ... | \ No newline at end of file |
train.py
0 → 100644
1 | +from model import DTN | ||
2 | +from solver import Solver | ||
3 | + | ||
def main():
    """Build a DTN with default hyper-parameters and train it through a Solver."""
    solver = Solver(
        DTN(),
        num_epoch=10,
        svhn_path='svhn/',
        model_save_path='model/',
        log_path='log/',
    )
    solver.train()


# run training only when executed as a script, not when imported
if __name__ == "__main__":
    main()
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment