Toggle navigation
Toggle navigation
This project
Loading...
Sign in
최강혁
/
dddd
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Graphs
Network
Create a new issue
Commits
Issue Boards
Authored by
yunjey
2017-01-22 05:37:06 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
051cf6043ae7255a94602bff735ada18f2415aab
051cf604
1 parent
aca152a9
main file
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
11 deletions
main.py
main.py
View file @
051cf60
...
...
@@ -5,25 +5,23 @@ from solver import Solver
flags
=
tf
.
app
.
flags
flags
.
DEFINE_
boolean
(
'is_train'
,
False
,
'True if train mode, False if test mode'
)
flags
.
DEFINE_
string
(
'mode'
,
'train'
,
"'pretrain', 'train' or 'eval'"
)
FLAGS
=
flags
.
FLAGS
def
main
(
_
):
model
=
DTN
(
batch_size
=
100
,
learning_rate
=
0.001
,
image_size
=
32
,
output_size
=
32
,
dim_color
=
3
,
dim_fout
=
100
,
dim_df
=
64
,
dim_gf
=
64
,
dim_ff
=
64
)
solver
=
Solver
(
model
,
num_epoch
=
10
,
mnist_path
=
'mnist/'
,
svhn_path
=
'svhn/'
,
model_save_path
=
'model/'
,
log_path
=
'log/'
,
sample_path
=
'sample/'
,
test_model_path
=
'model/dtn-2-1'
,
sample_iter
=
100
)
model
=
DTN
(
mode
=
FLAGS
.
mode
,
learning_rate
=
0.0003
)
solver
=
Solver
(
model
,
batch_size
=
100
,
pretrain_iter
=
5000
,
train_iter
=
2000
,
sample_iter
=
100
,
svhn_dir
=
'svhn'
,
mnist_dir
=
'mnist'
,
log_dir
=
'logs'
,
model_save_path
=
'model'
)
if
FLAGS
.
is_train
:
if
FLAGS
.
mode
==
'pretrain'
:
solver
.
pretrain
()
elif
FLAGS
.
mode
==
'train'
:
solver
.
train
()
else
:
solver
.
test
()
solver
.
eval
()
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
...
...
Please
register
or
login
to post a comment