yunjey

train and eval the model

...@@ -11,7 +11,7 @@ class Solver(object): ...@@ -11,7 +11,7 @@ class Solver(object):
11 11
12 def __init__(self, model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100, 12 def __init__(self, model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100,
13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', 13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
14 - model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-600'): 14 + model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-1800'):
15 15
16 self.model = model 16 self.model = model
17 self.batch_size = batch_size 17 self.batch_size = batch_size
...@@ -121,11 +121,9 @@ class Solver(object): ...@@ -121,11 +121,9 @@ class Solver(object):
121 tf.global_variables_initializer().run() 121 tf.global_variables_initializer().run()
122 # restore variables of F 122 # restore variables of F
123 print ('loading pretrained model F..') 123 print ('loading pretrained model F..')
124 - #variables_to_restore = slim.get_model_variables(scope='content_extractor') 124 + variables_to_restore = slim.get_model_variables(scope='content_extractor')
125 - #restorer = tf.train.Saver(variables_to_restore) 125 + restorer = tf.train.Saver(variables_to_restore)
126 - #restorer.restore(sess, self.pretrained_model) 126 + restorer.restore(sess, self.pretrained_model)
127 - restorer = tf.train.Saver()
128 - restorer.restore(sess, 'model/dtn-1600')
129 summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph()) 127 summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
130 saver = tf.train.Saver() 128 saver = tf.train.Saver()
131 129
......