Showing
1 changed file
with
10 additions
and
2 deletions
... | @@ -4,13 +4,21 @@ from solver import Solver | ... | @@ -4,13 +4,21 @@ from solver import Solver |
4 | 4 | ||
5 | flags = tf.app.flags | 5 | flags = tf.app.flags |
6 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") | 6 | flags.DEFINE_string('mode', 'train', "'pretrain', 'train' or 'eval'") |
7 | +flags.DEFINE_string('model_save_path', 'model', "directory for saving the model") | ||
8 | +flags.DEFINE_string('sample_save_path', 'sample', "directory for saving the sampled images") | ||
7 | FLAGS = flags.FLAGS | 9 | FLAGS = flags.FLAGS |
8 | 10 | ||
9 | def main(_): | 11 | def main(_): |
10 | 12 | ||
11 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) | 13 | model = DTN(mode=FLAGS.mode, learning_rate=0.0003) |
12 | - solver = Solver(model, batch_size=100, pretrain_iter=5000, train_iter=2000, sample_iter=100, | 14 | + solver = Solver(model, batch_size=100, pretrain_iter=20000, train_iter=2000, sample_iter=100, |
13 | - svhn_dir='svhn', mnist_dir='mnist', log_dir='logs', model_save_path='model') | 15 | + svhn_dir='svhn', mnist_dir='mnist', model_save_path=FLAGS.model_save_path, sample_save_path=FLAGS.sample_save_path) |
16 | + | ||
17 | + # create directories if not exist | ||
18 | + if not tf.gfile.Exists(FLAGS.model_save_path): | ||
19 | + tf.gfile.MakeDirs(FLAGS.model_save_path) | ||
20 | + if not tf.gfile.Exists(FLAGS.sample_save_path): | ||
21 | + tf.gfile.MakeDirs(FLAGS.sample_save_path) | ||
14 | 22 | ||
15 | if FLAGS.mode == 'pretrain': | 23 | if FLAGS.mode == 'pretrain': |
16 | solver.pretrain() | 24 | solver.pretrain() | ... | ... |
-
Please register or login to post a comment