김성주

fix/update for pretrained transfer learning

from __future__ import division, print_function
import os
import sys
import tensorflow as tf
import numpy as np
from model import yolov3
from misc_utils import parse_anchors, load_weights
img_size = 416
weight_path = '../../data/darknet_weights/yolov3.weights'
save_path = '../../data/darknet_weights/yolov3.ckpt'
anchors = parse_anchors('../../data/yolo_anchors.txt')
model = yolov3(80, anchors)
with tf.Session() as sess:
inputs = tf.placeholder(tf.float32, [1, img_size, img_size, 3])
with tf.variable_scope('yolov3'):
feature_map = model.forward(inputs)
saver = tf.train.Saver(var_list=tf.global_variables(scope='yolov3'))
load_ops = load_weights(tf.global_variables(scope='yolov3'), weight_path)
sess.run(load_ops)
saver.save(sess, save_path=save_path)
print('TensorFlow model checkpoint has been saved to {}'.format(save_path))
\ No newline at end of file
......@@ -97,9 +97,9 @@ saver_to_restore = tf.train.Saver()
with tf.Session() as sess:
sess.run([tf.global_variables_initializer()])
if os.path.exists(args.restore_path):
try:
saver_to_restore.restore(sess, args.restore_path)
else:
except:
raise ValueError('there is no model to evaluate. You should move/create the checkpoint file to restore path')
print('\nStart evaluation...\n')
......
......@@ -102,8 +102,12 @@ else:
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
if os.path.exists(args.restore_path):
saver_to_restore.restore(sess, args.restore_path)
try:
saver_to_restore.restore(sess, restore_path)
print("Restoring parameters...")
except:
print("*** Failed to restore parameters!!! You would need pretrained weights ***")
print('\nStart training...: Total epoches =', args.total_epoches, '\n')
......@@ -184,7 +188,6 @@ with tf.Session() as sess:
best_mAP = mAP
saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
saver_to_restore.save(sess, restore_path)
## all epoches end
sess.run(val_init_op)
......@@ -227,4 +230,3 @@ with tf.Session() as sess:
best_mAP = mAP
saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
\ No newline at end of file
saver_to_restore.save(sess, restore_path)
\ No newline at end of file
......
This diff could not be displayed because it is too large.