import tensorflow as tf
from ops import *
from config import *
class DTN(object):
"""Domain Transfer Network for unsupervised cross-domain image generation
Construct discriminator and generator to prepare for training.
def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32,
def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32,
dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
self.d_bn1 = batch_norm(name='d_bn1')
self.d_bn2 = batch_norm(name='d_bn2')
self.d_bn3 = batch_norm(name='d_bn3')
self.d_bn4 = batch_norm(name='d_bn4')
self.g_bn1 = batch_norm(name='g_bn1')
self.g_bn2 = batch_norm(name='g_bn2')
def function_f(self, images, reuse=False):
def function_f(self, images, reuse=False, train=True):
"""f consistancy
out: output vectors, of shape (batch_size, dim_f_out)
with tf.variable_scope('function_f', reuse=reuse):
h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train)) # (batch_size, 16, 16, 64)
h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train)) # (batch_size, 8, 8 128)
h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train)) # (batch_size, 4, 4, 256)
h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train)) # (batch_size, 2, 2, 512)
h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train)) # (batch_size, 16, 16, 64)
h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train)) # (batch_size, 8, 8 128)
h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train)) # (batch_size, 4, 4, 256)
h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train)) # (batch_size, 2, 2, 512)
h4 = tf.reshape(h4, [self.batch_size,-1])
out = linear(h4, self.dim_fout, name='f_out')
s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16) # 32, 16, 8, 4
# project and reshape z
h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512)
h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512)
h1= linear(z, s16*s16*self.dim_gf*8, name='g_h1') # (batch_size, 2*2*512)
h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8]) # (batch_size, 2, 2, 512)
h1 = relu(self.g_bn1(h1, train=train))
h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2') # (batch_size, 4, 4, 256)
with tf.variable_scope('discriminator', reuse=reuse):
# convolution layer
h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1'))) # (batch_size, 16, 16, 64)
h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128)
h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256)
h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512)
h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1'))) # (batch_size, 16, 16, 64)
h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2'))) # (batch_size, 8, 8, 128)
h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3'))) # (batch_size, 4, 4, 256)
h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4'))) # (batch_size, 2, 2, 512)
# fully connected layer
h4 = tf.reshape(h4, [self.batch_size, -1])
self.logits_fake = self.discriminator(self.fake_images, reuse=True) # (batch_size,)
self.fgf_x = self.function_f(self.fake_images, reuse=True) # (batch_size, dim_f)
# construct generator for test phase
# construct generator for test phase (use moving average and variance for batch norm)
self.f_x = self.function_f(self.images, reuse=True, train=False)
self.sampled_images = self.generator(self.f_x, reuse=True) # (batch_size, 32, 32, 3)
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.zeros_like(self.logits_fake)))
self.d_loss = self.d_loss_real + self.d_loss_fake
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.logits_fake, tf.ones_like(self.logits_fake)))
self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID
self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15 # L_CONST
self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images)) # L_TID
self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15 # L_CONST
# divide variables for discriminator and generator
t_vars = tf.trainable_variables()
self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars)
self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars)
self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.f_vars+self.g_vars)
self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars)
# summary ops for tensorboard visualization
scalar_summary('d_loss_real', self.d_loss_real)
scalar_summary('d_loss_fake', self.d_loss_fake)
scalar_summary('d_loss', self.d_loss)
scalar_summary('g_loss', self.g_loss)
scalar_summary('g_const_loss', self.g_const_loss)
scalar_summary('f_const_loss', self.f_const_loss)
image_summary('original_images', self.images, max_outputs=4)
image_summary('sampled_images', self.sampled_images, max_outputs=4)
scalar_summary('d_loss_real', self.d_loss_real)
scalar_summary('d_loss_fake', self.d_loss_fake)
scalar_summary('d_loss', self.d_loss)
scalar_summary('g_loss', self.g_loss)
scalar_summary('g_const_loss', self.g_const_loss)
scalar_summary('f_const_loss', self.f_const_loss)
image_summary('original_images', self.images, max_outputs=4)
image_summary('sampled_images', self.sampled_images, max_outputs=4)
image_summary('original_images', self.images, max_images=4)
image_summary('sampled_images', self.sampled_images, max_images=4)
for var in tf.trainable_variables():
histogram_summary(, var)
histogram_summary(, var)
self.summary_op = merge_summary()
self.summary_op = merge_summary()
self.saver = tf.train.Saver()
import os
import hickle
from scipy import ndimage
import scipy.misc
from config import SummaryWriter
class Solver(object):
"""Load dataset and train and test the model"""
"""Load dataset and train and test the model"""
def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/',
log_path='log/', sample_path='sample/', test_model_path=None, sample_iter=100):
def __init__(self, model, num_epoch=10, mnist_path= 'mnist/', svhn_path='svhn/', model_save_path='model/',
log_path='log/', sample_path='sample/', test_model_path=None, sample_iter=100):
self.model = model
self.num_epoch = num_epoch
self.mnist_path = mnist_path
self.svhn_path = svhn_path
self.model_save_path = model_save_path
self.log_path = log_path
self.sample_path = sample_path
self.test_model_path = test_model_path
self.sample_iter = sample_iter
# create directory if not exists
if not os.path.exists(log_path):
if not os.path.exists(model_save_path):
if not os.path.exists(sample_path):
# construct the dcgan model
# load dataset
self.svhn = self.load_svhn(self.svhn_path)
self.mnist = self.load_mnist(self.mnist_path)
def load_svhn(self, image_path, split='train'):
print ('loading svhn image dataset..')
if split == 'train':
images = images / 127.5 - 1
print ('finished loading mnist image dataset..!')
return images
def merge_images(self, sources, targets, k=10):
    """Tile source/target image pairs side by side into a single grid image.

    The grid holds ``row`` rows with ``row`` pairs per row, where ``row`` is
    ``int(sqrt(self.model.batch_size))``. Pair number ``idx`` is written at
    grid row ``idx // row`` and pair column ``idx % row``, with the source
    image on the left and the target image directly to its right.

    Args:
        sources: 4-D array; its second and third dimensions are taken as the
            image height ``h`` and width ``w``.
        targets: sequence of images, each assignable (broadcastable) into an
            ``(h, w, 3)`` slot.
        k: unused; kept so existing callers that pass it keep working.

    Returns:
        numpy array of shape ``(row * h, row * w * 2, 3)`` (zeros where no
        pair was written).
    """
    _, h, w, _ = sources.shape
    row = int(np.sqrt(self.model.batch_size))
    merged = np.zeros([row * h, row * w * 2, 3])
    for idx, (s, t) in enumerate(zip(sources, targets)):
        i, j = divmod(idx, row)
        top = i * h
        # Horizontal offsets must be measured in widths, not heights; each
        # pair occupies two image widths.
        left = 2 * j * w
        merged[top:top + h, left:left + w, :] = s
        merged[top:top + h, left + w:left + 2 * w, :] = t
    return merged
def train(self):
# load image dataset
svhn = self.load_svhn(self.svhn_path)
mnist = self.load_mnist(self.mnist_path)
#load image dataset
svhn = self.svhn
summary_writer = SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
config = tf.ConfigProto(allow_soft_placement = True)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
# initialize parameters
summary_writer = tf.train.SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
summary_writer = SummaryWriter(logdir=self.log_path, graph=tf.get_default_graph())
for e in range(self.num_epoch):
for i in range(num_iter_per_epoch):
# train model for source domain S
# train model for source domain S
image_batch = svhn[i*model.batch_size:(i+1)*model.batch_size]
if i % 3 == 0:, feed_dict)
if i % 3 == 0:, feed_dict)
if i % 10 == 0:
feed_dict = {model.images: image_batch}
summary, d_loss, g_loss =[model.summary_op, model.d_loss, model.g_loss], feed_dict)
summary_writer.add_summary(summary, e*num_iter_per_epoch + i)
print ('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]' %(e+1, i+1, num_iter_per_epoch, d_loss, g_loss))
# train model for target domain T
# train model for target domain T
image_batch = mnist[i*model.batch_size:(i+1)*model.batch_size]
, os.path.join(self.model_save_path, 'dtn-%d' %(e+1)), global_step=i+1)
print ('model/dtn-%d-%d saved' %(e+1, i+1))
if i % 500 == 0:, os.path.join(self.model_save_path, 'dcgan-%d' %(e+1)), global_step=i+1)
print ('model/dcgan-%d-%d saved' %(e+1, i+1))
\ No newline at end of file, os.path.join(self.model_save_path, 'dtn-%d' %(e+1)), global_step=i+1)
print ('model/dtn-%d-%d saved' %(e+1, i+1))
def test(self):
    """Run the restored model over SVHN batches and save comparison grids.

    Restores the checkpoint at ``self.test_model_path``, then for each batch
    fetches ``model.sampled_images`` with the batch fed to ``model.images``,
    merges sources and samples with ``self.merge_images`` and writes the
    result as a PNG under ``self.sample_path``.

    At most ``self.sample_iter`` batches are processed, and never more than
    the number of full batches the loaded dataset contains.
    """
    model = self.model
    batch_size = model.batch_size

    # load dataset
    svhn = self.load_svhn(self.svhn_path)
    num_iter = int(svhn.shape[0] / batch_size)
    # Slicing past the end of the data would feed short or empty batches,
    # so cap the requested sample count at the number of full batches.
    num_batches = min(self.sample_iter, num_iter)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # load trained parameters
        saver = tf.train.Saver()
        saver.restore(sess, self.test_model_path)

        for i in range(num_batches):
            start = i * batch_size
            end = start + batch_size

            # sample target-domain images for a batch from source domain S
            image_batch = svhn[start:end]
            feed_dict = {model.images: image_batch}
            # NOTE(review): the fetch expression was truncated in the
            # original line; reconstructed from model.sampled_images and
            # the feed_dict built above -- confirm against the model.
            sampled_image_batch = sess.run(model.sampled_images, feed_dict)

            # merge and save source images and sampled target images
            merged = self.merge_images(image_batch, sampled_image_batch)
            path = os.path.join(self.sample_path, 'sample-%d-to-%d.png' % (start, end))
            # NOTE(review): scipy.misc.imsave is not available in recent
            # SciPy releases; kept because no replacement is imported here.
            scipy.misc.imsave(path, merged)
            print ('saved %s' %path)
from model import DTN
from solver import Solver
def main():
model = DTN()
solver = Solver(model, num_epoch=10, svhn_path='svhn/', model_save_path='model/', log_path='log/')
if __name__ == "__main__":
