yunjey

model added

import tensorflow as tf
from ops import *
from config import *
class DTN(object):
"""Domain Transfer Network for unsupervised cross-domain image generation
......@@ -7,7 +9,7 @@ class DTN(object):
Construct discriminator and generator to prepare for training.
"""
def __init__(self, batch_size=100, learning_rate=0.0002, image_size=32, output_size=32,
def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32,
dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
"""
Args:
......@@ -39,6 +41,7 @@ class DTN(object):
self.d_bn1 = batch_norm(name='d_bn1')
self.d_bn2 = batch_norm(name='d_bn2')
self.d_bn3 = batch_norm(name='d_bn3')
self.d_bn4 = batch_norm(name='d_bn4')
self.g_bn1 = batch_norm(name='g_bn1')
self.g_bn2 = batch_norm(name='g_bn2')
......@@ -52,7 +55,7 @@ class DTN(object):
def function_f(self, images, reuse=False):
def function_f(self, images, reuse=False, train=True):
"""f consistancy
Args:
......@@ -60,12 +63,12 @@ class DTN(object):
Returns:
out: output vectors, of shape (batch_size, dim_f_out)
"""
"""
with tf.variable_scope('function_f', reuse=reuse):
h1 = lrelu(conv2d(images, self.dim_ff, name='f_h1')) # (batch_size, 16, 16, 64)
h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_ff*2, name='f_h2'))) # (batch_size, 8, 8 128)
h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_ff*4, name='f_h3'))) # (batch_size, 4, 4, 256)
h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_ff*8, name='f_h4'))) # (batch_size, 2, 2, 512)
h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train)) # (batch_size, 16, 16, 64)
h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train)) # (batch_size, 8, 8 128)
h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train)) # (batch_size, 4, 4, 256)
h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train)) # (batch_size, 2, 2, 512)
h4 = tf.reshape(h4, [self.batch_size,-1])
out = linear(h4, self.dim_fout, name='f_out')
......@@ -96,8 +99,8 @@ class DTN(object):
s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) # 32, 16, 8, 4
# project and reshape z
h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512)
h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512)
h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512)
h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512)
h1 = relu(self.g_bn1(h1, train=train))
h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2') # (batch_size, 4, 4, 256)
......@@ -128,10 +131,10 @@ class DTN(object):
with tf.variable_scope('discriminator', reuse=reuse):
# convolution layer
h1 = lrelu(conv2d(images, self.dim_df, name='d_h1')) # (batch_size, 16, 16, 64)
h2 = lrelu(self.d_bn1(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128)
h3 = lrelu(self.d_bn2(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256)
h4 = lrelu(self.d_bn3(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512)
h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1'))) # (batch_size, 16, 16, 64)
h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128)
h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256)
h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512)
# fully connected layer
h4 = tf.reshape(h4, [self.batch_size, -1])
......@@ -149,7 +152,8 @@ class DTN(object):
self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,)
self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f)
# construct generator for test phase
# construct generator for test phase (use moving average and variance for batch norm)
self.f_x = self.function_f(self.images, reuse=True, train=False)
self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3)
......@@ -158,8 +162,8 @@ class DTN(object):
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake)))
self.d_loss = self.d_loss_real + self.d_loss_fake
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake)))
self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID
self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) # L_CONST
self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID
self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15 # L_CONST
# divide variables for discriminator and generator
t_vars = tf.trainable_variables()
......@@ -172,23 +176,28 @@ class DTN(object):
self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars)
self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars)
self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.f_vars+self.g_vars)
self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars)
# summary ops for tensorboard visualization
tf.scalar_summary('d_loss_real', self.d_loss_real)
tf.scalar_summary('d_loss_fake', self.d_loss_fake)
tf.scalar_summary('d_loss', self.d_loss)
tf.scalar_summary('g_loss', self.g_loss)
tf.scalar_summary('g_const_loss', self.g_const_loss)
tf.scalar_summary('f_const_loss', self.f_const_loss)
tf.image_summary('original_images', self.images, max_images=6)
tf.image_summary('sampled_images', self.sampled_images, max_images=6)
scalar_summary('d_loss_real', self.d_loss_real)
scalar_summary('d_loss_fake', self.d_loss_fake)
scalar_summary('d_loss', self.d_loss)
scalar_summary('g_loss', self.g_loss)
scalar_summary('g_const_loss', self.g_const_loss)
scalar_summary('f_const_loss', self.f_const_loss)
try:
image_summary('original_images', self.images, max_outputs=4)
image_summary('sampled_images', self.sampled_images, max_outputs=4)
except:
image_summary('original_images', self.images, max_images=4)
image_summary('sampled_images', self.sampled_images, max_images=4)
for var in tf.trainable_variables():
tf.histogram_summary(var.op.name, var)
histogram_summary(var.op.name, var)
self.summary_op = tf.merge_all_summaries()
self.summary_op = merge_summary()
self.saver = tf.train.Saver()
\ No newline at end of file
......
......@@ -3,34 +3,37 @@ import numpy as np
import os
import scipy.io
import hickle
from scipy import ndimage
import scipy.misc
from config import SummaryWriter
class Solver(object):
"""Load dataset and train DCGAN"""
"""Load dataset and train and test the model"""
def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/', log_path='log/'):
def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/',
log_path='log/', sample_path='sample/', test_model_path=None, sample_iter=100):
self.model = model
self.num_epoch = num_epoch
self.mnist_path = mnist_path
self.svhn_path = svhn_path
self.model_save_path = model_save_path
self.log_path = log_path
self.sample_path = sample_path
self.test_model_path = test_model_path
self.sample_iter = sample_iter
# create directory if not exists
if not os.path.exists(log_path):
os.makedirs(log_path)
if not os.path.exists(model_save_path):
os.makedirs(model_save_path)
if not os.path.exists(sample_path):
os.makedirs(sample_path)
# construct the dcgan model
model.build_model()
# load dataset
self.svhn = self.load_svhn(self.svhn_path)
self.mnist = self.load_mnist(self.mnist_path)
def load_svhn(self, image_path, split='train'):
print ('loading svhn image dataset..')
if split == 'train':
......@@ -59,50 +62,99 @@ class Solver(object):
images = images / 127.5 - 1
print ('finished loading mnist image dataset..!')
return images
def merge_images(self, sources, targets, k=10):
_, h, w, _ = sources.shape
row = int(np.sqrt(self.model.batch_size))
merged = np.zeros([row*h, row*w*2, 3])
for idx, (s, t) in enumerate(zip(sources, targets)):
i = idx // row
j = idx % row
merged[i*h:(i+1)*h, (j*2)*h:(j*2+1)*h, :] = s
merged[i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h, :] = t
return merged
def train(self):
model=self.model
# load image dataset
svhn = self.load_svhn(self.svhn_path)
mnist = self.load_mnist(self.mnist_path)
#load image dataset
svhn = self.svhn
mnist = self.mnist
num_iter_per_epoch = int(mnist.shape[0] / model.batch_size)
config = tf.ConfigProto(allow_soft_placement = True)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
# initialize parameters
tf.initialize_all_variables().run()
summary_writer = tf.train.SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
try:
tf.global_variables_initializer().run()
except:
tf.initialize_all_variables().run()
summary_writer = SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
for e in range(self.num_epoch):
for i in range(num_iter_per_epoch):
# train model for domain S
# train model for source domain S
image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
feed_dict = {model.images: image_batch}
sess.run(model.d_optimizer_fake, feed_dict)
sess.run(model.f_optimizer_const, feed_dict)
sess.run(model.g_optimizer, feed_dict)
sess.run(model.g_optimizer, feed_dict)
if i % 3 == 0:
sess.run(model.f_optimizer_const, feed_dict)
if i % 10 == 0:
feed_dict = {model.images: image_batch}
summary, d_loss, g_loss = sess.run([model.summary_op, model.d_loss, model.g_loss], feed_dict)
summary_writer.add_summary(summary, e*num_iter_per_epoch + i)
print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss))
# train model for domain T
# train model for target domain T
image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size]
feed_dict = {model.images: image_batch}
sess.run(model.d_optimizer_real, feed_dict)
sess.run(model.d_optimizer_fake, feed_dict)
sess.run(model.g_optimizer, feed_dict)
sess.run(model.g_optimizer, feed_dict)
sess.run(model.g_optimizer, feed_dict)
sess.run(model.g_optimizer_const, feed_dict)
sess.run(model.g_optimizer_const, feed_dict)
if i % 500 == 0:
model.saver.save(sess, os.path.join(self.model_save_path, 'dcgan-%d' %(e+1)), global_step=i+1)
print ('model/dcgan-%d-%d saved' %(e+1, i+1))
\ No newline at end of file
model.saver.save(sess, os.path.join(self.model_save_path, 'dtn-%d' %(e+1)), global_step=i+1)
print ('model/dtn-%d-%d saved' %(e+1, i+1))
def test(self):
model = self.model
# load dataset
svhn = self.load_svhn(self.svhn_path)
num_iter = int(svhn.shape[0] / model.batch_size)
config = tf.ConfigProto(allow_soft_placement = True)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
# load trained parameters
saver = tf.train.Saver()
saver.restore(sess, self.test_model_path)
for i in range(self.sample_iter):
# train model for source domain S
image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
feed_dict = {model.images: image_batch}
sampled_image_batch = sess.run(model.sampled_images, feed_dict)
# merge and save source images and sampled target images
merged = self.merge_images(image_batch, sampled_image_batch)
path = os.path.join(self.sample_path, 'sample-%d-to-%d.png' %(i*model.batch_size, (i+1)*model.batch_size))
scipy.misc.imsave(path, merged)
print ('saved %s' %path)
......
from model import DTN
from solver import Solver
def main():
model = DTN()
solver = Solver(model, num_epoch=10, svhn_path='svhn/', model_save_path='model/', log_path='log/')
solver.train()
if __name__ == "__main__":
main()
\ No newline at end of file