# solver.py
import tensorflow as tf
import numpy as np
import os
import scipy.io
import hickle
from scipy import ndimage
class Solver(object):
    """Load the SVHN and MNIST image sets and train the supplied model on them.

    The solver builds the model graph on construction, loads both training
    sets into memory, and `train` then alternates optimizer steps between an
    SVHN batch (commented as "domain S") and an MNIST batch ("domain T").
    """

    def __init__(self, model, num_epoch=10, mnist_path='mnist/', svhn_path='svhn/',
                 model_save_path='model/', log_path='log/'):
        """Store the configuration, build the model and load both datasets.

        Args:
            model: object exposing `build_model()` plus the attributes used by
                `train` (`batch_size`, `images`, the optimizer ops, `d_loss`,
                `g_loss`, `summary_op`, `saver`).
            num_epoch: number of passes `train` makes over the data.
            mnist_path: directory containing `train.images.hkl`.
            svhn_path: directory containing `train_32x32.mat`.
            model_save_path: directory checkpoints are written to.
            log_path: directory summaries are written to.
        """
        self.model = model
        self.num_epoch = num_epoch
        self.mnist_path = mnist_path
        self.svhn_path = svhn_path
        self.model_save_path = model_save_path
        self.log_path = log_path

        # create the output directories if they do not exist yet
        for directory in (log_path, model_save_path):
            if not os.path.exists(directory):
                os.makedirs(directory)

        # construct the model graph before any data is loaded
        model.build_model()

        # load the training split of both datasets
        self.svhn = self.load_svhn(self.svhn_path)
        self.mnist = self.load_mnist(self.mnist_path)

    def load_svhn(self, image_path, split='train'):
        """Load SVHN images from a `.mat` file and rescale them.

        Args:
            image_path: directory holding `train_32x32.mat` / `test_32x32.mat`.
            split: `'train'` selects the training file; any other value
                selects the test file.

        Returns:
            The `'X'` array with its last axis moved to the front
            (presumably (H, W, C, N) -> (N, H, W, C) -- confirm against the
            data files), mapped through `x / 127.5 - 1`.
        """
        print('loading svhn image dataset..')

        file_name = 'train_32x32.mat' if split == 'train' else 'test_32x32.mat'
        svhn = scipy.io.loadmat(os.path.join(image_path, file_name))

        images = np.transpose(svhn['X'], [3, 0, 1, 2])
        # sends 0 -> -1 and 255 -> 1 (assumes 8-bit pixel values)
        images = images / 127.5 - 1
        print('finished loading svhn image dataset..!')
        return images

    def load_mnist(self, image_path, split='train'):
        """Load MNIST images from a hickle file and rescale them.

        Args:
            image_path: directory holding `train.images.hkl` /
                `test.images.hkl`.
            split: `'train'` selects the training file; any other value
                selects the test file.

        Returns:
            The loaded images mapped through `x / 127.5 - 1`.
        """
        print('loading mnist image dataset..')

        file_name = 'train.images.hkl' if split == 'train' else 'test.images.hkl'
        images = hickle.load(os.path.join(image_path, file_name))

        # sends 0 -> -1 and 255 -> 1 (assumes 8-bit pixel values)
        images = images / 127.5 - 1
        print('finished loading mnist image dataset..!')
        return images

    def train(self):
        """Run the alternating SVHN / MNIST training loop.

        Every iteration performs the optimizer steps for one SVHN batch and
        then for one MNIST batch. Losses and summaries are recorded every 10
        iterations and a checkpoint is written every 500 iterations.
        """
        model = self.model
        svhn = self.svhn
        mnist = self.mnist
        batch_size = model.batch_size

        # Both datasets are sliced with the same index, so only iterate as far
        # as the smaller one provides full batches; otherwise the shorter
        # dataset would yield empty or truncated batches.
        num_iter_per_epoch = min(svhn.shape[0], mnist.shape[0]) // batch_size

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            # initialize parameters
            tf.initialize_all_variables().run()
            summary_writer = tf.train.SummaryWriter(logdir=self.log_path,
                                                    graph=tf.get_default_graph())
            try:
                for e in range(self.num_epoch):
                    for i in range(num_iter_per_epoch):
                        start = i * batch_size
                        stop = start + batch_size

                        # train model for domain S
                        feed_dict = {model.images: svhn[start:stop]}
                        sess.run(model.d_optimizer_fake, feed_dict)
                        sess.run(model.f_optimizer_const, feed_dict)
                        sess.run(model.g_optimizer, feed_dict)

                        if i % 10 == 0:
                            # report on the SVHN batch that was just trained on
                            summary, d_loss, g_loss = sess.run(
                                [model.summary_op, model.d_loss, model.g_loss], feed_dict)
                            summary_writer.add_summary(summary, e * num_iter_per_epoch + i)
                            print('Epoch: [%d] Step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]'
                                  % (e + 1, i + 1, num_iter_per_epoch, d_loss, g_loss))

                        # train model for domain T
                        feed_dict = {model.images: mnist[start:stop]}
                        sess.run(model.d_optimizer_real, feed_dict)
                        sess.run(model.d_optimizer_fake, feed_dict)
                        sess.run(model.g_optimizer, feed_dict)
                        sess.run(model.g_optimizer_const, feed_dict)

                        if i % 500 == 0:
                            checkpoint = os.path.join(self.model_save_path, 'dcgan-%d' % (e + 1))
                            saved_path = model.saver.save(sess, checkpoint, global_step=i + 1)
                            # report the real location instead of a hard-coded 'model/' prefix
                            print('%s saved' % saved_path)
            finally:
                # flush pending summary events even if training is interrupted
                summary_writer.close()