김성주

android implementation update & readme

...@@ -12,4 +12,63 @@ ...@@ -12,4 +12,63 @@
12 - 2017103967 김성주 12 - 2017103967 김성주
13 - 2015104213 장수창 13 - 2015104213 장수창
14 14
15 +
15 16
17 +### 준비 사항
18 +tensorflow-android 라이브러리의 최신 버전이 (2020.06.01 기준) 1.13.1입니다.
19 +따라서 android implementation까지 구현하는 경우에는
20 +상위 버전과 호환이 되도록 라이브러리를 빌드하거나, 학습 혹은 pb 파일 생성 또한 tensorflow v1.13.1 이하로 진행하셔야 합니다.
21 +
22 +annotation에는 labelImg 툴을 이용하여 xml을 생성하였습니다.
23 +
24 +학습에는 TFrecord 형태로 저장된 파일을 사용합니다.
25 +데이터 하나의 형식은 {data index, image binary, image width, image height, boxes}이며
26 +boxes의 형식은 {label1, xmin, ymin, xmax, ymax, label2, xmin, ...}입니다.
27 +TFRecord 파일 작성은 code/tfrecord_writer.py를 참고하시기 바랍니다.
28 +
29 +tfrecord_writer.py에서 입력으로 받는 txt 파일은
30 +각 라인마다 {data index, image path, image width, image height, boxes} 형태로 저장되어 있습니다.
31 +txt 파일 생성은 code/annotation_xml_parser.py를 참고하시기 바랍니다.
32 +
33 +이 학습에서는 train/eval/test 데이터셋을 구분하여 사용합니다.
34 +txt 파일에 대한 데이터셋 분리는 code/dataset_splitter.py를 참고하기시 바랍니다.
35 +
36 +annotation_xml_parser.py에서 입력으로 받는 xml 파일은
37 +labelImg 툴로 생성된 Pascal VOC format XML 파일을 기준으로 합니다.
38 +
39 +학습을 위해서 anchor 파일이 필요합니다.
40 +anchor 파일 생성에는 code/yolov3/get_kmeans.py를 참고하시기 바랍니다.
41 +출력된 anchor를 code/yolov3/args.py의 anchor_path에 맞는 위치에 저장하시면 됩니다.
42 +
43 +이 학습에서는 pretrained model을 불러와 fine tuning을 이용합니다.
44 +따라서 pretrained model 파일을 준비해야 합니다.
45 +pretrained model은 [링크](https://pjreddie.com/media/files/yolov3.weights)에서 다운로드할 수 있습니다.
46 +이 파일은 darknet weights 파일이므로, tensorflow model로 변환하려면 code/yolov3/convert_weights.py를 참고하시기 바랍니다.
47 +(git에는 이미 변환된 yolov3.ckpt만이 업로드되어 있습니다. 다른 데이터셋 혹은 다른 용도로 학습을 진행하려면 새로 생성하셔야 합니다.)
48 +
49 +학습에는 train.py (train/eval dataset)를, 평가에는 eval.py (test dataset)를 사용하시면 됩니다.
50 +학습에 사용하는 파일의 경로 및 hyper parameter 설정은 args.py를 참고하시기 바랍니다.
51 +평가에 대한 경로 설정은 eval.py에서 할 수 있습니다.
52 +
53 +data/trained에 임시 테스트용 trained model 파일이 업로드되어 있습니다.
54 +
55 +
56 +android implementation을 하는 경우에는 학습된 모델에 대한 pb 파일을 생성해야 합니다.
57 +code/pb/pbCreator.py를 참고하시기 바랍니다. (code/yolov3/test_single_image.py를 약간 수정한 파일입니다)
58 +
59 +android에서는 freeze된 model만 사용할 수 있습니다.
60 +code/pb/freeze_pb.py를 참고하시기 바랍니다.
61 +
62 +android_App/assets에 pb file을 저장한 후, DetectorActivity.java에서 YOLO_MODEL_FILE의 값을 알맞게 수정하시면 됩니다.
63 +
64 +이 학습 코드로 생성된 모델의 input, output node name은
65 +각각 input_data, {yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3} 입니다.
66 +모델의 node name 참고에는 Netron 프로그램을 사용하였습니다.
67 +
68 +
69 +#### Reference
70 +학습 코드는 [링크](https://github.com/wizyoung/YOLOv3_TensorFlow)를 기반으로 작셩하였습니다.
71 +변경점은 code/yolov3/changes.txt를 참고하시기 바랍니다.
72 +
73 +android 코드는 [링크](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android)를 기반으로 작성하였습니다.
74 +
......
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="RemoteRepositoriesConfiguration">
4 + <remote-repository>
5 + <option name="id" value="central" />
6 + <option name="name" value="Maven Central repository" />
7 + <option name="url" value="https://repo1.maven.org/maven2" />
8 + </remote-repository>
9 + <remote-repository>
10 + <option name="id" value="jboss.community" />
11 + <option name="name" value="JBoss Community repository" />
12 + <option name="url" value="https://repository.jboss.org/nexus/content/repositories/public/" />
13 + </remote-repository>
14 + <remote-repository>
15 + <option name="id" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\android\m2repository" />
16 + <option name="name" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\android\m2repository" />
17 + <option name="url" value="file:/$USER_HOME$/AppData/Local/Android/Sdk/extras/android/m2repository" />
18 + </remote-repository>
19 + <remote-repository>
20 + <option name="id" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\m2repository" />
21 + <option name="name" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\m2repository" />
22 + <option name="url" value="file:/$USER_HOME$/AppData/Local/Android/Sdk/extras/m2repository" />
23 + </remote-repository>
24 + <remote-repository>
25 + <option name="id" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\google\m2repository" />
26 + <option name="name" value="C:\Users\Kareus\AppData\Local\Android\Sdk\extras\google\m2repository" />
27 + <option name="url" value="file:/$USER_HOME$/AppData/Local/Android/Sdk/extras/google/m2repository" />
28 + </remote-repository>
29 + <remote-repository>
30 + <option name="id" value="BintrayJCenter" />
31 + <option name="name" value="BintrayJCenter" />
32 + <option name="url" value="https://jcenter.bintray.com/" />
33 + </remote-repository>
34 + <remote-repository>
35 + <option name="id" value="Google" />
36 + <option name="name" value="Google" />
37 + <option name="url" value="https://dl.google.com/dl/android/maven2/" />
38 + </remote-repository>
39 + </component>
40 +</project>
...\ No newline at end of file ...\ No newline at end of file
...@@ -25,21 +25,10 @@ ...@@ -25,21 +25,10 @@
25 <uses-permission android:name="android.permission.RECORD_AUDIO" /> 25 <uses-permission android:name="android.permission.RECORD_AUDIO" />
26 26
27 <application android:allowBackup="true" 27 <application android:allowBackup="true"
28 - android:debuggable="true"
29 android:label="@string/app_name" 28 android:label="@string/app_name"
30 android:icon="@drawable/ic_launcher" 29 android:icon="@drawable/ic_launcher"
31 android:theme="@style/MaterialTheme"> 30 android:theme="@style/MaterialTheme">
32 31
33 -<!-- <activity android:name="org.tensorflow.demo.ClassifierActivity"-->
34 -<!-- android:screenOrientation="portrait"-->
35 -<!-- android:label="@string/activity_name_classification">-->
36 -<!-- <intent-filter>-->
37 -<!-- <action android:name="android.intent.action.MAIN" />-->
38 -<!-- <category android:name="android.intent.category.LAUNCHER" />-->
39 -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
40 -<!-- </intent-filter>-->
41 -<!-- </activity>-->
42 -
43 <activity android:name="org.tensorflow.demo.DetectorActivity" 32 <activity android:name="org.tensorflow.demo.DetectorActivity"
44 android:screenOrientation="portrait" 33 android:screenOrientation="portrait"
45 android:label="@string/activity_name_detection"> 34 android:label="@string/activity_name_detection">
...@@ -50,25 +39,38 @@ ...@@ -50,25 +39,38 @@
50 </intent-filter> 39 </intent-filter>
51 </activity> 40 </activity>
52 41
53 -<!-- <activity android:name="org.tensorflow.demo.StylizeActivity"--> 42 + <!--
54 -<!-- android:screenOrientation="portrait"--> 43 + <activity android:name="org.tensorflow.demo.ClassifierActivity"
55 -<!-- android:label="@string/activity_name_stylize">--> 44 + android:screenOrientation="portrait"
56 -<!-- <intent-filter>--> 45 + android:label="@string/activity_name_classification">
57 -<!-- <action android:name="android.intent.action.MAIN" />--> 46 + <intent-filter>
58 -<!-- <category android:name="android.intent.category.LAUNCHER" />--> 47 + <action android:name="android.intent.action.MAIN" />
59 -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />--> 48 + <category android:name="android.intent.category.LAUNCHER" />
60 -<!-- </intent-filter>--> 49 + <category android:name="android.intent.category.LEANBACK_LAUNCHER" />
61 -<!-- </activity>--> 50 + </intent-filter>
51 + </activity>
52 +
53 + <activity android:name="org.tensorflow.demo.StylizeActivity"
54 + android:screenOrientation="portrait"
55 + android:label="@string/activity_name_stylize">
56 + <intent-filter>
57 + <action android:name="android.intent.action.MAIN" />
58 + <category android:name="android.intent.category.LAUNCHER" />
59 + <category android:name="android.intent.category.LEANBACK_LAUNCHER" />
60 + </intent-filter>
61 + </activity>
62 +
63 + <activity android:name="org.tensorflow.demo.SpeechActivity"
64 + android:screenOrientation="portrait"
65 + android:label="@string/activity_name_speech">
66 + <intent-filter>
67 + <action android:name="android.intent.action.MAIN" />
68 + <category android:name="android.intent.category.LAUNCHER" />
69 + <category android:name="android.intent.category.LEANBACK_LAUNCHER" />
70 + </intent-filter>
71 + </activity>
72 + -->
62 73
63 -<!-- <activity android:name="org.tensorflow.demo.SpeechActivity"-->
64 -<!-- android:screenOrientation="portrait"-->
65 -<!-- android:label="@string/activity_name_speech">-->
66 -<!-- <intent-filter>-->
67 -<!-- <action android:name="android.intent.action.MAIN" />-->
68 -<!-- <category android:name="android.intent.category.LAUNCHER" />-->
69 -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
70 -<!-- </intent-filter>-->
71 -<!-- </activity>-->
72 </application> 74 </application>
73 75
74 </manifest> 76 </manifest>
......
1 -package(
2 - default_visibility = ["//visibility:public"],
3 - licenses = ["notice"], # Apache 2.0
4 -)
5 -
6 -# It is necessary to use this filegroup rather than globbing the files in this
7 -# folder directly the examples/android:tensorflow_demo target due to the fact
8 -# that assets_dir is necessarily set to "" there (to allow using other
9 -# arbitrary targets as assets).
10 -filegroup(
11 - name = "asset_files",
12 - srcs = glob(
13 - ["**/*"],
14 - exclude = ["BUILD"],
15 - ),
16 -)
This file is too large to display.
...@@ -42,7 +42,7 @@ allprojects { ...@@ -42,7 +42,7 @@ allprojects {
42 } 42 }
43 43
44 // set to 'bazel', 'cmake', 'makefile', 'none' 44 // set to 'bazel', 'cmake', 'makefile', 'none'
45 -def nativeBuildSystem = 'none' 45 +def nativeBuildSystem = 'cmake'
46 46
47 // Controls output directory in APK and CPU type for Bazel builds. 47 // Controls output directory in APK and CPU type for Bazel builds.
48 // NOTE: Does not affect the Makefile build target API (yet), which currently 48 // NOTE: Does not affect the Makefile build target API (yet), which currently
......
1 +org.gradle.jvmargs=-Xmx2048m
...\ No newline at end of file ...\ No newline at end of file
1 -#Sat Nov 18 15:06:47 CET 2017 1 +#Sat May 30 18:49:07 KST 2020
2 distributionBase=GRADLE_USER_HOME 2 distributionBase=GRADLE_USER_HOME
3 distributionPath=wrapper/dists 3 distributionPath=wrapper/dists
4 zipStoreBase=GRADLE_USER_HOME 4 zipStoreBase=GRADLE_USER_HOME
5 zipStorePath=wrapper/dists 5 zipStorePath=wrapper/dists
6 -distributionUrl=https\://services.gradle.org/distributions/gradle-4.1-all.zip 6 +distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.1-all.zip
......
...@@ -71,11 +71,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable ...@@ -71,11 +71,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
71 // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via 71 // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via
72 // DarkFlow (https://github.com/thtrieu/darkflow). Sample command: 72 // DarkFlow (https://github.com/thtrieu/darkflow). Sample command:
73 // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise 73 // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise
74 - private static final String YOLO_MODEL_FILE = "file:///android_asset/yolov3.pb"; 74 + private static final String YOLO_MODEL_FILE = "file:///android_asset/test_freeze_13.pb";
75 private static final int YOLO_INPUT_SIZE = 416; 75 private static final int YOLO_INPUT_SIZE = 416;
76 - private static final String YOLO_INPUT_NAME = "input"; 76 + private static final String YOLO_INPUT_NAME = "input_data";
77 - private static final String YOLO_OUTPUT_NAMES = "output"; 77 + private static final String YOLO_OUTPUT_NAMES = "yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3";
78 - private static final int YOLO_BLOCK_SIZE = 32; 78 + private static final int YOLO_BLOCK_SIZE = 16;
79 79
80 // Which detection model to use: by default uses Tensorflow Object Detection API frozen 80 // Which detection model to use: by default uses Tensorflow Object Detection API frozen
81 // checkpoints. Optionally use legacy Multibox (trained using an older version of the API) 81 // checkpoints. Optionally use legacy Multibox (trained using an older version of the API)
...@@ -131,6 +131,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable ...@@ -131,6 +131,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
131 int cropSize = TF_OD_API_INPUT_SIZE; 131 int cropSize = TF_OD_API_INPUT_SIZE;
132 if (MODE == DetectorMode.YOLO) { 132 if (MODE == DetectorMode.YOLO) {
133 detector = 133 detector =
134 +
134 TensorFlowYoloDetector.create( 135 TensorFlowYoloDetector.create(
135 getAssets(), 136 getAssets(),
136 YOLO_MODEL_FILE, 137 YOLO_MODEL_FILE,
......
...@@ -32,7 +32,7 @@ public class TensorFlowYoloDetector implements Classifier { ...@@ -32,7 +32,7 @@ public class TensorFlowYoloDetector implements Classifier {
32 private static final Logger LOGGER = new Logger(); 32 private static final Logger LOGGER = new Logger();
33 33
34 // Only return this many results with at least this confidence. 34 // Only return this many results with at least this confidence.
35 - private static final int MAX_RESULTS = 5; 35 + private static final int MAX_RESULTS = 10;
36 36
37 private static final int NUM_CLASSES = 1; 37 private static final int NUM_CLASSES = 1;
38 38
...@@ -41,17 +41,14 @@ public class TensorFlowYoloDetector implements Classifier { ...@@ -41,17 +41,14 @@ public class TensorFlowYoloDetector implements Classifier {
41 // TODO(andrewharp): allow loading anchors and classes 41 // TODO(andrewharp): allow loading anchors and classes
42 // from files. 42 // from files.
43 private static final double[] ANCHORS = { 43 private static final double[] ANCHORS = {
44 - 1.08, 1.19, 44 + 35,37, 75,48, 57,87, 116,73, 83,138, 119,110, 154,184, 250,216, 317,362
45 - 3.42, 4.41,
46 - 6.63, 11.38,
47 - 9.42, 5.11,
48 - 16.62, 10.52
49 }; 45 };
50 46
51 private static final String[] LABELS = { 47 private static final String[] LABELS = {
52 - "dog" 48 + "dog"
53 }; 49 };
54 50
51 +
55 // Config values. 52 // Config values.
56 private String inputName; 53 private String inputName;
57 private int inputSize; 54 private int inputSize;
......
1 -# Copyright 2015 Google Inc. All Rights Reserved.
2 -#
3 -# Licensed under the Apache License, Version 2.0 (the "License");
4 -# you may not use this file except in compliance with the License.
5 -# You may obtain a copy of the License at
6 -#
7 -# http://www.apache.org/licenses/LICENSE-2.0
8 -#
9 -# Unless required by applicable law or agreed to in writing, software
10 -# distributed under the License is distributed on an "AS IS" BASIS,
11 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -# See the License for the specific language governing permissions and
13 -# limitations under the License.
14 -# ==============================================================================
15 -"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
16 -This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
17 -variable values stored in a checkpoint file, and output a GraphDef with all of
18 -the variable ops converted into const ops containing the values of the
19 -variables.
20 -It's useful to do this when we need to load a single file in C++, especially in
21 -environments like mobile or embedded where we may not have access to the
22 -RestoreTensor ops and file loading calls that they rely on.
23 -An example of command-line usage is:
24 -bazel build tensorflow/python/tools:freeze_graph && \
25 -bazel-bin/tensorflow/python/tools/freeze_graph \
26 ---input_graph=some_graph_def.pb \
27 ---input_checkpoint=model.ckpt-8361242 \
28 ---output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
29 -You can also look at freeze_graph_test.py for an example of how to use it.
30 -"""
31 -from __future__ import absolute_import
32 -from __future__ import division
33 -from __future__ import print_function
34 -
35 -import tensorflow as tf
36 -
37 -from google.protobuf import text_format
38 -from tensorflow.python.framework import graph_util
39 -
40 -
41 -FLAGS = tf.app.flags.FLAGS
42 -
43 -tf.app.flags.DEFINE_string("input_graph", "",
44 - """TensorFlow 'GraphDef' file to load.""")
45 -tf.app.flags.DEFINE_string("input_saver", "",
46 - """TensorFlow saver file to load.""")
47 -tf.app.flags.DEFINE_string("input_checkpoint", "",
48 - """TensorFlow variables file to load.""")
49 -tf.app.flags.DEFINE_string("output_graph", "",
50 - """Output 'GraphDef' file name.""")
51 -tf.app.flags.DEFINE_boolean("input_binary", False,
52 - """Whether the input files are in binary format.""")
53 -tf.app.flags.DEFINE_string("output_node_names", "",
54 - """The name of the output nodes, comma separated.""")
55 -tf.app.flags.DEFINE_string("restore_op_name", "save/restore_all",
56 - """The name of the master restore operator.""")
57 -tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0",
58 - """The name of the tensor holding the save path.""")
59 -tf.app.flags.DEFINE_boolean("clear_devices", True,
60 - """Whether to remove device specifications.""")
61 -tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
62 - "initializer nodes to run before freezing.")
63 -
64 -
65 -def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
66 - output_node_names, restore_op_name, filename_tensor_name,
67 - output_graph, clear_devices, initializer_nodes):
68 - """Converts all variables in a graph and checkpoint into constants."""
69 -
70 - if not tf.gfile.Exists(input_graph):
71 - print("Input graph file '" + input_graph + "' does not exist!")
72 - return -1
73 -
74 - if input_saver and not tf.gfile.Exists(input_saver):
75 - print("Input saver file '" + input_saver + "' does not exist!")
76 - return -1
77 -
78 - if not tf.gfile.Glob(input_checkpoint):
79 - print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
80 - return -1
81 -
82 - if not output_node_names:
83 - print("You need to supply the name of a node to --output_node_names.")
84 - return -1
85 -
86 - input_graph_def = tf.GraphDef()
87 - mode = "rb" if input_binary else "r"
88 - with tf.gfile.FastGFile(input_graph, mode) as f:
89 - if input_binary:
90 - input_graph_def.ParseFromString(f.read())
91 - else:
92 - text_format.Merge(f.read(), input_graph_def)
93 - # Remove all the explicit device specifications for this node. This helps to
94 - # make the graph more portable.
95 - if clear_devices:
96 - for node in input_graph_def.node:
97 - node.device = ""
98 - _ = tf.import_graph_def(input_graph_def, name="")
99 -
100 - with tf.Session() as sess:
101 - if input_saver:
102 - with tf.gfile.FastGFile(input_saver, mode) as f:
103 - saver_def = tf.train.SaverDef()
104 - if input_binary:
105 - saver_def.ParseFromString(f.read())
106 - else:
107 - text_format.Merge(f.read(), saver_def)
108 - saver = tf.train.Saver(saver_def=saver_def)
109 - saver.restore(sess, input_checkpoint)
110 - else:
111 - sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
112 - if initializer_nodes:
113 - sess.run(initializer_nodes)
114 - output_graph_def = graph_util.convert_variables_to_constants(
115 - sess, input_graph_def, output_node_names.split(","))
116 -
117 - with tf.gfile.GFile(output_graph, "wb") as f:
118 - f.write(output_graph_def.SerializeToString())
119 - print("%d ops in the final graph." % len(output_graph_def.node))
120 -
121 -
122 -def main(unused_args):
123 - freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
124 - FLAGS.input_checkpoint, FLAGS.output_node_names,
125 - FLAGS.restore_op_name, FLAGS.filename_tensor_name,
126 - FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
127 -
128 -if __name__ == "__main__":
129 - tf.app.run()
...\ No newline at end of file ...\ No newline at end of file
1 +from tensorflow.python.tools import freeze_graph
2 +
3 +ckpt_filepath = '../../data/pb/pb.ckpt'
4 +pbtxt_filename = 'model.pbtxt'
5 +pbtxt_filepath = '../../data/pb/model.pbtxt'
6 +pb_filepath = '../../data/pb/freeze.pb'
7 +
8 +freeze_graph.freeze_graph(input_graph=pbtxt_filepath, input_saver='', input_binary=False, input_checkpoint=ckpt_filepath, output_node_names='yolov3/yolov3_head/feature_map_1,yolov3/yolov3_head/feature_map_2,yolov3/yolov3_head/feature_map_3', restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=pb_filepath, clear_devices=True, initializer_nodes='')
1 +from __future__ import division, print_function
2 +
3 +import tensorflow as tf
4 +import numpy as np
5 +import argparse
6 +import cv2
7 +
8 +from misc_utils import parse_anchors, read_class_names
9 +from nms_utils import gpu_nms
10 +from plot_utils import get_color_table, plot_one_box
11 +from data_utils import letterbox_resize
12 +
13 +from model import yolov3
14 +
15 +parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.")
16 +parser.add_argument("input_image", type=str,
17 + help="The path of the input image.")
18 +parser.add_argument("--anchor_path", type=str, default="../../data/yolo_anchors.txt",
19 + help="The path of the anchor txt file.")
20 +parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416],
21 + help="Resize the input image with `new_size`, size format: [width, height]")
22 +parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True,
23 + help="Whether to use the letterbox resize.")
24 +parser.add_argument("--class_name_path", type=str, default="../../data/classes.txt",
25 + help="The path of the class names.")
26 +parser.add_argument("--restore_path", type=str, default="../../data/darknet_weights/yolov3.ckpt",
27 + help="The path of the weights to restore.")
28 +parser.add_argument("--pb_path", type=str, default="../../data/pb",
29 + help="The directory of pb files")
30 +args = parser.parse_args()
31 +
32 +args.anchors = parse_anchors(args.anchor_path)
33 +args.classes = read_class_names(args.class_name_path)
34 +args.num_class = len(args.classes)
35 +
36 +color_table = get_color_table(args.num_class)
37 +
38 +img_ori = cv2.imread(args.input_image)
39 +if args.letterbox_resize:
40 + img, resize_ratio, dw, dh = letterbox_resize(img_ori, args.new_size[0], args.new_size[1])
41 +else:
42 + height_ori, width_ori = img_ori.shape[:2]
43 + img = cv2.resize(img_ori, tuple(args.new_size))
44 +img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
45 +img = np.asarray(img, np.float32)
46 +img = img[np.newaxis, :] / 255.
47 +
48 +graph = tf.Graph()
49 +with tf.Session(graph=graph) as sess:
50 + input_data = tf.placeholder(tf.float32, [1, args.new_size[1], args.new_size[0], 3], name='input_data')
51 + yolo_model = yolov3(args.num_class, args.anchors)
52 + with tf.variable_scope('yolov3'):
53 + pred_feature_maps = yolo_model.forward(input_data, False)
54 + pred_boxes, pred_confs, pred_probs = yolo_model.predict(pred_feature_maps)
55 +
56 + pred_scores = pred_confs * pred_probs
57 +
58 + boxes, scores, labels = gpu_nms(pred_boxes, pred_scores, args.num_class, max_boxes=200, score_thresh=0.3, nms_thresh=0.45)
59 +
60 + saver = tf.train.Saver()
61 + saver.restore(sess, args.restore_path)
62 +
63 + boxes_, scores_, labels_ = sess.run([boxes, scores, labels], feed_dict={input_data: img})
64 +
65 + if args.letterbox_resize:
66 + boxes_[:, [0, 2]] = (boxes_[:, [0, 2]] - dw) / resize_ratio
67 + boxes_[:, [1, 3]] = (boxes_[:, [1, 3]] - dh) / resize_ratio
68 + else:
69 + boxes_[:, [0, 2]] *= (width_ori/float(args.new_size[0]))
70 + boxes_[:, [1, 3]] *= (height_ori/float(args.new_size[1]))
71 +
72 + print("box coords:")
73 + print(boxes_)
74 + print('*' * 30)
75 + print("scores:")
76 + print(scores_)
77 + print('*' * 30)
78 + print("labels:")
79 + print(labels_)
80 +
81 + for i in range(len(boxes_)):
82 + x0, y0, x1, y1 = boxes_[i]
83 + plot_one_box(img_ori, [x0, y0, x1, y1], label=args.classes[labels_[i]] + ', {:.2f}%'.format(scores_[i] * 100), color=color_table[labels_[i]])
84 + cv2.imshow('Detection result', img_ori)
85 + cv2.imwrite('detection_result.jpg', img_ori)
86 + cv2.waitKey(0)
87 +
88 + saver.save(sess, args.pb_path+'/pb.ckpt')
89 + tf.io.write_graph(sess.graph_def, args.pb_path, 'model.pb', as_text=False)
90 + tf.io.write_graph(sess.graph_def, args.pb_path, 'model.pbtxt', as_text=True)
...\ No newline at end of file ...\ No newline at end of file
...@@ -15,15 +15,15 @@ from model import yolov3 ...@@ -15,15 +15,15 @@ from model import yolov3
15 parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.") 15 parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.")
16 parser.add_argument("input_image", type=str, 16 parser.add_argument("input_image", type=str,
17 help="The path of the input image.") 17 help="The path of the input image.")
18 -parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt", 18 +parser.add_argument("--anchor_path", type=str, default="../../data/yolo_anchors.txt",
19 help="The path of the anchor txt file.") 19 help="The path of the anchor txt file.")
20 parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416], 20 parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416],
21 help="Resize the input image with `new_size`, size format: [width, height]") 21 help="Resize the input image with `new_size`, size format: [width, height]")
22 parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True, 22 parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True,
23 help="Whether to use the letterbox resize.") 23 help="Whether to use the letterbox resize.")
24 -parser.add_argument("--class_name_path", type=str, default="./data/coco.names", 24 +parser.add_argument("--class_name_path", type=str, default="../../data/classes.txt",
25 help="The path of the class names.") 25 help="The path of the class names.")
26 -parser.add_argument("--restore_path", type=str, default="./data/darknet_weights/yolov3.ckpt", 26 +parser.add_argument("--restore_path", type=str, default="../../data/darknet_weights/yolov3.ckpt",
27 help="The path of the weights to restore.") 27 help="The path of the weights to restore.")
28 args = parser.parse_args() 28 args = parser.parse_args()
29 29
......
This file is too large to display.