yunjey

train and eval the model

Showing 1 changed file with 13 additions and 7 deletions
...@@ -9,9 +9,9 @@ import scipy.misc ...@@ -9,9 +9,9 @@ import scipy.misc
9 9
10 class Solver(object): 10 class Solver(object):
11 11
12 - def __init__(self, model, batch_size=100, pretrain_iter=10000, train_iter=2000, sample_iter=100, 12 + def __init__(self, model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100,
13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample', 13 svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', sample_save_path='sample',
14 - model_save_path='model', pretrained_model='model/svhn_model-10000', test_model='model/dtn-2000'): 14 + model_save_path='model', pretrained_model='model/svhn_model-20000', test_model='model/dtn-600'):
15 15
16 self.model = model 16 self.model = model
17 self.batch_size = batch_size 17 self.batch_size = batch_size
...@@ -111,7 +111,7 @@ class Solver(object): ...@@ -111,7 +111,7 @@ class Solver(object):
111 model = self.model 111 model = self.model
112 model.build_model() 112 model.build_model()
113 113
114 - # make log directory if not exists 114 + # make directory if not exists
115 if tf.gfile.Exists(self.log_dir): 115 if tf.gfile.Exists(self.log_dir):
116 tf.gfile.DeleteRecursively(self.log_dir) 116 tf.gfile.DeleteRecursively(self.log_dir)
117 tf.gfile.MakeDirs(self.log_dir) 117 tf.gfile.MakeDirs(self.log_dir)
...@@ -121,13 +121,16 @@ class Solver(object): ...@@ -121,13 +121,16 @@ class Solver(object):
121 tf.global_variables_initializer().run() 121 tf.global_variables_initializer().run()
122 # restore variables of F 122 # restore variables of F
123 print ('loading pretrained model F..') 123 print ('loading pretrained model F..')
124 - variables_to_restore = slim.get_model_variables(scope='content_extractor') 124 + #variables_to_restore = slim.get_model_variables(scope='content_extractor')
125 - restorer = tf.train.Saver(variables_to_restore) 125 + #restorer = tf.train.Saver(variables_to_restore)
126 - restorer.restore(sess, self.pretrained_model) 126 + #restorer.restore(sess, self.pretrained_model)
127 + restorer = tf.train.Saver()
128 + restorer.restore(sess, 'model/dtn-1600')
127 summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph()) 129 summary_writer = tf.summary.FileWriter(logdir=self.log_dir, graph=tf.get_default_graph())
128 saver = tf.train.Saver() 130 saver = tf.train.Saver()
129 131
130 print ('start training..!') 132 print ('start training..!')
133 + f_interval = 15
131 for step in range(self.train_iter+1): 134 for step in range(self.train_iter+1):
132 135
133 i = step % int(svhn_images.shape[0] / self.batch_size) 136 i = step % int(svhn_images.shape[0] / self.batch_size)
...@@ -143,7 +146,10 @@ class Solver(object): ...@@ -143,7 +146,10 @@ class Solver(object):
143 sess.run([model.g_train_op_src], feed_dict) 146 sess.run([model.g_train_op_src], feed_dict)
144 sess.run([model.g_train_op_src], feed_dict) 147 sess.run([model.g_train_op_src], feed_dict)
145 148
146 - if i % 15 == 0: 149 + if step > 1600:
150 + f_interval = 30
151 +
152 + if i % f_interval == 0:
147 sess.run(model.f_train_op_src, feed_dict) 153 sess.run(model.f_train_op_src, feed_dict)
148 154
149 if (step+1) % 10 == 0: 155 if (step+1) % 10 == 0:
......