yunjey

model version 0.1

1 +import tensorflow as tf
2 +from ops import *
3 +
4 +class DTN(object):
5 + """Domain Transfer Network for unsupervised cross-domain image generation
6 +
7 + Construct discriminator and generator to prepare for training.
8 + """
9 +
10 + def __init__(self, batch_size=100, learning_rate=0.0002, image_size=32, output_size=32,
11 + dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
12 + """
13 + Args:
14 + learning_rate: (optional) learning rate for discriminator and generator
15 + image_size: (optional) spatial size of input image for discriminator
16 + output_size: (optional) spatial size of image generated by generator
17 + dim_color: (optional) dimension of image color; default is 3 for rgb
18 + dim_fout: (optional) dimension of z (random input vector for generator)
19 + dim_df: (optional) dimension of discriminator's filter in first convolution layer
20 + dim_gf: (optional) dimension of generator's filter in last convolution layer
21 + dim_ff: (optional) dimension of function f's filter in first convolution layer
22 + """
23 + # hyper parameters
24 + self.batch_size = batch_size
25 + self.learning_rate = learning_rate
26 + self.image_size = image_size
27 + self.output_size = output_size
28 + self.dim_color = dim_color
29 + self.dim_fout = dim_fout
30 + self.dim_df = dim_df
31 + self.dim_gf = dim_gf
32 + self.dim_ff = dim_ff
33 +
34 + # placeholder
35 + self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images')
36 + #self.z = tf.placeholder(tf.float32, shape=[None, dim_z], name='input_for_generator')
37 +
38 + # batch normalization layer for discriminator, generator and funtion f
39 + self.d_bn1 = batch_norm(name='d_bn1')
40 + self.d_bn2 = batch_norm(name='d_bn2')
41 + self.d_bn3 = batch_norm(name='d_bn3')
42 +
43 + self.g_bn1 = batch_norm(name='g_bn1')
44 + self.g_bn2 = batch_norm(name='g_bn2')
45 + self.g_bn3 = batch_norm(name='g_bn3')
46 + self.g_bn4 = batch_norm(name='g_bn4')
47 +
48 + self.f_bn1 = batch_norm(name='f_bn1')
49 + self.f_bn2 = batch_norm(name='f_bn2')
50 + self.f_bn3 = batch_norm(name='f_bn3')
51 + self.f_bn4 = batch_norm(name='f_bn4')
52 +
53 +
54 +
55 + def function_f(self, images, reuse=False):
56 + """f consistancy
57 +
58 + Args:
59 + images: images for domain S and T, of shape (batch_size, image_size, image_size, dim_color)
60 +
61 + Returns:
62 + out: output vectors, of shape (batch_size, dim_f_out)
63 + """
64 + with tf.variable_scope('function_f', reuse=reuse):
65 + h1 = lrelu(conv2d(images, self.dim_ff, name='f_h1')) # (batch_size, 16, 16, 64)
66 + h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_ff*2, name='f_h2'))) # (batch_size, 8, 8 128)
67 + h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_ff*4, name='f_h3'))) # (batch_size, 4, 4, 256)
68 + h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_ff*8, name='f_h4'))) # (batch_size, 2, 2, 512)
69 +
70 + h4 = tf.reshape(h4, [self.batch_size,-1])
71 + out = linear(h4, self.dim_fout, name='f_out')
72 +
73 + return tf.nn.tanh(out)
74 +
75 +
76 + def generator(self, z, reuse=False):
77 + """Generator: Deconvolutional neural network with relu activations.
78 +
79 + Last deconv layer does not use batch normalization.
80 +
81 + Args:
82 + z: random input vectors, of shape (batch_size, dim_z)
83 +
84 + Returns:
85 + out: generated images, of shape (batch_size, image_size, image_size, dim_color)
86 + """
87 + if reuse:
88 + train = False
89 + else:
90 + train = True
91 +
92 + with tf.variable_scope('generator', reuse=reuse):
93 +
94 + # spatial size for convolution
95 + s = self.output_size
96 + s2, s4, s8, s16 = s/2, s/4, s/8, s/16 # 32, 16, 8, 4
97 +
98 + # project and reshape z
99 + h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512)
100 + h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512)
101 + h1 = relu(self.g_bn1(h1, train=train))
102 +
103 + h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2') # (batch_size, 4, 4, 256)
104 + h2 = relu(self.g_bn2(h2, train=train))
105 +
106 + h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3') # (batch_size, 8, 8, 128)
107 + h3 = relu(self.g_bn3(h3, train=train))
108 +
109 + h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4') # (batch_size, 16, 16, 64)
110 + h4 = relu(self.g_bn4(h4, train=train))
111 +
112 + out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out') # (batch_size, 32, 32, dim_color)
113 +
114 + return tf.nn.tanh(out)
115 +
116 +
117 + def discriminator(self, images, reuse=False):
118 + """Discrimator: Convolutional neural network with leaky relu activations.
119 +
120 + First conv layer does not use batch normalization.
121 +
122 + Args:
123 + images: real or fake images of shape (batch_size, image_size, image_size, dim_color)
124 +
125 + Returns:
126 + out: scores for whether it is a real image or a fake image, of shape (batch_size,)
127 + """
128 + with tf.variable_scope('discriminator', reuse=reuse):
129 +
130 + # convolution layer
131 + h1 = lrelu(conv2d(images, self.dim_df, name='d_h1')) # (batch_size, 16, 16, 64)
132 + h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128)
133 + h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256)
134 + h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512)
135 +
136 + # fully connected layer
137 + h4 = tf.reshape(h4, [self.batch_size, -1])
138 + out = linear(h4, 1, name='d_out') # (batch_size,)
139 +
140 + return out
141 +
142 +
143 + def build_model(self):
144 +
145 + # construct generator and discriminator for training phase
146 + self.f_x = self.function_f(self.images)
147 + self.fake_images = self.generator(self.f_x) # (batch_size, 32, 32, 3)
148 + self.logits_real = self.discriminator(self.images) # (batch_size,)
149 + self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,)
150 + self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f)
151 +
152 + # construct generator for test phase
153 + self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3)
154 +
155 +
156 + # compute loss
157 + self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_real, tf.ones_like(self.logits_real)))
158 + self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake)))
159 + self.d_loss = self.d_loss_real + self.d_loss_fake
160 + self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake)))
161 + self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID
162 + self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) # L_CONST
163 +
164 + # divide variables for discriminator and generator
165 + t_vars = tf.trainable_variables()
166 + self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
167 + self.g_vars = [var for var in t_vars if 'generator' in var.name]
168 + self.f_vars = [var for var in t_vars if 'function_f' in var.name]
169 +
170 + # optimizer for discriminator and generator
171 + with tf.name_scope('optimizer'):
172 + self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars)
173 + self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
174 + self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars)
175 + self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
176 + self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.f_vars+self.g_vars)
177 +
178 +
179 + # summary ops for tensorboard visualization
180 + tf.scalar_summary('d_loss_real', self.d_loss_real)
181 + tf.scalar_summary('d_loss_fake', self.d_loss_fake)
182 + tf.scalar_summary('d_loss', self.d_loss)
183 + tf.scalar_summary('g_loss', self.g_loss)
184 + tf.scalar_summary('g_const_loss', self.g_const_loss)
185 + tf.scalar_summary('f_const_loss', self.f_const_loss)
186 + tf.image_summary('original_images', self.images, max_images=6)
187 + tf.image_summary('sampled_images', self.sampled_images, max_images=6)
188 +
189 + for var in tf.trainable_variables():
190 + tf.histogram_summary(var.op.name, var)
191 +
192 + self.summary_op = tf.merge_all_summaries()
193 +
194 + self.saver = tf.train.Saver()
...\ No newline at end of file ...\ No newline at end of file
1 +import tensorflow as tf
2 +
3 +
4 +class batch_norm(object):
5 + """Computes batch normalization operation
6 +
7 + Args:
8 + x: input tensor of shape (batch_size, width, height, channels_in) or (batch_size, dim_in)
9 + train: True or False; At train mode, it normalizes the input with mini-batch statistics
10 + At test mode, it normalizes the input with the moving averages and variances
11 +
12 + Returns:
13 + out: batch normalized output of the same shape with x
14 + """
15 + def __init__(self, name):
16 + self.name = name
17 +
18 + def __call__(self, x, train=True):
19 + out = tf.contrib.layers.batch_norm(x, decay=0.99, center=True, scale=True, activation_fn=None,
20 + updates_collections=None, is_training=train, scope=self.name)
21 + return out
22 +
23 +
24 +def conv2d(x, channel_out, k_w=5, k_h=5, s_w=2, s_h=2, name=None):
25 + """Computes convolution operation
26 +
27 + Args:
28 + x: input tensor of shape (batch_size, width_in, heigth_in, channel_in)
29 + channel_out: number of channel for output tensor
30 + k_w: kernel width size; default is 5
31 + k_h: kernel height size; default is 5
32 + s_w: stride size for width; default is 2
33 + s_h: stride size for heigth; default is 2
34 +
35 + Returns:
36 + out: output tensor of shape (batch_size, width_out, height_out, channel_out)
37 + """
38 + channel_in = x.get_shape()[-1]
39 +
40 + with tf.variable_scope(name):
41 + w = tf.get_variable('w', shape=[k_w, k_h, channel_in, channel_out],
42 + initializer=tf.contrib.layers.xavier_initializer())
43 + b = tf.get_variable('b', shape=[channel_out], initializer=tf.constant_initializer(0.0))
44 +
45 + out = tf.nn.conv2d(x, w, strides=[1, s_w, s_h, 1], padding='SAME') + b
46 +
47 + return out
48 +
49 +
50 +def deconv2d(x, output_shape, k_w=5, k_h=5, s_w=2, s_h=2, name=None):
51 + """Computes deconvolution operation
52 +
53 + Args:
54 + x: input tensor of shape (batch_size, width_in, height_in, channel_in)
55 + output_shape: list corresponding to [batch_size, width_out, height_out, channel_out]
56 + k_w: kernel width size; default is 5
57 + k_h: kernel height size; default is 5
58 + s_w: stride size for width; default is 2
59 + s_h: stride size for heigth; default is 2
60 +
61 + Returns:
62 + out: output tensor of shape (batch_size, width_out, hegith_out, channel_out)
63 + """
64 + channel_in = x.get_shape()[-1]
65 + channel_out = output_shape[-1]
66 +
67 +
68 + with tf.variable_scope(name):
69 + w = tf.get_variable('w', shape=[k_w, k_h, channel_out, channel_in],
70 + initializer=tf.contrib.layers.xavier_initializer())
71 + b = tf.get_variable('b', shape=[channel_out], initializer=tf.constant_initializer(0.0))
72 +
73 + out = tf.nn.conv2d_transpose(x, filter=w, output_shape=output_shape, strides=[1, s_w, s_h, 1]) + b
74 +
75 + return out
76 +
77 +def linear(x, dim_out, name=None):
78 + """Computes linear transform (fully-connected layer)
79 +
80 + Args:
81 + x: input tensor of shape (batch_size, dim_in)
82 + dim_out: dimension for output tensor
83 +
84 + Returns:
85 + out: output tensor of shape (batch_size, dim_out)
86 + """
87 + dim_in = x.get_shape()[-1]
88 +
89 + with tf.variable_scope(name):
90 + w = tf.get_variable('w', shape=[dim_in, dim_out], initializer=tf.contrib.layers.xavier_initializer())
91 + b = tf.get_variable('b', shape=[dim_out], initializer=tf.constant_initializer(0.0))
92 +
93 + out = tf.matmul(x, w) + b
94 +
95 + return out
96 +
97 +
98 +def relu(x):
99 + return tf.nn.relu(x)
100 +
101 +
102 +def lrelu(x, leak=0.2):
103 + return tf.maximum(x, leak*x)
...\ No newline at end of file ...\ No newline at end of file
1 +import tensorflow as tf
2 +import numpy as np
3 +import os
4 +import scipy.io
5 +import hickle
6 +from scipy import ndimage
7 +
8 +
9 +class Solver(object):
10 + """Load dataset and train DCGAN"""
11 +
12 + def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', log_path='log/'):
13 + self.model = model
14 + self.num_epoch = num_epoch
15 + self.mnist_path = mnist_path
16 + self.svhn_path = svhn_path
17 + self.model_save_path = model_save_path
18 + self.log_path = log_path
19 +
20 + # create directory if not exists
21 + if not os.path.exists(log_path):
22 + os.makedirs(log_path)
23 + if not os.path.exists(model_save_path):
24 + os.makedirs(model_save_path)
25 +
26 + # construct the dcgan model
27 + model.build_model()
28 +
29 + # load dataset
30 + self.svhn = self.load_svhn(self.svhn_path)
31 + self.mnist = self.load_mnist(self.mnist_path)
32 +
33 +
34 + def load_svhn(self, image_path, split='train'):
35 + print ('loading svhn image dataset..')
36 + if split == 'train':
37 + svhn = scipy.io.loadmat(os.path.join(image_path, 'train_32x32.mat'))
38 + else:
39 + svhn = scipy.io.loadmat(os.path.join(image_path, 'test_32x32.mat'))
40 +
41 + images = np.transpose(svhn['X'], [3, 0, 1, 2])
42 + images = images / 127.5 - 1
43 + print ('finished loading svhn image dataset..!')
44 + return images
45 +
46 +
47 + def load_mnist(self, image_path, split='train'):
48 + print ('loading mnist image dataset..')
49 + if split == 'train':
50 + image_file = os.path.join(image_path, 'train.images.hkl')
51 + else:
52 + image_file = os.path.join(image_path, 'test.images.hkl')
53 +
54 + images = hickle.load(image_file)
55 + images = images / 127.5 - 1
56 + print ('finished loading mnist image dataset..!')
57 + return images
58 +
59 +
60 + def train(self):
61 + model=self.model
62 +
63 + #load image dataset
64 + svhn = self.svhn
65 + mnist = self.mnist
66 +
67 + num_iter_per_epoch = int(mnist.shape[0] / model.batch_size)
68 +
69 + config = tf.ConfigProto(allow_soft_placement = True)
70 + config.gpu_options.allow_growth = True
71 + with tf.Session(config=config) as sess:
72 + # initialize parameters
73 + tf.initialize_all_variables().run()
74 + summary_writer = tf.train.SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
75 +
76 + for e in range(self.num_epoch):
77 + for i in range(num_iter_per_epoch):
78 +
79 + # train model for domain S
80 + image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
81 + feed_dict = {model.images: image_batch}
82 + sess.run(model.d_optimizer_fake, feed_dict)
83 + sess.run(model.f_optimizer_const, feed_dict)
84 + sess.run(model.g_optimizer, feed_dict)
85 +
86 + if i % 10 == 0:
87 + feed_dict = {model.images: image_batch}
88 + summary, d_loss, g_loss = sess.run([model.summary_op, model.d_loss, model.g_loss], feed_dict)
89 + summary_writer.add_summary(summary, e*num_iter_per_epoch + i)
90 + print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss))
91 +
92 + # train model for domain T
93 + image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size]
94 + feed_dict = {model.images: image_batch}
95 + sess.run(model.d_optimizer_real, feed_dict)
96 + sess.run(model.d_optimizer_fake, feed_dict)
97 + sess.run(model.g_optimizer, feed_dict)
98 + sess.run(model.g_optimizer_const, feed_dict)
99 +
100 +
101 +
102 + if i % 500 == 0:
103 + model.saver.save(sess, os.path.join(self.model_save_path, 'dcgan-%d' %(e+1)), global_step=i+1)
104 + print ('model/dcgan-%d-%d saved' %(e+1, i+1))
...\ No newline at end of file ...\ No newline at end of file
1 +from model import DTN
2 +from solver import Solver
3 +
4 +def main():
5 + model = DTN()
6 + solver = Solver(model, num_epoch=10, svhn_path='svhn/', model_save_path='model/', log_path='log/')
7 + solver.train()
8 +
9 +
10 +if __name__ == "__main__":
11 + main()
...\ No newline at end of file ...\ No newline at end of file