yunjey

train and eval the model

Showing 1 changed file with 13 additions and 7 deletions
......@@ -9,9 +9,9 @@ import scipy.misc
class Solver(object):
def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100,
def __init__(self, model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100,
svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'):
model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-600'):
self.model = model
self.batch_size = batch_size
......@@ -111,7 +111,7 @@ class Solver(object):
model = self.model
model.build_model()
# make log directory if not exists
# make directory if not exists
if tf.gfile.Exists(self.log_dir):
tf.gfile.DeleteRecursively(self.log_dir)
tf.gfile.MakeDirs(self.log_dir)
......@@ -121,13 +121,16 @@ class Solver(object):
tf.global_variables_initializer().run()
# restore variables of F
print ('loading pretrained model F..')
variables_to_restore = slim.get_model_variables(scope='content_extractor')
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, self.pretrained_model)
#variables_to_restore = slim.get_model_variables(scope='content_extractor')
#restorer = tf.train.Saver(variables_to_restore)
#restorer.restore(sess, self.pretrained_model)
restorer = tf.train.Saver()
restorer.restore(sess, 'model/dtn-1600')
summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
saver = tf.train.Saver()
print ('start training..!')
f_interval = 15
for step in range(self.train_iter+1):
i = step % int(svhn_images.shape[0] / self.batch_size)
......@@ -143,7 +146,10 @@ class Solver(object):
sess.run([model.g_train_op_src], feed_dict)
sess.run([model.g_train_op_src], feed_dict)
if i % 15 == 0:
if step > 1600:
f_interval = 30
if i % f_interval == 0:
sess.run(model.f_train_op_src, feed_dict)
if (step+1) % 10 == 0:
......