yunjey

train and eval the model

......@@ -11,7 +11,7 @@ class Solver(object):
def __init__(self, model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100,
svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-600'):
model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-1800'):
self.model = model
self.batch_size = batch_size
......@@ -121,11 +121,9 @@ class Solver(object):
tf.global_variables_initializer().run()
# restore variables of F
print ('loading pretrained model F..')
#variables_to_restore = slim.get_model_variables(scope='content_extractor')
#restorer = tf.train.Saver(variables_to_restore)
#restorer.restore(sess, self.pretrained_model)
restorer = tf.train.Saver()
restorer.restore(sess, 'model/dtn-1600')
variables_to_restore = slim.get_model_variables(scope='content_extractor')
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, self.pretrained_model)
summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
saver = tf.train.Saver()
......