import tensorflow as tf
from ops import * 
from config import * 

class DTN(object):
    """Domain Transfer Network for unsupervised cross-domain image generation.

    Builds the TensorFlow graph for the function f, the generator g and the
    discriminator D, plus the losses, optimizers, summaries and saver needed
    for training. Call `build_model()` after construction.
    """

    def __init__(self, batch_size=100, learning_rate=0.0001, image_size=32, output_size=32,
                 dim_color=3, dim_fout=100, dim_df=64, dim_gf=64, dim_ff=64):
        """Store hyper-parameters and create the input placeholder and batch-norm layers.

        Args:
            batch_size: (optional) number of images per batch; the graph is built
                with this fixed batch dimension
            learning_rate: (optional) learning rate for every Adam optimizer
            image_size: (optional) spatial size of input images
            output_size: (optional) spatial size of images produced by the generator
            dim_color: (optional) number of image channels; default is 3 for rgb
            dim_fout: (optional) dimension of f's output vector, which is also the
                generator's input
            dim_df: (optional) number of filters in the discriminator's first conv layer
            dim_gf: (optional) number of filters in the generator's last hidden deconv layer
            dim_ff: (optional) number of filters in function f's first conv layer
        """
        # hyper parameters
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.image_size = image_size
        self.output_size = output_size
        self.dim_color = dim_color
        self.dim_fout = dim_fout
        self.dim_df = dim_df
        self.dim_gf = dim_gf
        self.dim_ff = dim_ff
        # placeholder (fixed batch dimension; the reshapes below rely on it)
        self.images = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, dim_color], name='images')
        # batch normalization layers for the discriminator, the generator and function f
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')
        self.d_bn3 = batch_norm(name='d_bn3')
        self.d_bn4 = batch_norm(name='d_bn4')
        self.g_bn1 = batch_norm(name='g_bn1')
        self.g_bn2 = batch_norm(name='g_bn2')
        self.g_bn3 = batch_norm(name='g_bn3')
        self.g_bn4 = batch_norm(name='g_bn4')
        self.f_bn1 = batch_norm(name='f_bn1')
        self.f_bn2 = batch_norm(name='f_bn2')
        self.f_bn3 = batch_norm(name='f_bn3')
        self.f_bn4 = batch_norm(name='f_bn4')

    def function_f(self, images, reuse=False, train=True):
        """Function f: convolutional encoder applied to images of both domains.

        Args:
            images: images of shape (batch_size, image_size, image_size, dim_color)
            reuse: (optional) reuse the variables of the 'function_f' scope
            train: (optional) forwarded to the batch-norm layers; build_model passes
                False for the sampling graph
        Returns:
            tanh-squashed feature vectors of shape (batch_size, dim_fout)
        """
        with tf.variable_scope('function_f', reuse=reuse):
            # shape comments assume 32x32 input and stride-2 conv2d (defaults) -- TODO confirm in ops
            h1 = lrelu(self.f_bn1(conv2d(images, self.dim_ff, name='f_h1'), train=train))      # (batch_size, 16, 16, 64)
            h2 = lrelu(self.f_bn2(conv2d(h1, self.dim_ff*2, name='f_h2'), train=train))        # (batch_size, 8, 8, 128)
            h3 = lrelu(self.f_bn3(conv2d(h2, self.dim_ff*4, name='f_h3'), train=train))        # (batch_size, 4, 4, 256)
            h4 = lrelu(self.f_bn4(conv2d(h3, self.dim_ff*8, name='f_h4'), train=train))        # (batch_size, 2, 2, 512)

            h4 = tf.reshape(h4, [self.batch_size, -1])
            out = linear(h4, self.dim_fout, name='f_out')
        return tf.nn.tanh(out)

    def generator(self, z, reuse=False):
        """Generator: deconvolutional neural network with relu activations.

        The last deconv layer does not use batch normalization. Batch norm runs in
        training mode on the first (reuse=False) call and in inference mode when
        the variables are reused (the sampling graph).

        Args:
            z: input vectors of shape (batch_size, dim_fout), i.e. the output of f
            reuse: (optional) reuse the variables of the 'generator' scope
        Returns:
            generated images of shape (batch_size, output_size, output_size, dim_color),
            squashed to [-1, 1] by tanh
        """
        train = not reuse
        with tf.variable_scope('generator', reuse=reuse):
            # spatial sizes for the deconvolutions
            s = self.output_size
            s2, s4, s8, s16 = s // 2, s // 4, s // 8, s // 16     # 16, 8, 4, 2 for output_size=32
            # project and reshape z
            h1 = linear(z, s16*s16*self.dim_gf*8, name='g_h1')                         # (batch_size, 2*2*512)
            h1 = tf.reshape(h1, [-1, s16, s16, self.dim_gf*8])                         # (batch_size, 2, 2, 512)
            h1 = relu(self.g_bn1(h1, train=train))
            h2 = deconv2d(h1, [self.batch_size, s8, s8, self.dim_gf*4], name='g_h2')   # (batch_size, 4, 4, 256)
            h2 = relu(self.g_bn2(h2, train=train))
            h3 = deconv2d(h2, [self.batch_size, s4, s4, self.dim_gf*2], name='g_h3')   # (batch_size, 8, 8, 128)
            h3 = relu(self.g_bn3(h3, train=train))
            h4 = deconv2d(h3, [self.batch_size, s2, s2, self.dim_gf], name='g_h4')     # (batch_size, 16, 16, 64)
            h4 = relu(self.g_bn4(h4, train=train))
            out = deconv2d(h4, [self.batch_size, s, s, self.dim_color], name='g_out')  # (batch_size, 32, 32, dim_color)
            return tf.nn.tanh(out)

    def discriminator(self, images, reuse=False):
        """Discriminator: convolutional neural network with leaky relu activations.

        Every conv layer, including the first one, is followed by batch
        normalization (d_bn1..d_bn4), called with its default mode.

        Args:
            images: real or fake images of shape (batch_size, image_size, image_size, dim_color)
            reuse: (optional) reuse the variables of the 'discriminator' scope
        Returns:
            unnormalized real/fake logits, presumably of shape (batch_size, 1)
            given linear(..., 1)
        """
        with tf.variable_scope('discriminator', reuse=reuse):
            # convolution layers
            h1 = lrelu(self.d_bn1(conv2d(images, self.dim_df, name='d_h1')))      # (batch_size, 16, 16, 64)
            h2 = lrelu(self.d_bn2(conv2d(h1, self.dim_df*2, name='d_h2')))        # (batch_size, 8, 8, 128)
            h3 = lrelu(self.d_bn3(conv2d(h2, self.dim_df*4, name='d_h3')))        # (batch_size, 4, 4, 256)
            h4 = lrelu(self.d_bn4(conv2d(h3, self.dim_df*8, name='d_h4')))        # (batch_size, 2, 2, 512)

            # fully connected layer
            h4 = tf.reshape(h4, [self.batch_size, -1])
            out = linear(h4, 1, name='d_out')                                     # (batch_size, 1)

            return out

    @staticmethod
    def _sigmoid_xent(logits, labels):
        """Element-wise sigmoid cross-entropy that works on TF 0.x and TF >= 1.0.

        TF >= 1.0 only accepts named `logits=`/`labels=` arguments, while older
        releases take `(logits, targets)` positionally and reject `labels=`.
        """
        try:
            return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
        except TypeError:
            return tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)

    def build_model(self):
        """Build the training and sampling graphs, losses, optimizers, summaries and saver."""
        # construct f, generator and discriminator for the training phase
        self.f_x = self.function_f(self.images)
        self.fake_images = self.generator(self.f_x)                              # (batch_size, 32, 32, 3)
        self.logits_real = self.discriminator(self.images)                       # (batch_size, 1)
        self.logits_fake = self.discriminator(self.fake_images, reuse=True)      # (batch_size, 1)
        self.fgf_x = self.function_f(self.fake_images, reuse=True)               # (batch_size, dim_fout)

        # construct f and generator for the test phase (batch norm uses moving mean and variance).
        # Stored under its own name: rebinding self.f_x here would make the f-constancy
        # loss below compare inference-mode f(x) against training-mode f(g(f(x))).
        self.f_x_sample = self.function_f(self.images, reuse=True, train=False)
        self.sampled_images = self.generator(self.f_x_sample, reuse=True)        # (batch_size, 32, 32, 3)

        # compute losses
        self.d_loss_real = tf.reduce_mean(self._sigmoid_xent(self.logits_real, tf.ones_like(self.logits_real)))
        self.d_loss_fake = tf.reduce_mean(self._sigmoid_xent(self.logits_fake, tf.zeros_like(self.logits_fake)))
        self.d_loss = self.d_loss_real + self.d_loss_fake
        self.g_loss = tf.reduce_mean(self._sigmoid_xent(self.logits_fake, tf.ones_like(self.logits_fake)))
        self.g_const_loss = tf.reduce_mean(tf.square(self.images - self.fake_images))   # L_TID
        self.f_const_loss = tf.reduce_mean(tf.square(self.f_x - self.fgf_x)) * 0.15     # L_CONST, weighted by 0.15

        # split trainable variables by the variable scope they were created in
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]
        self.f_vars = [var for var in t_vars if 'function_f' in var.name]

        # optimizers for discriminator, generator and function f
        with tf.name_scope('optimizer'):
            self.d_optimizer_real = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_real, var_list=self.d_vars)
            self.d_optimizer_fake = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.d_loss_fake, var_list=self.d_vars)
            self.g_optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_loss, var_list=self.g_vars+self.f_vars)
            self.g_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.g_const_loss, var_list=self.g_vars+self.f_vars)
            self.f_optimizer_const = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5).minimize(self.f_const_loss, var_list=self.g_vars+self.f_vars)

        # summary ops for tensorboard visualization
        scalar_summary('d_loss_real', self.d_loss_real)
        scalar_summary('d_loss_fake', self.d_loss_fake)
        scalar_summary('d_loss', self.d_loss)
        scalar_summary('g_loss', self.g_loss)
        scalar_summary('g_const_loss', self.g_const_loss)
        scalar_summary('f_const_loss', self.f_const_loss)
        try:
            # TF >= 1.0 (tf.summary.image) names the limit `max_outputs`
            image_summary('original_images', self.images, max_outputs=4)
            image_summary('sampled_images', self.sampled_images, max_outputs=4)
        except TypeError:
            # older TF (tf.image_summary) names it `max_images`
            image_summary('original_images', self.images, max_images=4)
            image_summary('sampled_images', self.sampled_images, max_images=4)
        for var in tf.trainable_variables():
            histogram_summary(var.op.name, var)
        # NOTE(review): assumes ops.merge_summary merges every summary when called
        # without arguments -- confirm against ops.py
        self.summary_op = merge_summary()
        self.saver = tf.train.Saver()