김성주

fixed errors for evaluation

...@@ -60,7 +60,7 @@ args = parser.parse_args() ...@@ -60,7 +60,7 @@ args = parser.parse_args()
60 args.anchors = parse_anchors(args.anchor_path) 60 args.anchors = parse_anchors(args.anchor_path)
61 args.classes = read_class_names(args.class_name_path) 61 args.classes = read_class_names(args.class_name_path)
62 args.class_num = len(args.classes) 62 args.class_num = len(args.classes)
63 -args.img_cnt = len(open(args.eval_file, 'r').readlines()) 63 +args.img_cnt = TFRecordIterator(args.eval_file, 'GZIP').count()
64 64
65 # setting placeholders 65 # setting placeholders
66 is_training = tf.placeholder(dtype=tf.bool, name="phase_train") 66 is_training = tf.placeholder(dtype=tf.bool, name="phase_train")
......
...@@ -183,4 +183,48 @@ with tf.Session() as sess: ...@@ -183,4 +183,48 @@ with tf.Session() as sess:
183 if args.save_optimizer and mAP > best_mAP: 183 if args.save_optimizer and mAP > best_mAP:
184 best_mAP = mAP 184 best_mAP = mAP
185 saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( 185 saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
186 - epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
...\ No newline at end of file ...\ No newline at end of file
186 + epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
187 + saver_to_restore.save(sess, restore_path)
188 +
189 + ## all epoches end
190 + sess.run(val_init_op)
191 +
192 + val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
193 +
194 + val_preds = []
195 +
196 + for j in trange(val_img_cnt):
197 + __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss],
198 + feed_dict={is_training: False})
199 + pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred)
200 + val_preds.extend(pred_content)
201 + val_loss_total.update(__loss[0])
202 + val_loss_xy.update(__loss[1])
203 + val_loss_wh.update(__loss[2])
204 + val_loss_conf.update(__loss[3])
205 + val_loss_class.update(__loss[4])
206 +
207 + # calc mAP
208 + rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter()
209 + gt_dict = parse_gt_rec(val_file, 'GZIP', img_size, letterbox_resize)
210 +
211 + info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr)
212 +
213 + for ii in range(class_num):
214 + npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=eval_threshold, use_07_metric=use_voc_07_metric)
215 + info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap)
216 + rec_total.update(rec, npos)
217 + prec_total.update(prec, nd)
218 + ap_total.update(ap, 1)
219 +
220 + mAP = ap_total.average
221 + info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\n'.format(rec_total.average, prec_total.average, mAP)
222 + info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\n'.format(
223 + val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average)
224 + print(info)
225 +
226 + if save_optimizer and mAP > best_mAP:
227 + best_mAP = mAP
228 + saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
229 + epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
230 + saver_to_restore.save(sess, restore_path)
...\ No newline at end of file ...\ No newline at end of file
......
This diff is collapsed. Click to expand it.