yunjey

add code for creating directories

Showing 1 changed file with 10 additions and 2 deletions
...@@ -4,13 +4,21 @@ from solver import Solver ...@@ -4,13 +4,21 @@ from solver import Solver
4 4
5 flags = tf.app.flags 5 flags = tf.app.flags
6 flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") 6 flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'")
7 +flags.DEFINE_string('model_save_path', 'model', "directory for saving the model")
8 +flags.DEFINE_string('sample_save_path', 'sample', "directory for saving the sampled images")
7 FLAGS = flags.FLAGS 9 FLAGS = flags.FLAGS
8 10
9 def main(_): 11 def main(_):
10 12
11 model = DTN(mode=FLAGS.mode, learning_rate=0.0003) 13 model = DTN(mode=FLAGS.mode, learning_rate=0.0003)
12 - solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, 14 + solver = Solver(model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100,
13 - svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') 15 + svhn_dir='svhn', mnist_dir='mnist', model_save_path=FLAGS.model_save_path, sample_save_path=FLAGS.sample_save_path)
16 +
17 + # create directories if not exist
18 + if not tf.gfile.Exists(FLAGS.model_save_path):
19 + tf.gfile.MakeDirs(FLAGS.model_save_path)
20 + if not tf.gfile.Exists(FLAGS.sample_save_path):
21 + tf.gfile.MakeDirs(FLAGS.sample_save_path)
14 22
15 if FLAGS.mode == 'pretrain': 23 if FLAGS.mode == 'pretrain':
16 solver.pretrain() 24 solver.pretrain()
......