Showing 115 changed files with 16,755 additions and 0 deletions
android/android/.gitignore
0 → 100644
# This file is based on https://github.com/github/gitignore/blob/master/Android.gitignore
*.iml
.idea/compiler.xml
.idea/copyright
.idea/dictionaries
.idea/gradle.xml
.idea/libraries
.idea/inspectionProfiles
.idea/misc.xml
.idea/modules.xml
.idea/runConfigurations.xml
.idea/tasks.xml
.idea/workspace.xml
.gradle
local.properties
.DS_Store
build/
gradleBuild/
*.apk
*.ap_
*.dex
*.class
bin/
gen/
out/
*.log
.navigation/
/captures
.externalNativeBuild
android/android/.idea/codeStyles/Project.xml
0 → 100644
<component name="ProjectCodeStyleConfiguration">
  <code_scheme name="Project" version="173">
    <codeStyleSettings language="XML">
      <indentOptions>
        <option name="CONTINUATION_INDENT_SIZE" value="4" />
      </indentOptions>
      <arrangement>
        <rules>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>xmlns:android</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>^$</XML_NAMESPACE>
                </AND>
              </match>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>xmlns:.*</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>^$</XML_NAMESPACE>
                </AND>
              </match>
              <order>BY_NAME</order>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>.*:id</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
                </AND>
              </match>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>.*:name</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
                </AND>
              </match>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>name</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>^$</XML_NAMESPACE>
                </AND>
              </match>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>style</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>^$</XML_NAMESPACE>
                </AND>
              </match>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>.*</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>^$</XML_NAMESPACE>
                </AND>
              </match>
              <order>BY_NAME</order>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>.*</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
                </AND>
              </match>
              <order>ANDROID_ATTRIBUTE_ORDER</order>
            </rule>
          </section>
          <section>
            <rule>
              <match>
                <AND>
                  <NAME>.*</NAME>
                  <XML_ATTRIBUTE />
                  <XML_NAMESPACE>.*</XML_NAMESPACE>
                </AND>
              </match>
              <order>BY_NAME</order>
            </rule>
          </section>
        </rules>
      </arrangement>
    </codeStyleSettings>
  </code_scheme>
</component>
\ No newline at end of file
android/android/.idea/vcs.xml
0 → 100644
android/android/AndroidManifest.xml
0 → 100644
<?xml version="1.0" encoding="UTF-8"?>
<!--
 Copyright 2016 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-->

<manifest xmlns:android="http://schemas.android.com/apk/res/android"
          package="org.tensorflow.demo">

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature android:name="android.hardware.camera" />
    <uses-feature android:name="android.hardware.camera.autofocus" />
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.RECORD_AUDIO" />

    <application android:allowBackup="true"
                 android:debuggable="true"
                 android:label="@string/app_name"
                 android:icon="@drawable/ic_launcher"
                 android:theme="@style/MaterialTheme">

<!--        <activity android:name="org.tensorflow.demo.ClassifierActivity"-->
<!--                  android:screenOrientation="portrait"-->
<!--                  android:label="@string/activity_name_classification">-->
<!--            <intent-filter>-->
<!--                <action android:name="android.intent.action.MAIN" />-->
<!--                <category android:name="android.intent.category.LAUNCHER" />-->
<!--                <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
<!--            </intent-filter>-->
<!--        </activity>-->

        <activity android:name="org.tensorflow.demo.DetectorActivity"
                  android:screenOrientation="portrait"
                  android:label="@string/activity_name_detection">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
                <category android:name="android.intent.category.LAUNCHER" />
                <category android:name="android.intent.category.LEANBACK_LAUNCHER" />
            </intent-filter>
        </activity>

<!--        <activity android:name="org.tensorflow.demo.StylizeActivity"-->
<!--                  android:screenOrientation="portrait"-->
<!--                  android:label="@string/activity_name_stylize">-->
<!--            <intent-filter>-->
<!--                <action android:name="android.intent.action.MAIN" />-->
<!--                <category android:name="android.intent.category.LAUNCHER" />-->
<!--                <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
<!--            </intent-filter>-->
<!--        </activity>-->

<!--        <activity android:name="org.tensorflow.demo.SpeechActivity"-->
<!--                  android:screenOrientation="portrait"-->
<!--                  android:label="@string/activity_name_speech">-->
<!--            <intent-filter>-->
<!--                <action android:name="android.intent.action.MAIN" />-->
<!--                <category android:name="android.intent.category.LAUNCHER" />-->
<!--                <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
<!--            </intent-filter>-->
<!--        </activity>-->
    </application>

</manifest>
android/android/BUILD
0 → 100644
# Description:
#   TensorFlow camera demo app for Android.

load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
load(
    "//tensorflow:tensorflow.bzl",
    "tf_copts",
)

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])

LINKER_SCRIPT = "jni/version_script.lds"

# libtensorflow_demo.so contains the native code for image colorspace conversion
# and object tracking used by the demo. It does not require TF as a dependency
# to build if STANDALONE_DEMO_LIB is defined.
# TF support for the demo is provided separately by libtensorflow_inference.so.
cc_binary(
    name = "libtensorflow_demo.so",
    srcs = glob([
        "jni/**/*.cc",
        "jni/**/*.h",
    ]),
    copts = tf_copts(),
    defines = ["STANDALONE_DEMO_LIB"],
    linkopts = [
        "-landroid",
        "-ldl",
        "-ljnigraphics",
        "-llog",
        "-lm",
        "-z defs",
        "-s",
        "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
    ],
    linkshared = 1,
    linkstatic = 1,
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        LINKER_SCRIPT,
    ],
)

cc_library(
    name = "tensorflow_native_libs",
    srcs = [
        ":libtensorflow_demo.so",
        "//tensorflow/tools/android/inference_interface:libtensorflow_inference.so",
    ],
    tags = [
        "manual",
        "notap",
    ],
)

android_binary(
    name = "tensorflow_demo",
    srcs = glob([
        "src/**/*.java",
    ]),
    # Package assets from assets dir as well as all model targets. Remove undesired models
    # (and corresponding Activities in source) to reduce APK size.
    assets = [
        "//tensorflow/examples/android/assets:asset_files",
        ":external_assets",
    ],
    assets_dir = "",
    custom_package = "org.tensorflow.demo",
    manifest = "AndroidManifest.xml",
    resource_files = glob(["res/**"]),
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        ":tensorflow_native_libs",
        "//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java",
    ],
)

# LINT.IfChange
filegroup(
    name = "external_assets",
    srcs = [
        "@inception_v1//:model_files",
        "@mobile_ssd//:model_files",
        "@speech_commands//:model_files",
        "@stylize//:model_files",
    ],
)
# LINT.ThenChange(//tensorflow/examples/android/download-models.gradle)

filegroup(
    name = "java_files",
    srcs = glob(["src/**/*.java"]),
)

filegroup(
    name = "jni_files",
    srcs = glob([
        "jni/**/*.cc",
        "jni/**/*.h",
    ]),
)

filegroup(
    name = "resource_files",
    srcs = glob(["res/**"]),
)

exports_files([
    "AndroidManifest.xml",
])
android/android/README.md
0 → 100644
# TensorFlow Android Camera Demo

This folder contains an example application utilizing TensorFlow for Android
devices.

## Description

The demos in this folder are designed to give straightforward samples of using
TensorFlow in mobile applications.

Inference is done using the [TensorFlow Android Inference
Interface](../../tools/android/inference_interface), which may be built
separately if you want a standalone library to drop into your existing
application. Object tracking and efficient YUV -> RGB conversion are handled by
`libtensorflow_demo.so`.
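
For orientation, the sketch below shows roughly how the demo's Activities drive
the inference interface. The `feed`/`run`/`fetch` methods follow
`TensorFlowInferenceInterface`; the model path and tensor names here are
illustrative placeholders, not the demo's actual values:

```java
// Sketch only: "model.pb", "input", and "output" are placeholder names.
TensorFlowInferenceInterface inferenceInterface =
    new TensorFlowInferenceInterface(getAssets(), "file:///android_asset/model.pb");

// Feed a [1, height, width, 3] float image, run the graph, fetch the scores.
inferenceInterface.feed("input", floatValues, 1, inputSize, inputSize, 3);
inferenceInterface.run(new String[] {"output"});
inferenceInterface.fetch("output", outputScores);
```
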
A device running Android 5.0 (API 21) or higher is required to run the demo due
to the use of the camera2 API, although the native libraries themselves can run
on API >= 14 devices.

## Current samples:

1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java):
   Uses the [Google Inception](https://arxiv.org/abs/1409.4842)
   model to classify camera frames in real-time, displaying the top results
   in an overlay on the camera image.
2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
   Demonstrates an SSD-MobileNet model trained using the
   [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection/)
   introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to
   localize and track objects (from 80 categories) in the camera preview
   in real-time.
3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java):
   Uses a model based on [A Learned Representation For Artistic
   Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
   image in the manner of a number of different artists.
4. [TF
   Speech](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java):
   Runs a simple speech recognition model built by the [audio training
   tutorial](https://www.tensorflow.org/versions/master/tutorials/audio_recognition). Listens
   for a small set of words, and highlights them in the UI when they are
   recognized.

<img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%">

## Prebuilt Components:

The fastest path to trying the demo is to download the [prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).

Also available are precompiled native libraries, and a JCenter package that you
may simply drop into your own applications. See
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
for more details.

## Running the Demo

Once the app is installed it can be started via the "TF Classify", "TF Detect",
"TF Stylize", and "TF Speech" icons, which use the orange TensorFlow logo.

While running the activities, pressing the volume keys on your device will
toggle debug visualizations on/off, rendering additional info to the screen
that may be useful for development purposes.

## Building in Android Studio using the TensorFlow AAR from JCenter

The simplest way to compile the demo app yourself and try out changes to the
project code is to use Android Studio. Simply set this `android` directory as
the project root.

Then edit the `build.gradle` file and change the value of `nativeBuildSystem`
to `'none'` so that the project is built in the simplest way possible:

```
def nativeBuildSystem = 'none'
```

While this project includes full build integration for TensorFlow, this setting
disables it, and uses the TensorFlow Inference Interface package from JCenter.

Note: Currently, in this build mode, YUV -> RGB is done using a less efficient
Java implementation, and object tracking is not available in the "TF Detect"
activity. Setting the build system to `'cmake'` currently only builds
`libtensorflow_demo.so`, which provides fast YUV -> RGB conversion and object
tracking, while still acquiring TensorFlow support via the downloaded AAR, so
it may be a lightweight way to enable these features.

For any project that does not include custom low-level TensorFlow code, this is
likely sufficient.
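
In this mode the entire TensorFlow dependency reduces to one Gradle line; this
is the same dependency declared in this project's `build.gradle`:

```groovy
dependencies {
    // Prebuilt TensorFlow Android AAR from JCenter.
    implementation 'org.tensorflow:tensorflow-android:+'
}
```
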
For details on how to include this JCenter package in your own project see
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md).

## Building the Demo with TensorFlow from Source

Pick your preferred approach below. At the moment, we have full support for
Bazel, and partial support for Gradle, CMake, Make, and Android Studio.

As a first step for all build types, clone the TensorFlow repo with:

```
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
```

Note that `--recurse-submodules` is necessary to prevent some issues with
protobuf compilation.

### Bazel

NOTE: Bazel does not currently support building for Android on Windows. Full
support for Gradle/CMake builds is coming soon, but in the meantime we suggest
that Windows users download the
[prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk)
instead.

##### Install Bazel and Android Prerequisites

Bazel is the primary build system for TensorFlow. To build with Bazel, it and
the Android NDK and SDK must be installed on your system.

1. Install the latest version of Bazel as per the instructions [on the Bazel
   website](https://bazel.build/versions/master/docs/install.html).
2. The Android NDK is required to build the native (C/C++) TensorFlow code.
   The current recommended version is 14b, which may be found
   [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
3. The Android SDK and build tools may be obtained
   [here](https://developer.android.com/tools/revisions/build-tools.html), or
   alternatively as part of [Android
   Studio](https://developer.android.com/studio/index.html). Build tools API >=
   23 is required to build the TF Android demo (though it will run on API >= 21
   devices).

##### Edit WORKSPACE

NOTE: As long as you have the SDK and NDK installed, the `./configure` script
will create these rules for you. Answer "Yes" when the script asks to
automatically configure the `./WORKSPACE`.

The Android entries in
[`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L19-L36) must be uncommented
with the paths filled in appropriately depending on where you installed the
NDK and SDK. Otherwise an error such as: "The external label
'//external:android/sdk' is not bound to anything" will be reported.

Also edit the API levels for the SDK in WORKSPACE to the highest level you
have installed in your SDK. This must be >= 23 (this is completely independent
of the API level of the demo, which is defined in AndroidManifest.xml). The
NDK API level may remain at 14.
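
For reference, the uncommented entries look roughly like the sketch below. The
paths and version numbers are placeholders for your own installation, not
values taken from this repository:

```python
# Sketch only: substitute your own install locations and versions.
android_sdk_repository(
    name = "androidsdk",
    api_level = 23,
    build_tools_version = "26.0.1",
    path = "/home/me/Android/Sdk",
)

android_ndk_repository(
    name = "androidndk",
    api_level = 14,
    path = "/home/me/android-ndk-r14b",
)
```
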
##### Install Model Files (optional)

The TensorFlow `GraphDef`s that contain the model definitions and weights are
not packaged in the repo because of their size. They are downloaded
automatically and packaged with the APK by Bazel via a `new_http_archive` rule
defined in `WORKSPACE` during the build process, and by Gradle via
`download-models.gradle`.
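
As an illustration, such a rule looks roughly like the sketch below; the
`build_file` name is a hypothetical placeholder (the real entries live in the
top-level `WORKSPACE`) and the checksum attribute is elided:

```python
new_http_archive(
    name = "inception_v1",
    # Assumption: a BUILD file exposing the archive contents as :model_files.
    build_file = "models.BUILD",
    url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip",
)
```
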
**Optional**: If you wish to place the models in your assets manually, remove
all of the `model_files` entries from the `assets` list in `tensorflow_demo`
found in the [`BUILD`](BUILD#L92) file. Then download and extract the archives
yourself to the `assets` directory in the source tree:

```bash
BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models
for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip
do
  curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP}
  unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/
done
```

This will extract the models and their associated metadata files to the local
assets/ directory.

If you are using Gradle, make sure to remove the `download-models.gradle`
reference from `build.gradle` after you manually download the models;
otherwise Gradle might download them again and overwrite your models.
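
Concretely, that means commenting out this line near the top of `build.gradle`:

```groovy
// Download default models; if you wish to use your own models then
// place them in the "assets" directory and comment out this line.
// apply from: "download-models.gradle"
```
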
##### Build

After editing your WORKSPACE file to update the SDK/NDK configuration, you may
build the APK. Run this from your workspace root:

```bash
bazel build --cxxopt='--std=c++11' -c opt //tensorflow/examples/android:tensorflow_demo
```

##### Install

Make sure that adb debugging is enabled on your Android 5.0 (API 21) or later
device, then after building use the following command from your workspace root
to install the APK:

```bash
adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
```

### Android Studio with Bazel

Android Studio may be used to build the demo in conjunction with Bazel. First,
make sure that you can build with Bazel following the above directions. Then,
look at [build.gradle](build.gradle) and make sure that the path to Bazel
matches that of your system.

At this point you can add the tensorflow/examples/android directory as a new
Android Studio project. Click through installing all the Gradle extensions it
requests, and you should be able to have Android Studio build the demo like any
other application (it will call out to Bazel to build the native code with the
NDK).

### CMake

Full CMake support for the demo is coming soon, but for now it is possible to
build the TensorFlow Android Inference library using
[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake).
android/android/__init__.py
0 → 100644
File mode changed
android/android/assets/BUILD
0 → 100644
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

# It is necessary to use this filegroup rather than globbing the files in this
# folder directly in the examples/android:tensorflow_demo target due to the
# fact that assets_dir is necessarily set to "" there (to allow using other
# arbitrary targets as assets).
filegroup(
    name = "asset_files",
    srcs = glob(
        ["**/*"],
        exclude = ["BUILD"],
    ),
)
android/android/assets/yolov3.pb
0 → 100644
This file is too large to display.
android/android/build.gradle
0 → 100644
// This file provides basic support for building the TensorFlow demo
// in Android Studio with Gradle.
//
// Note that Bazel is still used by default to compile the native libs,
// and should be installed at the location noted below. This build file
// automates the process of calling out to it and copying the compiled
// libraries back into the appropriate directory.
//
// Alternatively, experimental support for Makefile builds is provided by
// setting nativeBuildSystem below to 'makefile'. This will allow building the demo
// on Windows machines, but note that full equivalence with the Bazel
// build is not yet guaranteed. See comments below for caveats and tips
// for speeding up the build, such as enabling ccache.
// NOTE: Running a make build will cause subsequent Bazel builds to *fail*
// unless the contrib/makefile/downloads/ and gen/ dirs are deleted afterwards.

// The cmake build only creates libtensorflow_demo.so. In this situation,
// libtensorflow_inference.so will be acquired via the tensorflow.aar dependency.

// It is necessary to customize Gradle's build directory, as otherwise
// it will conflict with the BUILD file used by Bazel on case-insensitive OSs.
project.buildDir = 'gradleBuild'
getProject().setBuildDir('gradleBuild')

buildscript {
    repositories {
        jcenter()
        google()
    }

    dependencies {
        classpath 'com.android.tools.build:gradle:3.3.1'
        classpath 'org.apache.httpcomponents:httpclient:4.5.4'
    }
}

allprojects {
    repositories {
        jcenter()
        google()
    }
}

// set to 'bazel', 'cmake', 'makefile', 'none'
def nativeBuildSystem = 'none'

// Controls output directory in APK and CPU type for Bazel builds.
// NOTE: Does not affect the Makefile build target API (yet), which currently
// assumes armeabi-v7a. If building with make, changing this will require
// editing the Makefile as well.
// The CMake build has only been tested with armeabi-v7a; others may not work.
def cpuType = 'armeabi-v7a'

// Output directory in the local directory for packaging into the APK.
def nativeOutDir = 'libs/' + cpuType

// Default to building with Bazel and override with make if requested.
def nativeBuildRule = 'buildNativeBazel'
def demoLibPath = '../../../bazel-bin/tensorflow/examples/android/libtensorflow_demo.so'
def inferenceLibPath = '../../../bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so'

// Override for Makefile builds.
if (nativeBuildSystem == 'makefile') {
    nativeBuildRule = 'buildNativeMake'
    demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_demo.so'
    inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_inference.so'
}

// If building with Bazel, this is the location of the bazel binary.
// NOTE: Bazel does not yet support building for Android on Windows,
// so in this case the Makefile build must be used as described above.
def bazelLocation = '/usr/local/bin/bazel'

// import DownloadModels task
project.ext.ASSET_DIR = projectDir.toString() + '/assets'
project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'

// Download default models; if you wish to use your own models then
// place them in the "assets" directory and comment out this line.
apply from: "download-models.gradle"

apply plugin: 'com.android.application'

android {
    compileSdkVersion 23

    if (nativeBuildSystem == 'cmake') {
        defaultConfig {
            applicationId = 'org.tensorflow.demo'
            minSdkVersion 21
            targetSdkVersion 23
            ndk {
                abiFilters "${cpuType}"
            }
            externalNativeBuild {
                cmake {
                    arguments '-DANDROID_STL=c++_static'
                }
            }
        }
        externalNativeBuild {
            cmake {
                path './jni/CMakeLists.txt'
            }
        }
    }

    lintOptions {
        abortOnError false
    }

    sourceSets {
        main {
            if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
                // TensorFlow Java API sources.
                java {
                    srcDir '../../java/src/main/java'
                    exclude '**/examples/**'
                }

                // Android TensorFlow wrappers, etc.
                java {
                    srcDir '../../tools/android/inference_interface/java'
                }
            }
            // Android demo app sources.
            java {
                srcDir 'src'
            }

            manifest.srcFile 'AndroidManifest.xml'
            resources.srcDirs = ['src']
            aidl.srcDirs = ['src']
            renderscript.srcDirs = ['src']
            res.srcDirs = ['res']
            assets.srcDirs = [project.ext.ASSET_DIR]
            jniLibs.srcDirs = ['libs']
        }

        debug.setRoot('build-types/debug')
        release.setRoot('build-types/release')
    }
    defaultConfig {
        targetSdkVersion 23
        minSdkVersion 21
    }
}

task buildNativeBazel(type: Exec) {
    workingDir '../../..'
    commandLine bazelLocation, 'build', '-c', 'opt', \
        'tensorflow/examples/android:tensorflow_native_libs', \
        '--crosstool_top=//external:android/crosstool', \
        '--cpu=' + cpuType, \
        '--host_crosstool_top=@bazel_tools//tools/cpp:toolchain'
}

task buildNativeMake(type: Exec) {
    environment "NDK_ROOT", android.ndkDirectory
    // Tip: install ccache and uncomment the following to speed up
    // builds significantly.
    // environment "CC_PREFIX", 'ccache'
    workingDir '../../..'
    commandLine 'tensorflow/contrib/makefile/build_all_android.sh', \
        '-s', \
        'tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in', \
        '-t', \
        'libtensorflow_inference.so libtensorflow_demo.so all' \
        , '-a', cpuType \
        //, '-T' // Uncomment to skip protobuf and speed up subsequent builds.
}


task copyNativeLibs(type: Copy) {
    from demoLibPath
    from inferenceLibPath
    into nativeOutDir
    duplicatesStrategy = 'include'
    dependsOn nativeBuildRule
    fileMode 0644
}

tasks.whenTaskAdded { task ->
    if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
        if (task.name == 'assembleDebug') {
            task.dependsOn 'copyNativeLibs'
        }
        if (task.name == 'assembleRelease') {
            task.dependsOn 'copyNativeLibs'
        }
    }
}

dependencies {
    if (nativeBuildSystem == 'cmake' || nativeBuildSystem == 'none') {
        implementation 'org.tensorflow:tensorflow-android:+'
    }
}
android/android/download-models.gradle
0 → 100644
/*
 * download-models.gradle
 *     Downloads model files from ${MODEL_URL} into the application's asset folder
 * Input:
 *     project.ext.TMP_DIR: absolute path to hold downloaded zip files
 *     project.ext.ASSET_DIR: absolute path to save unzipped model files
 * Output:
 *     3 model files will be downloaded into given folder of ext.ASSET_DIR
 */
// hard coded model files
// LINT.IfChange
def models = ['inception_v1.zip',
              'object_detection/ssd_mobilenet_v1_android_export.zip',
              'stylize_v1.zip',
              'speech_commands_conv_actions.zip']
// LINT.ThenChange(//tensorflow/examples/android/BUILD)

// Root URL for model archives
def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models'

buildscript {
    repositories {
        jcenter()
    }
    dependencies {
        classpath 'de.undercouch:gradle-download-task:3.2.0'
    }
}

import de.undercouch.gradle.tasks.download.Download
task downloadFile(type: Download){
    for (f in models) {
        src "${MODEL_URL}/" + f
    }
    dest new File(project.ext.TMP_DIR)
    overwrite true
}

task extractModels(type: Copy) {
    for (f in models) {
        def localFile = f.split("/")[-1]
        from zipTree(project.ext.TMP_DIR + '/' + localFile)
    }

    into file(project.ext.ASSET_DIR)
    fileMode 0644
    exclude '**/LICENSE'

    def needDownload = false
    for (f in models) {
        def localFile = f.split("/")[-1]
        if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) {
            needDownload = true
        }
    }

    if (needDownload) {
        dependsOn downloadFile
    }
}

tasks.whenTaskAdded { task ->
    if (task.name == 'assembleDebug') {
        task.dependsOn 'extractModels'
    }
    if (task.name == 'assembleRelease') {
        task.dependsOn 'extractModels'
    }
}
No preview for this file type
android/android/gradlew
0 → 100644
#!/usr/bin/env bash

##############################################################################
##
##  Gradle start up script for UN*X
##
##############################################################################

# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""

APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`

# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"

warn ( ) {
    echo "$*"
}

die ( ) {
    echo
    echo "$*"
    echo
    exit 1
}

# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
case "`uname`" in
  CYGWIN* )
    cygwin=true
    ;;
  Darwin* )
    darwin=true
    ;;
  MINGW* )
    msys=true
    ;;
esac

# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
    ls=`ls -ld "$PRG"`
    link=`expr "$ls" : '.*-> \(.*\)$'`
    if expr "$link" : '/.*' > /dev/null; then
        PRG="$link"
    else
        PRG=`dirname "$PRG"`"/$link"
    fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null

CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar

# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
        # IBM's JDK on AIX uses strange locations for the executables
        JAVACMD="$JAVA_HOME/jre/sh/java"
    else
        JAVACMD="$JAVA_HOME/bin/java"
    fi
    if [ ! -x "$JAVACMD" ] ; then
        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
    fi
else
    JAVACMD="java"
    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi

# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
    MAX_FD_LIMIT=`ulimit -H -n`
    if [ $? -eq 0 ] ; then
        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
            MAX_FD="$MAX_FD_LIMIT"
        fi
        ulimit -n $MAX_FD
        if [ $? -ne 0 ] ; then
            warn "Could not set maximum file descriptor limit: $MAX_FD"
        fi
    else
        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
    fi
fi

# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi

# For Cygwin, switch paths to Windows format before running java
if $cygwin ; then
    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
    JAVACMD=`cygpath --unix "$JAVACMD"`

    # We build the pattern for arguments to be converted via cygpath
    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
    SEP=""
    for dir in $ROOTDIRSRAW ; do
        ROOTDIRS="$ROOTDIRS$SEP$dir"
        SEP="|"
    done
    OURCYGPATTERN="(^($ROOTDIRS))"
    # Add a user-defined pattern to the cygpath arguments
    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
    fi
    # Now convert the arguments - kludge to limit ourselves to /bin/sh
    i=0
    for arg in "$@" ; do
        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
        CHECK2=`echo "$arg"|egrep -c "^-"`  ### Determine if an option

        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then  ### Added a condition
            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
        else
            eval `echo args$i`="\"$arg\""
        fi
        i=$((i+1))
    done
    case $i in
        (0) set -- ;;
        (1) set -- "$args0" ;;
        (2) set -- "$args0" "$args1" ;;
        (3) set -- "$args0" "$args1" "$args2" ;;
        (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
        (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
        (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
        (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
        (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
        (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
    esac
fi

# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
function splitJvmOpts() {
    JVM_OPTS=("$@")
}
eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"

exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
android/android/gradlew.bat
0 → 100644
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem  Gradle startup script for Windows
@rem
@rem ##########################################################################

@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal

@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS=

set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%

@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome

set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init

echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe

if exist "%JAVA_EXE%" goto init

echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:init
@rem Get command-line arguments, handling Windows variants

if not "%OS%" == "Windows_NT" goto win9xME_args
if "%@eval[2+2]" == "4" goto 4NT_args

:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2

:win9xME_args_slurp
if "x%~1" == "x" goto execute

set CMD_LINE_ARGS=%*
goto execute

:4NT_args
@rem Get arguments from the 4NT Shell from JP Software
set CMD_LINE_ARGS=%$

:execute
@rem Setup the command line

set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar

@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%

:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd

:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1

:mainEnd
if "%OS%"=="Windows_NT" endlocal

:omega
android/android/jni/CMakeLists.txt
0 → 100644
#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

project(TENSORFLOW_DEMO)
cmake_minimum_required(VERSION 3.4.1)

set(CMAKE_VERBOSE_MAKEFILE on)

get_filename_component(TF_SRC_ROOT ${CMAKE_SOURCE_DIR}/../../../.. ABSOLUTE)
get_filename_component(SAMPLE_SRC_DIR ${CMAKE_SOURCE_DIR}/.. ABSOLUTE)

if (ANDROID_ABI MATCHES "^armeabi-v7a$")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
elseif(ANDROID_ABI MATCHES "^arm64-v8a")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -ftree-vectorize")
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSTANDALONE_DEMO_LIB \
    -std=c++11 -fno-exceptions -fno-rtti -O2 -Wno-narrowing \
    -fPIE")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
    -Wl,--allow-multiple-definition \
    -Wl,--whole-archive -fPIE -v")

file(GLOB_RECURSE tensorflow_demo_sources ${SAMPLE_SRC_DIR}/jni/*.*)
add_library(tensorflow_demo SHARED
            ${tensorflow_demo_sources})
target_include_directories(tensorflow_demo PRIVATE
                           ${TF_SRC_ROOT}
                           ${CMAKE_SOURCE_DIR})

target_link_libraries(tensorflow_demo
                      android
                      log
                      jnigraphics
                      m
                      atomic
                      z)
android/android/jni/__init__.py
0 → 100644
File mode changed
android/android/jni/imageutils_jni.cc
0 → 100644
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file binds the native image utility code to the Java class
// which exposes them.

#include <jni.h>
#include <stdio.h>
#include <stdlib.h>

#include "tensorflow/examples/android/jni/rgb2yuv.h"
#include "tensorflow/examples/android/jni/yuv2rgb.h"

#define IMAGEUTILS_METHOD(METHOD_NAME) \
  Java_org_tensorflow_demo_env_ImageUtils_##METHOD_NAME  // NOLINT

#ifdef __cplusplus
extern "C" {
#endif

JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
    JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
    jint width, jint height, jboolean halfSize);

JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
    JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
    jintArray output, jint width, jint height, jint y_row_stride,
    jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize);

JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
    JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
    jint height);

JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
    JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
    jint width, jint height);

JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
    JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
    jint width, jint height);

#ifdef __cplusplus
}
#endif

JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
    JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
    jint width, jint height, jboolean halfSize) {
  jboolean inputCopy = JNI_FALSE;
  jbyte* const i = env->GetByteArrayElements(input, &inputCopy);

  jboolean outputCopy = JNI_FALSE;
  jint* const o = env->GetIntArrayElements(output, &outputCopy);

  if (halfSize) {
    ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(i),
                                      reinterpret_cast<uint32_t*>(o), width,
                                      height);
  } else {
    ConvertYUV420SPToARGB8888(reinterpret_cast<uint8_t*>(i),
                              reinterpret_cast<uint8_t*>(i) + width * height,
                              reinterpret_cast<uint32_t*>(o), width, height);
  }

  env->ReleaseByteArrayElements(input, i, JNI_ABORT);
  env->ReleaseIntArrayElements(output, o, 0);
}

JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
    JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
    jintArray output, jint width, jint height, jint y_row_stride,
    jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize) {
  jboolean inputCopy = JNI_FALSE;
  jbyte* const y_buff = env->GetByteArrayElements(y, &inputCopy);
  jboolean outputCopy = JNI_FALSE;
  jint* const o = env->GetIntArrayElements(output, &outputCopy);

  if (halfSize) {
    ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(y_buff),
                                      reinterpret_cast<uint32_t*>(o), width,
                                      height);
  } else {
    jbyte* const u_buff = env->GetByteArrayElements(u, &inputCopy);
    jbyte* const v_buff = env->GetByteArrayElements(v, &inputCopy);

    ConvertYUV420ToARGB8888(
        reinterpret_cast<uint8_t*>(y_buff), reinterpret_cast<uint8_t*>(u_buff),
        reinterpret_cast<uint8_t*>(v_buff), reinterpret_cast<uint32_t*>(o),
        width, height, y_row_stride, uv_row_stride, uv_pixel_stride);

    env->ReleaseByteArrayElements(u, u_buff, JNI_ABORT);
    env->ReleaseByteArrayElements(v, v_buff, JNI_ABORT);
  }

  env->ReleaseByteArrayElements(y, y_buff, JNI_ABORT);
  env->ReleaseIntArrayElements(output, o, 0);
}

JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
    JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
    jint height) {
  jboolean inputCopy = JNI_FALSE;
  jbyte* const i = env->GetByteArrayElements(input, &inputCopy);

  jboolean outputCopy = JNI_FALSE;
  jbyte* const o = env->GetByteArrayElements(output, &outputCopy);

  ConvertYUV420SPToRGB565(reinterpret_cast<uint8_t*>(i),
                          reinterpret_cast<uint16_t*>(o), width, height);

  env->ReleaseByteArrayElements(input, i, JNI_ABORT);
  env->ReleaseByteArrayElements(output, o, 0);
}

JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
    JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
    jint width, jint height) {
  jboolean inputCopy = JNI_FALSE;
  jint* const i = env->GetIntArrayElements(input, &inputCopy);

  jboolean outputCopy = JNI_FALSE;
  jbyte* const o = env->GetByteArrayElements(output, &outputCopy);

  ConvertARGB8888ToYUV420SP(reinterpret_cast<uint32_t*>(i),
                            reinterpret_cast<uint8_t*>(o), width, height);

  env->ReleaseIntArrayElements(input, i, JNI_ABORT);
  env->ReleaseByteArrayElements(output, o, 0);
}

JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
    JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
    jint width, jint height) {
  jboolean inputCopy = JNI_FALSE;
  jbyte* const i = env->GetByteArrayElements(input, &inputCopy);

  jboolean outputCopy = JNI_FALSE;
  jbyte* const o = env->GetByteArrayElements(output, &outputCopy);

  ConvertRGB565ToYUV420SP(reinterpret_cast<uint16_t*>(i),
                          reinterpret_cast<uint8_t*>(o), width, height);

  env->ReleaseByteArrayElements(input, i, JNI_ABORT);
  env->ReleaseByteArrayElements(output, o, 0);
}
android/android/jni/object_tracking/config.h
0 → 100644
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ | ||
18 | + | ||
19 | +#include <math.h> | ||
20 | + | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
22 | + | ||
23 | +namespace tf_tracking { | ||
24 | + | ||
25 | +// Arbitrary keypoint type ids for labeling the origin of tracked keypoints. | ||
26 | +enum KeypointType { | ||
27 | + KEYPOINT_TYPE_DEFAULT = 0, | ||
28 | + KEYPOINT_TYPE_FAST = 1, | ||
29 | + KEYPOINT_TYPE_INTEREST = 2 | ||
30 | +}; | ||
31 | + | ||
32 | +// Struct that can be used to more richly store the results of a detection | ||
33 | +// than a single number, while still maintaining comparability. | ||
34 | +struct MatchScore { | ||
35 | + explicit MatchScore(double val) : value(val) {} | ||
36 | + MatchScore() { value = 0.0; } | ||
37 | + | ||
38 | + double value; | ||
39 | + | ||
40 | + MatchScore& operator+(const MatchScore& rhs) {  // Mutates; acts as +=. | ||
41 | + value += rhs.value; | ||
42 | + return *this; | ||
43 | + } | ||
44 | + | ||
45 | + friend std::ostream& operator<<(std::ostream& stream, | ||
46 | + const MatchScore& detection) { | ||
47 | + stream << detection.value; | ||
48 | + return stream; | ||
49 | + } | ||
50 | +}; | ||
51 | +inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) { | ||
52 | + return cC1.value < cC2.value; | ||
53 | +} | ||
54 | +inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) { | ||
55 | + return cC1.value > cC2.value; | ||
56 | +} | ||
57 | +inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) { | ||
58 | + return cC1.value >= cC2.value; | ||
59 | +} | ||
60 | +inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) { | ||
61 | + return cC1.value <= cC2.value; | ||
62 | +} | ||
63 | + | ||
64 | +// Fixed seed used for all random number generators. | ||
65 | +static const int kRandomNumberSeed = 11111; | ||
66 | + | ||
67 | +// TODO(andrewharp): Move as many of these settings as possible into a settings | ||
68 | +// object which can be passed in from Java at runtime. | ||
69 | + | ||
70 | +// Whether or not to use ESM instead of LK flow. | ||
71 | +static const bool kUseEsm = false; | ||
72 | + | ||
73 | +// This constant gets added to the diagonal of the Hessian | ||
74 | +// before solving for translation in 2dof ESM. | ||
75 | +// It ensures better behavior especially in the absence of | ||
76 | +// strong texture. | ||
77 | +static const int kEsmRegularizer = 20; | ||
78 | + | ||
79 | +// Do we want to brightness-normalize each keypoint patch when we compute | ||
80 | +// its flow using ESM? | ||
81 | +static const bool kDoBrightnessNormalize = true; | ||
82 | + | ||
83 | +// Whether or not to use fixed-point interpolated pixel lookups in optical flow. | ||
84 | +#define USE_FIXED_POINT_FLOW 1 | ||
85 | + | ||
86 | +// Whether to normalize keypoint windows for intensity in LK optical flow. | ||
87 | +// This is a define for now because it helps keep the code streamlined. | ||
88 | +#define NORMALIZE 1 | ||
89 | + | ||
90 | +// Number of keypoints to store per frame. | ||
91 | +static const int kMaxKeypoints = 76; | ||
92 | + | ||
93 | +// Keypoint detection. | ||
94 | +static const int kMaxTempKeypoints = 1024; | ||
95 | + | ||
96 | +// Number of floats each keypoint takes up when exporting to an array. | ||
97 | +static const int kKeypointStep = 7; | ||
98 | + | ||
99 | +// Number of frame deltas to keep around in the circular queue. | ||
100 | +static const int kNumFrames = 512; | ||
101 | + | ||
102 | +// Number of iterations to do tracking on each keypoint at each pyramid level. | ||
103 | +static const int kNumIterations = 3; | ||
104 | + | ||
105 | +// The number of bins (on a side) to divide each bin from the previous | ||
106 | +// cache level into. Higher numbers will decrease performance by increasing | ||
107 | +// cache misses, but mean that cache hits are more locally relevant. | ||
108 | +static const int kCacheBranchFactor = 2; | ||
109 | + | ||
110 | +// Number of levels to put in the cache. | ||
111 | +// Each level of the cache is a square grid of bins, length: | ||
112 | +// branch_factor^(level - 1) on each side. | ||
113 | +// | ||
114 | +// This must not be greater than kNumPyramidLevels (see | ||
115 | +// PyramidLevelForCacheLevel in flow_cache.h). Setting it to 0 disables caching. | ||
116 | +static const int kNumCacheLevels = 3; | ||
117 | + | ||
118 | +// The level at which the cache pyramid gets cut off and replaced by a matrix | ||
119 | +// transform if such a matrix has been provided to the cache. | ||
120 | +static const int kCacheCutoff = 1; | ||
121 | + | ||
122 | +static const int kNumPyramidLevels = 4; | ||
123 | + | ||
124 | +// The maximum number of keypoints to use within an object's area. | ||
125 | +static const int kMaxKeypointsForObject = 16; | ||
126 | + | ||
127 | +// Minimum number of pyramid levels to use after getting cached value. | ||
128 | +// This allows fine-scale adjustment from the cached value, which is taken | ||
129 | +// from the center of the corresponding top cache level box. | ||
130 | +// Can be [0, kNumPyramidLevels). | ||
131 | +static const int kMinNumPyramidLevelsToUseForAdjustment = 1; | ||
132 | + | ||
133 | +// Window size to integrate over to find local image derivative. | ||
134 | +static const int kFlowIntegrationWindowSize = 3; | ||
135 | + | ||
136 | +// Total area of integration windows. | ||
137 | +static const int kFlowArraySize = | ||
138 | + (2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1); | ||
139 | + | ||
140 | +// Error that's considered good enough to early abort tracking. | ||
141 | +static const float kTrackingAbortThreshold = 0.03f; | ||
142 | + | ||
143 | +// Maximum number of deviations a keypoint-correspondence delta can be from the | ||
144 | +// weighted average before being thrown out for region-based queries. | ||
145 | +static const float kNumDeviations = 2.0f; | ||
146 | + | ||
147 | +// The maximum allowed length of the difference between the forward and | ||
148 | +// backward flow deltas, as a fraction of the forward flow vector's length. | ||
149 | +static const float kMaxForwardBackwardErrorAllowed = 0.5f; | ||
150 | + | ||
151 | +// Threshold for pixels to be considered different. | ||
152 | +static const int kFastDiffAmount = 10; | ||
153 | + | ||
154 | +// How far from edge of frame to stop looking for FAST keypoints. | ||
155 | +static const int kFastBorderBuffer = 10; | ||
156 | + | ||
157 | +// Determines if non-detected arbitrary keypoints should be added to regions. | ||
158 | +// This will help if no keypoints have been detected in the region yet. | ||
159 | +static const bool kAddArbitraryKeypoints = true; | ||
160 | + | ||
161 | +// How many arbitrary keypoints to add along each axis as candidates for each | ||
162 | +// region? | ||
163 | +static const int kNumToAddAsCandidates = 1; | ||
164 | + | ||
165 | +// In terms of region dimensions, how closely can we place keypoints | ||
166 | +// next to each other? | ||
167 | +static const float kClosestPercent = 0.6f; | ||
168 | + | ||
169 | +// How many FAST qualifying pixels must be connected to a pixel for it to be | ||
170 | +// considered a candidate keypoint for Harris filtering. | ||
171 | +static const int kMinNumConnectedForFastKeypoint = 8; | ||
172 | + | ||
173 | +// Size of the window to integrate over for Harris filtering. | ||
174 | +// Compare to kFlowIntegrationWindowSize. | ||
175 | +static const int kHarrisWindowSize = 2; | ||
176 | + | ||
177 | + | ||
178 | +// DETECTOR PARAMETERS | ||
179 | + | ||
180 | +// Before relocalizing, make sure the new proposed position is better than | ||
181 | +// the existing position by a small amount to prevent thrashing. | ||
182 | +static const MatchScore kMatchScoreBuffer(0.01f); | ||
183 | + | ||
184 | +// Minimum score a tracked object can have and still be considered a match. | ||
185 | +// TODO(andrewharp): Make this a per detector thing. | ||
186 | +static const MatchScore kMinimumMatchScore(0.5f); | ||
187 | + | ||
188 | +static const float kMinimumCorrelationForTracking = 0.4f; | ||
189 | + | ||
190 | +static const MatchScore kMatchScoreForImmediateTermination(0.0f); | ||
191 | + | ||
192 | +// Run the detector every N frames. | ||
193 | +static const int kDetectEveryNFrames = 4; | ||
194 | + | ||
195 | +// How many features does each feature_set contain? | ||
196 | +static const int kFeaturesPerFeatureSet = 10; | ||
197 | + | ||
198 | +// The number of FeatureSets managed by the object detector. | ||
199 | +// More FeatureSets can increase recall at the cost of performance. | ||
200 | +static const int kNumFeatureSets = 7; | ||
201 | + | ||
202 | +// How many FeatureSets must respond affirmatively for a candidate descriptor | ||
203 | +// and position to be given more thorough attention? | ||
204 | +static const int kNumFeatureSetsForCandidate = 2; | ||
205 | + | ||
206 | +// How large the thumbnails used for correlation validation are. Used for both | ||
207 | +// width and height. | ||
208 | +static const int kNormalizedThumbnailSize = 11; | ||
209 | + | ||
210 | +// The intersection over union of the bounding boxes that determines whether | ||
211 | +// tracking has slipped enough to invalidate all unlocked examples. | ||
212 | +static const float kPositionOverlapThreshold = 0.6f; | ||
213 | + | ||
214 | +// The number of detection failures allowed before an object goes invisible. | ||
215 | +// Tracking will still occur, so if it is actually still being tracked and | ||
216 | +// comes back into a detectable position, it's likely to be found. | ||
217 | +static const int kMaxNumDetectionFailures = 4; | ||
218 | + | ||
219 | + | ||
220 | +// Minimum square size to scan with sliding window. | ||
221 | +static const float kScanMinSquareSize = 16.0f; | ||
222 | + | ||
223 | +// Maximum square size to scan with sliding window. | ||
224 | +static const float kScanMaxSquareSize = 64.0f; | ||
225 | + | ||
226 | +// Scale difference for consecutive scans of the sliding window. | ||
227 | +static const float kScanScaleFactor = sqrtf(2.0f); | ||
228 | + | ||
229 | +// Step size for sliding window. | ||
230 | +static const int kScanStepSize = 10; | ||
231 | + | ||
232 | + | ||
233 | +// How tightly to pack the descriptor boxes for confirmed exemplars. | ||
234 | +static const float kLockedScaleFactor = 1 / sqrtf(2.0f); | ||
235 | + | ||
236 | +// How tightly to pack the descriptor boxes for unconfirmed exemplars. | ||
237 | +static const float kUnlockedScaleFactor = 1 / 2.0f; | ||
238 | + | ||
239 | +// How tightly the boxes to scan centered at the last known position will be | ||
240 | +// packed. | ||
241 | +static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f); | ||
242 | + | ||
243 | +// The bounds on how close a new object example must be to existing object | ||
244 | +// examples for detection to be valid. | ||
245 | +static const float kMinCorrelationForNewExample = 0.75f; | ||
246 | +static const float kMaxCorrelationForNewExample = 0.99f; | ||
247 | + | ||
248 | + | ||
249 | +// The number of safe tries an exemplar has after being created before | ||
250 | +// missed detections count against it. | ||
251 | +static const int kFreeTries = 5; | ||
252 | + | ||
253 | +// A false positive is worth this many missed detections. | ||
254 | +static const int kFalsePositivePenalty = 5; | ||
255 | + | ||
256 | +struct ObjectDetectorConfig { | ||
257 | + const Size image_size; | ||
258 | + | ||
259 | + explicit ObjectDetectorConfig(const Size& image_size) | ||
260 | + : image_size(image_size) {} | ||
261 | + virtual ~ObjectDetectorConfig() = default; | ||
262 | +}; | ||
263 | + | ||
264 | +struct KeypointDetectorConfig { | ||
265 | + const Size image_size; | ||
266 | + | ||
267 | + bool detect_skin; | ||
268 | + | ||
269 | + explicit KeypointDetectorConfig(const Size& image_size) | ||
270 | + : image_size(image_size), | ||
271 | + detect_skin(false) {} | ||
272 | +}; | ||
273 | + | ||
274 | + | ||
275 | +struct OpticalFlowConfig { | ||
276 | + const Size image_size; | ||
277 | + | ||
278 | + explicit OpticalFlowConfig(const Size& image_size) | ||
279 | + : image_size(image_size) {} | ||
280 | +}; | ||
281 | + | ||
282 | +struct TrackerConfig { | ||
283 | + const Size image_size; | ||
284 | + KeypointDetectorConfig keypoint_detector_config; | ||
285 | + OpticalFlowConfig flow_config; | ||
286 | + bool always_track; | ||
287 | + | ||
288 | + float object_box_scale_factor_for_features; | ||
289 | + | ||
290 | + explicit TrackerConfig(const Size& image_size) | ||
291 | + : image_size(image_size), | ||
292 | + keypoint_detector_config(image_size), | ||
293 | + flow_config(image_size), | ||
294 | + always_track(false), | ||
295 | + object_box_scale_factor_for_features(1.0f) {} | ||
296 | +}; | ||
297 | + | ||
298 | +} // namespace tf_tracking | ||
299 | + | ||
300 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ |
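
With the defaults above (kNumCacheLevels = 3, kCacheBranchFactor = 2, kNumPyramidLevels = 4), the cache levels form 8x8, 4x4, and 2x2 grids of bins and map to pyramid levels 1 through 3. A standalone sketch, not part of this change, that tabulates this by mirroring the recurrence used in FlowCache::BlockDimForCacheLevel (flow_cache.h, below):

    #include <cstdio>

    static const int kNumCacheLevels = 3;
    static const int kCacheBranchFactor = 2;
    static const int kNumPyramidLevels = 4;

    // Mirrors FlowCache::BlockDimForCacheLevel: the coarsest level has
    // kCacheBranchFactor bins per side; each finer level multiplies by it.
    int BlockDimForCacheLevel(const int cache_level) {
      int block_dim = kCacheBranchFactor;
      for (int level = kNumCacheLevels - 1; level > cache_level; --level) {
        block_dim *= kCacheBranchFactor;
      }
      return block_dim;
    }

    int main() {
      for (int level = 0; level < kNumCacheLevels; ++level) {
        const int dim = BlockDimForCacheLevel(level);
        // Prints 8x8 -> pyramid 1, 4x4 -> pyramid 2, 2x2 -> pyramid 3.
        printf("cache level %d: %dx%d bins, pyramid level %d\n", level, dim,
               dim, level + (kNumPyramidLevels - kNumCacheLevels));
      }
      return 0;
    }

android/android/jni/object_tracking/flow_cache.h
0 → 100644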
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ | ||
18 | + | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
21 | + | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" | ||
24 | + | ||
25 | +namespace tf_tracking { | ||
26 | + | ||
27 | +// Class that helps OpticalFlow to speed up flow computation | ||
28 | +// by caching coarse-grained flow. | ||
29 | +class FlowCache { | ||
30 | + public: | ||
31 | + explicit FlowCache(const OpticalFlowConfig* const config) | ||
32 | + : config_(config), | ||
33 | + image_size_(config->image_size), | ||
34 | + optical_flow_(config), | ||
35 | + fullframe_matrix_(NULL) { | ||
36 | + for (int i = 0; i < kNumCacheLevels; ++i) { | ||
37 | + const int curr_dims = BlockDimForCacheLevel(i); | ||
38 | + has_cache_[i] = new Image<bool>(curr_dims, curr_dims); | ||
39 | + displacements_[i] = new Image<Point2f>(curr_dims, curr_dims); | ||
40 | + } | ||
41 | + } | ||
42 | + | ||
43 | + ~FlowCache() { | ||
44 | + for (int i = 0; i < kNumCacheLevels; ++i) { | ||
45 | + SAFE_DELETE(has_cache_[i]); | ||
46 | + SAFE_DELETE(displacements_[i]); | ||
47 | + } | ||
48 | + delete[](fullframe_matrix_); | ||
49 | + fullframe_matrix_ = NULL; | ||
50 | + } | ||
51 | + | ||
52 | + void NextFrame(ImageData* const new_frame, | ||
53 | + const float* const align_matrix23) { | ||
54 | + ClearCache(); | ||
55 | + SetFullframeAlignmentMatrix(align_matrix23); | ||
56 | + optical_flow_.NextFrame(new_frame); | ||
57 | + } | ||
58 | + | ||
59 | + void ClearCache() { | ||
60 | + for (int i = 0; i < kNumCacheLevels; ++i) { | ||
61 | + has_cache_[i]->Clear(false); | ||
62 | + } | ||
63 | + delete[](fullframe_matrix_); | ||
64 | + fullframe_matrix_ = NULL; | ||
65 | + } | ||
66 | + | ||
67 | + // Finds the flow at a point, using the cache for performance. | ||
68 | + bool FindFlowAtPoint(const float u_x, const float u_y, | ||
69 | + float* const flow_x, float* const flow_y) const { | ||
70 | + // Get the best guess from the cache. | ||
71 | + const Point2f guess_from_cache = LookupGuess(u_x, u_y); | ||
72 | + | ||
73 | + *flow_x = guess_from_cache.x; | ||
74 | + *flow_y = guess_from_cache.y; | ||
75 | + | ||
76 | + // Now refine the guess using the image pyramid. | ||
77 | + for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1; | ||
78 | + pyramid_level >= 0; --pyramid_level) { | ||
79 | + if (!optical_flow_.FindFlowAtPointSingleLevel( | ||
80 | + pyramid_level, u_x, u_y, false, flow_x, flow_y)) { | ||
81 | + return false; | ||
82 | + } | ||
83 | + } | ||
84 | + | ||
85 | + return true; | ||
86 | + } | ||
87 | + | ||
88 | + // Determines the displacement of a point, and uses that to calculate a new | ||
89 | + // position. | ||
90 | + // Returns true iff the displacement determination worked and the new position | ||
91 | + // is in the image. | ||
92 | + bool FindNewPositionOfPoint(const float u_x, const float u_y, | ||
93 | + float* final_x, float* final_y) const { | ||
94 | + float flow_x; | ||
95 | + float flow_y; | ||
96 | + if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) { | ||
97 | + return false; | ||
98 | + } | ||
99 | + | ||
100 | + // Add in the displacement to get the final position. | ||
101 | + *final_x = u_x + flow_x; | ||
102 | + *final_y = u_y + flow_y; | ||
103 | + | ||
104 | + // Succeed only if the final position is still inside the image. | ||
105 | + if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) && | ||
106 | + InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) { | ||
107 | + return true; | ||
108 | + } else { | ||
109 | + return false; | ||
110 | + } | ||
111 | + } | ||
112 | + | ||
113 | + // Comparison function for qsort. | ||
114 | + static int Compare(const void* a, const void* b) { | ||
115 | + const float fa = *reinterpret_cast<const float*>(a); | ||
116 | + const float fb = *reinterpret_cast<const float*>(b); | ||
117 | + return (fa > fb) - (fa < fb);  // A raw float diff would truncate as int. | ||
118 | + } | ||
118 | + | ||
119 | + // Returns the median flow within the given bounding box as determined | ||
120 | + // by a grid_width x grid_height grid. | ||
121 | + Point2f GetMedianFlow(const BoundingBox& bounding_box, | ||
122 | + const bool filter_by_fb_error, | ||
123 | + const int grid_width, | ||
124 | + const int grid_height) const { | ||
125 | + const int kMaxPoints = 100; | ||
126 | + SCHECK(grid_width * grid_height <= kMaxPoints, | ||
127 | + "Too many points for Median flow!"); | ||
128 | + | ||
129 | + const BoundingBox valid_box = bounding_box.Intersect( | ||
130 | + BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1)); | ||
131 | + | ||
132 | + if (valid_box.GetArea() <= 0.0f) { | ||
133 | + return Point2f(0, 0); | ||
134 | + } | ||
135 | + | ||
136 | + float x_deltas[kMaxPoints]; | ||
137 | + float y_deltas[kMaxPoints]; | ||
138 | + | ||
139 | + int curr_offset = 0; | ||
140 | + for (int i = 0; i < grid_width; ++i) { | ||
141 | + for (int j = 0; j < grid_height; ++j) { | ||
142 | + const float x_in = valid_box.left_ + | ||
143 | + (valid_box.GetWidth() * i) / (grid_width - 1); | ||
144 | + | ||
145 | + const float y_in = valid_box.top_ + | ||
146 | + (valid_box.GetHeight() * j) / (grid_height - 1); | ||
147 | + | ||
148 | + float curr_flow_x; | ||
149 | + float curr_flow_y; | ||
150 | + const bool success = FindNewPositionOfPoint(x_in, y_in, | ||
151 | + &curr_flow_x, &curr_flow_y); | ||
152 | + | ||
153 | + if (success) { | ||
154 | + x_deltas[curr_offset] = curr_flow_x; | ||
155 | + y_deltas[curr_offset] = curr_flow_y; | ||
156 | + ++curr_offset; | ||
157 | + } else { | ||
158 | + LOGW("Tracking failure!"); | ||
159 | + } | ||
160 | + } | ||
161 | + } | ||
162 | + | ||
163 | + if (curr_offset > 0) { | ||
164 | + qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare); | ||
165 | + qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare); | ||
166 | + | ||
167 | + return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]); | ||
168 | + } | ||
169 | + | ||
170 | + LOGW("No points were valid!"); | ||
171 | + return Point2f(0, 0); | ||
172 | + } | ||
173 | + | ||
174 | + void SetFullframeAlignmentMatrix(const float* const align_matrix23) { | ||
175 | + if (align_matrix23 != NULL) { | ||
176 | + if (fullframe_matrix_ == NULL) { | ||
177 | + fullframe_matrix_ = new float[6]; | ||
178 | + } | ||
179 | + | ||
180 | + memcpy(fullframe_matrix_, align_matrix23, | ||
181 | + 6 * sizeof(fullframe_matrix_[0])); | ||
182 | + } | ||
183 | + } | ||
184 | + | ||
185 | + private: | ||
186 | + Point2f LookupGuessFromLevel( | ||
187 | + const int cache_level, const float x, const float y) const { | ||
188 | + // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level); | ||
189 | + | ||
190 | + // Cutoff at the target level and use the matrix transform instead. | ||
191 | + if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) { | ||
192 | + const float xnew = x * fullframe_matrix_[0] + | ||
193 | + y * fullframe_matrix_[1] + | ||
194 | + fullframe_matrix_[2]; | ||
195 | + const float ynew = x * fullframe_matrix_[3] + | ||
196 | + y * fullframe_matrix_[4] + | ||
197 | + fullframe_matrix_[5]; | ||
198 | + | ||
199 | + return Point2f(xnew - x, ynew - y); | ||
200 | + } | ||
201 | + | ||
202 | + const int level_dim = BlockDimForCacheLevel(cache_level); | ||
203 | + const int pixels_per_cache_block_x = | ||
204 | + (image_size_.width + level_dim - 1) / level_dim; | ||
205 | + const int pixels_per_cache_block_y = | ||
206 | + (image_size_.height + level_dim - 1) / level_dim; | ||
207 | + const int index_x = x / pixels_per_cache_block_x; | ||
208 | + const int index_y = y / pixels_per_cache_block_y; | ||
209 | + | ||
210 | + Point2f displacement; | ||
211 | + if (!(*has_cache_[cache_level])[index_y][index_x]) { | ||
212 | + (*has_cache_[cache_level])[index_y][index_x] = true; | ||
213 | + | ||
214 | + // Get the lower cache level's best guess, if it exists. | ||
215 | + displacement = cache_level >= kNumCacheLevels - 1 ? | ||
216 | + Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y); | ||
217 | + // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level, | ||
218 | + // best_guess.x, best_guess.y); | ||
219 | + | ||
220 | + // Find the center of the block. | ||
221 | + const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x; | ||
222 | + const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y; | ||
223 | + const int pyramid_level = PyramidLevelForCacheLevel(cache_level); | ||
224 | + | ||
225 | + // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] " | ||
226 | + // "Querying %5.2f, %5.2f at pyramid level %d, ", | ||
227 | + // cache_level, index_x, index_y, | ||
228 | + // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y, | ||
229 | + // center_x, center_y, pyramid_level); | ||
230 | + | ||
231 | + // TODO(andrewharp): Turn on FB error filtering. | ||
232 | + const bool success = optical_flow_.FindFlowAtPointSingleLevel( | ||
233 | + pyramid_level, center_x, center_y, false, | ||
234 | + &displacement.x, &displacement.y); | ||
235 | + | ||
236 | + if (!success) { | ||
237 | + LOGV("Computation of cached value failed for level %d!", cache_level); | ||
238 | + } | ||
239 | + | ||
240 | + // Store the value for later use. | ||
241 | + (*displacements_[cache_level])[index_y][index_x] = displacement; | ||
242 | + } else { | ||
243 | + displacement = (*displacements_[cache_level])[index_y][index_x]; | ||
244 | + } | ||
245 | + | ||
246 | + // LOGI("Returning %5.2f, %5.2f for level %d", | ||
247 | + // displacement.x, displacement.y, cache_level); | ||
248 | + return displacement; | ||
249 | + } | ||
250 | + | ||
251 | + Point2f LookupGuess(const float x, const float y) const { | ||
252 | + if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) { | ||
253 | + return Point2f(0, 0); | ||
254 | + } | ||
255 | + | ||
256 | + // LOGI("Looking up guess at %5.2f %5.2f.", x, y); | ||
257 | + if (kNumCacheLevels > 0) { | ||
258 | + return LookupGuessFromLevel(0, x, y); | ||
259 | + } else { | ||
260 | + return Point2f(0, 0); | ||
261 | + } | ||
262 | + } | ||
263 | + | ||
264 | + // Returns the number of cache bins in each dimension for a given level | ||
265 | + // of the cache. | ||
266 | + int BlockDimForCacheLevel(const int cache_level) const { | ||
267 | + // The highest (coarsest) cache level has a block dim of kCacheBranchFactor, | ||
268 | + // thus if there are 4 cache levels, requesting level 3 (0-based) should | ||
269 | + // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2, | ||
270 | + // and so on. | ||
271 | + int block_dim = kCacheBranchFactor; | ||
272 | + for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level; | ||
273 | + --curr_level) { | ||
274 | + block_dim *= kCacheBranchFactor; | ||
275 | + } | ||
276 | + return block_dim; | ||
277 | + } | ||
278 | + | ||
279 | + // Returns the level of the image pyramid that a given cache level maps to. | ||
280 | + int PyramidLevelForCacheLevel(const int cache_level) const { | ||
281 | + // Higher cache and pyramid levels have smaller dimensions. The highest | ||
282 | + // cache level should refer to the highest image pyramid level. The | ||
283 | + // lower, finer image pyramid levels are uncached (assuming | ||
284 | + // kNumCacheLevels < kNumPyramidLevels). | ||
285 | + return cache_level + (kNumPyramidLevels - kNumCacheLevels); | ||
286 | + } | ||
287 | + | ||
288 | + const OpticalFlowConfig* const config_; | ||
289 | + | ||
290 | + const Size image_size_; | ||
291 | + OpticalFlow optical_flow_; | ||
292 | + | ||
293 | + float* fullframe_matrix_; | ||
294 | + | ||
295 | + // Whether this value is currently present in the cache. | ||
296 | + Image<bool>* has_cache_[kNumCacheLevels]; | ||
297 | + | ||
298 | + // The cached displacement values. | ||
299 | + Image<Point2f>* displacements_[kNumCacheLevels]; | ||
300 | + | ||
301 | + TF_DISALLOW_COPY_AND_ASSIGN(FlowCache); | ||
302 | +}; | ||
303 | + | ||
304 | +} // namespace tf_tracking | ||
305 | + | ||
306 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ |
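
GetMedianFlow above tracks a grid of sample points and takes the per-axis median of the successful deltas, so a minority of badly tracked points cannot drag the estimate. A self-contained sketch of that median step with hypothetical deltas (not from this change):

    #include <cstdio>
    #include <cstdlib>

    // Three-way comparator for qsort, as used on the per-axis delta arrays.
    int CompareFloats(const void* a, const void* b) {
      const float fa = *reinterpret_cast<const float*>(a);
      const float fb = *reinterpret_cast<const float*>(b);
      return (fa > fb) - (fa < fb);
    }

    int main() {
      // Hypothetical x-axis flow deltas from five successfully tracked points.
      float x_deltas[] = {1.9f, 2.1f, -7.0f, 2.0f, 2.2f};
      const int num_valid = sizeof(x_deltas) / sizeof(*x_deltas);

      qsort(x_deltas, num_valid, sizeof(*x_deltas), CompareFloats);

      // The -7.0f outlier sorts to one end and cannot move the median: 2.0.
      printf("median x flow: %.1f\n", x_deltas[num_valid / 2]);
      return 0;
    }

android/android/jni/object_tracking/frame_pair.cc
0 → 100644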
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#include <float.h> | ||
17 | + | ||
18 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" | ||
20 | + | ||
21 | +namespace tf_tracking { | ||
22 | + | ||
23 | +void FramePair::Init(const int64_t start_time, const int64_t end_time) { | ||
24 | + start_time_ = start_time; | ||
25 | + end_time_ = end_time; | ||
26 | + memset(optical_flow_found_keypoint_, false, | ||
27 | + sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints); | ||
28 | + number_of_keypoints_ = 0; | ||
29 | +} | ||
30 | + | ||
31 | +void FramePair::AdjustBox(const BoundingBox box, | ||
32 | + float* const translation_x, | ||
33 | + float* const translation_y, | ||
34 | + float* const scale_x, | ||
35 | + float* const scale_y) const { | ||
36 | + static float weights[kMaxKeypoints]; | ||
37 | + static Point2f deltas[kMaxKeypoints]; | ||
38 | + memset(weights, 0, sizeof(*weights) * kMaxKeypoints); | ||
39 | + | ||
40 | + BoundingBox resized_box(box); | ||
41 | + resized_box.Scale(0.4f, 0.4f); | ||
42 | + FillWeights(resized_box, weights); | ||
43 | + FillTranslations(deltas); | ||
44 | + | ||
45 | + const Point2f translation = GetWeightedMedian(weights, deltas); | ||
46 | + | ||
47 | + *translation_x = translation.x; | ||
48 | + *translation_y = translation.y; | ||
49 | + | ||
50 | + const Point2f old_center = box.GetCenter(); | ||
51 | + const int good_scale_points = | ||
52 | + FillScales(old_center, translation, weights, deltas); | ||
53 | + | ||
54 | + // Default scale factor is 1 for x and y. | ||
55 | + *scale_x = 1.0f; | ||
56 | + *scale_y = 1.0f; | ||
57 | + | ||
58 | + // The assumption is that all deltas that make it to this stage with a | ||
59 | + // corresponding optical_flow_found_keypoint_[i] == true are not in | ||
60 | + // themselves degenerate. | ||
61 | + // | ||
62 | + // Scale becomes degenerate when points lie too close to the center of the | ||
63 | + // object, making the scale ratio incalculable. | ||
64 | + // | ||
65 | + // The check for kMinNumInRange is not a degeneracy check, but merely an | ||
66 | + // attempt to ensure some stability. The actual degeneracy check is the | ||
67 | + // comparison to EPSILON in FillScales (which also returns the number of | ||
68 | + // good points remaining). | ||
69 | + static const int kMinNumInRange = 5; | ||
70 | + if (good_scale_points >= kMinNumInRange) { | ||
71 | + const float scale_factor = GetWeightedMedianScale(weights, deltas); | ||
72 | + | ||
73 | + if (scale_factor > 0.0f) { | ||
74 | + *scale_x = scale_factor; | ||
75 | + *scale_y = scale_factor; | ||
76 | + } | ||
77 | + } | ||
78 | +} | ||
79 | + | ||
80 | +int FramePair::FillWeights(const BoundingBox& box, | ||
81 | + float* const weights) const { | ||
82 | + // Compute the min and max scores. | ||
83 | + float max_score = -FLT_MAX; | ||
84 | + float min_score = FLT_MAX; | ||
85 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
86 | + if (optical_flow_found_keypoint_[i]) { | ||
87 | + max_score = MAX(max_score, frame1_keypoints_[i].score_); | ||
88 | + min_score = MIN(min_score, frame1_keypoints_[i].score_); | ||
89 | + } | ||
90 | + } | ||
91 | + | ||
92 | + int num_in_range = 0; | ||
93 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
94 | + if (!optical_flow_found_keypoint_[i]) { | ||
95 | + weights[i] = 0.0f; | ||
96 | + continue; | ||
97 | + } | ||
98 | + | ||
99 | + const bool in_box = box.Contains(frame1_keypoints_[i].pos_); | ||
100 | + if (in_box) { | ||
101 | + ++num_in_range; | ||
102 | + } | ||
103 | + | ||
104 | + // The weighting based on distance. Anything within the bounding box | ||
105 | + // has a weight of 1, and everything outside of that is within the range | ||
106 | + // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio. | ||
107 | + float distance_score = 1.0f; | ||
108 | + if (!in_box) { | ||
109 | + const Point2f initial = box.GetCenter(); | ||
110 | + const float sq_x_dist = | ||
111 | + Square(initial.x - frame1_keypoints_[i].pos_.x); | ||
112 | + const float sq_y_dist = | ||
113 | + Square(initial.y - frame1_keypoints_[i].pos_.y); | ||
114 | + const float squared_half_width = Square(box.GetWidth() / 2.0f); | ||
115 | + const float squared_half_height = Square(box.GetHeight() / 2.0f); | ||
116 | + | ||
117 | + static const float kOutOfBoxMultiplier = 0.5f; | ||
118 | + distance_score = kOutOfBoxMultiplier * | ||
119 | + MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist); | ||
120 | + } | ||
121 | + | ||
122 | + // The weighting based on relative score strength, in [kBaseScore, 1.0f]. | ||
123 | + float intrinsic_score = 1.0f; | ||
124 | + if (max_score > min_score) { | ||
125 | + static const float kBaseScore = 0.5f; | ||
126 | + intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) / | ||
127 | + (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore; | ||
128 | + } | ||
129 | + | ||
130 | + // The final score will be in the range [0, 1]. | ||
131 | + weights[i] = distance_score * intrinsic_score; | ||
132 | + } | ||
133 | + | ||
134 | + return num_in_range; | ||
135 | +} | ||
136 | + | ||
137 | +void FramePair::FillTranslations(Point2f* const translations) const { | ||
138 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
139 | + if (!optical_flow_found_keypoint_[i]) { | ||
140 | + continue; | ||
141 | + } | ||
142 | + translations[i].x = | ||
143 | + frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x; | ||
144 | + translations[i].y = | ||
145 | + frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y; | ||
146 | + } | ||
147 | +} | ||
148 | + | ||
149 | +int FramePair::FillScales(const Point2f& old_center, | ||
150 | + const Point2f& translation, | ||
151 | + float* const weights, | ||
152 | + Point2f* const scales) const { | ||
153 | + int num_good = 0; | ||
154 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
155 | + if (!optical_flow_found_keypoint_[i]) { | ||
156 | + continue; | ||
157 | + } | ||
158 | + | ||
159 | + const Keypoint keypoint1 = frame1_keypoints_[i]; | ||
160 | + const Keypoint keypoint2 = frame2_keypoints_[i]; | ||
161 | + | ||
162 | + const float dist1_x = keypoint1.pos_.x - old_center.x; | ||
163 | + const float dist1_y = keypoint1.pos_.y - old_center.y; | ||
164 | + | ||
165 | + const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x; | ||
166 | + const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y; | ||
167 | + | ||
168 | + // Make sure that the scale makes sense; points too close to the center | ||
169 | + // will result in either NaNs or infinite results for scale due to | ||
170 | + // limited tracking and floating point resolution. | ||
171 | + // Also check that the parity of the points is the same with respect to | ||
172 | + // x and y, as we can't really make sense of data that has flipped. | ||
173 | + if (((dist2_x > EPSILON && dist1_x > EPSILON) || | ||
174 | + (dist2_x < -EPSILON && dist1_x < -EPSILON)) && | ||
175 | + ((dist2_y > EPSILON && dist1_y > EPSILON) || | ||
176 | + (dist2_y < -EPSILON && dist1_y < -EPSILON))) { | ||
177 | + scales[i].x = dist2_x / dist1_x; | ||
178 | + scales[i].y = dist2_y / dist1_y; | ||
179 | + ++num_good; | ||
180 | + } else { | ||
181 | + weights[i] = 0.0f; | ||
182 | + scales[i].x = 1.0f; | ||
183 | + scales[i].y = 1.0f; | ||
184 | + } | ||
185 | + } | ||
186 | + return num_good; | ||
187 | +} | ||
188 | + | ||
189 | +struct WeightedDelta { | ||
190 | + float weight; | ||
191 | + float delta; | ||
192 | +}; | ||
193 | + | ||
194 | +// Sort by delta (descending), not by weight; three-way for qsort consistency. | ||
195 | +inline int WeightedDeltaCompare(const void* const a, const void* const b) { | ||
196 | + const float delta_a = reinterpret_cast<const WeightedDelta*>(a)->delta; | ||
197 | + const float delta_b = reinterpret_cast<const WeightedDelta*>(b)->delta; | ||
198 | + return (delta_a < delta_b) - (delta_a > delta_b); | ||
198 | +} | ||
199 | + | ||
200 | +// Returns the median delta from a sorted set of weighted deltas. | ||
201 | +static float GetMedian(const int num_items, | ||
202 | + const WeightedDelta* const weighted_deltas, | ||
203 | + const float sum) { | ||
204 | + if (num_items == 0 || sum < EPSILON) { | ||
205 | + return 0.0f; | ||
206 | + } | ||
207 | + | ||
208 | + float current_weight = 0.0f; | ||
209 | + const float target_weight = sum / 2.0f; | ||
210 | + for (int i = 0; i < num_items; ++i) { | ||
211 | + if (weighted_deltas[i].weight > 0.0f) { | ||
212 | + current_weight += weighted_deltas[i].weight; | ||
213 | + if (current_weight >= target_weight) { | ||
214 | + return weighted_deltas[i].delta; | ||
215 | + } | ||
216 | + } | ||
217 | + } | ||
218 | + LOGW("Median not found! %d points, sum of %.2f", num_items, sum); | ||
219 | + return 0.0f; | ||
220 | +} | ||
221 | + | ||
222 | +Point2f FramePair::GetWeightedMedian( | ||
223 | + const float* const weights, const Point2f* const deltas) const { | ||
224 | + Point2f median_delta; | ||
225 | + | ||
226 | + // TODO(andrewharp): only sort deltas that could possibly have an effect. | ||
227 | + static WeightedDelta weighted_deltas[kMaxKeypoints]; | ||
228 | + | ||
229 | + // Compute median X value. | ||
230 | + { | ||
231 | + float total_weight = 0.0f; | ||
232 | + | ||
233 | + // Fill in the weights and x deltas. | ||
234 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
235 | + weighted_deltas[i].delta = deltas[i].x; | ||
236 | + const float weight = weights[i]; | ||
237 | + weighted_deltas[i].weight = weight; | ||
238 | + if (weight > 0.0f) { | ||
239 | + total_weight += weight; | ||
240 | + } | ||
241 | + } | ||
242 | + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), | ||
243 | + WeightedDeltaCompare); | ||
244 | + median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); | ||
245 | + } | ||
246 | + | ||
247 | + // Compute median Y value. | ||
248 | + { | ||
249 | + float total_weight = 0.0f; | ||
250 | + | ||
251 | + // Fill in the weights and y deltas. | ||
252 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
253 | + const float weight = weights[i]; | ||
254 | + weighted_deltas[i].weight = weight; | ||
255 | + weighted_deltas[i].delta = deltas[i].y; | ||
256 | + if (weight > 0.0f) { | ||
257 | + total_weight += weight; | ||
258 | + } | ||
259 | + } | ||
260 | + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), | ||
261 | + WeightedDeltaCompare); | ||
262 | + median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); | ||
263 | + } | ||
264 | + | ||
265 | + return median_delta; | ||
266 | +} | ||
267 | + | ||
268 | +float FramePair::GetWeightedMedianScale( | ||
269 | + const float* const weights, const Point2f* const deltas) const { | ||
270 | + float median_delta; | ||
271 | + | ||
272 | + // TODO(andrewharp): only sort deltas that could possibly have an effect. | ||
273 | + static WeightedDelta weighted_deltas[kMaxKeypoints * 2]; | ||
274 | + | ||
275 | + // Compute median scale value across x and y. | ||
276 | + { | ||
277 | + float total_weight = 0.0f; | ||
278 | + | ||
279 | + // Add X values. | ||
280 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
281 | + weighted_deltas[i].delta = deltas[i].x; | ||
282 | + const float weight = weights[i]; | ||
283 | + weighted_deltas[i].weight = weight; | ||
284 | + if (weight > 0.0f) { | ||
285 | + total_weight += weight; | ||
286 | + } | ||
287 | + } | ||
288 | + | ||
289 | + // Add Y values. | ||
290 | + for (int i = 0; i < kMaxKeypoints; ++i) { | ||
291 | + weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y; | ||
292 | + const float weight = weights[i]; | ||
293 | + weighted_deltas[i + kMaxKeypoints].weight = weight; | ||
294 | + if (weight > 0.0f) { | ||
295 | + total_weight += weight; | ||
296 | + } | ||
297 | + } | ||
298 | + | ||
299 | + qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta), | ||
300 | + WeightedDeltaCompare); | ||
301 | + | ||
302 | + median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight); | ||
303 | + } | ||
304 | + | ||
305 | + return median_delta; | ||
306 | +} | ||
307 | + | ||
308 | +} // namespace tf_tracking |
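
GetMedian above returns the delta at which the accumulated weight first reaches half of the total weight, i.e. a weighted median. A self-contained sketch with hypothetical weights and deltas (std::sort stands in for the qsort comparator):

    #include <algorithm>
    #include <cstdio>

    struct WeightedDelta {
      float weight;
      float delta;
    };

    int main() {
      // Hypothetical deltas; the heavily weighted 3.0f entry should win.
      WeightedDelta deltas[] = {
          {0.1f, -5.0f}, {1.0f, 3.0f}, {0.2f, 3.1f}, {0.1f, 9.0f}};
      const int n = sizeof(deltas) / sizeof(*deltas);

      std::sort(deltas, deltas + n,
                [](const WeightedDelta& a, const WeightedDelta& b) {
                  return a.delta < b.delta;
                });

      float total_weight = 0.0f;
      for (int i = 0; i < n; ++i) total_weight += deltas[i].weight;

      // Walk the sorted list until half of the total weight is covered.
      float accumulated = 0.0f;
      for (int i = 0; i < n; ++i) {
        accumulated += deltas[i].weight;
        if (accumulated >= total_weight / 2.0f) {
          printf("weighted median delta: %.1f\n", deltas[i].delta);  // 3.0
          break;
        }
      }
      return 0;
    }

android/android/jni/object_tracking/frame_pair.h
0 → 100644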
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ | ||
18 | + | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" | ||
20 | + | ||
21 | +namespace tf_tracking { | ||
22 | + | ||
23 | +// A class that records keypoint correspondences from pairs of | ||
24 | +// consecutive frames. | ||
25 | +class FramePair { | ||
26 | + public: | ||
27 | + FramePair() | ||
28 | + : start_time_(0), | ||
29 | + end_time_(0), | ||
30 | + number_of_keypoints_(0) {} | ||
31 | + | ||
32 | + // Cleans up the FramePair so that it can be reused. | ||
33 | + void Init(const int64_t start_time, const int64_t end_time); | ||
34 | + | ||
35 | + void AdjustBox(const BoundingBox box, | ||
36 | + float* const translation_x, | ||
37 | + float* const translation_y, | ||
38 | + float* const scale_x, | ||
39 | + float* const scale_y) const; | ||
40 | + | ||
41 | + private: | ||
42 | + // Returns the weighted median of the given deltas, computed independently on | ||
43 | + // x and y. Returns 0,0 in case of failure. The assumption is that a | ||
44 | + // translation of 0.0 in the degenerate case is the best that can be done, and | ||
45 | + // should not be considered an error. | ||
46 | + // | ||
47 | + // In the case of scale, a slight exception is made just to be safe: there | ||
48 | + // is an explicit check for 0.0, though that should never happen naturally | ||
49 | + // because of the non-zero and parity checks in FillScales. | ||
50 | + Point2f GetWeightedMedian(const float* const weights, | ||
51 | + const Point2f* const deltas) const; | ||
52 | + | ||
53 | + float GetWeightedMedianScale(const float* const weights, | ||
54 | + const Point2f* const deltas) const; | ||
55 | + | ||
56 | + // Weights points based on the query_point and cutoff_dist. | ||
57 | + int FillWeights(const BoundingBox& box, | ||
58 | + float* const weights) const; | ||
59 | + | ||
60 | + // Fills in the array of deltas with the translations of the points | ||
61 | + // between frames. | ||
62 | + void FillTranslations(Point2f* const translations) const; | ||
63 | + | ||
64 | + // Fills in the array of deltas with the relative scale factor of points | ||
65 | + // relative to a given center. Has the ability to override the weight to 0 if | ||
66 | + // a degenerate scale is detected. | ||
67 | + // Translation is the amount the center of the box has moved from one frame to | ||
68 | + // the next. | ||
69 | + int FillScales(const Point2f& old_center, | ||
70 | + const Point2f& translation, | ||
71 | + float* const weights, | ||
72 | + Point2f* const scales) const; | ||
73 | + | ||
74 | + // TODO(andrewharp): Make these private. | ||
75 | + public: | ||
76 | + // The time at frame1. | ||
77 | + int64_t start_time_; | ||
78 | + | ||
79 | + // The time at frame2. | ||
80 | + int64_t end_time_; | ||
81 | + | ||
82 | + // This array will contain the keypoints found in frame 1. | ||
83 | + Keypoint frame1_keypoints_[kMaxKeypoints]; | ||
84 | + | ||
85 | + // Contains the locations of the keypoints from frame 1 in frame 2. | ||
86 | + Keypoint frame2_keypoints_[kMaxKeypoints]; | ||
87 | + | ||
88 | + // The number of keypoints in frame 1. | ||
89 | + int number_of_keypoints_; | ||
90 | + | ||
91 | + // Keeps track of which keypoint correspondences were actually found from one | ||
92 | + // frame to another. | ||
93 | + // The i-th element of this array will be non-zero if and only if the i-th | ||
94 | + // keypoint of frame 1 was found in frame 2. | ||
95 | + bool optical_flow_found_keypoint_[kMaxKeypoints]; | ||
96 | + | ||
97 | + private: | ||
98 | + TF_DISALLOW_COPY_AND_ASSIGN(FramePair); | ||
99 | +}; | ||
100 | + | ||
101 | +} // namespace tf_tracking | ||
102 | + | ||
103 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ |
android/android/jni/object_tracking/geom.h
0 → 100644
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ | ||
18 | + | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/logging.h" | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
21 | + | ||
22 | +namespace tf_tracking { | ||
23 | + | ||
24 | +struct Size { | ||
25 | + Size(const int width, const int height) : width(width), height(height) {} | ||
26 | + | ||
27 | + int width; | ||
28 | + int height; | ||
29 | +}; | ||
30 | + | ||
31 | + | ||
32 | +class Point2f { | ||
33 | + public: | ||
34 | + Point2f() : x(0.0f), y(0.0f) {} | ||
35 | + Point2f(const float x, const float y) : x(x), y(y) {} | ||
36 | + | ||
37 | + inline Point2f operator- (const Point2f& that) const { | ||
38 | + return Point2f(this->x - that.x, this->y - that.y); | ||
39 | + } | ||
40 | + | ||
41 | + inline Point2f operator+ (const Point2f& that) const { | ||
42 | + return Point2f(this->x + that.x, this->y + that.y); | ||
43 | + } | ||
44 | + | ||
45 | + inline Point2f& operator+= (const Point2f& that) { | ||
46 | + this->x += that.x; | ||
47 | + this->y += that.y; | ||
48 | + return *this; | ||
49 | + } | ||
50 | + | ||
51 | + inline Point2f& operator-= (const Point2f& that) { | ||
52 | + this->x -= that.x; | ||
53 | + this->y -= that.y; | ||
54 | + return *this; | ||
55 | + } | ||
56 | + | ||
57 | + inline float LengthSquared() const { | ||
58 | + return Square(this->x) + Square(this->y); | ||
59 | + } | ||
60 | + | ||
61 | + inline float Length() const { | ||
62 | + return sqrtf(LengthSquared()); | ||
63 | + } | ||
64 | + | ||
65 | + inline float DistanceSquared(const Point2f& that) const { | ||
66 | + return Square(this->x - that.x) + Square(this->y - that.y); | ||
67 | + } | ||
68 | + | ||
69 | + inline float Distance(const Point2f& that) const { | ||
70 | + return sqrtf(DistanceSquared(that)); | ||
71 | + } | ||
76 | + | ||
77 | + float x; | ||
78 | + float y; | ||
79 | +}; | ||
80 | + | ||
81 | +inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) { | ||
82 | + stream << point.x << "," << point.y; | ||
83 | + return stream; | ||
84 | +} | ||
85 | + | ||
86 | +class BoundingBox { | ||
87 | + public: | ||
88 | + BoundingBox() | ||
89 | + : left_(0), | ||
90 | + top_(0), | ||
91 | + right_(0), | ||
92 | + bottom_(0) {} | ||
93 | + | ||
94 | + BoundingBox(const BoundingBox& bounding_box) | ||
95 | + : left_(bounding_box.left_), | ||
96 | + top_(bounding_box.top_), | ||
97 | + right_(bounding_box.right_), | ||
98 | + bottom_(bounding_box.bottom_) { | ||
99 | + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_); | ||
100 | + SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_); | ||
101 | + } | ||
102 | + | ||
103 | + BoundingBox(const float left, | ||
104 | + const float top, | ||
105 | + const float right, | ||
106 | + const float bottom) | ||
107 | + : left_(left), | ||
108 | + top_(top), | ||
109 | + right_(right), | ||
110 | + bottom_(bottom) { | ||
111 | + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_); | ||
112 | + SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_); | ||
113 | + } | ||
114 | + | ||
115 | + BoundingBox(const Point2f& point1, const Point2f& point2) | ||
116 | + : left_(MIN(point1.x, point2.x)), | ||
117 | + top_(MIN(point1.y, point2.y)), | ||
118 | + right_(MAX(point1.x, point2.x)), | ||
119 | + bottom_(MAX(point1.y, point2.y)) {} | ||
120 | + | ||
121 | + inline void CopyToArray(float* const bounds_array) const { | ||
122 | + bounds_array[0] = left_; | ||
123 | + bounds_array[1] = top_; | ||
124 | + bounds_array[2] = right_; | ||
125 | + bounds_array[3] = bottom_; | ||
126 | + } | ||
127 | + | ||
128 | + inline float GetWidth() const { | ||
129 | + return right_ - left_; | ||
130 | + } | ||
131 | + | ||
132 | + inline float GetHeight() const { | ||
133 | + return bottom_ - top_; | ||
134 | + } | ||
135 | + | ||
136 | + inline float GetArea() const { | ||
137 | + const float width = GetWidth(); | ||
138 | + const float height = GetHeight(); | ||
139 | + if (width <= 0 || height <= 0) { | ||
140 | + return 0.0f; | ||
141 | + } | ||
142 | + | ||
143 | + return width * height; | ||
144 | + } | ||
145 | + | ||
146 | + inline Point2f GetCenter() const { | ||
147 | + return Point2f((left_ + right_) / 2.0f, | ||
148 | + (top_ + bottom_) / 2.0f); | ||
149 | + } | ||
150 | + | ||
151 | + inline bool ValidBox() const { | ||
152 | + return GetArea() > 0.0f; | ||
153 | + } | ||
154 | + | ||
155 | + // Returns a bounding box created from the overlapping area of these two. | ||
156 | + inline BoundingBox Intersect(const BoundingBox& that) const { | ||
157 | + const float new_left = MAX(this->left_, that.left_); | ||
158 | + const float new_right = MIN(this->right_, that.right_); | ||
159 | + | ||
160 | + if (new_left >= new_right) { | ||
161 | + return BoundingBox(); | ||
162 | + } | ||
163 | + | ||
164 | + const float new_top = MAX(this->top_, that.top_); | ||
165 | + const float new_bottom = MIN(this->bottom_, that.bottom_); | ||
166 | + | ||
167 | + if (new_top >= new_bottom) { | ||
168 | + return BoundingBox(); | ||
169 | + } | ||
170 | + | ||
171 | + return BoundingBox(new_left, new_top, new_right, new_bottom); | ||
172 | + } | ||
173 | + | ||
174 | + // Returns a bounding box that can contain both boxes. | ||
175 | + inline BoundingBox Union(const BoundingBox& that) const { | ||
176 | + return BoundingBox(MIN(this->left_, that.left_), | ||
177 | + MIN(this->top_, that.top_), | ||
178 | + MAX(this->right_, that.right_), | ||
179 | + MAX(this->bottom_, that.bottom_)); | ||
180 | + } | ||
181 | + | ||
182 | + inline float PascalScore(const BoundingBox& that) const { | ||
183 | + SCHECK(GetArea() > 0.0f, "Empty bounding box!"); | ||
184 | + SCHECK(that.GetArea() > 0.0f, "Empty bounding box!"); | ||
185 | + | ||
186 | + const float intersect_area = this->Intersect(that).GetArea(); | ||
187 | + | ||
188 | + if (intersect_area <= 0) { | ||
189 | + return 0; | ||
190 | + } | ||
191 | + | ||
192 | + const float score = | ||
193 | + intersect_area / (GetArea() + that.GetArea() - intersect_area); | ||
194 | + SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score); | ||
195 | + return score; | ||
196 | + } | ||
197 | + | ||
198 | + // Compares the projected intervals on each axis; checking only whether | ||
199 | + // edges fall in range would miss the case where one box contains the other. | ||
200 | + inline bool Intersects(const BoundingBox& that) const { | ||
201 | + return MAX(left_, that.left_) < MIN(right_, that.right_) && | ||
202 | + MAX(top_, that.top_) < MIN(bottom_, that.bottom_); | ||
203 | + } | ||
204 | + | ||
205 | + // Returns whether another bounding box is completely inside of this bounding | ||
206 | + // box. Sharing edges is ok. | ||
207 | + inline bool Contains(const BoundingBox& that) const { | ||
208 | + return that.left_ >= left_ && | ||
209 | + that.right_ <= right_ && | ||
210 | + that.top_ >= top_ && | ||
211 | + that.bottom_ <= bottom_; | ||
212 | + } | ||
213 | + | ||
214 | + inline bool Contains(const Point2f& point) const { | ||
215 | + return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_); | ||
216 | + } | ||
217 | + | ||
218 | + inline void Shift(const Point2f shift_amount) { | ||
219 | + left_ += shift_amount.x; | ||
220 | + top_ += shift_amount.y; | ||
221 | + right_ += shift_amount.x; | ||
222 | + bottom_ += shift_amount.y; | ||
223 | + } | ||
224 | + | ||
225 | + inline void ScaleOrigin(const float scale_x, const float scale_y) { | ||
226 | + left_ *= scale_x; | ||
227 | + right_ *= scale_x; | ||
228 | + top_ *= scale_y; | ||
229 | + bottom_ *= scale_y; | ||
230 | + } | ||
231 | + | ||
232 | + inline void Scale(const float scale_x, const float scale_y) { | ||
233 | + const Point2f center = GetCenter(); | ||
234 | + const float half_width = GetWidth() / 2.0f; | ||
235 | + const float half_height = GetHeight() / 2.0f; | ||
236 | + | ||
237 | + left_ = center.x - half_width * scale_x; | ||
238 | + right_ = center.x + half_width * scale_x; | ||
239 | + | ||
240 | + top_ = center.y - half_height * scale_y; | ||
241 | + bottom_ = center.y + half_height * scale_y; | ||
242 | + } | ||
243 | + | ||
244 | + float left_; | ||
245 | + float top_; | ||
246 | + float right_; | ||
247 | + float bottom_; | ||
248 | +}; | ||
249 | +inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) { | ||
250 | + stream << "[" << box.left_ << " - " << box.right_ | ||
251 | + << ", " << box.top_ << " - " << box.bottom_ | ||
252 | + << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]"; | ||
253 | + return stream; | ||
254 | +} | ||
255 | + | ||
256 | + | ||
257 | +class BoundingSquare { | ||
258 | + public: | ||
259 | + BoundingSquare(const float x, const float y, const float size) | ||
260 | + : x_(x), y_(y), size_(size) {} | ||
261 | + | ||
262 | + explicit BoundingSquare(const BoundingBox& box) | ||
263 | + : x_(box.left_), y_(box.top_), size_(box.GetWidth()) { | ||
264 | +#ifdef SANITY_CHECKS | ||
265 | + if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) { | ||
266 | + LOG(WARNING) << "This is not a square: " << box << std::endl; | ||
267 | + } | ||
268 | +#endif | ||
269 | + } | ||
270 | + | ||
271 | + inline BoundingBox ToBoundingBox() const { | ||
272 | + return BoundingBox(x_, y_, x_ + size_, y_ + size_); | ||
273 | + } | ||
274 | + | ||
275 | + inline bool ValidBox() { | ||
276 | + return size_ > 0.0f; | ||
277 | + } | ||
278 | + | ||
279 | + inline void Shift(const Point2f shift_amount) { | ||
280 | + x_ += shift_amount.x; | ||
281 | + y_ += shift_amount.y; | ||
282 | + } | ||
283 | + | ||
284 | + inline void Scale(const float scale) { | ||
285 | + const float new_size = size_ * scale; | ||
286 | + const float position_diff = (new_size - size_) / 2.0f; | ||
287 | + x_ -= position_diff; | ||
288 | + y_ -= position_diff; | ||
289 | + size_ = new_size; | ||
290 | + } | ||
291 | + | ||
292 | + float x_; | ||
293 | + float y_; | ||
294 | + float size_; | ||
295 | +}; | ||
296 | +inline std::ostream& operator<<(std::ostream& stream, | ||
297 | + const BoundingSquare& square) { | ||
298 | + stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]"; | ||
299 | + return stream; | ||
300 | +} | ||
301 | + | ||
302 | + | ||
303 | +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box, | ||
304 | + const float size) { | ||
305 | + const float width_diff = (original_box.GetWidth() - size) / 2.0f; | ||
306 | + const float height_diff = (original_box.GetHeight() - size) / 2.0f; | ||
307 | + return BoundingSquare(original_box.left_ + width_diff, | ||
308 | + original_box.top_ + height_diff, | ||
309 | + size); | ||
310 | +} | ||
311 | + | ||
312 | +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) { | ||
313 | + return GetCenteredSquare( | ||
314 | + original_box, MIN(original_box.GetWidth(), original_box.GetHeight())); | ||
315 | +} | ||
316 | + | ||
317 | +} // namespace tf_tracking | ||
318 | + | ||
319 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ |
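
PascalScore above is the usual intersection-over-union measure, built from Intersect() and GetArea(). A quick usage sketch, assuming the include path given in this header's guard:

    #include <cstdio>

    #include "tensorflow/examples/android/jni/object_tracking/geom.h"

    int main() {
      const tf_tracking::BoundingBox a(0.0f, 0.0f, 10.0f, 10.0f);
      const tf_tracking::BoundingBox b(5.0f, 5.0f, 15.0f, 15.0f);

      // The overlap is the 5x5 box [5, 10] x [5, 10], so the area is 25.
      const float intersection = a.Intersect(b).GetArea();

      // IoU = 25 / (100 + 100 - 25) ~= 0.143, well below the 0.6f
      // kPositionOverlapThreshold used to invalidate unlocked examples.
      printf("intersection: %.1f, IoU: %.3f\n", intersection, a.PascalScore(b));
      return 0;
    }

android/android/jni/object_tracking/gl_utils.h
0 → 100644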
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ | ||
18 | + | ||
19 | +#include <GLES/gl.h> | ||
20 | +#include <GLES/glext.h> | ||
21 | + | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
23 | + | ||
24 | +namespace tf_tracking { | ||
25 | + | ||
26 | +// Draws a box at the given position. | ||
27 | +inline static void DrawBox(const BoundingBox& bounding_box) { | ||
28 | + const GLfloat line[] = { | ||
29 | + bounding_box.left_, bounding_box.bottom_, | ||
30 | + bounding_box.left_, bounding_box.top_, | ||
31 | + bounding_box.left_, bounding_box.top_, | ||
32 | + bounding_box.right_, bounding_box.top_, | ||
33 | + bounding_box.right_, bounding_box.top_, | ||
34 | + bounding_box.right_, bounding_box.bottom_, | ||
35 | + bounding_box.right_, bounding_box.bottom_, | ||
36 | + bounding_box.left_, bounding_box.bottom_ | ||
37 | + }; | ||
38 | + | ||
39 | + glVertexPointer(2, GL_FLOAT, 0, line); | ||
40 | + glEnableClientState(GL_VERTEX_ARRAY); | ||
41 | + | ||
42 | + glDrawArrays(GL_LINES, 0, 8); | ||
43 | +} | ||
44 | + | ||
45 | + | ||
46 | +// Changes the coordinate system so that an arbitrary square in world space | ||
47 | +// can subsequently be drawn to using coordinates in the range [0, 1]. | ||
48 | +inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) { | ||
49 | + glScalef(square.size_, square.size_, 1.0f); | ||
50 | + glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f); | ||
51 | +} | ||
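For context, a minimal sketch of how these two helpers are meant to compose, assuming a live GLES 1.x context with the modelview matrix current (the wrapper name is illustrative, not part of this file):

```cpp
// Draw a box, expressed in [0, 1] coordinates, into a world-space square.
void DrawBoxInSquare(const tf_tracking::BoundingSquare& square,
                     const tf_tracking::BoundingBox& unit_box) {
  glPushMatrix();
  tf_tracking::MapWorldSquareToUnitSquare(square);  // [0,1] now maps to square.
  tf_tracking::DrawBox(unit_box);
  glPopMatrix();  // Restore the previous modelview matrix.
}
```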
52 | + | ||
53 | +} // namespace tf_tracking | ||
54 | + | ||
55 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ | ||
18 | + | ||
19 | +#include <stdint.h> | ||
20 | + | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
24 | + | ||
25 | +namespace tf_tracking { | ||
26 | + | ||
27 | +template <typename T> | ||
28 | +Image<T>::Image(const int width, const int height) | ||
29 | + : width_less_one_(width - 1), | ||
30 | + height_less_one_(height - 1), | ||
31 | + data_size_(width * height), | ||
32 | + own_data_(true), | ||
33 | + width_(width), | ||
34 | + height_(height), | ||
35 | + stride_(width) { | ||
36 | + Allocate(); | ||
37 | +} | ||
38 | + | ||
39 | +template <typename T> | ||
40 | +Image<T>::Image(const Size& size) | ||
41 | + : width_less_one_(size.width - 1), | ||
42 | + height_less_one_(size.height - 1), | ||
43 | + data_size_(size.width * size.height), | ||
44 | + own_data_(true), | ||
45 | + width_(size.width), | ||
46 | + height_(size.height), | ||
47 | + stride_(size.width) { | ||
48 | + Allocate(); | ||
49 | +} | ||
50 | + | ||
51 | +// Constructor that creates an image from preallocated data. | ||
52 | +// Note: The image takes ownership of the data lifecycle, unless own_data is | ||
53 | +// set to false. | ||
54 | +template <typename T> | ||
55 | +Image<T>::Image(const int width, const int height, T* const image_data, | ||
56 | + const bool own_data) : | ||
57 | + width_less_one_(width - 1), | ||
58 | + height_less_one_(height - 1), | ||
59 | + data_size_(width * height), | ||
60 | + own_data_(own_data), | ||
61 | + width_(width), | ||
62 | + height_(height), | ||
63 | + stride_(width) { | ||
64 | + image_data_ = image_data; | ||
65 | + SCHECK(image_data_ != NULL, "Can't create image with NULL data!"); | ||
66 | +} | ||
67 | + | ||
68 | +template <typename T> | ||
69 | +Image<T>::~Image() { | ||
70 | + if (own_data_) { | ||
71 | + delete[] image_data_; | ||
72 | + } | ||
73 | + image_data_ = NULL; | ||
74 | +} | ||
75 | + | ||
76 | +template<typename T> | ||
77 | +template<class DstType> | ||
78 | +bool Image<T>::ExtractPatchAtSubpixelFixed1616(const int fp_x, | ||
79 | + const int fp_y, | ||
80 | + const int patchwidth, | ||
81 | + const int patchheight, | ||
82 | + DstType* to_data) const { | ||
83 | + // Calculate weights. | ||
84 | + const int trunc_x = fp_x >> 16; | ||
85 | + const int trunc_y = fp_y >> 16; | ||
86 | + | ||
87 | + if (trunc_x < 0 || trunc_y < 0 || | ||
88 | + (trunc_x + patchwidth) >= width_less_one_ || | ||
89 | + (trunc_y + patchheight) >= height_less_one_) { | ||
90 | + return false; | ||
91 | + } | ||
92 | + | ||
93 | + // Now walk over destination patch and fill from interpolated source image. | ||
94 | + for (int y = 0; y < patchheight; ++y, to_data += patchwidth) { | ||
95 | + for (int x = 0; x < patchwidth; ++x) { | ||
96 | + to_data[x] = | ||
97 | + static_cast<DstType>(GetPixelInterpFixed1616(fp_x + (x << 16), | ||
98 | + fp_y + (y << 16))); | ||
99 | + } | ||
100 | + } | ||
101 | + | ||
102 | + return true; | ||
103 | +} | ||
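Hypothetical usage of the patch extractor (frame size and coordinates are made up for illustration; 16:16 fixed point means scaling real coordinates by `1 << 16`):

```cpp
// Pull an 8x8 patch whose top-left corner sits at subpixel (12.5, 34.25).
tf_tracking::Image<uint8_t> frame(64, 64);
uint8_t patch[8 * 8];  // Contiguous destination, as the API requires.
const int fp_x = static_cast<int>(12.5f * (1 << 16));
const int fp_y = static_cast<int>(34.25f * (1 << 16));
const bool ok =
    frame.ExtractPatchAtSubpixelFixed1616(fp_x, fp_y, 8, 8, patch);
// ok is false if any part of the patch would fall outside the image.
```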
104 | + | ||
105 | +template <typename T> | ||
106 | +Image<T>* Image<T>::Crop( | ||
107 | + const int left, const int top, const int right, const int bottom) const { | ||
108 | + SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left); | ||
109 | + SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right); | ||
110 | + SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top); | ||
111 | + SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom); | ||
112 | + | ||
113 | + SCHECK(left <= right, "mismatch!"); | ||
114 | + SCHECK(top <= bottom, "mismatch!"); | ||
115 | + | ||
116 | + const int new_width = right - left + 1; | ||
117 | + const int new_height = bottom - top + 1; | ||
118 | + | ||
119 | + Image<T>* const cropped_image = new Image(new_width, new_height); | ||
120 | + | ||
121 | + for (int y = 0; y < new_height; ++y) { | ||
122 | + memcpy((*cropped_image)[y], ((*this)[y + top] + left), | ||
123 | + new_width * sizeof(T)); | ||
124 | + } | ||
125 | + | ||
126 | + return cropped_image; | ||
127 | +} | ||
128 | + | ||
129 | +template <typename T> | ||
130 | +inline float Image<T>::GetPixelInterp(const float x, const float y) const { | ||
131 | + // Do int conversion one time. | ||
132 | + const int floored_x = static_cast<int>(x); | ||
133 | + const int floored_y = static_cast<int>(y); | ||
134 | + | ||
135 | + // Note: it might be the case that the *_[min|max] values are clipped, and | ||
136 | + // these (the a b c d vals) aren't (for speed purposes), but that doesn't | ||
137 | + // matter. We'll just be blending the pixel with itself in that case anyway. | ||
138 | + const float b = x - floored_x; | ||
139 | + const float a = 1.0f - b; | ||
140 | + | ||
141 | + const float d = y - floored_y; | ||
142 | + const float c = 1.0f - d; | ||
143 | + | ||
144 | + SCHECK(ValidInterpPixel(x, y), | ||
145 | + "x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)", | ||
146 | + x, width_less_one_, y, height_less_one_); | ||
147 | + | ||
148 | + const T* const pix_ptr = (*this)[floored_y] + floored_x; | ||
149 | + | ||
150 | + // Get the pixel values surrounding this point. | ||
151 | + const T& p1 = pix_ptr[0]; | ||
152 | + const T& p2 = pix_ptr[1]; | ||
153 | + const T& p3 = pix_ptr[width_]; | ||
154 | + const T& p4 = pix_ptr[width_ + 1]; | ||
155 | + | ||
156 | + // Simple bilinear interpolation between four reference pixels. | ||
157 | + // If x is the value requested: | ||
158 | + // a b | ||
159 | + // ------- | ||
160 | + // c |p1 p2| | ||
161 | + // | x | | ||
162 | + // d |p3 p4| | ||
163 | + // ------- | ||
164 | + return c * ((a * p1) + (b * p2)) + | ||
165 | + d * ((a * p3) + (b * p4)); | ||
166 | +} | ||
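As a worked instance of the blend above (values chosen purely for illustration): sampling at (x, y) = (2.25, 3.5) gives b = 0.25, a = 0.75, d = 0.5, c = 0.5, so

```latex
I(2.25,\,3.5) \;=\; 0.5\,\bigl(0.75\,p_1 + 0.25\,p_2\bigr)
            \;+\; 0.5\,\bigl(0.75\,p_3 + 0.25\,p_4\bigr)
```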
167 | + | ||
168 | + | ||
169 | +template <typename T> | ||
170 | +inline T Image<T>::GetPixelInterpFixed1616( | ||
171 | + const int fp_x_whole, const int fp_y_whole) const { | ||
172 | + static const int kFixedPointOne = 0x00010000; | ||
173 | + static const int kFixedPointHalf = 0x00008000; | ||
174 | + static const int kFixedPointTruncateMask = 0xFFFF0000; | ||
175 | + | ||
176 | + int trunc_x = fp_x_whole & kFixedPointTruncateMask; | ||
177 | + int trunc_y = fp_y_whole & kFixedPointTruncateMask; | ||
178 | + const int fp_x = fp_x_whole - trunc_x; | ||
179 | + const int fp_y = fp_y_whole - trunc_y; | ||
180 | + | ||
181 | + // Scale the truncated values back to regular ints. | ||
182 | + trunc_x >>= 16; | ||
183 | + trunc_y >>= 16; | ||
184 | + | ||
185 | + const int one_minus_fp_x = kFixedPointOne - fp_x; | ||
186 | + const int one_minus_fp_y = kFixedPointOne - fp_y; | ||
187 | + | ||
188 | + const T* trunc_start = (*this)[trunc_y] + trunc_x; | ||
189 | + | ||
190 | + const T a = trunc_start[0]; | ||
191 | + const T b = trunc_start[1]; | ||
192 | + const T c = trunc_start[stride_]; | ||
193 | + const T d = trunc_start[stride_ + 1]; | ||
194 | + | ||
195 | + return ( | ||
196 | + (one_minus_fp_y * static_cast<int64_t>(one_minus_fp_x * a + fp_x * b) + | ||
197 | + fp_y * static_cast<int64_t>(one_minus_fp_x * c + fp_x * d) + | ||
198 | + kFixedPointHalf) >> | ||
199 | + 32); | ||
200 | +} | ||
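One property worth noting as a sanity check (a sketch, assuming `frame` is an in-bounds `Image<uint8_t>`): at whole-pixel positions the fractional parts are zero, so the fixed-point sampler degenerates to a plain lookup.

```cpp
#include <cassert>
// Fixed-point coordinates with zero fractional bits select exactly one pixel.
assert(frame.GetPixelInterpFixed1616(10 << 16, 7 << 16) == frame[7][10]);
```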
201 | + | ||
202 | +template <typename T> | ||
203 | +inline bool Image<T>::ValidPixel(const int x, const int y) const { | ||
204 | + return InRange(x, ZERO, width_less_one_) && | ||
205 | + InRange(y, ZERO, height_less_one_); | ||
206 | +} | ||
207 | + | ||
208 | +template <typename T> | ||
209 | +inline BoundingBox Image<T>::GetContainingBox() const { | ||
210 | + return BoundingBox( | ||
211 | + 0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON); | ||
212 | +} | ||
213 | + | ||
214 | +template <typename T> | ||
215 | +inline bool Image<T>::Contains(const BoundingBox& bounding_box) const { | ||
216 | + // TODO(andrewharp): Come up with a more elegant way of ensuring that bounds | ||
217 | + // are ok. | ||
218 | + return GetContainingBox().Contains(bounding_box); | ||
219 | +} | ||
220 | + | ||
221 | +template <typename T> | ||
222 | +inline bool Image<T>::ValidInterpPixel(const float x, const float y) const { | ||
223 | + // Exclusive of max because we can be more efficient if we don't handle | ||
224 | + // interpolating on or past the last pixel. | ||
225 | + return (x >= ZERO) && (x < width_less_one_) && | ||
226 | + (y >= ZERO) && (y < height_less_one_); | ||
227 | +} | ||
228 | + | ||
229 | +template <typename T> | ||
230 | +void Image<T>::DownsampleAveraged(const T* const original, const int stride, | ||
231 | + const int factor) { | ||
232 | +#ifdef __ARM_NEON | ||
233 | + if (factor == 4 || factor == 2) { | ||
234 | + DownsampleAveragedNeon(original, stride, factor); | ||
235 | + return; | ||
236 | + } | ||
237 | +#endif | ||
238 | + | ||
239 | + // TODO(andrewharp): delete or enable this for non-uint8_t downsamples. | ||
240 | + const int pixels_per_block = factor * factor; | ||
241 | + | ||
242 | + // For every pixel in resulting image. | ||
243 | + for (int y = 0; y < height_; ++y) { | ||
244 | + const int orig_y = y * factor; | ||
245 | + const int y_bound = orig_y + factor; | ||
246 | + | ||
247 | + // Sum up the original pixels. | ||
248 | + for (int x = 0; x < width_; ++x) { | ||
249 | + const int orig_x = x * factor; | ||
250 | + const int x_bound = orig_x + factor; | ||
251 | + | ||
252 | + // Making this int32_t because type U or T might overflow. | ||
253 | + int32_t pixel_sum = 0; | ||
254 | + | ||
255 | + // Grab all the pixels that make up this pixel. | ||
256 | + for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) { | ||
257 | + const T* p = original + curr_y * stride + orig_x; | ||
258 | + | ||
259 | + for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) { | ||
260 | + pixel_sum += *p++; | ||
261 | + } | ||
262 | + } | ||
263 | + | ||
264 | + (*this)[y][x] = pixel_sum / pixels_per_block; | ||
265 | + } | ||
266 | + } | ||
267 | +} | ||
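In formula form, the block-averaging path above computes the following for downsample factor f (the integer division in the code truncates rather than rounds):

```latex
\mathrm{out}(x, y) \;=\; \left\lfloor \frac{1}{f^{2}}
    \sum_{j=0}^{f-1} \sum_{i=0}^{f-1} \mathrm{in}(fx + i,\; fy + j) \right\rfloor
```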
268 | + | ||
269 | +template <typename T> | ||
270 | +void Image<T>::DownsampleInterpolateNearest(const Image<T>& original) { | ||
271 | + // Calculating the scaling factors based on target image size. | ||
272 | + const float factor_x = static_cast<float>(original.GetWidth()) / | ||
273 | + static_cast<float>(width_); | ||
274 | + const float factor_y = static_cast<float>(original.GetHeight()) / | ||
275 | + static_cast<float>(height_); | ||
276 | + | ||
277 | + // Calculating initial offset in x-axis. | ||
278 | + const float offset_x = 0.5f * (original.GetWidth() - width_) / width_; | ||
279 | + | ||
280 | + // Calculating initial offset in y-axis. | ||
281 | + const float offset_y = 0.5f * (original.GetHeight() - height_) / height_; | ||
282 | + | ||
283 | + float orig_y = offset_y; | ||
284 | + | ||
285 | + // For every pixel in resulting image. | ||
286 | + for (int y = 0; y < height_; ++y) { | ||
287 | + float orig_x = offset_x; | ||
288 | + | ||
289 | + // Finding nearest pixel on y-axis. | ||
290 | + const int nearest_y = static_cast<int>(orig_y + 0.5f); | ||
291 | + const T* row_data = original[nearest_y]; | ||
292 | + | ||
293 | + T* pixel_ptr = (*this)[y]; | ||
294 | + | ||
295 | + for (int x = 0; x < width_; ++x) { | ||
296 | + // Finding nearest pixel on x-axis. | ||
297 | + const int nearest_x = static_cast<int>(orig_x + 0.5f); | ||
298 | + | ||
299 | + *pixel_ptr++ = row_data[nearest_x]; | ||
300 | + | ||
301 | + orig_x += factor_x; | ||
302 | + } | ||
303 | + | ||
304 | + orig_y += factor_y; | ||
305 | + } | ||
306 | +} | ||
307 | + | ||
308 | +template <typename T> | ||
309 | +void Image<T>::DownsampleInterpolateLinear(const Image<T>& original) { | ||
310 | + // TODO(andrewharp): Turn this into a general compare sizes/bulk | ||
311 | + // copy method. | ||
312 | + if (original.GetWidth() == GetWidth() && | ||
313 | + original.GetHeight() == GetHeight() && | ||
314 | + original.stride() == stride()) { | ||
315 | + memcpy(image_data_, original.data(), data_size_ * sizeof(T)); | ||
316 | + return; | ||
317 | + } | ||
318 | + | ||
319 | + // Calculating the scaling factors based on target image size. | ||
320 | + const float factor_x = static_cast<float>(original.GetWidth()) / | ||
321 | + static_cast<float>(width_); | ||
322 | + const float factor_y = static_cast<float>(original.GetHeight()) / | ||
323 | + static_cast<float>(height_); | ||
324 | + | ||
325 | + // Calculating initial offset in x-axis. | ||
326 | + const float offset_x = 0; | ||
327 | + const int offset_x_fp = RealToFixed1616(offset_x); | ||
328 | + | ||
329 | + // Calculating initial offset in y-axis. | ||
330 | + const float offset_y = 0; | ||
331 | + const int offset_y_fp = RealToFixed1616(offset_y); | ||
332 | + | ||
333 | + // Get the fixed point scaling factor value. | ||
334 | + // Shift by 8 so we can fit everything into a 4 byte int later for speed | ||
335 | + // reasons. This means the precision is limited to 1 / 256th of a pixel, | ||
336 | + // but this should be good enough. | ||
337 | + const int factor_x_fp = RealToFixed1616(factor_x) >> 8; | ||
338 | + const int factor_y_fp = RealToFixed1616(factor_y) >> 8; | ||
339 | + | ||
340 | + int src_y_fp = offset_y_fp >> 8; | ||
341 | + | ||
342 | + static const int kFixedPointOne8 = 0x00000100; | ||
343 | + static const int kFixedPointHalf8 = 0x00000080; | ||
344 | + static const int kFixedPointTruncateMask8 = 0xFFFFFF00; | ||
345 | + | ||
346 | + // For every pixel in resulting image. | ||
347 | + for (int y = 0; y < height_; ++y) { | ||
348 | + int src_x_fp = offset_x_fp >> 8; | ||
349 | + | ||
350 | + int trunc_y = src_y_fp & kFixedPointTruncateMask8; | ||
351 | + const int fp_y = src_y_fp - trunc_y; | ||
352 | + | ||
353 | + // Scale the truncated values back to regular ints. | ||
354 | + trunc_y >>= 8; | ||
355 | + | ||
356 | + const int one_minus_fp_y = kFixedPointOne8 - fp_y; | ||
357 | + | ||
358 | + T* pixel_ptr = (*this)[y]; | ||
359 | + | ||
360 | + // Make sure not to read from an invalid row. | ||
361 | + const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1); | ||
362 | + const T* other_top_ptr = original[trunc_y]; | ||
363 | + const T* other_bot_ptr = original[trunc_y_b]; | ||
364 | + | ||
365 | + int last_trunc_x = -1; | ||
366 | + int trunc_x = -1; | ||
367 | + | ||
368 | + T a = 0; | ||
369 | + T b = 0; | ||
370 | + T c = 0; | ||
371 | + T d = 0; | ||
372 | + | ||
373 | + for (int x = 0; x < width_; ++x) { | ||
374 | + trunc_x = src_x_fp & kFixedPointTruncateMask8; | ||
375 | + | ||
376 | + const int fp_x = (src_x_fp - trunc_x) >> 8; | ||
377 | + | ||
378 | + // Scale the truncated values back to regular ints. | ||
379 | + trunc_x >>= 8; | ||
380 | + | ||
381 | + // It's possible we're reading from the same pixels | ||
382 | + if (trunc_x != last_trunc_x) { | ||
383 | + // Make sure not to read from an invalid column. | ||
384 | + const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1); | ||
385 | + a = other_top_ptr[trunc_x]; | ||
386 | + b = other_top_ptr[trunc_x_b]; | ||
387 | + c = other_bot_ptr[trunc_x]; | ||
388 | + d = other_bot_ptr[trunc_x_b]; | ||
389 | + last_trunc_x = trunc_x; | ||
390 | + } | ||
391 | + | ||
392 | + const int one_minus_fp_x = kFixedPointOne8 - fp_x; | ||
393 | + | ||
394 | +      const int32_t value = | ||
395 | +          (one_minus_fp_y * (one_minus_fp_x * a + fp_x * b) + | ||
396 | +           fp_y * (one_minus_fp_x * c + fp_x * d) + | ||
397 | +           (kFixedPointHalf8 << 8)) >> 16; | ||
398 | + | ||
399 | + *pixel_ptr++ = value; | ||
400 | + | ||
401 | + src_x_fp += factor_x_fp; | ||
402 | + } | ||
403 | + src_y_fp += factor_y_fp; | ||
404 | + } | ||
405 | +} | ||
406 | + | ||
407 | +template <typename T> | ||
408 | +void Image<T>::DownsampleSmoothed3x3(const Image<T>& original) { | ||
409 | + for (int y = 0; y < height_; ++y) { | ||
410 | + const int orig_y = Clip(2 * y, ZERO, original.height_less_one_); | ||
411 | + const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_); | ||
412 | + const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_); | ||
413 | + | ||
414 | + for (int x = 0; x < width_; ++x) { | ||
415 | + const int orig_x = Clip(2 * x, ZERO, original.width_less_one_); | ||
416 | + const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_); | ||
417 | + const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_); | ||
418 | + | ||
419 | + // Center. | ||
420 | + int32_t pixel_sum = original[orig_y][orig_x] * 4; | ||
421 | + | ||
422 | + // Sides. | ||
423 | + pixel_sum += (original[orig_y][max_x] + | ||
424 | + original[orig_y][min_x] + | ||
425 | + original[max_y][orig_x] + | ||
426 | + original[min_y][orig_x]) * 2; | ||
427 | + | ||
428 | + // Diagonals. | ||
429 | + pixel_sum += (original[min_y][max_x] + | ||
430 | + original[min_y][min_x] + | ||
431 | + original[max_y][max_x] + | ||
432 | + original[max_y][min_x]); | ||
433 | + | ||
434 | + (*this)[y][x] = pixel_sum >> 4; // 16 | ||
435 | + } | ||
436 | + } | ||
437 | +} | ||
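The weighting above (center times 4, edge neighbors times 2, diagonals times 1, then a right shift by 4) is exactly the separable 3x3 binomial smoothing kernel:

```latex
\frac{1}{16}
\begin{bmatrix}
1 & 2 & 1 \\
2 & 4 & 2 \\
1 & 2 & 1
\end{bmatrix}
\;=\;
\frac{1}{4}\begin{bmatrix}1\\2\\1\end{bmatrix}
\cdot
\frac{1}{4}\begin{bmatrix}1 & 2 & 1\end{bmatrix}
```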
438 | + | ||
439 | +template <typename T> | ||
440 | +void Image<T>::DownsampleSmoothed5x5(const Image<T>& original) { | ||
441 | + const int max_x = original.width_less_one_; | ||
442 | + const int max_y = original.height_less_one_; | ||
443 | + | ||
444 | +  // The J.-Y. Bouguet paper on pyramidal Lucas-Kanade recommends a | ||
445 | + // [1/16 1/4 3/8 1/4 1/16]^2 filter. | ||
446 | + // This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below. | ||
447 | + static const int window_radius = 2; | ||
448 | + static const int window_size = window_radius*2 + 1; | ||
449 | + static const int window_weights[] = {1, 4, 6, 4, 1, // 16 + | ||
450 | + 4, 16, 24, 16, 4, // 64 + | ||
451 | + 6, 24, 36, 24, 6, // 96 + | ||
452 | + 4, 16, 24, 16, 4, // 64 + | ||
453 | + 1, 4, 6, 4, 1}; // 16 = 256 | ||
454 | + | ||
455 | + // We'll multiply and sum with the whole numbers first, then divide by | ||
456 | + // the total weight to normalize at the last moment. | ||
457 | + for (int y = 0; y < height_; ++y) { | ||
458 | + for (int x = 0; x < width_; ++x) { | ||
459 | + int32_t pixel_sum = 0; | ||
460 | + | ||
461 | + const int* w = window_weights; | ||
462 | + const int start_x = Clip((x << 1) - window_radius, ZERO, max_x); | ||
463 | + | ||
464 | + // Clip the boundaries to the size of the image. | ||
465 | + for (int window_y = 0; window_y < window_size; ++window_y) { | ||
466 | + const int start_y = | ||
467 | + Clip((y << 1) - window_radius + window_y, ZERO, max_y); | ||
468 | + | ||
469 | + const T* p = original[start_y] + start_x; | ||
470 | + | ||
471 | + for (int window_x = 0; window_x < window_size; ++window_x) { | ||
472 | + pixel_sum += *p++ * *w++; | ||
473 | + } | ||
474 | + } | ||
475 | + | ||
476 | + // Conversion to type T will happen here after shifting right 8 bits to | ||
477 | + // divide by 256. | ||
478 | + (*this)[y][x] = pixel_sum >> 8; | ||
479 | + } | ||
480 | + } | ||
481 | +} | ||
482 | + | ||
483 | +template <typename T> | ||
484 | +template <typename U> | ||
485 | +inline T Image<T>::ScharrPixelX(const Image<U>& original, | ||
486 | + const int center_x, const int center_y) const { | ||
487 | + const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_); | ||
488 | + const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_); | ||
489 | + const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_); | ||
490 | + const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_); | ||
491 | + | ||
492 | + // Convolution loop unrolled for performance... | ||
493 | + return (3 * (original[min_y][max_x] | ||
494 | + + original[max_y][max_x] | ||
495 | + - original[min_y][min_x] | ||
496 | + - original[max_y][min_x]) | ||
497 | + + 10 * (original[center_y][max_x] | ||
498 | + - original[center_y][min_x])) / 32; | ||
499 | +} | ||
500 | + | ||
501 | +template <typename T> | ||
502 | +template <typename U> | ||
503 | +inline T Image<T>::ScharrPixelY(const Image<U>& original, | ||
504 | + const int center_x, const int center_y) const { | ||
505 | + const int min_x = Clip(center_x - 1, 0, original.width_less_one_); | ||
506 | + const int max_x = Clip(center_x + 1, 0, original.width_less_one_); | ||
507 | + const int min_y = Clip(center_y - 1, 0, original.height_less_one_); | ||
508 | + const int max_y = Clip(center_y + 1, 0, original.height_less_one_); | ||
509 | + | ||
510 | + // Convolution loop unrolled for performance... | ||
511 | + return (3 * (original[max_y][min_x] | ||
512 | + + original[max_y][max_x] | ||
513 | + - original[min_y][min_x] | ||
514 | + - original[min_y][max_x]) | ||
515 | + + 10 * (original[max_y][center_x] | ||
516 | + - original[min_y][center_x])) / 32; | ||
517 | +} | ||
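Written out as kernels, the two unrolled loops above implement the 3x3 Scharr operators with this file's 1/32 normalization:

```latex
G_x = \frac{1}{32}
\begin{bmatrix}
-3 & 0 & 3 \\
-10 & 0 & 10 \\
-3 & 0 & 3
\end{bmatrix},
\qquad
G_y = \frac{1}{32}
\begin{bmatrix}
-3 & -10 & -3 \\
 0 &   0 &  0 \\
 3 &  10 &  3
\end{bmatrix}
```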
518 | + | ||
519 | +template <typename T> | ||
520 | +template <typename U> | ||
521 | +inline void Image<T>::ScharrX(const Image<U>& original) { | ||
522 | + for (int y = 0; y < height_; ++y) { | ||
523 | + for (int x = 0; x < width_; ++x) { | ||
524 | + SetPixel(x, y, ScharrPixelX(original, x, y)); | ||
525 | + } | ||
526 | + } | ||
527 | +} | ||
528 | + | ||
529 | +template <typename T> | ||
530 | +template <typename U> | ||
531 | +inline void Image<T>::ScharrY(const Image<U>& original) { | ||
532 | + for (int y = 0; y < height_; ++y) { | ||
533 | + for (int x = 0; x < width_; ++x) { | ||
534 | + SetPixel(x, y, ScharrPixelY(original, x, y)); | ||
535 | + } | ||
536 | + } | ||
537 | +} | ||
538 | + | ||
539 | +template <typename T> | ||
540 | +template <typename U> | ||
541 | +void Image<T>::DerivativeX(const Image<U>& original) { | ||
542 | + for (int y = 0; y < height_; ++y) { | ||
543 | + const U* const source_row = original[y]; | ||
544 | + T* const dest_row = (*this)[y]; | ||
545 | + | ||
546 | + // Compute first pixel. Approximated with forward difference. | ||
547 | + dest_row[0] = source_row[1] - source_row[0]; | ||
548 | + | ||
549 | + // All the pixels in between. Central difference method. | ||
550 | + const U* source_prev_pixel = source_row; | ||
551 | + T* dest_pixel = dest_row + 1; | ||
552 | + const U* source_next_pixel = source_row + 2; | ||
553 | + for (int x = 1; x < width_less_one_; ++x) { | ||
554 | + *dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++); | ||
555 | + } | ||
556 | + | ||
557 | + // Last pixel. Approximated with backward difference. | ||
558 | + dest_row[width_less_one_] = | ||
559 | + source_row[width_less_one_] - source_row[width_less_one_ - 1]; | ||
560 | + } | ||
561 | +} | ||
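The interior loop above is the standard central-difference approximation, with one-sided differences only at the two border columns:

```latex
\frac{\partial I}{\partial x}(x, y) \;\approx\; \frac{I(x+1,\,y) - I(x-1,\,y)}{2}
```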
562 | + | ||
563 | +template <typename T> | ||
564 | +template <typename U> | ||
565 | +void Image<T>::DerivativeY(const Image<U>& original) { | ||
566 | + const int src_stride = original.stride(); | ||
567 | + | ||
568 | + // Compute 1st row. Approximated with forward difference. | ||
569 | + { | ||
570 | + const U* const src_row = original[0]; | ||
571 | + T* dest_row = (*this)[0]; | ||
572 | + for (int x = 0; x < width_; ++x) { | ||
573 | + dest_row[x] = src_row[x + src_stride] - src_row[x]; | ||
574 | + } | ||
575 | + } | ||
576 | + | ||
577 | + // Compute all rows in between using central difference. | ||
578 | + for (int y = 1; y < height_less_one_; ++y) { | ||
579 | + T* dest_row = (*this)[y]; | ||
580 | + | ||
581 | + const U* source_prev_pixel = original[y - 1]; | ||
582 | + const U* source_next_pixel = original[y + 1]; | ||
583 | + for (int x = 0; x < width_; ++x) { | ||
584 | + *dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++); | ||
585 | + } | ||
586 | + } | ||
587 | + | ||
588 | + // Compute last row. Approximated with backward difference. | ||
589 | + { | ||
590 | + const U* const src_row = original[height_less_one_]; | ||
591 | + T* dest_row = (*this)[height_less_one_]; | ||
592 | + for (int x = 0; x < width_; ++x) { | ||
593 | + dest_row[x] = src_row[x] - src_row[x - src_stride]; | ||
594 | + } | ||
595 | + } | ||
596 | +} | ||
597 | + | ||
598 | +template <typename T> | ||
599 | +template <typename U> | ||
600 | +inline T Image<T>::ConvolvePixel3x3(const Image<U>& original, | ||
601 | + const int* const filter, | ||
602 | + const int center_x, const int center_y, | ||
603 | + const int total) const { | ||
604 | + int32_t sum = 0; | ||
605 | + for (int filter_y = 0; filter_y < 3; ++filter_y) { | ||
606 | +    const int y = Clip(center_y - 1 + filter_y, 0, original.height_less_one_); | ||
607 | +    for (int filter_x = 0; filter_x < 3; ++filter_x) { | ||
608 | +      const int x = Clip(center_x - 1 + filter_x, 0, original.width_less_one_); | ||
609 | + sum += original[y][x] * filter[filter_y * 3 + filter_x]; | ||
610 | + } | ||
611 | + } | ||
612 | + return sum / total; | ||
613 | +} | ||
614 | + | ||
615 | +template <typename T> | ||
616 | +template <typename U> | ||
617 | +inline void Image<T>::Convolve3x3(const Image<U>& original, | ||
618 | + const int32_t* const filter) { | ||
619 | + int32_t sum = 0; | ||
620 | + for (int i = 0; i < 9; ++i) { | ||
621 | + sum += abs(filter[i]); | ||
622 | + } | ||
623 | + for (int y = 0; y < height_; ++y) { | ||
624 | + for (int x = 0; x < width_; ++x) { | ||
625 | + SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum)); | ||
626 | + } | ||
627 | + } | ||
628 | +} | ||
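A hypothetical use of the generic path (the Sobel kernel here is illustrative, and `frame` is assumed to be an `Image<uint8_t>` of matching dimensions):

```cpp
// Horizontal Sobel via Convolve3x3. The sum of |weights| is 8, so
// ConvolvePixel3x3 normalizes each response by 8.
static const int32_t kSobelX[9] = {-1, 0, 1,
                                   -2, 0, 2,
                                   -1, 0, 1};
tf_tracking::Image<int32_t> gx(frame.GetWidth(), frame.GetHeight());
gx.Convolve3x3(frame, kSobelX);
```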
629 | + | ||
630 | +template <typename T> | ||
631 | +inline void Image<T>::FromArray(const T* const pixels, const int stride, | ||
632 | + const int factor) { | ||
633 | + if (factor == 1 && stride == width_) { | ||
634 | + // If not subsampling, memcpy per line should be faster. | ||
635 | + memcpy(this->image_data_, pixels, data_size_ * sizeof(T)); | ||
636 | + return; | ||
637 | + } | ||
638 | + | ||
639 | + DownsampleAveraged(pixels, stride, factor); | ||
640 | +} | ||
641 | + | ||
642 | +} // namespace tf_tracking | ||
643 | + | ||
644 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ |
android/android/jni/object_tracking/image.h
0 → 100644
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ | ||
18 | + | ||
19 | +#include <stdint.h> | ||
20 | + | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
23 | + | ||
24 | +// TODO(andrewharp): Make this a cast to uint32_t if/when we go unsigned for | ||
25 | +// operations. | ||
26 | +#define ZERO 0 | ||
27 | + | ||
28 | +#ifdef SANITY_CHECKS | ||
29 | + #define CHECK_PIXEL(IMAGE, X, Y) {\ | ||
30 | + SCHECK((IMAGE)->ValidPixel((X), (Y)), \ | ||
31 | + "CHECK_PIXEL(%d,%d) in %dx%d image.", \ | ||
32 | + static_cast<int>(X), static_cast<int>(Y), \ | ||
33 | + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\ | ||
34 | + } | ||
35 | + | ||
36 | + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\ | ||
37 | +    SCHECK((IMAGE)->ValidInterpPixel((X), (Y)), \ | ||
38 | + "CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \ | ||
39 | + static_cast<float>(X), static_cast<float>(Y), \ | ||
40 | + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\ | ||
41 | + } | ||
42 | +#else | ||
43 | +  #define CHECK_PIXEL(IMAGE, X, Y) {} | ||
44 | + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {} | ||
45 | +#endif | ||
46 | + | ||
47 | +namespace tf_tracking { | ||
48 | + | ||
49 | +#ifdef SANITY_CHECKS | ||
50 | +// Class which exists solely to provide bounds checking for array-style image | ||
51 | +// data access. | ||
52 | +template <typename T> | ||
53 | +class RowData { | ||
54 | + public: | ||
55 | + RowData(T* const row_data, const int max_col) | ||
56 | + : row_data_(row_data), max_col_(max_col) {} | ||
57 | + | ||
58 | + inline T& operator[](const int col) const { | ||
59 | + SCHECK(InRange(col, 0, max_col_), | ||
60 | + "Column out of range: %d (%d max)", col, max_col_); | ||
61 | + return row_data_[col]; | ||
62 | + } | ||
63 | + | ||
64 | + inline operator T*() const { | ||
65 | + return row_data_; | ||
66 | + } | ||
67 | + | ||
68 | + private: | ||
69 | + T* const row_data_; | ||
70 | + const int max_col_; | ||
71 | +}; | ||
72 | +#endif | ||
73 | + | ||
74 | +// Naive templated sorting function. | ||
75 | +template <typename T> | ||
76 | +int Comp(const void* a, const void* b) { | ||
77 | + const T val1 = *reinterpret_cast<const T*>(a); | ||
78 | + const T val2 = *reinterpret_cast<const T*>(b); | ||
79 | + | ||
80 | + if (val1 == val2) { | ||
81 | + return 0; | ||
82 | + } else if (val1 < val2) { | ||
83 | + return -1; | ||
84 | + } else { | ||
85 | + return 1; | ||
86 | + } | ||
87 | +} | ||
88 | + | ||
89 | +// TODO(andrewharp): Make explicit which operations support negative numbers or | ||
90 | +// struct/class types in image data (possibly create fast multi-dim array class | ||
91 | +// for data where pixel arithmetic does not make sense). | ||
92 | + | ||
93 | +// Image class optimized for working on numeric arrays as grayscale image data. | ||
94 | +// Supports other data types as a 2D array class, so long as no pixel math | ||
95 | +// operations are called (convolution, downsampling, etc). | ||
96 | +template <typename T> | ||
97 | +class Image { | ||
98 | + public: | ||
99 | + Image(const int width, const int height); | ||
100 | + explicit Image(const Size& size); | ||
101 | + | ||
102 | + // Constructor that creates an image from preallocated data. | ||
103 | + // Note: The image takes ownership of the data lifecycle, unless own_data is | ||
104 | + // set to false. | ||
105 | + Image(const int width, const int height, T* const image_data, | ||
106 | + const bool own_data = true); | ||
107 | + | ||
108 | + ~Image(); | ||
109 | + | ||
110 | + // Extract a pixel patch from this image, starting at a subpixel location. | ||
111 | + // Uses 16:16 fixed point format for representing real values and doing the | ||
112 | + // bilinear interpolation. | ||
113 | + // | ||
114 | + // Arguments fp_x and fp_y tell the subpixel position in fixed point format, | ||
115 | + // patchwidth/patchheight give the size of the patch in pixels and | ||
116 | + // to_data must be a valid pointer to a *contiguous* destination data array. | ||
117 | + template<class DstType> | ||
118 | + bool ExtractPatchAtSubpixelFixed1616(const int fp_x, | ||
119 | + const int fp_y, | ||
120 | + const int patchwidth, | ||
121 | + const int patchheight, | ||
122 | + DstType* to_data) const; | ||
123 | + | ||
124 | + Image<T>* Crop( | ||
125 | + const int left, const int top, const int right, const int bottom) const; | ||
126 | + | ||
127 | + inline int GetWidth() const { return width_; } | ||
128 | + inline int GetHeight() const { return height_; } | ||
129 | + | ||
130 | + // Bilinearly sample a value between pixels. Values must be within the image. | ||
131 | + inline float GetPixelInterp(const float x, const float y) const; | ||
132 | + | ||
133 | +  // Bilinearly sample a pixel at a subpixel position using fixed-point | ||
134 | + // arithmetic. | ||
135 | + // Avoids float<->int conversions. | ||
136 | + // Values must be within the image. | ||
137 | + // Arguments fp_x and fp_y tell the subpixel position in | ||
138 | + // 16:16 fixed point format. | ||
139 | + // | ||
140 | + // Important: This function only makes sense for integer-valued images, such | ||
141 | + // as Image<uint8_t> or Image<int> etc. | ||
142 | + inline T GetPixelInterpFixed1616(const int fp_x_whole, | ||
143 | + const int fp_y_whole) const; | ||
144 | + | ||
145 | + // Returns true iff the pixel is in the image's boundaries. | ||
146 | + inline bool ValidPixel(const int x, const int y) const; | ||
147 | + | ||
148 | + inline BoundingBox GetContainingBox() const; | ||
149 | + | ||
150 | + inline bool Contains(const BoundingBox& bounding_box) const; | ||
151 | + | ||
152 | + inline T GetMedianValue() { | ||
153 | + qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp<T>); | ||
154 | + return image_data_[data_size_ >> 1]; | ||
155 | + } | ||
156 | + | ||
157 | + // Returns true iff the pixel is in the image's boundaries for interpolation | ||
158 | + // purposes. | ||
159 | + // TODO(andrewharp): check in interpolation follow-up change. | ||
160 | + inline bool ValidInterpPixel(const float x, const float y) const; | ||
161 | + | ||
162 | + // Safe lookup with boundary enforcement. | ||
163 | + inline T GetPixelClipped(const int x, const int y) const { | ||
164 | + return (*this)[Clip(y, ZERO, height_less_one_)] | ||
165 | + [Clip(x, ZERO, width_less_one_)]; | ||
166 | + } | ||
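Note: the `ScharrX`/`ScharrY` and `Convolve3x3` implementations call a `SetPixel` helper that does not appear in this header as shown. A minimal definition consistent with the unchecked `operator[]` accessors below would be (a sketch, not confirmed against the original source):

```cpp
// Fast setter without boundary enforcement, mirroring GetPixelClipped's
// indexing but leaving bounds checks to the CHECK_PIXEL macro.
inline void SetPixel(const int x, const int y, const T& val) {
  CHECK_PIXEL(this, x, y);
  (*this)[y][x] = val;
}
```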
167 | + | ||
168 | +#ifdef SANITY_CHECKS | ||
169 | + inline RowData<T> operator[](const int row) { | ||
170 | + SCHECK(InRange(row, 0, height_less_one_), | ||
171 | + "Row out of range: %d (%d max)", row, height_less_one_); | ||
172 | + return RowData<T>(image_data_ + row * stride_, width_less_one_); | ||
173 | + } | ||
174 | + | ||
175 | + inline const RowData<T> operator[](const int row) const { | ||
176 | + SCHECK(InRange(row, 0, height_less_one_), | ||
177 | + "Row out of range: %d (%d max)", row, height_less_one_); | ||
178 | + return RowData<T>(image_data_ + row * stride_, width_less_one_); | ||
179 | + } | ||
180 | +#else | ||
181 | + inline T* operator[](const int row) { | ||
182 | + return image_data_ + row * stride_; | ||
183 | + } | ||
184 | + | ||
185 | + inline const T* operator[](const int row) const { | ||
186 | + return image_data_ + row * stride_; | ||
187 | + } | ||
188 | +#endif | ||
189 | + | ||
190 | + const T* data() const { return image_data_; } | ||
191 | + | ||
192 | + inline int stride() const { return stride_; } | ||
193 | + | ||
194 | +  // Clears image to a single value. Note: memset-based, so only byte-replicable values (e.g. 0) behave as expected for multi-byte T. | ||
195 | + inline void Clear(const T& val) { | ||
196 | + memset(image_data_, val, sizeof(*image_data_) * data_size_); | ||
197 | + } | ||
198 | + | ||
199 | +#ifdef __ARM_NEON | ||
200 | + void Downsample2x32ColumnsNeon(const uint8_t* const original, | ||
201 | + const int stride, const int orig_x); | ||
202 | + | ||
203 | + void Downsample4x32ColumnsNeon(const uint8_t* const original, | ||
204 | + const int stride, const int orig_x); | ||
205 | + | ||
206 | + void DownsampleAveragedNeon(const uint8_t* const original, const int stride, | ||
207 | + const int factor); | ||
208 | +#endif | ||
209 | + | ||
210 | + // Naive downsampler that reduces image size by factor by averaging pixels in | ||
211 | + // blocks of size factor x factor. | ||
212 | + void DownsampleAveraged(const T* const original, const int stride, | ||
213 | + const int factor); | ||
214 | + | ||
215 | + // Naive downsampler that reduces image size by factor by averaging pixels in | ||
216 | + // blocks of size factor x factor. | ||
217 | + inline void DownsampleAveraged(const Image<T>& original, const int factor) { | ||
218 | + DownsampleAveraged(original.data(), original.GetWidth(), factor); | ||
219 | + } | ||
220 | + | ||
221 | +  // Naive downsampler that reduces image size using nearest interpolation. | ||
222 | +  void DownsampleInterpolateNearest(const Image<T>& original); | ||
223 | + | ||
224 | +  // Naive downsampler that reduces image size using fixed-point bilinear | ||
225 | +  // interpolation. | ||
226 | + void DownsampleInterpolateLinear(const Image<T>& original); | ||
227 | + | ||
228 | + // Relatively efficient downsampling of an image by a factor of two with a | ||
229 | + // low-pass 3x3 smoothing operation thrown in. | ||
230 | + void DownsampleSmoothed3x3(const Image<T>& original); | ||
231 | + | ||
232 | + // Relatively efficient downsampling of an image by a factor of two with a | ||
233 | + // low-pass 5x5 smoothing operation thrown in. | ||
234 | + void DownsampleSmoothed5x5(const Image<T>& original); | ||
235 | + | ||
236 | + // Optimized Scharr filter on a single pixel in the X direction. | ||
237 | + // Scharr filters are like central-difference operators, but have more | ||
238 | + // rotational symmetry in their response because they also consider the | ||
239 | + // diagonal neighbors. | ||
240 | + template <typename U> | ||
241 | + inline T ScharrPixelX(const Image<U>& original, | ||
242 | + const int center_x, const int center_y) const; | ||
243 | + | ||
244 | +  // Optimized Scharr filter on a single pixel in the Y direction. | ||
245 | + // Scharr filters are like central-difference operators, but have more | ||
246 | + // rotational symmetry in their response because they also consider the | ||
247 | + // diagonal neighbors. | ||
248 | + template <typename U> | ||
249 | + inline T ScharrPixelY(const Image<U>& original, | ||
250 | + const int center_x, const int center_y) const; | ||
251 | + | ||
252 | + // Convolve the image with a Scharr filter in the X direction. | ||
253 | + // Much faster than an equivalent generic convolution. | ||
254 | + template <typename U> | ||
255 | + inline void ScharrX(const Image<U>& original); | ||
256 | + | ||
257 | + // Convolve the image with a Scharr filter in the Y direction. | ||
258 | + // Much faster than an equivalent generic convolution. | ||
259 | + template <typename U> | ||
260 | + inline void ScharrY(const Image<U>& original); | ||
261 | + | ||
262 | + static inline T HalfDiff(int32_t first, int32_t second) { | ||
263 | + return (second - first) / 2; | ||
264 | + } | ||
265 | + | ||
266 | + template <typename U> | ||
267 | + void DerivativeX(const Image<U>& original); | ||
268 | + | ||
269 | + template <typename U> | ||
270 | + void DerivativeY(const Image<U>& original); | ||
271 | + | ||
272 | + // Generic function for convolving pixel with 3x3 filter. | ||
273 | + // Filter pixels should be in row major order. | ||
274 | + template <typename U> | ||
275 | + inline T ConvolvePixel3x3(const Image<U>& original, | ||
276 | + const int* const filter, | ||
277 | + const int center_x, const int center_y, | ||
278 | + const int total) const; | ||
279 | + | ||
280 | + // Generic function for convolving an image with a 3x3 filter. | ||
281 | + // TODO(andrewharp): Generalize this for any size filter. | ||
282 | + template <typename U> | ||
283 | + inline void Convolve3x3(const Image<U>& original, | ||
284 | + const int32_t* const filter); | ||
285 | + | ||
286 | + // Load this image's data from a data array. The data at pixels is assumed to | ||
287 | + // have dimensions equivalent to this image's dimensions * factor. | ||
288 | + inline void FromArray(const T* const pixels, const int stride, | ||
289 | + const int factor = 1); | ||
290 | + | ||
291 | + // Copy the image back out to an appropriately sized data array. | ||
292 | + inline void ToArray(T* const pixels) const { | ||
293 | + // If not subsampling, memcpy should be faster. | ||
294 | + memcpy(pixels, this->image_data_, data_size_ * sizeof(T)); | ||
295 | + } | ||
296 | + | ||
297 | + // Precompute these for efficiency's sake as they're used by a lot of | ||
298 | + // clipping code and loop code. | ||
299 | + // TODO(andrewharp): make these only accessible by other Images. | ||
300 | + const int width_less_one_; | ||
301 | + const int height_less_one_; | ||
302 | + | ||
303 | + // The raw size of the allocated data. | ||
304 | + const int data_size_; | ||
305 | + | ||
306 | + private: | ||
307 | + inline void Allocate() { | ||
308 | + image_data_ = new T[data_size_]; | ||
309 | + if (image_data_ == NULL) { | ||
310 | + LOGE("Couldn't allocate image data!"); | ||
311 | + } | ||
312 | + } | ||
313 | + | ||
314 | + T* image_data_; | ||
315 | + | ||
316 | + bool own_data_; | ||
317 | + | ||
318 | + const int width_; | ||
319 | + const int height_; | ||
320 | + | ||
321 | + // The image stride (offset to next row). | ||
322 | + // TODO(andrewharp): Make sure that stride is honored in all code. | ||
323 | + const int stride_; | ||
324 | + | ||
325 | + TF_DISALLOW_COPY_AND_ASSIGN(Image); | ||
326 | +}; | ||
327 | + | ||
328 | +template <typename T> | ||
329 | +inline std::ostream& operator<<(std::ostream& stream, const Image<T>& image) { | ||
330 | + for (int y = 0; y < image.GetHeight(); ++y) { | ||
331 | + for (int x = 0; x < image.GetWidth(); ++x) { | ||
332 | + stream << image[y][x] << " "; | ||
333 | + } | ||
334 | + stream << std::endl; | ||
335 | + } | ||
336 | + return stream; | ||
337 | +} | ||
338 | + | ||
339 | +} // namespace tf_tracking | ||
340 | + | ||
341 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ | ||
18 | + | ||
19 | +#include <stdint.h> | ||
20 | +#include <memory> | ||
21 | + | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
27 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
28 | + | ||
29 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
30 | + | ||
31 | +namespace tf_tracking { | ||
32 | + | ||
33 | +// Class that encapsulates all bulky processed data for a frame. | ||
34 | +class ImageData { | ||
35 | + public: | ||
36 | + explicit ImageData(const int width, const int height) | ||
37 | + : uv_frame_width_(width << 1), | ||
38 | + uv_frame_height_(height << 1), | ||
39 | + timestamp_(0), | ||
40 | + image_(width, height) { | ||
41 | + InitPyramid(width, height); | ||
42 | + ResetComputationCache(); | ||
43 | + } | ||
44 | + | ||
45 | + private: | ||
46 | + void ResetComputationCache() { | ||
47 | + uv_data_computed_ = false; | ||
48 | + integral_image_computed_ = false; | ||
49 | + for (int i = 0; i < kNumPyramidLevels; ++i) { | ||
50 | + spatial_x_computed_[i] = false; | ||
51 | + spatial_y_computed_[i] = false; | ||
52 | + pyramid_sqrt2_computed_[i * 2] = false; | ||
53 | + pyramid_sqrt2_computed_[i * 2 + 1] = false; | ||
54 | + } | ||
55 | + } | ||
56 | + | ||
57 | + void InitPyramid(const int width, const int height) { | ||
58 | + int level_width = width; | ||
59 | + int level_height = height; | ||
60 | + | ||
61 | + for (int i = 0; i < kNumPyramidLevels; ++i) { | ||
62 | + pyramid_sqrt2_[i * 2] = NULL; | ||
63 | + pyramid_sqrt2_[i * 2 + 1] = NULL; | ||
64 | + spatial_x_[i] = NULL; | ||
65 | + spatial_y_[i] = NULL; | ||
66 | + | ||
67 | + level_width /= 2; | ||
68 | + level_height /= 2; | ||
69 | + } | ||
70 | + | ||
71 | + // Alias the first pyramid level to image_. | ||
72 | + pyramid_sqrt2_[0] = &image_; | ||
73 | + } | ||
74 | + | ||
75 | + public: | ||
76 | + ~ImageData() { | ||
77 | + // The first pyramid level is actually an alias to image_, | ||
78 | + // so make sure it doesn't get deleted here. | ||
79 | + pyramid_sqrt2_[0] = NULL; | ||
80 | + | ||
81 | + for (int i = 0; i < kNumPyramidLevels; ++i) { | ||
82 | + SAFE_DELETE(pyramid_sqrt2_[i * 2]); | ||
83 | + SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]); | ||
84 | + SAFE_DELETE(spatial_x_[i]); | ||
85 | + SAFE_DELETE(spatial_y_[i]); | ||
86 | + } | ||
87 | + } | ||
88 | + | ||
89 | + void SetData(const uint8_t* const new_frame, const int stride, | ||
90 | + const int64_t timestamp, const int downsample_factor) { | ||
91 | + SetData(new_frame, NULL, stride, timestamp, downsample_factor); | ||
92 | + } | ||
93 | + | ||
94 | + void SetData(const uint8_t* const new_frame, const uint8_t* const uv_frame, | ||
95 | + const int stride, const int64_t timestamp, | ||
96 | + const int downsample_factor) { | ||
97 | + ResetComputationCache(); | ||
98 | + | ||
99 | + timestamp_ = timestamp; | ||
100 | + | ||
101 | + TimeLog("SetData!"); | ||
102 | + | ||
103 | + pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor); | ||
104 | + pyramid_sqrt2_computed_[0] = true; | ||
105 | + TimeLog("Downsampled image"); | ||
106 | + | ||
107 | + if (uv_frame != NULL) { | ||
108 | + if (u_data_.get() == NULL) { | ||
109 | + u_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_)); | ||
110 | + v_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_)); | ||
111 | + } | ||
112 | + | ||
113 | + GetUV(uv_frame, u_data_.get(), v_data_.get()); | ||
114 | + uv_data_computed_ = true; | ||
115 | + TimeLog("Copied UV data"); | ||
116 | + } else { | ||
117 | + LOGV("No uv data!"); | ||
118 | + } | ||
119 | + | ||
120 | +#ifdef LOG_TIME | ||
121 | + // If profiling is enabled, precompute here to make it easier to distinguish | ||
122 | + // total costs. | ||
123 | + Precompute(); | ||
124 | +#endif | ||
125 | + } | ||
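A hypothetical frame-ingestion call (the `y_plane` pointer, stride, timestamp, and downsample factor are all illustrative; the constructor takes the tracker-resolution dimensions, so a factor-2 downsample of 640x480 input lands in a 320x240 cache):

```cpp
// Feed one camera frame (Y plane only) into the per-frame cache.
tf_tracking::ImageData image_data(320, 240);
image_data.SetData(y_plane,             // const uint8_t* luminance data.
                   640,                 // Source stride in pixels.
                   frame_timestamp_ns,  // int64_t timestamp.
                   2);                  // 640x480 -> 320x240 downsample.
```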
126 | + | ||
127 | +  inline int64_t GetTimestamp() const { return timestamp_; } | ||
128 | + | ||
129 | + inline const Image<uint8_t>* GetImage() const { | ||
130 | + SCHECK(pyramid_sqrt2_computed_[0], "image not set!"); | ||
131 | + return pyramid_sqrt2_[0]; | ||
132 | + } | ||
133 | + | ||
134 | + const Image<uint8_t>* GetPyramidSqrt2Level(const int level) const { | ||
135 | + if (!pyramid_sqrt2_computed_[level]) { | ||
136 | + SCHECK(level != 0, "Level equals 0!"); | ||
137 | + if (level == 1) { | ||
138 | + const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(0); | ||
139 | + if (pyramid_sqrt2_[level] == NULL) { | ||
140 | + const int new_width = | ||
141 | + (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2; | ||
142 | + const int new_height = | ||
143 | + (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 * | ||
144 | + 2; | ||
145 | + | ||
146 | + pyramid_sqrt2_[level] = new Image<uint8_t>(new_width, new_height); | ||
147 | + } | ||
148 | + pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level); | ||
149 | + } else { | ||
150 | + const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(level - 2); | ||
151 | + if (pyramid_sqrt2_[level] == NULL) { | ||
152 | + pyramid_sqrt2_[level] = new Image<uint8_t>( | ||
153 | + upper_level.GetWidth() / 2, upper_level.GetHeight() / 2); | ||
154 | + } | ||
155 | + pyramid_sqrt2_[level]->DownsampleAveraged( | ||
156 | + upper_level.data(), upper_level.stride(), 2); | ||
157 | + } | ||
158 | + pyramid_sqrt2_computed_[level] = true; | ||
159 | + } | ||
160 | + return pyramid_sqrt2_[level]; | ||
161 | + } | ||
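The indexing scheme above interleaves two pyramids: even levels are successive 2x block-averaged reductions of each other, while odd levels form a parallel chain seeded by a one-time 1/sqrt(2) bilinear resize of level 0. The linear size of level k is therefore approximately:

```latex
w_k \;\approx\; w_0 \cdot 2^{-k/2}, \qquad h_k \;\approx\; h_0 \cdot 2^{-k/2}
```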
162 | + | ||
163 | + inline const Image<int32_t>* GetSpatialX(const int level) const { | ||
164 | + if (!spatial_x_computed_[level]) { | ||
165 | + const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2); | ||
166 | + if (spatial_x_[level] == NULL) { | ||
167 | + spatial_x_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight()); | ||
168 | + } | ||
169 | + spatial_x_[level]->DerivativeX(src); | ||
170 | + spatial_x_computed_[level] = true; | ||
171 | + } | ||
172 | + return spatial_x_[level]; | ||
173 | + } | ||
174 | + | ||
175 | + inline const Image<int32_t>* GetSpatialY(const int level) const { | ||
176 | + if (!spatial_y_computed_[level]) { | ||
177 | + const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2); | ||
178 | + if (spatial_y_[level] == NULL) { | ||
179 | + spatial_y_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight()); | ||
180 | + } | ||
181 | + spatial_y_[level]->DerivativeY(src); | ||
182 | + spatial_y_computed_[level] = true; | ||
183 | + } | ||
184 | + return spatial_y_[level]; | ||
185 | + } | ||
186 | + | ||
187 | + // The integral image is currently only used for object detection, so lazily | ||
188 | + // initialize it on request. | ||
189 | + inline const IntegralImage* GetIntegralImage() const { | ||
190 | + if (integral_image_.get() == NULL) { | ||
191 | + integral_image_.reset(new IntegralImage(image_)); | ||
192 | + } else if (!integral_image_computed_) { | ||
193 | + integral_image_->Recompute(image_); | ||
194 | + } | ||
195 | + integral_image_computed_ = true; | ||
196 | + return integral_image_.get(); | ||
197 | + } | ||
198 | + | ||
199 | + inline const Image<uint8_t>* GetU() const { | ||
200 | + SCHECK(uv_data_computed_, "UV data not provided!"); | ||
201 | + return u_data_.get(); | ||
202 | + } | ||
203 | + | ||
204 | + inline const Image<uint8_t>* GetV() const { | ||
205 | + SCHECK(uv_data_computed_, "UV data not provided!"); | ||
206 | + return v_data_.get(); | ||
207 | + } | ||
208 | + | ||
209 | + private: | ||
210 | + void Precompute() { | ||
211 | + // Create the smoothed pyramids. | ||
212 | + for (int i = 0; i < kNumPyramidLevels * 2; i += 2) { | ||
213 | + (void) GetPyramidSqrt2Level(i); | ||
214 | + } | ||
215 | + TimeLog("Created smoothed pyramids"); | ||
216 | + | ||
217 | +    // Create the interleaved sqrt(2) pyramid levels. | ||
218 | + for (int i = 1; i < kNumPyramidLevels * 2; i += 2) { | ||
219 | + (void) GetPyramidSqrt2Level(i); | ||
220 | + } | ||
221 | + TimeLog("Created smoothed sqrt pyramids"); | ||
222 | + | ||
223 | + // Create the spatial derivatives for frame 1. | ||
224 | + for (int i = 0; i < kNumPyramidLevels; ++i) { | ||
225 | + (void) GetSpatialX(i); | ||
226 | + (void) GetSpatialY(i); | ||
227 | + } | ||
228 | + TimeLog("Created spatial derivatives"); | ||
229 | + | ||
230 | + (void) GetIntegralImage(); | ||
231 | + TimeLog("Got integral image!"); | ||
232 | + } | ||
233 | + | ||
234 | + const int uv_frame_width_; | ||
235 | + const int uv_frame_height_; | ||
236 | + | ||
237 | + int64_t timestamp_; | ||
238 | + | ||
239 | + Image<uint8_t> image_; | ||
240 | + | ||
241 | + bool uv_data_computed_; | ||
242 | + std::unique_ptr<Image<uint8_t> > u_data_; | ||
243 | + std::unique_ptr<Image<uint8_t> > v_data_; | ||
244 | + | ||
245 | + mutable bool spatial_x_computed_[kNumPyramidLevels]; | ||
246 | + mutable Image<int32_t>* spatial_x_[kNumPyramidLevels]; | ||
247 | + | ||
248 | + mutable bool spatial_y_computed_[kNumPyramidLevels]; | ||
249 | + mutable Image<int32_t>* spatial_y_[kNumPyramidLevels]; | ||
250 | + | ||
251 | + // Mutable so the lazy initialization can work when this class is const. | ||
252 | + // Whether or not the integral image has been computed for the current image. | ||
253 | + mutable bool integral_image_computed_; | ||
254 | + mutable std::unique_ptr<IntegralImage> integral_image_; | ||
255 | + | ||
256 | + mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2]; | ||
257 | + mutable Image<uint8_t>* pyramid_sqrt2_[kNumPyramidLevels * 2]; | ||
258 | + | ||
259 | + TF_DISALLOW_COPY_AND_ASSIGN(ImageData); | ||
260 | +}; | ||
261 | + | ||
262 | +} // namespace tf_tracking | ||
263 | + | ||
264 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// NEON implementations of Image methods for compatible devices. Control | ||
17 | +// should never enter this compilation unit on incompatible devices. | ||
18 | + | ||
19 | +#ifdef __ARM_NEON | ||
20 | + | ||
21 | +#include <arm_neon.h> | ||
22 | + | ||
23 | +#include <stdint.h> | ||
24 | + | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
27 | +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" | ||
28 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
29 | + | ||
30 | +namespace tf_tracking { | ||
31 | + | ||
32 | +// This function does the bulk of the work. | ||
33 | +template <> | ||
34 | +void Image<uint8_t>::Downsample2x32ColumnsNeon(const uint8_t* const original, | ||
35 | + const int stride, | ||
36 | + const int orig_x) { | ||
37 | + // Divide input x offset by 2 to find output offset. | ||
38 | + const int new_x = orig_x >> 1; | ||
39 | + | ||
40 | + // Initial offset into top row. | ||
41 | + const uint8_t* offset = original + orig_x; | ||
42 | + | ||
43 | + // This points to the leftmost pixel of our 8 horizontally arranged | ||
44 | + // pixels in the destination data. | ||
45 | + uint8_t* ptr_dst = (*this)[0] + new_x; | ||
46 | + | ||
47 | + // Sum along vertical columns. | ||
48 | + // Process 32x2 input pixels and 16x1 output pixels per iteration. | ||
49 | + for (int new_y = 0; new_y < height_; ++new_y) { | ||
50 | + uint16x8_t accum1 = vdupq_n_u16(0); | ||
51 | + uint16x8_t accum2 = vdupq_n_u16(0); | ||
52 | + | ||
53 | +    // Go top to bottom across the two rows of input pixels that make up | ||
54 | + // this output row. | ||
55 | + for (int row_num = 0; row_num < 2; ++row_num) { | ||
56 | + // First 16 bytes. | ||
57 | + { | ||
58 | + // Load 16 bytes of data from current offset. | ||
59 | + const uint8x16_t curr_data1 = vld1q_u8(offset); | ||
60 | + | ||
61 | + // Pairwise add and accumulate into accum vectors (16 bit to account | ||
62 | + // for values above 255). | ||
63 | + accum1 = vpadalq_u8(accum1, curr_data1); | ||
64 | + } | ||
65 | + | ||
66 | + // Second 16 bytes. | ||
67 | + { | ||
68 | + // Load 16 bytes of data from current offset. | ||
69 | + const uint8x16_t curr_data2 = vld1q_u8(offset + 16); | ||
70 | + | ||
71 | + // Pairwise add and accumulate into accum vectors (16 bit to account | ||
72 | + // for values above 255). | ||
73 | + accum2 = vpadalq_u8(accum2, curr_data2); | ||
74 | + } | ||
75 | + | ||
76 | + // Move offset down one row. | ||
77 | + offset += stride; | ||
78 | + } | ||
79 | + | ||
80 | + // Divide by 4 (number of input pixels per output | ||
81 | + // pixel) and narrow data from 16 bits per pixel to 8 bpp. | ||
82 | + const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2); | ||
83 | + const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2); | ||
84 | + | ||
85 | + // Concatenate 8x1 pixel strips into 16x1 pixel strip. | ||
86 | + const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2); | ||
87 | + | ||
88 | + // Copy all pixels from composite 16x1 vector into output strip. | ||
89 | + vst1q_u8(ptr_dst, allpixels); | ||
90 | + | ||
91 | + ptr_dst += stride_; | ||
92 | + } | ||
93 | +} | ||
94 | + | ||
95 | +// This function does the bulk of the 4x downsampling work. | ||
96 | +template <> | ||
97 | +void Image<uint8_t>::Downsample4x32ColumnsNeon(const uint8_t* const original, | ||
98 | + const int stride, | ||
99 | + const int orig_x) { | ||
100 | + // Divide input x offset by 4 to find output offset. | ||
101 | + const int new_x = orig_x >> 2; | ||
102 | + | ||
103 | + // Initial offset into top row. | ||
104 | + const uint8_t* offset = original + orig_x; | ||
105 | + | ||
106 | + // This points to the leftmost pixel of our 8 horizontally arranged | ||
107 | + // pixels in the destination data. | ||
108 | + uint8_t* ptr_dst = (*this)[0] + new_x; | ||
109 | + | ||
110 | + // Sum along vertical columns. | ||
111 | + // Process 32x4 input pixels and 8x1 output pixels per iteration. | ||
112 | + for (int new_y = 0; new_y < height_; ++new_y) { | ||
113 | + uint16x8_t accum1 = vdupq_n_u16(0); | ||
114 | + uint16x8_t accum2 = vdupq_n_u16(0); | ||
115 | + | ||
116 | + // Go top to bottom across the four rows of input pixels that make up | ||
117 | + // this output row. | ||
118 | + for (int row_num = 0; row_num < 4; ++row_num) { | ||
119 | + // First 16 bytes. | ||
120 | + { | ||
121 | + // Load 16 bytes of data from current offset. | ||
122 | + const uint8x16_t curr_data1 = vld1q_u8(offset); | ||
123 | + | ||
124 | + // Pairwise add and accumulate into accum vectors (16 bit to account | ||
125 | + // for values above 255). | ||
126 | + accum1 = vpadalq_u8(accum1, curr_data1); | ||
127 | + } | ||
128 | + | ||
129 | + // Second 16 bytes. | ||
130 | + { | ||
131 | + // Load 16 bytes of data from current offset. | ||
132 | + const uint8x16_t curr_data2 = vld1q_u8(offset + 16); | ||
133 | + | ||
134 | + // Pairwise add and accumulate into accum vectors (16 bit to account | ||
135 | + // for values above 255). | ||
136 | + accum2 = vpadalq_u8(accum2, curr_data2); | ||
137 | + } | ||
138 | + | ||
139 | + // Move offset down one row. | ||
140 | + offset += stride; | ||
141 | + } | ||
142 | + | ||
143 | + // Add and widen, then divide by 16 (number of input pixels per output | ||
144 | + // pixel) and narrow data from 32 bits per pixel to 16 bpp. | ||
145 | + const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4); | ||
146 | + const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4); | ||
147 | + | ||
148 | + // Combine 4x1 pixel strips into 8x1 pixel strip and narrow from | ||
149 | + // 16 bits to 8 bits per pixel. | ||
150 | + const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2)); | ||
151 | + | ||
152 | + // Copy all pixels from composite 8x1 vector into output strip. | ||
153 | + vst1_u8(ptr_dst, allpixels); | ||
154 | + | ||
155 | + ptr_dst += stride_; | ||
156 | + } | ||
157 | +} | ||
158 | + | ||
159 | + | ||
160 | +// Hardware accelerated downsampling method for supported devices. | ||
161 | +// Requires that image size be a multiple of 16 pixels in each dimension, | ||
162 | +// and that downsampling be by a factor of 2 or 4. | ||
163 | +template <> | ||
164 | +void Image<uint8_t>::DownsampleAveragedNeon(const uint8_t* const original, | ||
165 | + const int stride, | ||
166 | + const int factor) { | ||
167 | + // TODO(andrewharp): stride is a bad approximation for the src image's width. | ||
168 | + // Better to pass that in directly. | ||
169 | + SCHECK(width_ * factor <= stride, "Uh oh!"); | ||
170 | + const int last_starting_index = width_ * factor - 32; | ||
171 | + | ||
172 | + // We process 32 input pixels lengthwise at a time. | ||
173 | +  // The output per pass of this loop is a pixel strip 16 (factor 2) or 8 | ||
174 | +  // (factor 4) wide by the downsampled height tall. | ||
175 | + int orig_x = 0; | ||
176 | + for (; orig_x <= last_starting_index; orig_x += 32) { | ||
177 | + if (factor == 2) { | ||
178 | + Downsample2x32ColumnsNeon(original, stride, orig_x); | ||
179 | + } else { | ||
180 | + Downsample4x32ColumnsNeon(original, stride, orig_x); | ||
181 | + } | ||
182 | + } | ||
183 | + | ||
184 | + // If a last pass is required, push it to the left enough so that it never | ||
185 | + // goes out of bounds. This will result in some extra computation on devices | ||
186 | + // whose frame widths are multiples of 16 and not 32. | ||
187 | + if (orig_x < last_starting_index + 32) { | ||
188 | + if (factor == 2) { | ||
189 | + Downsample2x32ColumnsNeon(original, stride, last_starting_index); | ||
190 | + } else { | ||
191 | + Downsample4x32ColumnsNeon(original, stride, last_starting_index); | ||
192 | + } | ||
193 | + } | ||
194 | +} | ||
195 | + | ||
196 | + | ||
197 | +// Puts the image gradient matrix about a pixel into the 2x2 float array G. | ||
198 | +// vals_x should be an array of the window x gradient values, whose indices | ||
199 | +// can be in any order but are parallel to the vals_y entries. | ||
200 | +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details. | ||
201 | +void CalculateGNeon(const float* const vals_x, const float* const vals_y, | ||
202 | + const int num_vals, float* const G) { | ||
203 | + const float32_t* const arm_vals_x = (const float32_t*) vals_x; | ||
204 | + const float32_t* const arm_vals_y = (const float32_t*) vals_y; | ||
205 | + | ||
206 | + // Running sums. | ||
207 | + float32x4_t xx = vdupq_n_f32(0.0f); | ||
208 | + float32x4_t xy = vdupq_n_f32(0.0f); | ||
209 | + float32x4_t yy = vdupq_n_f32(0.0f); | ||
210 | + | ||
211 | + // Maximum index we can load 4 consecutive values from. | ||
212 | + // e.g. if there are 81 values, our last full pass can be from index 77: | ||
213 | + // 81-4=>77 (77, 78, 79, 80) | ||
214 | + const int max_i = num_vals - 4; | ||
215 | + | ||
216 | + // Defined here because we want to keep track of how many values were | ||
217 | + // processed by NEON, so that we can finish off the remainder the normal | ||
218 | + // way. | ||
219 | + int i = 0; | ||
220 | + | ||
221 | + // Process values 4 at a time, accumulating the sums of | ||
222 | + // the pixel-wise x*x, x*y, and y*y values. | ||
223 | + for (; i <= max_i; i += 4) { | ||
224 | + // Load xs | ||
225 | + float32x4_t x = vld1q_f32(arm_vals_x + i); | ||
226 | + | ||
227 | + // Multiply x*x and accumulate. | ||
228 | + xx = vmlaq_f32(xx, x, x); | ||
229 | + | ||
230 | + // Load ys | ||
231 | + float32x4_t y = vld1q_f32(arm_vals_y + i); | ||
232 | + | ||
233 | + // Multiply x*y and accumulate. | ||
234 | + xy = vmlaq_f32(xy, x, y); | ||
235 | + | ||
236 | + // Multiply y*y and accumulate. | ||
237 | + yy = vmlaq_f32(yy, y, y); | ||
238 | + } | ||
239 | + | ||
240 | + static float32_t xx_vals[4]; | ||
241 | + static float32_t xy_vals[4]; | ||
242 | + static float32_t yy_vals[4]; | ||
243 | + | ||
244 | + vst1q_f32(xx_vals, xx); | ||
245 | + vst1q_f32(xy_vals, xy); | ||
246 | + vst1q_f32(yy_vals, yy); | ||
247 | + | ||
248 | +  // Accumulated values are stored in sets of 4 lanes, so we have to add | ||
249 | +  // the lanes together manually. | ||
250 | + for (int j = 0; j < 4; ++j) { | ||
251 | + G[0] += xx_vals[j]; | ||
252 | + G[1] += xy_vals[j]; | ||
253 | + G[3] += yy_vals[j]; | ||
254 | + } | ||
255 | + | ||
256 | + // Finishes off last few values (< 4) from above. | ||
257 | + for (; i < num_vals; ++i) { | ||
258 | + G[0] += Square(vals_x[i]); | ||
259 | + G[1] += vals_x[i] * vals_y[i]; | ||
260 | + G[3] += Square(vals_y[i]); | ||
261 | + } | ||
262 | + | ||
263 | + // The matrix is symmetric, so this is a given. | ||
264 | + G[2] = G[1]; | ||
265 | +} | ||
266 | + | ||
267 | +} // namespace tf_tracking | ||
268 | + | ||
269 | +#endif  // __ARM_NEON | ||
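270 | + | ||
271 | +// Illustration only, not part of the original file: a minimal scalar | ||
272 | +// sketch of what the 2x NEON path above computes. The function name is | ||
273 | +// hypothetical. Each output pixel is the truncated average of a 2x2 input | ||
274 | +// block, matching the vpadalq_u8 accumulate + vqshrn_n_u16(, 2) narrowing. | ||
275 | +#include <stdint.h> | ||
276 | + | ||
277 | +static void Downsample2xScalarSketch(const uint8_t* in, const int in_stride, | ||
278 | +                                     const int out_width, const int out_height, | ||
279 | +                                     uint8_t* out, const int out_stride) { | ||
280 | +  for (int y = 0; y < out_height; ++y) { | ||
281 | +    const uint8_t* top = in + (2 * y) * in_stride; | ||
282 | +    const uint8_t* bottom = top + in_stride; | ||
283 | +    for (int x = 0; x < out_width; ++x) { | ||
284 | +      // Sum the 2x2 block, then divide by 4 with truncation. | ||
285 | +      const int sum = top[2 * x] + top[2 * x + 1] + | ||
286 | +                      bottom[2 * x] + bottom[2 * x + 1]; | ||
287 | +      out[y * out_stride + x] = static_cast<uint8_t>(sum / 4); | ||
288 | +    } | ||
289 | +  } | ||
290 | +} |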
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ | ||
18 | + | ||
19 | +#include <stdint.h> | ||
20 | + | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
25 | + | ||
26 | + | ||
27 | +namespace tf_tracking { | ||
28 | + | ||
29 | +inline void GetUV(const uint8_t* const input, Image<uint8_t>* const u, | ||
30 | + Image<uint8_t>* const v) { | ||
31 | + const uint8_t* pUV = input; | ||
32 | + | ||
33 | + for (int row = 0; row < u->GetHeight(); ++row) { | ||
34 | + uint8_t* u_curr = (*u)[row]; | ||
35 | + uint8_t* v_curr = (*v)[row]; | ||
36 | + for (int col = 0; col < u->GetWidth(); ++col) { | ||
37 | +#ifdef __APPLE__ | ||
38 | + *u_curr++ = *pUV++; | ||
39 | + *v_curr++ = *pUV++; | ||
40 | +#else | ||
41 | + *v_curr++ = *pUV++; | ||
42 | + *u_curr++ = *pUV++; | ||
43 | +#endif | ||
44 | + } | ||
45 | + } | ||
46 | +} | ||
47 | + | ||
48 | +// Marks every point within a circle of a given radius on the given boolean | ||
49 | +// image true. | ||
50 | +template <typename U> | ||
51 | +inline static void MarkImage(const int x, const int y, const int radius, | ||
52 | + Image<U>* const img) { | ||
53 | + SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y); | ||
54 | + | ||
55 | + // Precomputed for efficiency. | ||
56 | + const int squared_radius = Square(radius); | ||
57 | + | ||
58 | + // Mark every row in the circle. | ||
59 | + for (int d_y = 0; d_y <= radius; ++d_y) { | ||
60 | + const int squared_y_dist = Square(d_y); | ||
61 | + | ||
62 | + const int min_y = MAX(y - d_y, 0); | ||
63 | + const int max_y = MIN(y + d_y, img->height_less_one_); | ||
64 | + | ||
65 | +    // The max d_x of the circle must be greater than or equal to | ||
66 | + // radius - d_y for any positive d_y. Thus, starting from radius - d_y will | ||
67 | + // reduce the number of iterations required as compared to starting from | ||
68 | + // either 0 and counting up or radius and counting down. | ||
69 | + for (int d_x = radius - d_y; d_x <= radius; ++d_x) { | ||
70 | +      // The first time this criterion is met, we know the width of the circle | ||
71 | +      // at this row (without using sqrt). | ||
72 | + if (squared_y_dist + Square(d_x) >= squared_radius) { | ||
73 | + const int min_x = MAX(x - d_x, 0); | ||
74 | + const int max_x = MIN(x + d_x, img->width_less_one_); | ||
75 | + | ||
76 | + // Mark both above and below the center row. | ||
77 | + bool* const top_row_start = (*img)[min_y] + min_x; | ||
78 | + bool* const bottom_row_start = (*img)[max_y] + min_x; | ||
79 | + | ||
80 | + const int x_width = max_x - min_x + 1; | ||
81 | + memset(top_row_start, true, sizeof(*top_row_start) * x_width); | ||
82 | + memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width); | ||
83 | + | ||
84 | + // This row is marked, time to move on to the next row. | ||
85 | + break; | ||
86 | + } | ||
87 | + } | ||
88 | + } | ||
89 | +} | ||
90 | + | ||
91 | +#ifdef __ARM_NEON | ||
92 | +void CalculateGNeon( | ||
93 | + const float* const vals_x, const float* const vals_y, | ||
94 | + const int num_vals, float* const G); | ||
95 | +#endif | ||
96 | + | ||
97 | +// Puts the image gradient matrix about a pixel into the 2x2 float array G. | ||
98 | +// vals_x should be an array of the window x gradient values, whose indices | ||
99 | +// can be in any order but are parallel to the vals_y entries. | ||
100 | +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details. | ||
101 | +inline void CalculateG(const float* const vals_x, const float* const vals_y, | ||
102 | + const int num_vals, float* const G) { | ||
103 | +#ifdef __ARM_NEON | ||
104 | + CalculateGNeon(vals_x, vals_y, num_vals, G); | ||
105 | + return; | ||
106 | +#endif | ||
107 | + | ||
108 | + // Non-accelerated version. | ||
109 | + for (int i = 0; i < num_vals; ++i) { | ||
110 | + G[0] += Square(vals_x[i]); | ||
111 | + G[1] += vals_x[i] * vals_y[i]; | ||
112 | + G[3] += Square(vals_y[i]); | ||
113 | + } | ||
114 | + | ||
115 | + // The matrix is symmetric, so this is a given. | ||
116 | + G[2] = G[1]; | ||
117 | +} | ||
118 | + | ||
119 | +inline void CalculateGInt16(const int16_t* const vals_x, | ||
120 | + const int16_t* const vals_y, const int num_vals, | ||
121 | + int* const G) { | ||
122 | + // Non-accelerated version. | ||
123 | + for (int i = 0; i < num_vals; ++i) { | ||
124 | + G[0] += Square(vals_x[i]); | ||
125 | + G[1] += vals_x[i] * vals_y[i]; | ||
126 | + G[3] += Square(vals_y[i]); | ||
127 | + } | ||
128 | + | ||
129 | + // The matrix is symmetric, so this is a given. | ||
130 | + G[2] = G[1]; | ||
131 | +} | ||
132 | + | ||
133 | + | ||
134 | +// Puts the image gradient matrix about a pixel into the 2x2 float array G. | ||
135 | +// Looks up interpolated pixels, then calls above method for implementation. | ||
136 | +inline void CalculateG(const int window_radius, const float center_x, | ||
137 | + const float center_y, const Image<int32_t>& I_x, | ||
138 | + const Image<int32_t>& I_y, float* const G) { | ||
139 | + SCHECK(I_x.ValidPixel(center_x, center_y), "Problem in calculateG!"); | ||
140 | + | ||
141 | +  // Hardcoded to allow for a max window radius of 5 (11 pixels x 11 pixels). | ||
142 | + static const int kMaxWindowRadius = 5; | ||
143 | + SCHECK(window_radius <= kMaxWindowRadius, | ||
144 | + "Window %d > %d!", window_radius, kMaxWindowRadius); | ||
145 | + | ||
146 | + // Diameter of window is 2 * radius + 1 for center pixel. | ||
147 | + static const int kWindowBufferSize = | ||
148 | + (kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1); | ||
149 | + | ||
150 | + // Preallocate buffers statically for efficiency. | ||
151 | + static int16_t vals_x[kWindowBufferSize]; | ||
152 | + static int16_t vals_y[kWindowBufferSize]; | ||
153 | + | ||
154 | + const int src_left_fixed = RealToFixed1616(center_x - window_radius); | ||
155 | + const int src_top_fixed = RealToFixed1616(center_y - window_radius); | ||
156 | + | ||
157 | + int16_t* vals_x_ptr = vals_x; | ||
158 | + int16_t* vals_y_ptr = vals_y; | ||
159 | + | ||
160 | + const int window_size = 2 * window_radius + 1; | ||
161 | + for (int y = 0; y < window_size; ++y) { | ||
162 | + const int fp_y = src_top_fixed + (y << 16); | ||
163 | + | ||
164 | + for (int x = 0; x < window_size; ++x) { | ||
165 | + const int fp_x = src_left_fixed + (x << 16); | ||
166 | + | ||
167 | + *vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y); | ||
168 | + *vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y); | ||
169 | + } | ||
170 | + } | ||
171 | + | ||
172 | + int32_t g_temp[] = {0, 0, 0, 0}; | ||
173 | + CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp); | ||
174 | + | ||
175 | + for (int i = 0; i < 4; ++i) { | ||
176 | + G[i] = g_temp[i]; | ||
177 | + } | ||
178 | +} | ||
179 | + | ||
180 | +inline float ImageCrossCorrelation(const Image<float>& image1, | ||
181 | + const Image<float>& image2, | ||
182 | + const int x_offset, const int y_offset) { | ||
183 | + SCHECK(image1.GetWidth() == image2.GetWidth() && | ||
184 | + image1.GetHeight() == image2.GetHeight(), | ||
185 | + "Dimension mismatch! %dx%d vs %dx%d", | ||
186 | + image1.GetWidth(), image1.GetHeight(), | ||
187 | + image2.GetWidth(), image2.GetHeight()); | ||
188 | + | ||
189 | + const int num_pixels = image1.GetWidth() * image1.GetHeight(); | ||
190 | + const float* data1 = image1.data(); | ||
191 | + const float* data2 = image2.data(); | ||
192 | + return ComputeCrossCorrelation(data1, data2, num_pixels); | ||
193 | +} | ||
194 | + | ||
195 | +// Copies an arbitrary region of an image to another (floating point) | ||
196 | +// image, scaling as it goes using bilinear interpolation. | ||
197 | +inline void CopyArea(const Image<uint8_t>& image, | ||
198 | + const BoundingBox& area_to_copy, | ||
199 | + Image<float>* const patch_image) { | ||
200 | + VLOG(2) << "Copying from: " << area_to_copy << std::endl; | ||
201 | + | ||
202 | + const int patch_width = patch_image->GetWidth(); | ||
203 | + const int patch_height = patch_image->GetHeight(); | ||
204 | + | ||
205 | +  const float x_dist_between_samples = patch_width > 1 ? | ||
206 | +      area_to_copy.GetWidth() / (patch_width - 1) : 0; | ||
207 | + | ||
208 | +  const float y_dist_between_samples = patch_height > 1 ? | ||
209 | +      area_to_copy.GetHeight() / (patch_height - 1) : 0; | ||
210 | + | ||
211 | + for (int y_index = 0; y_index < patch_height; ++y_index) { | ||
212 | + const float sample_y = | ||
213 | + y_index * y_dist_between_samples + area_to_copy.top_; | ||
214 | + | ||
215 | + for (int x_index = 0; x_index < patch_width; ++x_index) { | ||
216 | + const float sample_x = | ||
217 | + x_index * x_dist_between_samples + area_to_copy.left_; | ||
218 | + | ||
219 | + if (image.ValidInterpPixel(sample_x, sample_y)) { | ||
220 | + // TODO(andrewharp): Do area averaging when downsampling. | ||
221 | + (*patch_image)[y_index][x_index] = | ||
222 | + image.GetPixelInterp(sample_x, sample_y); | ||
223 | + } else { | ||
224 | + (*patch_image)[y_index][x_index] = -1.0f; | ||
225 | + } | ||
226 | + } | ||
227 | + } | ||
228 | +} | ||
229 | + | ||
230 | + | ||
231 | +// Takes a floating point image and normalizes it in-place. | ||
232 | +// | ||
233 | +// First, negative values will be set to the mean of the non-negative pixels | ||
234 | +// in the image. | ||
235 | +// | ||
236 | +// Then, the resulting image will be normalized such that it has a mean value | ||
237 | +// of 0.0 and a standard deviation of 1.0. | ||
238 | +inline void NormalizeImage(Image<float>* const image) { | ||
239 | + const float* const data_ptr = image->data(); | ||
240 | + | ||
241 | + // Copy only the non-negative values to some temp memory. | ||
242 | + float running_sum = 0.0f; | ||
243 | + int num_data_gte_zero = 0; | ||
244 | + { | ||
245 | + float* const curr_data = (*image)[0]; | ||
246 | + for (int i = 0; i < image->data_size_; ++i) { | ||
247 | + if (curr_data[i] >= 0.0f) { | ||
248 | + running_sum += curr_data[i]; | ||
249 | + ++num_data_gte_zero; | ||
250 | + } else { | ||
251 | + curr_data[i] = -1.0f; | ||
252 | + } | ||
253 | + } | ||
254 | + } | ||
255 | + | ||
256 | + // If none of the pixels are valid, just set the entire thing to 0.0f. | ||
257 | + if (num_data_gte_zero == 0) { | ||
258 | + image->Clear(0.0f); | ||
259 | + return; | ||
260 | + } | ||
261 | + | ||
262 | + const float corrected_mean = running_sum / num_data_gte_zero; | ||
263 | + | ||
264 | + float* curr_data = (*image)[0]; | ||
265 | + for (int i = 0; i < image->data_size_; ++i) { | ||
266 | + const float curr_val = *curr_data; | ||
267 | + *curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean; | ||
268 | + } | ||
269 | + | ||
270 | + const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f); | ||
271 | + | ||
272 | + if (std_dev > 0.0f) { | ||
273 | + curr_data = (*image)[0]; | ||
274 | + for (int i = 0; i < image->data_size_; ++i) { | ||
275 | + *curr_data++ /= std_dev; | ||
276 | + } | ||
277 | + | ||
278 | +#ifdef SANITY_CHECKS | ||
279 | + LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev); | ||
280 | + const float correlation = | ||
281 | + ComputeCrossCorrelation(image->data(), | ||
282 | + image->data(), | ||
283 | + image->data_size_); | ||
284 | + | ||
285 | + if (std::abs(correlation - 1.0f) > EPSILON) { | ||
286 | + LOG(ERROR) << "Bad image!" << std::endl; | ||
287 | + LOG(ERROR) << *image << std::endl; | ||
288 | + } | ||
289 | + | ||
290 | + SCHECK(std::abs(correlation - 1.0f) < EPSILON, | ||
291 | + "Correlation wasn't 1.0f: %.10f", correlation); | ||
292 | +#endif | ||
293 | + } | ||
294 | +} | ||
295 | + | ||
296 | +} // namespace tf_tracking | ||
297 | + | ||
298 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ |
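299 | + | ||
300 | +// Illustration only, not part of the original header: the 2x2 matrix G | ||
301 | +// filled in by CalculateG above is the structure tensor of the window; the | ||
302 | +// Lucas-Kanade step recovers a flow vector d by solving G * d = b. Below is | ||
303 | +// a minimal sketch of that solve via Cramer's rule, assuming G = {xx, xy, | ||
304 | +// xy, yy} and a mismatch vector b = {bx, by} have been accumulated | ||
305 | +// elsewhere. The function name is hypothetical, and it carries its own | ||
306 | +// include guard since it sits outside the one above. | ||
307 | +#ifndef TF_TRACKING_SOLVE_FLOW_SKETCH_ | ||
308 | +#define TF_TRACKING_SOLVE_FLOW_SKETCH_ | ||
309 | +inline bool SolveFlowSketch(const float* const G, const float* const b, | ||
310 | +                            float* const dx, float* const dy) { | ||
311 | +  const float det = G[0] * G[3] - G[1] * G[2]; | ||
312 | +  if (det == 0.0f) { | ||
313 | +    return false;  // Singular matrix: no corner-like structure in window. | ||
314 | +  } | ||
315 | +  *dx = (G[3] * b[0] - G[1] * b[1]) / det; | ||
316 | +  *dy = (G[0] * b[1] - G[2] * b[0]) / det; | ||
317 | +  return true; | ||
318 | +} | ||
319 | +#endif  // TF_TRACKING_SOLVE_FLOW_SKETCH_ |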
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ | ||
18 | + | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
23 | + | ||
24 | +namespace tf_tracking { | ||
25 | + | ||
26 | +typedef uint8_t Code; | ||
27 | + | ||
28 | +class IntegralImage : public Image<uint32_t> { | ||
29 | + public: | ||
30 | + explicit IntegralImage(const Image<uint8_t>& image_base) | ||
31 | + : Image<uint32_t>(image_base.GetWidth(), image_base.GetHeight()) { | ||
32 | + Recompute(image_base); | ||
33 | + } | ||
34 | + | ||
35 | + IntegralImage(const int width, const int height) | ||
36 | + : Image<uint32_t>(width, height) {} | ||
37 | + | ||
38 | + void Recompute(const Image<uint8_t>& image_base) { | ||
39 | + SCHECK(image_base.GetWidth() == GetWidth() && | ||
40 | + image_base.GetHeight() == GetHeight(), "Dimensions don't match!"); | ||
41 | + | ||
42 | + // Sum along first row. | ||
43 | + { | ||
44 | + int x_sum = 0; | ||
45 | + for (int x = 0; x < image_base.GetWidth(); ++x) { | ||
46 | + x_sum += image_base[0][x]; | ||
47 | + (*this)[0][x] = x_sum; | ||
48 | + } | ||
49 | + } | ||
50 | + | ||
51 | + // Sum everything else. | ||
52 | + for (int y = 1; y < image_base.GetHeight(); ++y) { | ||
53 | + uint32_t* curr_sum = (*this)[y]; | ||
54 | + | ||
55 | + // Previously summed pointers. | ||
56 | + const uint32_t* up_one = (*this)[y - 1]; | ||
57 | + | ||
58 | + // Current value pointer. | ||
59 | + const uint8_t* curr_delta = image_base[y]; | ||
60 | + | ||
61 | + uint32_t row_till_now = 0; | ||
62 | + | ||
63 | + for (int x = 0; x < GetWidth(); ++x) { | ||
64 | + // Add the one above and the one to the left. | ||
65 | + row_till_now += *curr_delta; | ||
66 | + *curr_sum = *up_one + row_till_now; | ||
67 | + | ||
68 | + // Scoot everything along. | ||
69 | + ++curr_sum; | ||
70 | + ++up_one; | ||
71 | + ++curr_delta; | ||
72 | + } | ||
73 | + } | ||
74 | + | ||
75 | + SCHECK(VerifyData(image_base), "Images did not match!"); | ||
76 | + } | ||
77 | + | ||
78 | + bool VerifyData(const Image<uint8_t>& image_base) { | ||
79 | + for (int y = 0; y < GetHeight(); ++y) { | ||
80 | + for (int x = 0; x < GetWidth(); ++x) { | ||
81 | + uint32_t curr_val = (*this)[y][x]; | ||
82 | + | ||
83 | + if (x > 0) { | ||
84 | + curr_val -= (*this)[y][x - 1]; | ||
85 | + } | ||
86 | + | ||
87 | + if (y > 0) { | ||
88 | + curr_val -= (*this)[y - 1][x]; | ||
89 | + } | ||
90 | + | ||
91 | + if (x > 0 && y > 0) { | ||
92 | + curr_val += (*this)[y - 1][x - 1]; | ||
93 | + } | ||
94 | + | ||
95 | + if (curr_val != image_base[y][x]) { | ||
96 | + LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]); | ||
97 | + return false; | ||
98 | + } | ||
99 | + | ||
100 | + if (GetRegionSum(x, y, x, y) != curr_val) { | ||
101 | + LOGE("Mismatch!"); | ||
102 | + } | ||
103 | + } | ||
104 | + } | ||
105 | + | ||
106 | + return true; | ||
107 | + } | ||
108 | + | ||
109 | + // Returns the sum of all pixels in the specified region. | ||
110 | + inline uint32_t GetRegionSum(const int x1, const int y1, const int x2, | ||
111 | + const int y2) const { | ||
112 | + SCHECK(x1 >= 0 && y1 >= 0 && | ||
113 | + x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(), | ||
114 | + "indices out of bounds! %d-%d / %d, %d-%d / %d, ", | ||
115 | + x1, x2, GetWidth(), y1, y2, GetHeight()); | ||
116 | + | ||
117 | + const uint32_t everything = (*this)[y2][x2]; | ||
118 | + | ||
119 | + uint32_t sum = everything; | ||
120 | + if (x1 > 0 && y1 > 0) { | ||
121 | + // Most common case. | ||
122 | + const uint32_t left = (*this)[y2][x1 - 1]; | ||
123 | + const uint32_t top = (*this)[y1 - 1][x2]; | ||
124 | + const uint32_t top_left = (*this)[y1 - 1][x1 - 1]; | ||
125 | + | ||
126 | + sum = everything - left - top + top_left; | ||
127 | + SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d", | ||
128 | + everything, left, top, top_left, sum, x1, y1, x2, y2); | ||
129 | +    } else if (x1 > 0) { | ||
130 | +      // Flush against top of image. | ||
131 | +      // Subtract out the region to the left only. | ||
132 | +      const uint32_t left = (*this)[y2][x1 - 1]; | ||
133 | +      sum = everything - left; | ||
134 | +      SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum); | ||
135 | +    } else if (y1 > 0) { | ||
136 | +      // Flush against left side of image. | ||
137 | +      // Subtract out the region above only. | ||
138 | +      const uint32_t top = (*this)[y1 - 1][x2]; | ||
139 | +      sum = everything - top; | ||
140 | +      SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum); | ||
141 | + } | ||
142 | + | ||
143 | + SCHECK(sum >= 0, "Negative sum!"); | ||
144 | + | ||
145 | + return sum; | ||
146 | + } | ||
147 | + | ||
148 | +  // Returns the 2-bit code associated with this region, which represents | ||
149 | + // the overall gradient. | ||
150 | + inline Code GetCode(const BoundingBox& bounding_box) const { | ||
151 | + return GetCode(bounding_box.left_, bounding_box.top_, | ||
152 | + bounding_box.right_, bounding_box.bottom_); | ||
153 | + } | ||
154 | + | ||
155 | + inline Code GetCode(const int x1, const int y1, | ||
156 | + const int x2, const int y2) const { | ||
157 | + SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d", | ||
158 | + x1, y1, x2, y2); | ||
159 | + | ||
160 | + // Gradient computed vertically. | ||
161 | + const int box_height = (y2 - y1) / 2; | ||
162 | + const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height); | ||
163 | + const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2); | ||
164 | + const bool vertical_code = top_sum > bottom_sum; | ||
165 | + | ||
166 | + // Gradient computed horizontally. | ||
167 | + const int box_width = (x2 - x1) / 2; | ||
168 | + const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2); | ||
169 | + const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2); | ||
170 | + const bool horizontal_code = left_sum > right_sum; | ||
171 | + | ||
172 | + const Code final_code = (vertical_code << 1) | horizontal_code; | ||
173 | + | ||
174 | + SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)), | ||
175 | + "Invalid code! %d", final_code); | ||
176 | + | ||
177 | + // Returns a value 0-3. | ||
178 | + return final_code; | ||
179 | + } | ||
180 | + | ||
181 | + private: | ||
182 | + TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage); | ||
183 | +}; | ||
184 | + | ||
185 | +} // namespace tf_tracking | ||
186 | + | ||
187 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ |
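188 | + | ||
189 | +// Illustration only, not part of the original header: with the integral | ||
190 | +// image above, the mean intensity of any axis-aligned box can be computed | ||
191 | +// in constant time regardless of the box size. A minimal usage sketch | ||
192 | +// (variable names are hypothetical): | ||
193 | +// | ||
194 | +//   tf_tracking::IntegralImage integral(frame);  // frame: Image<uint8_t>. | ||
195 | +//   const uint32_t sum = integral.GetRegionSum(x1, y1, x2, y2); | ||
196 | +//   const int area = (x2 - x1 + 1) * (y2 - y1 + 1); | ||
197 | +//   const float mean = static_cast<float>(sum) / area; |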
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ | ||
18 | + | ||
19 | +#include <jni.h> | ||
20 | +#include <stdint.h> | ||
21 | + | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
23 | + | ||
24 | +// The JniLongField class is used to access Java fields from native code. This | ||
25 | +// technique of hiding pointers to native objects in opaque Java fields is how | ||
26 | +// the Android hardware libraries work. This reduces the number of static | ||
27 | +// native methods and makes it easier to manage the lifetime of native objects. | ||
28 | +class JniLongField { | ||
29 | + public: | ||
30 | + JniLongField(const char* field_name) | ||
31 | + : field_name_(field_name), field_ID_(0) {} | ||
32 | + | ||
33 | + int64_t get(JNIEnv* env, jobject thiz) { | ||
34 | + if (field_ID_ == 0) { | ||
35 | + jclass cls = env->GetObjectClass(thiz); | ||
36 | + CHECK_ALWAYS(cls != 0, "Unable to find class"); | ||
37 | + field_ID_ = env->GetFieldID(cls, field_name_, "J"); | ||
38 | + CHECK_ALWAYS(field_ID_ != 0, | ||
39 | + "Unable to find field %s. (Check proguard cfg)", field_name_); | ||
40 | + } | ||
41 | + | ||
42 | + return env->GetLongField(thiz, field_ID_); | ||
43 | + } | ||
44 | + | ||
45 | + void set(JNIEnv* env, jobject thiz, int64_t value) { | ||
46 | + if (field_ID_ == 0) { | ||
47 | + jclass cls = env->GetObjectClass(thiz); | ||
48 | + CHECK_ALWAYS(cls != 0, "Unable to find class"); | ||
49 | + field_ID_ = env->GetFieldID(cls, field_name_, "J"); | ||
50 | + CHECK_ALWAYS(field_ID_ != 0, | ||
51 | + "Unable to find field %s (Check proguard cfg)", field_name_); | ||
52 | + } | ||
53 | + | ||
54 | + env->SetLongField(thiz, field_ID_, value); | ||
55 | + } | ||
56 | + | ||
57 | + private: | ||
58 | + const char* const field_name_; | ||
59 | + | ||
60 | +  // This is just a cache of the looked-up field ID. | ||
61 | + jfieldID field_ID_; | ||
62 | +}; | ||
63 | + | ||
64 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ |
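65 | + | ||
66 | +// Illustration only, not part of the original header: typical usage caches | ||
67 | +// the address of a native object in a Java "long" field. The field and | ||
68 | +// type names below are hypothetical. | ||
69 | +// | ||
70 | +//   static JniLongField tracker_field("nativeObjectTracker"); | ||
71 | +// | ||
72 | +//   void SetTracker(JNIEnv* env, jobject thiz, ObjectTracker* tracker) { | ||
73 | +//     tracker_field.set(env, thiz, reinterpret_cast<int64_t>(tracker)); | ||
74 | +//   } | ||
75 | +// | ||
76 | +//   ObjectTracker* GetTracker(JNIEnv* env, jobject thiz) { | ||
77 | +//     return reinterpret_cast<ObjectTracker*>(tracker_field.get(env, thiz)); | ||
78 | +//   } |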
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ | ||
18 | + | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/logging.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
25 | + | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
27 | + | ||
28 | +namespace tf_tracking { | ||
29 | + | ||
30 | +// For keeping track of keypoints. | ||
31 | +struct Keypoint { | ||
32 | + Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {} | ||
33 | + Keypoint(const float x, const float y) | ||
34 | + : pos_(x, y), score_(0.0f), type_(0) {} | ||
35 | + | ||
36 | + Point2f pos_; | ||
37 | + float score_; | ||
38 | + uint8_t type_; | ||
39 | +}; | ||
40 | + | ||
41 | +inline std::ostream& operator<<(std::ostream& stream, const Keypoint& keypoint) { | ||
42 | +  return stream << "[" << keypoint.pos_ << ", " << keypoint.score_ << ", " | ||
43 | +                << static_cast<int>(keypoint.type_) << "]"; | ||
44 | +} | ||
45 | + | ||
46 | +} // namespace tf_tracking | ||
47 | + | ||
48 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ |
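49 | + | ||
50 | +// Illustration only, not part of the original header: the stream operator | ||
51 | +// above (taking the keypoint by const reference and printing type_ as an | ||
52 | +// int rather than a raw char) makes keypoints easy to dump in logs, e.g.: | ||
53 | +// | ||
54 | +//   tf_tracking::Keypoint keypoint(12.0f, 34.0f); | ||
55 | +//   std::cout << keypoint << std::endl; | ||
56 | +// | ||
57 | +// This prints the position, score, and type; the exact position format | ||
58 | +// depends on Point2f's own stream operator. |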
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// Various keypoint detecting functions. | ||
17 | + | ||
18 | +#include <float.h> | ||
19 | + | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
24 | + | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" | ||
27 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" | ||
28 | + | ||
29 | +namespace tf_tracking { | ||
30 | + | ||
31 | +static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) { | ||
32 | + return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]); | ||
33 | +} | ||
34 | + | ||
35 | +void KeypointDetector::ScoreKeypoints(const ImageData& image_data, | ||
36 | + const int num_candidates, | ||
37 | + Keypoint* const candidate_keypoints) { | ||
38 | + const Image<int>& I_x = *image_data.GetSpatialX(0); | ||
39 | + const Image<int>& I_y = *image_data.GetSpatialY(0); | ||
40 | + | ||
41 | + if (config_->detect_skin) { | ||
42 | + const Image<uint8_t>& u_data = *image_data.GetU(); | ||
43 | + const Image<uint8_t>& v_data = *image_data.GetV(); | ||
44 | + | ||
45 | + static const int reference[] = {111, 155}; | ||
46 | + | ||
47 | + // Score all the keypoints. | ||
48 | + for (int i = 0; i < num_candidates; ++i) { | ||
49 | + Keypoint* const keypoint = candidate_keypoints + i; | ||
50 | + | ||
51 | + const int x_pos = keypoint->pos_.x * 2; | ||
52 | + const int y_pos = keypoint->pos_.y * 2; | ||
53 | + | ||
54 | + const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]}; | ||
55 | + keypoint->score_ = | ||
56 | + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) / | ||
57 | + GetDistSquaredBetween(reference, curr_color); | ||
58 | + } | ||
59 | + } else { | ||
60 | + // Score all the keypoints. | ||
61 | + for (int i = 0; i < num_candidates; ++i) { | ||
62 | + Keypoint* const keypoint = candidate_keypoints + i; | ||
63 | + keypoint->score_ = | ||
64 | + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y); | ||
65 | + } | ||
66 | + } | ||
67 | +} | ||
68 | + | ||
69 | + | ||
70 | +inline int KeypointCompare(const void* const a, const void* const b) { | ||
71 | + return (reinterpret_cast<const Keypoint*>(a)->score_ - | ||
72 | + reinterpret_cast<const Keypoint*>(b)->score_) <= 0 ? 1 : -1; | ||
73 | +} | ||
74 | + | ||
75 | + | ||
76 | +// Quicksorts detected keypoints by score, highest first. | ||
77 | +void KeypointDetector::SortKeypoints(const int num_candidates, | ||
78 | + Keypoint* const candidate_keypoints) const { | ||
79 | + qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare); | ||
80 | + | ||
81 | +#ifdef SANITY_CHECKS | ||
82 | + // Verify that the array got sorted. | ||
83 | + float last_score = FLT_MAX; | ||
84 | + for (int i = 0; i < num_candidates; ++i) { | ||
85 | + const float curr_score = candidate_keypoints[i].score_; | ||
86 | + | ||
87 | +    // Scores should be monotonically non-increasing. | ||
88 | + SCHECK(last_score >= curr_score, | ||
89 | + "Quicksort failure! %d: %.5f > %d: %.5f (%d total)", | ||
90 | + i - 1, last_score, i, curr_score, num_candidates); | ||
91 | + | ||
92 | + last_score = curr_score; | ||
93 | + } | ||
94 | +#endif | ||
95 | +} | ||
96 | + | ||
97 | + | ||
98 | +int KeypointDetector::SelectKeypointsInBox( | ||
99 | + const BoundingBox& box, | ||
100 | + const Keypoint* const candidate_keypoints, | ||
101 | + const int num_candidates, | ||
102 | + const int max_keypoints, | ||
103 | + const int num_existing_keypoints, | ||
104 | + const Keypoint* const existing_keypoints, | ||
105 | + Keypoint* const final_keypoints) const { | ||
106 | + if (max_keypoints <= 0) { | ||
107 | + return 0; | ||
108 | + } | ||
109 | + | ||
110 | +  // This is the minimum distance required between keypoints within this box, | ||
111 | +  // roughly based on the box dimensions. | ||
112 | + const int distance = | ||
113 | + MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f); | ||
114 | + | ||
115 | + // First, mark keypoints that already happen to be inside this region. Ignore | ||
116 | + // keypoints that are outside it, however close they might be. | ||
117 | + interest_map_->Clear(false); | ||
118 | + for (int i = 0; i < num_existing_keypoints; ++i) { | ||
119 | + const Keypoint& candidate = existing_keypoints[i]; | ||
120 | + | ||
121 | + const int x_pos = candidate.pos_.x; | ||
122 | + const int y_pos = candidate.pos_.y; | ||
123 | + if (box.Contains(candidate.pos_)) { | ||
124 | + MarkImage(x_pos, y_pos, distance, interest_map_.get()); | ||
125 | + } | ||
126 | + } | ||
127 | + | ||
128 | + // Now, go through and check which keypoints will still fit in the box. | ||
129 | + int num_keypoints_selected = 0; | ||
130 | + for (int i = 0; i < num_candidates; ++i) { | ||
131 | + const Keypoint& candidate = candidate_keypoints[i]; | ||
132 | + | ||
133 | + const int x_pos = candidate.pos_.x; | ||
134 | + const int y_pos = candidate.pos_.y; | ||
135 | + | ||
136 | + if (!box.Contains(candidate.pos_) || | ||
137 | + !interest_map_->ValidPixel(x_pos, y_pos)) { | ||
138 | + continue; | ||
139 | + } | ||
140 | + | ||
141 | + if (!(*interest_map_)[y_pos][x_pos]) { | ||
142 | + final_keypoints[num_keypoints_selected++] = candidate; | ||
143 | + if (num_keypoints_selected >= max_keypoints) { | ||
144 | + break; | ||
145 | + } | ||
146 | + MarkImage(x_pos, y_pos, distance, interest_map_.get()); | ||
147 | + } | ||
148 | + } | ||
149 | + return num_keypoints_selected; | ||
150 | +} | ||
151 | + | ||
152 | + | ||
153 | +void KeypointDetector::SelectKeypoints( | ||
154 | + const std::vector<BoundingBox>& boxes, | ||
155 | + const Keypoint* const candidate_keypoints, | ||
156 | + const int num_candidates, | ||
157 | + FramePair* const curr_change) const { | ||
158 | +  // Now select all the interesting keypoints that fall inside our boxes. | ||
159 | + curr_change->number_of_keypoints_ = 0; | ||
160 | + for (std::vector<BoundingBox>::const_iterator iter = boxes.begin(); | ||
161 | + iter != boxes.end(); ++iter) { | ||
162 | + const BoundingBox bounding_box = *iter; | ||
163 | + | ||
164 | + // Count up keypoints that have already been selected, and fall within our | ||
165 | + // box. | ||
166 | + int num_keypoints_already_in_box = 0; | ||
167 | + for (int i = 0; i < curr_change->number_of_keypoints_; ++i) { | ||
168 | + if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) { | ||
169 | + ++num_keypoints_already_in_box; | ||
170 | + } | ||
171 | + } | ||
172 | + | ||
173 | + const int max_keypoints_to_find_in_box = | ||
174 | + MIN(kMaxKeypointsForObject - num_keypoints_already_in_box, | ||
175 | + kMaxKeypoints - curr_change->number_of_keypoints_); | ||
176 | + | ||
177 | + const int num_new_keypoints_in_box = SelectKeypointsInBox( | ||
178 | + bounding_box, | ||
179 | + candidate_keypoints, | ||
180 | + num_candidates, | ||
181 | + max_keypoints_to_find_in_box, | ||
182 | + curr_change->number_of_keypoints_, | ||
183 | + curr_change->frame1_keypoints_, | ||
184 | + curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_); | ||
185 | + | ||
186 | + curr_change->number_of_keypoints_ += num_new_keypoints_in_box; | ||
187 | + | ||
188 | + LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_); | ||
189 | + } | ||
190 | +} | ||
191 | + | ||
192 | + | ||
193 | +// Walks along the given circle checking for pixels above or below the center. | ||
194 | +// Returns a score, or 0 if the keypoint did not pass the criteria. | ||
195 | +// | ||
196 | +// Parameters: | ||
197 | +// circle_perimeter: the circumference in pixels of the circle. | ||
198 | +// threshold: the minimum number of contiguous pixels that must be above or | ||
199 | +// below the center value. | ||
200 | +// center_ptr: the location of the center pixel in memory | ||
201 | +// offsets: the relative offsets from the center pixel of the edge pixels. | ||
202 | +inline int TestCircle(const int circle_perimeter, const int threshold, | ||
203 | + const uint8_t* const center_ptr, const int* offsets) { | ||
204 | + // Get the actual value of the center pixel for easier reference later on. | ||
205 | + const int center_value = static_cast<int>(*center_ptr); | ||
206 | + | ||
207 | + // Number of total pixels to check. Have to wrap around some in case | ||
208 | + // the contiguous section is split by the array edges. | ||
209 | + const int num_total = circle_perimeter + threshold - 1; | ||
210 | + | ||
211 | + int num_above = 0; | ||
212 | + int above_diff = 0; | ||
213 | + | ||
214 | + int num_below = 0; | ||
215 | + int below_diff = 0; | ||
216 | + | ||
217 | + // Used to tell when this is definitely not going to meet the threshold so we | ||
218 | + // can early abort. | ||
219 | + int minimum_by_now = threshold - num_total + 1; | ||
220 | + | ||
221 | + // Go through every pixel along the perimeter of the circle, and then around | ||
222 | + // again a little bit. | ||
223 | + for (int i = 0; i < num_total; ++i) { | ||
224 | + // This should be faster than mod. | ||
225 | + const int perim_index = i < circle_perimeter ? i : i - circle_perimeter; | ||
226 | + | ||
227 | + // This gets the value of the current pixel along the perimeter by using | ||
228 | + // a precomputed offset. | ||
229 | + const int curr_value = | ||
230 | + static_cast<int>(center_ptr[offsets[perim_index]]); | ||
231 | + | ||
232 | + const int difference = curr_value - center_value; | ||
233 | + | ||
234 | + if (difference > kFastDiffAmount) { | ||
235 | + above_diff += difference; | ||
236 | + ++num_above; | ||
237 | + | ||
238 | + num_below = 0; | ||
239 | + below_diff = 0; | ||
240 | + | ||
241 | + if (num_above >= threshold) { | ||
242 | + return above_diff; | ||
243 | + } | ||
244 | + } else if (difference < -kFastDiffAmount) { | ||
245 | + below_diff += difference; | ||
246 | + ++num_below; | ||
247 | + | ||
248 | + num_above = 0; | ||
249 | + above_diff = 0; | ||
250 | + | ||
251 | + if (num_below >= threshold) { | ||
252 | + return below_diff; | ||
253 | + } | ||
254 | + } else { | ||
255 | + num_above = 0; | ||
256 | + num_below = 0; | ||
257 | + above_diff = 0; | ||
258 | + below_diff = 0; | ||
259 | + } | ||
260 | + | ||
261 | + // See if there's any chance of making the threshold. | ||
262 | + if (MAX(num_above, num_below) < minimum_by_now) { | ||
263 | + // Didn't pass. | ||
264 | + return 0; | ||
265 | + } | ||
266 | + ++minimum_by_now; | ||
267 | + } | ||
268 | + | ||
269 | + // Didn't pass. | ||
270 | + return 0; | ||
271 | +} | ||
272 | + | ||
273 | + | ||
274 | +// Returns a score in the range [0.0, positive infinity) which represents the | ||
275 | +// relative likelihood of a point being a corner. | ||
276 | +float KeypointDetector::HarrisFilter(const Image<int32_t>& I_x, | ||
277 | + const Image<int32_t>& I_y, const float x, | ||
278 | + const float y) const { | ||
279 | + if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) && | ||
280 | + I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) { | ||
281 | + // Image gradient matrix. | ||
282 | + float G[] = { 0, 0, 0, 0 }; | ||
283 | + CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G); | ||
284 | + | ||
285 | + const float dx = G[0]; | ||
286 | + const float dy = G[3]; | ||
287 | + const float dxy = G[1]; | ||
288 | + | ||
289 | + // Harris-Nobel corner score. | ||
290 | + return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN); | ||
291 | + } | ||
292 | + | ||
293 | + return 0.0f; | ||
294 | +} | ||
295 | + | ||
296 | + | ||
297 | +int KeypointDetector::AddExtraCandidatesForBoxes( | ||
298 | + const std::vector<BoundingBox>& boxes, | ||
299 | + const int max_num_keypoints, | ||
300 | + Keypoint* const keypoints) const { | ||
301 | + int num_keypoints_added = 0; | ||
302 | + | ||
303 | + for (std::vector<BoundingBox>::const_iterator iter = boxes.begin(); | ||
304 | + iter != boxes.end(); ++iter) { | ||
305 | + const BoundingBox box = *iter; | ||
306 | + | ||
307 | + for (int i = 0; i < kNumToAddAsCandidates; ++i) { | ||
308 | + for (int j = 0; j < kNumToAddAsCandidates; ++j) { | ||
309 | + if (num_keypoints_added >= max_num_keypoints) { | ||
310 | + LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints); | ||
311 | + return num_keypoints_added; | ||
312 | + } | ||
313 | + | ||
314 | + Keypoint& curr_keypoint = keypoints[num_keypoints_added++]; | ||
315 | + curr_keypoint.pos_ = Point2f( | ||
316 | + box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates, | ||
317 | + box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates); | ||
318 | + curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST; | ||
319 | + } | ||
320 | + } | ||
321 | + } | ||
322 | + | ||
323 | + return num_keypoints_added; | ||
324 | +} | ||
325 | + | ||
326 | + | ||
327 | +void KeypointDetector::FindKeypoints(const ImageData& image_data, | ||
328 | + const std::vector<BoundingBox>& rois, | ||
329 | + const FramePair& prev_change, | ||
330 | + FramePair* const curr_change) { | ||
331 | + // Copy keypoints from second frame of last pass to temp keypoints of this | ||
332 | + // pass. | ||
333 | + int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_); | ||
334 | + | ||
335 | + const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints; | ||
336 | + number_of_tmp_keypoints += | ||
337 | + FindFastKeypoints(image_data, max_num_fast, | ||
338 | + tmp_keypoints_ + number_of_tmp_keypoints); | ||
339 | + | ||
340 | + TimeLog("Found FAST keypoints"); | ||
341 | + | ||
342 | + if (number_of_tmp_keypoints >= kMaxTempKeypoints) { | ||
343 | + LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints", | ||
344 | + kMaxTempKeypoints, number_of_tmp_keypoints); | ||
345 | + } | ||
346 | + | ||
347 | + if (kAddArbitraryKeypoints) { | ||
348 | + // Add some for each object prior to scoring. | ||
349 | + const int max_num_box_keypoints = | ||
350 | + kMaxTempKeypoints - number_of_tmp_keypoints; | ||
351 | + number_of_tmp_keypoints += | ||
352 | + AddExtraCandidatesForBoxes(rois, max_num_box_keypoints, | ||
353 | + tmp_keypoints_ + number_of_tmp_keypoints); | ||
354 | + TimeLog("Added box keypoints"); | ||
355 | + | ||
356 | + if (number_of_tmp_keypoints >= kMaxTempKeypoints) { | ||
357 | + LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints", | ||
358 | + kMaxTempKeypoints, number_of_tmp_keypoints); | ||
359 | + } | ||
360 | + } | ||
361 | + | ||
362 | + // Score them... | ||
363 | + LOGV("Scoring %d keypoints!", number_of_tmp_keypoints); | ||
364 | + ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_); | ||
365 | + TimeLog("Scored keypoints"); | ||
366 | + | ||
367 | + // Now pare it down a bit. | ||
368 | + SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_); | ||
369 | + TimeLog("Sorted keypoints"); | ||
370 | + | ||
371 | + LOGV("%d keypoints to select from!", number_of_tmp_keypoints); | ||
372 | + | ||
373 | + SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change); | ||
374 | + TimeLog("Selected keypoints"); | ||
375 | + | ||
376 | + LOGV("Picked %d (%d max) final keypoints out of %d potential.", | ||
377 | + curr_change->number_of_keypoints_, | ||
378 | + kMaxKeypoints, number_of_tmp_keypoints); | ||
379 | +} | ||
380 | + | ||
381 | + | ||
382 | +int KeypointDetector::CopyKeypoints(const FramePair& prev_change, | ||
383 | + Keypoint* const new_keypoints) { | ||
384 | + int number_of_keypoints = 0; | ||
385 | + | ||
386 | + // Caching values from last pass, just copy and compact. | ||
387 | + for (int i = 0; i < prev_change.number_of_keypoints_; ++i) { | ||
388 | + if (prev_change.optical_flow_found_keypoint_[i]) { | ||
389 | + new_keypoints[number_of_keypoints] = | ||
390 | + prev_change.frame2_keypoints_[i]; | ||
391 | + | ||
392 | + new_keypoints[number_of_keypoints].score_ = | ||
393 | + prev_change.frame1_keypoints_[i].score_; | ||
394 | + | ||
395 | + ++number_of_keypoints; | ||
396 | + } | ||
397 | + } | ||
398 | + | ||
399 | + TimeLog("Copied keypoints"); | ||
400 | + return number_of_keypoints; | ||
401 | +} | ||
402 | + | ||
403 | + | ||
404 | +// FAST keypoint detector. | ||
405 | +int KeypointDetector::FindFastKeypoints(const Image<uint8_t>& frame, | ||
406 | + const int quadrant, | ||
407 | + const int downsample_factor, | ||
408 | + const int max_num_keypoints, | ||
409 | + Keypoint* const keypoints) { | ||
410 | + /* | ||
411 | + // Reference for a circle of diameter 7. | ||
412 | + const int circle[] = {0, 0, 1, 1, 1, 0, 0, | ||
413 | + 0, 1, 0, 0, 0, 1, 0, | ||
414 | + 1, 0, 0, 0, 0, 0, 1, | ||
415 | + 1, 0, 0, 0, 0, 0, 1, | ||
416 | + 1, 0, 0, 0, 0, 0, 1, | ||
417 | + 0, 1, 0, 0, 0, 1, 0, | ||
418 | + 0, 0, 1, 1, 1, 0, 0}; | ||
419 | + const int circle_offset[] = | ||
420 | + {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46}; | ||
421 | + */ | ||
422 | + | ||
423 | +  // Quick test of compass directions. Any run of 12+ contiguous pixels on the | ||
424 | +  // 16-pixel circle must include at least 3 of these 4 compass pixels. | ||
425 | + static const int short_circle_perimeter = 4; | ||
426 | + static const int short_threshold = 3; | ||
427 | + static const int short_circle_x[] = { -3, 0, +3, 0 }; | ||
428 | + static const int short_circle_y[] = { 0, -3, 0, +3 }; | ||
429 | + | ||
430 | + // Precompute image offsets. | ||
431 | + int short_offsets[short_circle_perimeter]; | ||
432 | + for (int i = 0; i < short_circle_perimeter; ++i) { | ||
433 | + short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth(); | ||
434 | + } | ||
435 | + | ||
436 | + // Large circle values. | ||
437 | + static const int full_circle_perimeter = 16; | ||
438 | + static const int full_threshold = 12; | ||
439 | + static const int full_circle_x[] = | ||
440 | + { -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 }; | ||
441 | + static const int full_circle_y[] = | ||
442 | + { -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 }; | ||
443 | + | ||
444 | + // Precompute image offsets. | ||
445 | + int full_offsets[full_circle_perimeter]; | ||
446 | + for (int i = 0; i < full_circle_perimeter; ++i) { | ||
447 | + full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth(); | ||
448 | + } | ||
449 | + | ||
450 | + const int scratch_stride = frame.stride(); | ||
451 | + | ||
452 | + keypoint_scratch_->Clear(0); | ||
453 | + | ||
454 | + // Set up the bounds on the region to test based on the passed-in quadrant. | ||
455 | + const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer; | ||
456 | + const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer; | ||
457 | + const int start_x = | ||
458 | + kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width); | ||
459 | + const int start_y = | ||
460 | + kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height); | ||
461 | + const int end_x = start_x + quadrant_width; | ||
462 | + const int end_y = start_y + quadrant_height; | ||
463 | + | ||
464 | + // Loop through once to find FAST keypoint clumps. | ||
465 | + for (int img_y = start_y; img_y < end_y; ++img_y) { | ||
466 | + const uint8_t* curr_pixel_ptr = frame[img_y] + start_x; | ||
467 | + | ||
468 | + for (int img_x = start_x; img_x < end_x; ++img_x) { | ||
469 | + // Only insert it if it meets the quick minimum requirements test. | ||
470 | + if (TestCircle(short_circle_perimeter, short_threshold, | ||
471 | + curr_pixel_ptr, short_offsets) != 0) { | ||
472 | +        // Longer test for the actual keypoint score. | ||
473 | + const int fast_score = TestCircle(full_circle_perimeter, | ||
474 | + full_threshold, | ||
475 | + curr_pixel_ptr, | ||
476 | + full_offsets); | ||
477 | + | ||
478 | + // Non-zero score means the keypoint was found. | ||
479 | + if (fast_score != 0) { | ||
480 | + uint8_t* const center_ptr = (*keypoint_scratch_)[img_y] + img_x; | ||
481 | + | ||
482 | + // Increase the keypoint count on this pixel and the pixels in all | ||
483 | + // 4 cardinal directions. | ||
484 | + *center_ptr += 5; | ||
485 | + *(center_ptr - 1) += 1; | ||
486 | + *(center_ptr + 1) += 1; | ||
487 | + *(center_ptr - scratch_stride) += 1; | ||
488 | + *(center_ptr + scratch_stride) += 1; | ||
489 | + } | ||
490 | + } | ||
491 | + | ||
492 | + ++curr_pixel_ptr; | ||
493 | + } // x | ||
494 | + } // y | ||
495 | + | ||
496 | + TimeLog("Found FAST keypoints."); | ||
497 | + | ||
498 | + int num_keypoints = 0; | ||
499 | + // Loop through again and Harris filter pixels in the center of clumps. | ||
500 | + // We can shrink the window by 1 pixel on every side. | ||
501 | + for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) { | ||
502 | + const uint8_t* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x; | ||
503 | + | ||
504 | + for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) { | ||
505 | + if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) { | ||
506 | + Keypoint* const keypoint = keypoints + num_keypoints; | ||
507 | + keypoint->pos_ = Point2f( | ||
508 | + img_x * downsample_factor, img_y * downsample_factor); | ||
509 | + keypoint->score_ = 0; | ||
510 | + keypoint->type_ = KEYPOINT_TYPE_FAST; | ||
511 | + | ||
512 | + ++num_keypoints; | ||
513 | + if (num_keypoints >= max_num_keypoints) { | ||
514 | + return num_keypoints; | ||
515 | + } | ||
516 | + } | ||
517 | + | ||
518 | + ++curr_pixel_ptr; | ||
519 | + } // x | ||
520 | + } // y | ||
521 | + | ||
522 | + TimeLog("Picked FAST keypoints."); | ||
523 | + | ||
524 | + return num_keypoints; | ||
525 | +} | ||
526 | + | ||
527 | +int KeypointDetector::FindFastKeypoints(const ImageData& image_data, | ||
528 | + const int max_num_keypoints, | ||
529 | + Keypoint* const keypoints) { | ||
530 | + int downsample_factor = 1; | ||
531 | + int num_found = 0; | ||
532 | + | ||
533 | + // TODO(andrewharp): Get this working for multiple image scales. | ||
534 | + for (int i = 0; i < 1; ++i) { | ||
535 | + const Image<uint8_t>& frame = *image_data.GetPyramidSqrt2Level(i); | ||
536 | + num_found += FindFastKeypoints( | ||
537 | + frame, fast_quadrant_, | ||
538 | + downsample_factor, max_num_keypoints, keypoints + num_found); | ||
539 | + downsample_factor *= 2; | ||
540 | + } | ||
541 | + | ||
542 | + // Increment the current quadrant. | ||
543 | + fast_quadrant_ = (fast_quadrant_ + 1) % 4; | ||
544 | + | ||
545 | + return num_found; | ||
546 | +} | ||
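// Note: because fast_quadrant_ is advanced after each call, any single call
// above scans only one quarter of the frame; four consecutive frames together
// cover the full image once. This staggering trades detection latency for
// lower per-frame cost, as also noted in the header.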
547 | + | ||
548 | +} // namespace tf_tracking |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ | ||
18 | + | ||
19 | +#include <stdint.h> | ||
20 | +#include <vector> | ||
21 | + | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" | ||
26 | + | ||
27 | +namespace tf_tracking { | ||
28 | + | ||
29 | +struct Keypoint; | ||
30 | + | ||
31 | +class KeypointDetector { | ||
32 | + public: | ||
33 | + explicit KeypointDetector(const KeypointDetectorConfig* const config) | ||
34 | + : config_(config), | ||
35 | + keypoint_scratch_(new Image<uint8_t>(config_->image_size)), | ||
36 | + interest_map_(new Image<bool>(config_->image_size)), | ||
37 | + fast_quadrant_(0) { | ||
38 | + interest_map_->Clear(false); | ||
39 | + } | ||
40 | + | ||
41 | + ~KeypointDetector() {} | ||
42 | + | ||
43 | + // Finds a new set of keypoints for the current frame, picked from the current | ||
44 | + // set of keypoints and also from a set discovered via a keypoint detector. | ||
45 | + // Special attention is applied to make sure that keypoints are distributed | ||
46 | + // within the supplied ROIs. | ||
47 | + void FindKeypoints(const ImageData& image_data, | ||
48 | + const std::vector<BoundingBox>& rois, | ||
49 | + const FramePair& prev_change, | ||
50 | + FramePair* const curr_change); | ||
51 | + | ||
52 | + private: | ||
53 | + // Compute the corneriness of a point in the image. | ||
54 | + float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y, | ||
55 | + const float x, const float y) const; | ||
56 | + | ||
57 | + // Adds a grid of candidate keypoints to the given box, up to | ||
58 | + // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower. | ||
59 | + int AddExtraCandidatesForBoxes( | ||
60 | + const std::vector<BoundingBox>& boxes, | ||
61 | + const int max_num_keypoints, | ||
62 | + Keypoint* const keypoints) const; | ||
63 | + | ||
64 | + // Scan the frame for potential keypoints using the FAST keypoint detector. | ||
65 | + // Quadrant is an argument 0-3 which refers to the quadrant of the image in | ||
66 | + // which to detect keypoints. | ||
67 | + int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant, | ||
68 | + const int downsample_factor, | ||
69 | + const int max_num_keypoints, Keypoint* const keypoints); | ||
70 | + | ||
71 | + int FindFastKeypoints(const ImageData& image_data, | ||
72 | + const int max_num_keypoints, | ||
73 | + Keypoint* const keypoints); | ||
74 | + | ||
75 | + // Score a bunch of candidate keypoints. Assigns the scores to the input | ||
76 | + // candidate_keypoints array entries. | ||
77 | + void ScoreKeypoints(const ImageData& image_data, | ||
78 | + const int num_candidates, | ||
79 | + Keypoint* const candidate_keypoints); | ||
80 | + | ||
81 | + void SortKeypoints(const int num_candidates, | ||
82 | + Keypoint* const candidate_keypoints) const; | ||
83 | + | ||
84 | + // Selects a set of keypoints falling within the supplied box such that the | ||
85 | + // most highly rated keypoints are picked first, and so that none of them are | ||
86 | + // too close together. | ||
87 | + int SelectKeypointsInBox( | ||
88 | + const BoundingBox& box, | ||
89 | + const Keypoint* const candidate_keypoints, | ||
90 | + const int num_candidates, | ||
91 | + const int max_keypoints, | ||
92 | + const int num_existing_keypoints, | ||
93 | + const Keypoint* const existing_keypoints, | ||
94 | + Keypoint* const final_keypoints) const; | ||
95 | + | ||
96 | + // Selects from the supplied sorted keypoint pool a set of keypoints that will | ||
97 | + // best cover the given set of boxes, such that each box is covered at a | ||
98 | + // resolution proportional to its size. | ||
99 | + void SelectKeypoints( | ||
100 | + const std::vector<BoundingBox>& boxes, | ||
101 | + const Keypoint* const candidate_keypoints, | ||
102 | + const int num_candidates, | ||
103 | + FramePair* const frame_change) const; | ||
104 | + | ||
105 | + // Copies and compacts the found keypoints in the second frame of prev_change | ||
106 | + // into the array at new_keypoints. | ||
107 | + static int CopyKeypoints(const FramePair& prev_change, | ||
108 | + Keypoint* const new_keypoints); | ||
109 | + | ||
110 | + const KeypointDetectorConfig* const config_; | ||
111 | + | ||
112 | + // Scratch memory for keypoint candidacy detection and non-max suppression. | ||
113 | + std::unique_ptr<Image<uint8_t> > keypoint_scratch_; | ||
114 | + | ||
115 | + // Regions of the image to pay special attention to. | ||
116 | + std::unique_ptr<Image<bool> > interest_map_; | ||
117 | + | ||
118 | + // The current quadrant of the image to detect FAST keypoints in. | ||
119 | + // Keypoint detection is staggered for performance reasons. Every four frames | ||
120 | + // a full scan of the frame will have been performed. | ||
121 | + int fast_quadrant_; | ||
122 | + | ||
123 | + Keypoint tmp_keypoints_[kMaxTempKeypoints]; | ||
124 | +}; | ||
125 | + | ||
126 | +} // namespace tf_tracking | ||
127 | + | ||
128 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ |
1 | +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#include "tensorflow/examples/android/jni/object_tracking/logging.h" | ||
17 | + | ||
18 | +#ifdef STANDALONE_DEMO_LIB | ||
19 | + | ||
20 | +#include <android/log.h> | ||
21 | +#include <stdlib.h> | ||
22 | +#include <time.h> | ||
23 | +#include <iostream> | ||
24 | +#include <sstream> | ||
25 | + | ||
26 | +LogMessage::LogMessage(const char* fname, int line, int severity) | ||
27 | + : fname_(fname), line_(line), severity_(severity) {} | ||
28 | + | ||
29 | +void LogMessage::GenerateLogMessage() { | ||
30 | + int android_log_level; | ||
31 | + switch (severity_) { | ||
32 | + case INFO: | ||
33 | + android_log_level = ANDROID_LOG_INFO; | ||
34 | + break; | ||
35 | + case WARNING: | ||
36 | + android_log_level = ANDROID_LOG_WARN; | ||
37 | + break; | ||
38 | + case ERROR: | ||
39 | + android_log_level = ANDROID_LOG_ERROR; | ||
40 | + break; | ||
41 | + case FATAL: | ||
42 | + android_log_level = ANDROID_LOG_FATAL; | ||
43 | + break; | ||
44 | + default: | ||
45 | + if (severity_ < INFO) { | ||
46 | + android_log_level = ANDROID_LOG_VERBOSE; | ||
47 | + } else { | ||
48 | + android_log_level = ANDROID_LOG_ERROR; | ||
49 | + } | ||
50 | + break; | ||
51 | + } | ||
52 | + | ||
53 | + std::stringstream ss; | ||
54 | + const char* const partial_name = strrchr(fname_, '/'); | ||
55 | + ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_ | ||
56 | + << " " << str(); | ||
57 | + __android_log_write(android_log_level, "native", ss.str().c_str()); | ||
58 | + | ||
59 | + // Also log to stderr (for standalone Android apps). | ||
60 | + std::cerr << "native : " << ss.str() << std::endl; | ||
61 | + | ||
62 | + // Android logging at level FATAL does not terminate execution, so abort() | ||
63 | + // is still required to stop the program. | ||
64 | + if (severity_ == FATAL) { | ||
65 | + abort(); | ||
66 | + } | ||
67 | +} | ||
68 | + | ||
69 | +namespace { | ||
70 | + | ||
71 | +// Parse log level (int64) from environment variable (char*) | ||
72 | +int64_t LogLevelStrToInt(const char* tf_env_var_val) { | ||
73 | + if (tf_env_var_val == nullptr) { | ||
74 | + return 0; | ||
75 | + } | ||
76 | + | ||
77 | + // Ideally we would use env_var / safe_strto64, but it is | ||
78 | + // hard to use here without pulling in a lot of dependencies, | ||
79 | +  // so we use std::istringstream instead. | ||
80 | + std::string min_log_level(tf_env_var_val); | ||
81 | + std::istringstream ss(min_log_level); | ||
82 | + int64_t level; | ||
83 | + if (!(ss >> level)) { | ||
84 | + // Invalid vlog level setting, set level to default (0) | ||
85 | + level = 0; | ||
86 | + } | ||
87 | + | ||
88 | + return level; | ||
89 | +} | ||
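// Examples: LogLevelStrToInt(nullptr) == 0, LogLevelStrToInt("2") == 2, and
// LogLevelStrToInt("junk") == 0 (the parse fails, so the default is used).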
90 | + | ||
91 | +int64_t MinLogLevelFromEnv() { | ||
92 | + const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL"); | ||
93 | + return LogLevelStrToInt(tf_env_var_val); | ||
94 | +} | ||
95 | + | ||
96 | +int64_t MinVLogLevelFromEnv() { | ||
97 | + const char* tf_env_var_val = getenv("TF_CPP_MIN_VLOG_LEVEL"); | ||
98 | + return LogLevelStrToInt(tf_env_var_val); | ||
99 | +} | ||
100 | + | ||
101 | +} // namespace | ||
102 | + | ||
103 | +LogMessage::~LogMessage() { | ||
104 | + // Read the min log level once during the first call to logging. | ||
105 | + static int64_t min_log_level = MinLogLevelFromEnv(); | ||
106 | + if (TF_PREDICT_TRUE(severity_ >= min_log_level)) GenerateLogMessage(); | ||
107 | +} | ||
108 | + | ||
109 | +int64_t LogMessage::MinVLogLevel() { | ||
110 | + static const int64_t min_vlog_level = MinVLogLevelFromEnv(); | ||
111 | + return min_vlog_level; | ||
112 | +} | ||
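// Example of the resulting behavior: with TF_CPP_MIN_LOG_LEVEL=2 exported
// before the process starts, INFO (0) and WARNING (1) messages are dropped in
// ~LogMessage() while ERROR (2) and FATAL (3) still log. The level is read
// once and cached for the lifetime of the process.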
113 | + | ||
114 | +LogMessageFatal::LogMessageFatal(const char* file, int line) | ||
115 | +    : LogMessage(file, line, FATAL) {} | ||
116 | +LogMessageFatal::~LogMessageFatal() { | ||
117 | + // abort() ensures we don't return (we promised we would not via | ||
118 | + // ATTRIBUTE_NORETURN). | ||
119 | + GenerateLogMessage(); | ||
120 | + abort(); | ||
121 | +} | ||
122 | + | ||
123 | +void LogString(const char* fname, int line, int severity, | ||
124 | + const std::string& message) { | ||
125 | + LogMessage(fname, line, severity) << message; | ||
126 | +} | ||
127 | + | ||
128 | +void LogPrintF(const int severity, const char* format, ...) { | ||
129 | + char message[1024]; | ||
130 | + va_list argptr; | ||
131 | + va_start(argptr, format); | ||
132 | + vsnprintf(message, 1024, format, argptr); | ||
133 | + va_end(argptr); | ||
134 | + __android_log_write(severity, "native", message); | ||
135 | + | ||
136 | + // Also log to stderr (for standalone Android apps). | ||
137 | + std::cerr << "native : " << message << std::endl; | ||
138 | +} | ||
139 | + | ||
140 | +#endif |
1 | +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_ | ||
18 | + | ||
19 | +#include <android/log.h> | ||
20 | +#include <string.h> | ||
21 | +#include <ostream> | ||
22 | +#include <sstream> | ||
23 | +#include <string> | ||
24 | + | ||
25 | +// Allow this library to be built without depending on TensorFlow by | ||
26 | +// defining STANDALONE_DEMO_LIB. Otherwise TensorFlow headers will be | ||
27 | +// used. | ||
28 | +#ifdef STANDALONE_DEMO_LIB | ||
29 | + | ||
30 | +// A macro to disallow the copy constructor and operator= functions | ||
31 | +// This is usually placed in the private: declarations for a class. | ||
32 | +#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \ | ||
33 | + TypeName(const TypeName&) = delete; \ | ||
34 | + void operator=(const TypeName&) = delete | ||
35 | + | ||
36 | +#if defined(COMPILER_GCC3) | ||
37 | +#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0)) | ||
38 | +#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) | ||
39 | +#else | ||
40 | +#define TF_PREDICT_FALSE(x) (x) | ||
41 | +#define TF_PREDICT_TRUE(x) (x) | ||
42 | +#endif | ||
43 | + | ||
44 | +// Log levels equivalent to those defined by | ||
45 | +// third_party/tensorflow/core/platform/logging.h | ||
46 | +const int INFO = 0; // base_logging::INFO; | ||
47 | +const int WARNING = 1; // base_logging::WARNING; | ||
48 | +const int ERROR = 2; // base_logging::ERROR; | ||
49 | +const int FATAL = 3; // base_logging::FATAL; | ||
50 | +const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES; | ||
51 | + | ||
52 | +class LogMessage : public std::basic_ostringstream<char> { | ||
53 | + public: | ||
54 | + LogMessage(const char* fname, int line, int severity); | ||
55 | + ~LogMessage(); | ||
56 | + | ||
57 | + // Returns the minimum log level for VLOG statements. | ||
58 | + // E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output, | ||
59 | + // but VLOG(3) will not. Defaults to 0. | ||
60 | + static int64_t MinVLogLevel(); | ||
61 | + | ||
62 | + protected: | ||
63 | + void GenerateLogMessage(); | ||
64 | + | ||
65 | + private: | ||
66 | + const char* fname_; | ||
67 | + int line_; | ||
68 | + int severity_; | ||
69 | +}; | ||
70 | + | ||
71 | +// LogMessageFatal ensures the process will exit in failure after | ||
72 | +// logging this message. | ||
73 | +class LogMessageFatal : public LogMessage { | ||
74 | + public: | ||
75 | + LogMessageFatal(const char* file, int line); | ||
76 | + ~LogMessageFatal(); | ||
77 | +}; | ||
78 | + | ||
79 | +#define _TF_LOG_INFO \ | ||
80 | +  LogMessage(__FILE__, __LINE__, INFO) | ||
81 | +#define _TF_LOG_WARNING \ | ||
82 | +  LogMessage(__FILE__, __LINE__, WARNING) | ||
83 | +#define _TF_LOG_ERROR \ | ||
84 | +  LogMessage(__FILE__, __LINE__, ERROR) | ||
85 | +#define _TF_LOG_FATAL \ | ||
86 | +  LogMessageFatal(__FILE__, __LINE__) | ||
87 | + | ||
88 | +#define _TF_LOG_QFATAL _TF_LOG_FATAL | ||
89 | + | ||
90 | +#define LOG(severity) _TF_LOG_##severity | ||
91 | + | ||
92 | +#define VLOG_IS_ON(lvl) ((lvl) <= LogMessage::MinVLogLevel()) | ||
93 | + | ||
94 | +#define VLOG(lvl) \ | ||
95 | + if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \ | ||
96 | +  LogMessage(__FILE__, __LINE__, INFO) | ||
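// Example: with TF_CPP_MIN_VLOG_LEVEL=1 in the environment,
//
//   VLOG(1) << "coarse detail";  // emitted
//   VLOG(2) << "fine detail";    // skipped, since 2 > 1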
97 | + | ||
98 | +void LogPrintF(const int severity, const char* format, ...); | ||
99 | + | ||
100 | +// Support for printf style logging. | ||
101 | +#define LOGV(...) | ||
102 | +#define LOGD(...) | ||
103 | +#define LOGI(...) LogPrintF(ANDROID_LOG_INFO, __VA_ARGS__); | ||
104 | +#define LOGW(...) LogPrintF(ANDROID_LOG_WARN, __VA_ARGS__); | ||
105 | +#define LOGE(...) LogPrintF(ANDROID_LOG_ERROR, __VA_ARGS__); | ||
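// Example usage of the printf-style helpers (hypothetical values):
//
//   LOGI("Tracking %d objects at %.1f fps", num_objects, fps);
//   LOGE("Failed to process frame %d", frame_index);
//
// LOGV and LOGD expand to nothing in this configuration, so verbose and debug
// messages are compiled out entirely.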
106 | + | ||
107 | +#else | ||
108 | + | ||
109 | +#include "tensorflow/core/lib/strings/stringprintf.h" | ||
110 | +#include "tensorflow/core/platform/logging.h" | ||
111 | + | ||
112 | +// Support for printf style logging. | ||
113 | +#define LOGV(...) | ||
114 | +#define LOGD(...) | ||
115 | +#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__); | ||
116 | +#define LOGW(...) LOG(WARNING) << tensorflow::strings::Printf(__VA_ARGS__); | ||
117 | +#define LOGE(...) LOG(ERROR) << tensorflow::strings::Printf(__VA_ARGS__); | ||
118 | + | ||
119 | +#endif | ||
120 | + | ||
121 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// NOTE: no native object detectors are currently provided or used by the code | ||
17 | +// in this directory. This class remains mainly for historical reasons. | ||
18 | +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. | ||
19 | + | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" | ||
21 | + | ||
22 | +namespace tf_tracking { | ||
23 | + | ||
24 | +// This is here so that the vtable gets created properly. | ||
25 | +ObjectDetectorBase::~ObjectDetectorBase() {} | ||
26 | + | ||
27 | +} // namespace tf_tracking |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// NOTE: no native object detectors are currently provided or used by the code | ||
17 | +// in this directory. This class remains mainly for historical reasons. | ||
18 | +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. | ||
19 | + | ||
20 | +// Defines the ObjectDetector class that is the main interface for detecting | ||
21 | +// ObjectModelBases in frames. | ||
22 | + | ||
23 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ | ||
24 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ | ||
25 | + | ||
26 | +#include <float.h> | ||
27 | +#include <map> | ||
28 | +#include <memory> | ||
29 | +#include <sstream> | ||
30 | +#include <string> | ||
31 | +#include <vector> | ||
32 | + | ||
33 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
34 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
35 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
36 | +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" | ||
37 | +#ifdef __RENDER_OPENGL__ | ||
38 | +#include "tensorflow/examples/android/jni/object_tracking/sprite.h" | ||
39 | +#endif | ||
40 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
41 | + | ||
42 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
43 | +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" | ||
44 | +#include "tensorflow/examples/android/jni/object_tracking/object_model.h" | ||
45 | + | ||
46 | +namespace tf_tracking { | ||
47 | + | ||
48 | +// Adds BoundingSquares to a vector such that the first square added is centered | ||
49 | +// at the given position with size starting_square_size, and the remaining | ||
50 | +// squares are added concentrically, scaling down by scale_factor until the | ||
51 | +// minimum threshold size is passed. | ||
52 | +// Squares that do not fall completely within image_bounds will not be added. | ||
53 | +static inline void FillWithSquares( | ||
54 | + const BoundingBox& image_bounds, | ||
55 | + const BoundingBox& position, | ||
56 | + const float starting_square_size, | ||
57 | + const float smallest_square_size, | ||
58 | + const float scale_factor, | ||
59 | + std::vector<BoundingSquare>* const squares) { | ||
60 | + BoundingSquare descriptor_area = | ||
61 | + GetCenteredSquare(position, starting_square_size); | ||
62 | + | ||
63 | + SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor); | ||
64 | + | ||
65 | + // Use a do/while loop to ensure that at least one descriptor is created. | ||
66 | + do { | ||
67 | + if (image_bounds.Contains(descriptor_area.ToBoundingBox())) { | ||
68 | + squares->push_back(descriptor_area); | ||
69 | + } | ||
70 | + descriptor_area.Scale(scale_factor); | ||
71 | + } while (descriptor_area.size_ >= smallest_square_size - EPSILON); | ||
72 | + LOGV("Created %zu squares starting from size %.2f to min size %.2f " | ||
73 | + "using scale factor: %.2f", | ||
74 | + squares->size(), starting_square_size, smallest_square_size, | ||
75 | + scale_factor); | ||
76 | +} | ||
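// Worked example: starting_square_size = 100, smallest_square_size = 32, and
// scale_factor = 0.75 yields candidate sizes 100, 75, 56.25, and ~42.2; the
// next size (~31.6) falls below the minimum and ends the loop. Each square is
// added only if it fits entirely inside image_bounds.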
77 | + | ||
78 | + | ||
79 | +// Represents a potential detection of a specific ObjectExemplar and Descriptor | ||
80 | +// at a specific position in the image. | ||
81 | +class Detection { | ||
82 | + public: | ||
83 | + explicit Detection(const ObjectModelBase* const object_model, | ||
84 | + const MatchScore match_score, | ||
85 | + const BoundingBox& bounding_box) | ||
86 | + : object_model_(object_model), | ||
87 | + match_score_(match_score), | ||
88 | + bounding_box_(bounding_box) {} | ||
89 | + | ||
90 | + Detection(const Detection& other) | ||
91 | + : object_model_(other.object_model_), | ||
92 | + match_score_(other.match_score_), | ||
93 | + bounding_box_(other.bounding_box_) {} | ||
94 | + | ||
95 | + virtual ~Detection() {} | ||
96 | + | ||
97 | + inline BoundingBox GetObjectBoundingBox() const { | ||
98 | + return bounding_box_; | ||
99 | + } | ||
100 | + | ||
101 | + inline MatchScore GetMatchScore() const { | ||
102 | + return match_score_; | ||
103 | + } | ||
104 | + | ||
105 | + inline const ObjectModelBase* GetObjectModel() const { | ||
106 | + return object_model_; | ||
107 | + } | ||
108 | + | ||
109 | + inline bool Intersects(const Detection& other) { | ||
110 | +    // The boxes intersect iff no separating axis exists between them. | ||
111 | + return bounding_box_.Intersects(other.bounding_box_); | ||
112 | + } | ||
113 | + | ||
114 | + struct Comp { | ||
115 | + inline bool operator()(const Detection& a, const Detection& b) const { | ||
116 | + return a.match_score_ > b.match_score_; | ||
117 | + } | ||
118 | + }; | ||
119 | + | ||
120 | + // TODO(andrewharp): add accessors to update these instead. | ||
121 | + const ObjectModelBase* object_model_; | ||
122 | + MatchScore match_score_; | ||
123 | + BoundingBox bounding_box_; | ||
124 | +}; | ||
125 | + | ||
126 | +inline std::ostream& operator<<(std::ostream& stream, | ||
127 | + const Detection& detection) { | ||
128 | + const BoundingBox actual_area = detection.GetObjectBoundingBox(); | ||
129 | + stream << actual_area; | ||
130 | + return stream; | ||
131 | +} | ||
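// Example: because Detection::Comp orders by descending match score,
// candidates can be ranked best-first with std::sort (assuming <algorithm> is
// included by the caller):
//
//   std::sort(detections.begin(), detections.end(), Detection::Comp());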
132 | + | ||
133 | +class ObjectDetectorBase { | ||
134 | + public: | ||
135 | + explicit ObjectDetectorBase(const ObjectDetectorConfig* const config) | ||
136 | + : config_(config), | ||
137 | + image_data_(NULL) {} | ||
138 | + | ||
139 | + virtual ~ObjectDetectorBase(); | ||
140 | + | ||
141 | + // Sets the current image data. All calls to ObjectDetector other than | ||
142 | + // FillDescriptors use the image data last set. | ||
143 | + inline void SetImageData(const ImageData* const image_data) { | ||
144 | + image_data_ = image_data; | ||
145 | + } | ||
146 | + | ||
147 | + // Main entry point into the detection algorithm. | ||
148 | + // Scans the frame for candidates, tweaks them, and fills in the | ||
149 | + // given std::vector of Detection objects with acceptable matches. | ||
150 | + virtual void Detect(const std::vector<BoundingSquare>& positions, | ||
151 | + std::vector<Detection>* const detections) const = 0; | ||
152 | + | ||
153 | + virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0; | ||
154 | + | ||
155 | + virtual void DeleteObjectModel(const std::string& name) = 0; | ||
156 | + | ||
157 | + virtual void GetObjectModels( | ||
158 | + std::vector<const ObjectModelBase*>* models) const = 0; | ||
159 | + | ||
160 | +  // Updates the given model with appearance information from the given | ||
161 | +  // bounding box region of base_image, in the context of the last frame | ||
162 | +  // passed to NextFrame. | ||
163 | +  // No update is made if there's no room for a descriptor to be created in | ||
164 | +  // the example area, or the area is not completely contained in the frame. | ||
165 | + virtual void UpdateModel(const Image<uint8_t>& base_image, | ||
166 | + const IntegralImage& integral_image, | ||
167 | + const BoundingBox& bounding_box, const bool locked, | ||
168 | + ObjectModelBase* model) const = 0; | ||
169 | + | ||
170 | + virtual void Draw() const = 0; | ||
171 | + | ||
172 | + virtual bool AllowSpontaneousDetections() = 0; | ||
173 | + | ||
174 | + protected: | ||
175 | + const std::unique_ptr<const ObjectDetectorConfig> config_; | ||
176 | + | ||
177 | + // The latest frame data, upon which all detections will be performed. | ||
178 | + // Not owned by this object, just provided for reference by ObjectTracker | ||
179 | + // via SetImageData(). | ||
180 | + const ImageData* image_data_; | ||
181 | + | ||
182 | + private: | ||
183 | + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase); | ||
184 | +}; | ||
185 | + | ||
186 | +template <typename ModelType> | ||
187 | +class ObjectDetector : public ObjectDetectorBase { | ||
188 | + public: | ||
189 | + explicit ObjectDetector(const ObjectDetectorConfig* const config) | ||
190 | + : ObjectDetectorBase(config) {} | ||
191 | + | ||
192 | + virtual ~ObjectDetector() { | ||
193 | + typename std::map<std::string, ModelType*>::const_iterator it = | ||
194 | + object_models_.begin(); | ||
195 | + for (; it != object_models_.end(); ++it) { | ||
196 | + ModelType* model = it->second; | ||
197 | + delete model; | ||
198 | + } | ||
199 | + } | ||
200 | + | ||
201 | + virtual void DeleteObjectModel(const std::string& name) { | ||
202 | + ModelType* model = object_models_[name]; | ||
203 | + CHECK_ALWAYS(model != NULL, "Model was null!"); | ||
204 | + object_models_.erase(name); | ||
205 | + SAFE_DELETE(model); | ||
206 | + } | ||
207 | + | ||
208 | + virtual void GetObjectModels( | ||
209 | + std::vector<const ObjectModelBase*>* models) const { | ||
210 | + typename std::map<std::string, ModelType*>::const_iterator it = | ||
211 | + object_models_.begin(); | ||
212 | + for (; it != object_models_.end(); ++it) { | ||
213 | + models->push_back(it->second); | ||
214 | + } | ||
215 | + } | ||
216 | + | ||
217 | + virtual bool AllowSpontaneousDetections() { | ||
218 | + return false; | ||
219 | + } | ||
220 | + | ||
221 | + protected: | ||
222 | + std::map<std::string, ModelType*> object_models_; | ||
223 | + | ||
224 | + private: | ||
225 | + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector); | ||
226 | +}; | ||
227 | + | ||
228 | +} // namespace tf_tracking | ||
229 | + | ||
230 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// NOTE: no native object detectors are currently provided or used by the code | ||
17 | +// in this directory. This class remains mainly for historical reasons. | ||
18 | +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. | ||
19 | + | ||
20 | +// Contains ObjectModelBase declaration. | ||
21 | + | ||
22 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_ | ||
23 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_ | ||
24 | + | ||
25 | +#ifdef __RENDER_OPENGL__ | ||
26 | +#include <GLES/gl.h> | ||
27 | +#include <GLES/glext.h> | ||
28 | +#endif | ||
29 | + | ||
30 | +#include <vector> | ||
31 | + | ||
32 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
33 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
34 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
35 | +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" | ||
36 | +#ifdef __RENDER_OPENGL__ | ||
37 | +#include "tensorflow/examples/android/jni/object_tracking/sprite.h" | ||
38 | +#endif | ||
39 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
40 | + | ||
41 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
42 | +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" | ||
43 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" | ||
44 | + | ||
45 | +namespace tf_tracking { | ||
46 | + | ||
47 | +// The ObjectModelBase class represents all the known appearance information for | ||
48 | +// an object. It is not a specific instance of the object in the world, | ||
49 | +// but just the general appearance information that enables detection. An | ||
50 | +// ObjectModelBase can be reused across multiple instances of TrackedObject. | ||
51 | +class ObjectModelBase { | ||
52 | + public: | ||
53 | + ObjectModelBase(const std::string& name) : name_(name) {} | ||
54 | + | ||
55 | + virtual ~ObjectModelBase() {} | ||
56 | + | ||
57 | + // Called when the next step in an ongoing track occurs. | ||
58 | + virtual void TrackStep(const BoundingBox& position, | ||
59 | + const Image<uint8_t>& image, | ||
60 | + const IntegralImage& integral_image, | ||
61 | + const bool authoritative) {} | ||
62 | + | ||
63 | + // Called when an object track is lost. | ||
64 | + virtual void TrackLost() {} | ||
65 | + | ||
66 | + // Called when an object track is confirmed as legitimate. | ||
67 | + virtual void TrackConfirmed() {} | ||
68 | + | ||
69 | + virtual float GetMaxCorrelation(const Image<float>& patch_image) const = 0; | ||
70 | + | ||
71 | + virtual MatchScore GetMatchScore( | ||
72 | + const BoundingBox& position, const ImageData& image_data) const = 0; | ||
73 | + | ||
74 | + virtual void Draw(float* const depth) const = 0; | ||
75 | + | ||
76 | + inline const std::string& GetName() const { | ||
77 | + return name_; | ||
78 | + } | ||
79 | + | ||
80 | + protected: | ||
81 | + const std::string name_; | ||
82 | + | ||
83 | + private: | ||
84 | + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase); | ||
85 | +}; | ||
86 | + | ||
87 | +template <typename DetectorType> | ||
88 | +class ObjectModel : public ObjectModelBase { | ||
89 | + public: | ||
90 | + ObjectModel<DetectorType>(const DetectorType* const detector, | ||
91 | + const std::string& name) | ||
92 | + : ObjectModelBase(name), detector_(detector) {} | ||
93 | + | ||
94 | + protected: | ||
95 | + const DetectorType* const detector_; | ||
96 | + | ||
97 | + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel<DetectorType>); | ||
98 | +}; | ||
99 | + | ||
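// Illustrative sketch only (no concrete native model ships in this
// directory): the minimum a subclass of ObjectModelBase must override.
// MatchScore is assumed here to be constructible from a float.
class NullObjectModel : public ObjectModelBase {
 public:
  explicit NullObjectModel(const std::string& name) : ObjectModelBase(name) {}

  // Reports no correlation with any patch.
  float GetMaxCorrelation(const Image<float>& patch_image) const override {
    return 0.0f;
  }

  // Reports the weakest possible match at every position.
  MatchScore GetMatchScore(const BoundingBox& position,
                           const ImageData& image_data) const override {
    return MatchScore(0.0f);
  }

  // Nothing to render.
  void Draw(float* const depth) const override {}
};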
100 | +} // namespace tf_tracking | ||
101 | + | ||
102 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifdef __RENDER_OPENGL__ | ||
17 | +#include <GLES/gl.h> | ||
18 | +#include <GLES/glext.h> | ||
19 | +#endif | ||
20 | + | ||
21 | +#include <cinttypes> | ||
22 | +#include <map> | ||
23 | +#include <string> | ||
24 | + | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" | ||
27 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
28 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
29 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
30 | +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" | ||
31 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" | ||
32 | +#include "tensorflow/examples/android/jni/object_tracking/logging.h" | ||
33 | +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" | ||
34 | +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" | ||
35 | +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" | ||
36 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
37 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
38 | + | ||
39 | +namespace tf_tracking { | ||
40 | + | ||
41 | +ObjectTracker::ObjectTracker(const TrackerConfig* const config, | ||
42 | + ObjectDetectorBase* const detector) | ||
43 | + : config_(config), | ||
44 | + frame_width_(config->image_size.width), | ||
45 | + frame_height_(config->image_size.height), | ||
46 | + curr_time_(0), | ||
47 | + num_frames_(0), | ||
48 | + flow_cache_(&config->flow_config), | ||
49 | + keypoint_detector_(&config->keypoint_detector_config), | ||
50 | + curr_num_frame_pairs_(0), | ||
51 | + first_frame_index_(0), | ||
52 | + frame1_(new ImageData(frame_width_, frame_height_)), | ||
53 | + frame2_(new ImageData(frame_width_, frame_height_)), | ||
54 | + detector_(detector), | ||
55 | + num_detected_(0) { | ||
56 | + for (int i = 0; i < kNumFrames; ++i) { | ||
57 | + frame_pairs_[i].Init(-1, -1); | ||
58 | + } | ||
59 | +} | ||
60 | + | ||
61 | + | ||
62 | +ObjectTracker::~ObjectTracker() { | ||
63 | + for (TrackedObjectMap::iterator iter = objects_.begin(); | ||
64 | + iter != objects_.end(); iter++) { | ||
65 | + TrackedObject* object = iter->second; | ||
66 | + SAFE_DELETE(object); | ||
67 | + } | ||
68 | +} | ||
69 | + | ||
70 | + | ||
71 | +// Finds the correspondences for all the points in the current pair of frames. | ||
72 | +// Stores the results in the given FramePair. | ||
73 | +void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const { | ||
74 | +  // Mark all keypoints as not yet found by optical flow. | ||
75 | + memset(frame_pair->optical_flow_found_keypoint_, false, | ||
76 | + sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints); | ||
77 | + TimeLog("Cleared old found keypoints"); | ||
78 | + | ||
79 | + int num_keypoints_found = 0; | ||
80 | + | ||
81 | + // For every keypoint... | ||
82 | + for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) { | ||
83 | + Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat; | ||
84 | + Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat; | ||
85 | + | ||
86 | + if (flow_cache_.FindNewPositionOfPoint( | ||
87 | + keypoint1->pos_.x, keypoint1->pos_.y, | ||
88 | + &keypoint2->pos_.x, &keypoint2->pos_.y)) { | ||
89 | + frame_pair->optical_flow_found_keypoint_[i_feat] = true; | ||
90 | + ++num_keypoints_found; | ||
91 | + } | ||
92 | + } | ||
93 | + | ||
94 | + TimeLog("Found correspondences"); | ||
95 | + | ||
96 | + LOGV("Found %d of %d keypoint correspondences", | ||
97 | + num_keypoints_found, frame_pair->number_of_keypoints_); | ||
98 | +} | ||
99 | + | ||
100 | +void ObjectTracker::NextFrame(const uint8_t* const new_frame, | ||
101 | + const uint8_t* const uv_frame, | ||
102 | + const int64_t timestamp, | ||
103 | + const float* const alignment_matrix_2x3) { | ||
104 | + IncrementFrameIndex(); | ||
105 | + LOGV("Received frame %d", num_frames_); | ||
106 | + | ||
107 | + FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0); | ||
108 | + curr_change->Init(curr_time_, timestamp); | ||
109 | + | ||
110 | + CHECK_ALWAYS(curr_time_ < timestamp, | ||
111 | + "Timestamp must monotonically increase! Went from %" PRId64 | ||
112 | + " to %" PRId64 " on frame %d.", | ||
113 | + curr_time_, timestamp, num_frames_); | ||
114 | + | ||
115 | + curr_time_ = timestamp; | ||
116 | + | ||
117 | + // Swap the frames. | ||
118 | + frame1_.swap(frame2_); | ||
119 | + | ||
120 | + frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1); | ||
121 | + | ||
122 | + if (detector_.get() != NULL) { | ||
123 | + detector_->SetImageData(frame2_.get()); | ||
124 | + } | ||
125 | + | ||
126 | + flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3); | ||
127 | + | ||
128 | + if (num_frames_ == 1) { | ||
130 | +    // This must be the first frame; there is nothing to track from yet. | ||
130 | + return; | ||
131 | + } | ||
132 | + | ||
133 | + if (config_->always_track || objects_.size() > 0) { | ||
134 | + LOGV("Tracking %zu targets", objects_.size()); | ||
135 | + ComputeKeypoints(true); | ||
136 | + TimeLog("Keypoints computed!"); | ||
137 | + | ||
138 | + FindCorrespondences(curr_change); | ||
139 | + TimeLog("Flow computed!"); | ||
140 | + | ||
141 | + TrackObjects(); | ||
142 | + } | ||
143 | + TimeLog("Targets tracked!"); | ||
144 | + | ||
145 | + if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) { | ||
146 | + DetectTargets(); | ||
147 | + } | ||
148 | + TimeLog("Detected objects."); | ||
149 | +} | ||
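// Notes on the contract above: timestamps must strictly increase across calls
// (enforced by CHECK_ALWAYS); the very first frame only primes the frame
// buffers, so no tracking output is produced until the second call; and
// detection runs only every kDetectEveryNFrames frames when a detector is
// attached.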
150 | + | ||
151 | +TrackedObject* ObjectTracker::MaybeAddObject( | ||
152 | + const std::string& id, const Image<uint8_t>& source_image, | ||
153 | + const BoundingBox& bounding_box, const ObjectModelBase* object_model) { | ||
154 | +  // If this object is already being tracked, just return it. | ||
155 | + if (objects_.find(id) != objects_.end()) { | ||
156 | + return objects_[id]; | ||
157 | + } | ||
158 | + | ||
159 | + // Need to get a non-const version of the model, or create a new one if it | ||
160 | + // wasn't given. | ||
161 | + ObjectModelBase* model = NULL; | ||
162 | + if (detector_ != NULL) { | ||
163 | + // If a detector is registered, then this new object must have a model. | ||
164 | + CHECK_ALWAYS(object_model != NULL, "No model given!"); | ||
165 | + model = detector_->CreateObjectModel(object_model->GetName()); | ||
166 | + } | ||
167 | + TrackedObject* const object = | ||
168 | + new TrackedObject(id, source_image, bounding_box, model); | ||
169 | + | ||
170 | + objects_[id] = object; | ||
171 | + return object; | ||
172 | +} | ||
173 | + | ||
174 | +void ObjectTracker::RegisterNewObjectWithAppearance( | ||
175 | + const std::string& id, const uint8_t* const new_frame, | ||
176 | + const BoundingBox& bounding_box) { | ||
177 | + ObjectModelBase* object_model = NULL; | ||
178 | + | ||
179 | + Image<uint8_t> image(frame_width_, frame_height_); | ||
180 | + image.FromArray(new_frame, frame_width_, 1); | ||
181 | + | ||
182 | + if (detector_ != NULL) { | ||
183 | + object_model = detector_->CreateObjectModel(id); | ||
184 | + CHECK_ALWAYS(object_model != NULL, "Null object model!"); | ||
185 | + | ||
186 | + const IntegralImage integral_image(image); | ||
187 | + object_model->TrackStep(bounding_box, image, integral_image, true); | ||
188 | + } | ||
189 | + | ||
190 | + // Create an object at this position. | ||
191 | + CHECK_ALWAYS(!HaveObject(id), "Already have this object!"); | ||
192 | + if (objects_.find(id) == objects_.end()) { | ||
193 | + TrackedObject* const object = | ||
194 | + MaybeAddObject(id, image, bounding_box, object_model); | ||
195 | + CHECK_ALWAYS(object != NULL, "Object not created!"); | ||
196 | + } | ||
197 | +} | ||
198 | + | ||
199 | +void ObjectTracker::SetPreviousPositionOfObject(const std::string& id, | ||
200 | + const BoundingBox& bounding_box, | ||
201 | + const int64_t timestamp) { | ||
202 | + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp); | ||
203 | + CHECK_ALWAYS(timestamp <= curr_time_, | ||
204 | + "Timestamp too great! %" PRId64 " vs %" PRId64, timestamp, | ||
205 | + curr_time_); | ||
206 | + | ||
207 | + TrackedObject* const object = GetObject(id); | ||
208 | + | ||
209 | + // Track this bounding box from the past to the current time. | ||
210 | + const BoundingBox current_position = TrackBox(bounding_box, timestamp); | ||
211 | + | ||
212 | + object->UpdatePosition(current_position, curr_time_, *frame2_, false); | ||
213 | + | ||
214 | + VLOG(2) << "Set tracked position for " << id << " to " << bounding_box | ||
215 | + << std::endl; | ||
216 | +} | ||
217 | + | ||
218 | + | ||
219 | +void ObjectTracker::SetCurrentPositionOfObject( | ||
220 | + const std::string& id, const BoundingBox& bounding_box) { | ||
221 | + SetPreviousPositionOfObject(id, bounding_box, curr_time_); | ||
222 | +} | ||
223 | + | ||
224 | + | ||
225 | +void ObjectTracker::ForgetTarget(const std::string& id) { | ||
226 | + LOGV("Forgetting object %s", id.c_str()); | ||
227 | + TrackedObject* const object = GetObject(id); | ||
228 | + delete object; | ||
229 | + objects_.erase(id); | ||
230 | + | ||
231 | + if (detector_ != NULL) { | ||
232 | + detector_->DeleteObjectModel(id); | ||
233 | + } | ||
234 | +} | ||
235 | + | ||
236 | +int ObjectTracker::GetKeypointsPacked(uint16_t* const out_data, | ||
237 | + const float scale) const { | ||
238 | + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)]; | ||
239 | + uint16_t* curr_data = out_data; | ||
240 | + int num_keypoints = 0; | ||
241 | + | ||
242 | + for (int i = 0; i < change.number_of_keypoints_; ++i) { | ||
243 | + if (change.optical_flow_found_keypoint_[i]) { | ||
244 | + ++num_keypoints; | ||
245 | + const Point2f& point1 = change.frame1_keypoints_[i].pos_; | ||
246 | + *curr_data++ = RealToFixed115(point1.x * scale); | ||
247 | + *curr_data++ = RealToFixed115(point1.y * scale); | ||
248 | + | ||
249 | + const Point2f& point2 = change.frame2_keypoints_[i].pos_; | ||
250 | + *curr_data++ = RealToFixed115(point2.x * scale); | ||
251 | + *curr_data++ = RealToFixed115(point2.y * scale); | ||
252 | + } | ||
253 | + } | ||
254 | + | ||
255 | + return num_keypoints; | ||
256 | +} | ||
257 | + | ||
258 | + | ||
259 | +int ObjectTracker::GetKeypoints(const bool only_found, | ||
260 | + float* const out_data) const { | ||
261 | + int curr_keypoint = 0; | ||
262 | + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)]; | ||
263 | + | ||
264 | + for (int i = 0; i < change.number_of_keypoints_; ++i) { | ||
265 | + if (!only_found || change.optical_flow_found_keypoint_[i]) { | ||
266 | + const int base = curr_keypoint * kKeypointStep; | ||
267 | + out_data[base + 0] = change.frame1_keypoints_[i].pos_.x; | ||
268 | + out_data[base + 1] = change.frame1_keypoints_[i].pos_.y; | ||
269 | + | ||
270 | + out_data[base + 2] = | ||
271 | + change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f; | ||
272 | + out_data[base + 3] = change.frame2_keypoints_[i].pos_.x; | ||
273 | + out_data[base + 4] = change.frame2_keypoints_[i].pos_.y; | ||
274 | + | ||
275 | + out_data[base + 5] = change.frame1_keypoints_[i].score_; | ||
276 | + out_data[base + 6] = change.frame1_keypoints_[i].type_; | ||
277 | + ++curr_keypoint; | ||
278 | + } | ||
279 | + } | ||
280 | + | ||
281 | + LOGV("Got %d keypoints.", curr_keypoint); | ||
282 | + | ||
283 | + return curr_keypoint; | ||
284 | +} | ||
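// Layout written per keypoint above (kKeypointStep floats, at least 7):
// {x1, y1, found_flag (+1.0 / -1.0), x2, y2, score, type}. Callers must size
// out_data for kKeypointStep * kMaxKeypoints floats in the worst case.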
285 | + | ||
286 | + | ||
287 | +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region, | ||
288 | + const FramePair& frame_pair) const { | ||
289 | + float translation_x; | ||
290 | + float translation_y; | ||
291 | + | ||
292 | + float scale_x; | ||
293 | + float scale_y; | ||
294 | + | ||
295 | + BoundingBox tracked_box(region); | ||
296 | + frame_pair.AdjustBox( | ||
297 | + tracked_box, &translation_x, &translation_y, &scale_x, &scale_y); | ||
298 | + | ||
299 | + tracked_box.Shift(Point2f(translation_x, translation_y)); | ||
300 | + | ||
301 | + if (scale_x > 0 && scale_y > 0) { | ||
302 | + tracked_box.Scale(scale_x, scale_y); | ||
303 | + } | ||
304 | + return tracked_box; | ||
305 | +} | ||
306 | + | ||
307 | +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region, | ||
308 | + const int64_t timestamp) const { | ||
309 | + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp); | ||
310 | + CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!"); | ||
311 | + | ||
312 | + // Anything that ended before the requested timestamp is of no concern to us. | ||
313 | + bool found_it = false; | ||
314 | + int num_frames_back = -1; | ||
315 | + for (int i = 0; i < curr_num_frame_pairs_; ++i) { | ||
316 | + const FramePair& frame_pair = | ||
317 | + frame_pairs_[GetNthIndexFromEnd(i)]; | ||
318 | + | ||
319 | + if (frame_pair.end_time_ <= timestamp) { | ||
320 | + num_frames_back = i - 1; | ||
321 | + | ||
322 | + if (num_frames_back > 0) { | ||
323 | + LOGV("Went %d out of %d frames before finding frame. (index: %d)", | ||
324 | + num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i)); | ||
325 | + } | ||
326 | + | ||
327 | + found_it = true; | ||
328 | + break; | ||
329 | + } | ||
330 | + } | ||
331 | + | ||
332 | + if (!found_it) { | ||
333 | + LOGW("History did not go back far enough! %" PRId64 " vs %" PRId64, | ||
334 | + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - | ||
335 | + frame_pairs_[GetNthIndexFromStart(0)].end_time_, | ||
336 | + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp); | ||
337 | + } | ||
338 | + | ||
339 | + // Loop over all the frames in the queue, tracking the accumulated delta | ||
340 | + // of the point from frame to frame. It's possible the point could | ||
341 | + // go out of frame, but keep tracking as best we can, using points near | ||
342 | + // the edge of the screen where it went out of bounds. | ||
343 | + BoundingBox tracked_box(region); | ||
344 | + for (int i = num_frames_back; i >= 0; --i) { | ||
345 | + const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)]; | ||
346 | + SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!"); | ||
347 | + tracked_box = TrackBox(tracked_box, frame_pair); | ||
348 | + } | ||
349 | + return tracked_box; | ||
350 | +} | ||
351 | + | ||
352 | + | ||
353 | +// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4 | ||
354 | +// 3d transformation matrix. | ||
355 | +inline void Convert3x3To4x4( | ||
356 | + const float* const in_matrix, float* const out_matrix) { | ||
357 | + // X | ||
358 | + out_matrix[0] = in_matrix[0]; | ||
359 | + out_matrix[1] = in_matrix[3]; | ||
360 | + out_matrix[2] = 0.0f; | ||
361 | + out_matrix[3] = 0.0f; | ||
362 | + | ||
363 | + // Y | ||
364 | + out_matrix[4] = in_matrix[1]; | ||
365 | + out_matrix[5] = in_matrix[4]; | ||
366 | + out_matrix[6] = 0.0f; | ||
367 | + out_matrix[7] = 0.0f; | ||
368 | + | ||
369 | + // Z | ||
370 | + out_matrix[8] = 0.0f; | ||
371 | + out_matrix[9] = 0.0f; | ||
372 | + out_matrix[10] = 1.0f; | ||
373 | + out_matrix[11] = 0.0f; | ||
374 | + | ||
375 | + // Translation | ||
376 | + out_matrix[12] = in_matrix[2]; | ||
377 | + out_matrix[13] = in_matrix[5]; | ||
378 | + out_matrix[14] = 0.0f; | ||
379 | + out_matrix[15] = 1.0f; | ||
380 | +} | ||
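// Example: the row-major 2D affine
//   [ a  b  tx ]
//   [ c  d  ty ]
//   [ 0  0   1 ]
// becomes the column-major 4x4 GL matrix
//   { a, c, 0, 0,   b, d, 0, 0,   0, 0, 1, 0,   tx, ty, 0, 1 }.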
381 | + | ||
382 | + | ||
383 | +void ObjectTracker::Draw(const int canvas_width, const int canvas_height, | ||
384 | + const float* const frame_to_canvas) const { | ||
385 | +#ifdef __RENDER_OPENGL__ | ||
386 | + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); | ||
387 | + | ||
388 | + glMatrixMode(GL_PROJECTION); | ||
389 | + glLoadIdentity(); | ||
390 | + | ||
391 | + glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f); | ||
392 | + | ||
393 | + // To make Y go the right direction (0 at top of frame). | ||
394 | + glScalef(1.0f, -1.0f, 1.0f); | ||
395 | + glTranslatef(0.0f, -canvas_height, 0.0f); | ||
396 | + | ||
397 | + glMatrixMode(GL_MODELVIEW); | ||
398 | + glLoadIdentity(); | ||
399 | + | ||
400 | + glPushMatrix(); | ||
401 | + | ||
402 | + // Apply the frame to canvas transformation. | ||
403 | + static GLfloat transformation[16]; | ||
404 | + Convert3x3To4x4(frame_to_canvas, transformation); | ||
405 | + glMultMatrixf(transformation); | ||
406 | + | ||
407 | + // Draw tracked object bounding boxes. | ||
408 | + for (TrackedObjectMap::const_iterator iter = objects_.begin(); | ||
409 | + iter != objects_.end(); ++iter) { | ||
410 | + TrackedObject* tracked_object = iter->second; | ||
411 | + tracked_object->Draw(); | ||
412 | + } | ||
413 | + | ||
414 | + static const bool kRenderDebugPyramid = false; | ||
415 | + if (kRenderDebugPyramid) { | ||
416 | + glColor4f(1.0f, 1.0f, 1.0f, 1.0f); | ||
417 | + for (int i = 0; i < kNumPyramidLevels * 2; ++i) { | ||
418 | + Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw(); | ||
419 | + } | ||
420 | + } | ||
421 | + | ||
422 | + static const bool kRenderDebugDerivative = false; | ||
423 | + if (kRenderDebugDerivative) { | ||
424 | + glColor4f(1.0f, 1.0f, 1.0f, 1.0f); | ||
425 | + for (int i = 0; i < kNumPyramidLevels; ++i) { | ||
426 | + const Image<int32_t>& dx = *frame1_->GetSpatialX(i); | ||
427 | + Image<uint8_t> render_image(dx.GetWidth(), dx.GetHeight()); | ||
428 | + for (int y = 0; y < dx.GetHeight(); ++y) { | ||
429 | + const int32_t* dx_ptr = dx[y]; | ||
430 | + uint8_t* dst_ptr = render_image[y]; | ||
431 | + for (int x = 0; x < dx.GetWidth(); ++x) { | ||
432 | + *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255); | ||
433 | + } | ||
434 | + } | ||
435 | + | ||
436 | + Sprite(render_image).Draw(); | ||
437 | + } | ||
438 | + } | ||
439 | + | ||
440 | + if (detector_ != NULL) { | ||
441 | + glDisable(GL_CULL_FACE); | ||
442 | + detector_->Draw(); | ||
443 | + } | ||
444 | + glPopMatrix(); | ||
445 | +#endif | ||
446 | +} | ||
447 | + | ||
448 | +static void AddQuadrants(const BoundingBox& box, | ||
449 | + std::vector<BoundingBox>* boxes) { | ||
450 | + const Point2f center = box.GetCenter(); | ||
451 | + | ||
452 | + float x1 = box.left_; | ||
453 | + float x2 = center.x; | ||
454 | + float x3 = box.right_; | ||
455 | + | ||
456 | + float y1 = box.top_; | ||
457 | + float y2 = center.y; | ||
458 | + float y3 = box.bottom_; | ||
459 | + | ||
460 | + // Upper left. | ||
461 | + boxes->push_back(BoundingBox(x1, y1, x2, y2)); | ||
462 | + | ||
463 | + // Upper right. | ||
464 | + boxes->push_back(BoundingBox(x2, y1, x3, y2)); | ||
465 | + | ||
466 | + // Bottom left. | ||
467 | + boxes->push_back(BoundingBox(x1, y2, x2, y3)); | ||
468 | + | ||
469 | + // Bottom right. | ||
470 | + boxes->push_back(BoundingBox(x2, y2, x3, y3)); | ||
471 | + | ||
472 | + // Whole thing. | ||
473 | + boxes->push_back(box); | ||
474 | +} | ||
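// Example: for box = (left 0, top 0, right 100, bottom 100), AddQuadrants
// appends five boxes: (0,0,50,50), (50,0,100,50), (0,50,50,100),
// (50,50,100,100), and the original (0,0,100,100); the quadrants share edges
// through the box center.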
475 | + | ||
476 | +void ObjectTracker::ComputeKeypoints(const bool cached_ok) { | ||
477 | + const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)]; | ||
478 | + FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)]; | ||
479 | + | ||
480 | + std::vector<BoundingBox> boxes; | ||
481 | + | ||
482 | + for (TrackedObjectMap::iterator object_iter = objects_.begin(); | ||
483 | + object_iter != objects_.end(); ++object_iter) { | ||
484 | + BoundingBox box = object_iter->second->GetPosition(); | ||
485 | + box.Scale(config_->object_box_scale_factor_for_features, | ||
486 | + config_->object_box_scale_factor_for_features); | ||
487 | + AddQuadrants(box, &boxes); | ||
488 | + } | ||
489 | + | ||
490 | + AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes); | ||
491 | + | ||
492 | + keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change); | ||
493 | +} | ||
494 | + | ||
495 | + | ||
496 | +// Given a vector of detections and a model, simply returns the Detection for | ||
497 | +// that model with the highest correlation. | ||
498 | +bool ObjectTracker::GetBestObjectForDetection( | ||
499 | + const Detection& detection, TrackedObject** match) const { | ||
500 | + TrackedObject* best_match = NULL; | ||
501 | + float best_overlap = -FLT_MAX; | ||
502 | + | ||
503 | + LOGV("Looking for matches in %zu objects!", objects_.size()); | ||
504 | + for (TrackedObjectMap::const_iterator object_iter = objects_.begin(); | ||
505 | + object_iter != objects_.end(); ++object_iter) { | ||
506 | + TrackedObject* const tracked_object = object_iter->second; | ||
507 | + | ||
508 | + const float overlap = tracked_object->GetPosition().PascalScore( | ||
509 | + detection.GetObjectBoundingBox()); | ||
510 | + | ||
511 | + if (!detector_->AllowSpontaneousDetections() && | ||
512 | + (detection.GetObjectModel() != tracked_object->GetModel())) { | ||
513 | + if (overlap > 0.0f) { | ||
514 | + return false; | ||
515 | + } | ||
516 | + continue; | ||
517 | + } | ||
518 | + | ||
519 | + const float jump_distance = | ||
520 | + (tracked_object->GetPosition().GetCenter() - | ||
521 | + detection.GetObjectBoundingBox().GetCenter()).LengthSquared(); | ||
522 | + | ||
523 | + const float allowed_distance = | ||
524 | + tracked_object->GetAllowableDistanceSquared(); | ||
525 | + | ||
526 | + LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f", | ||
527 | + jump_distance, allowed_distance, overlap); | ||
528 | + | ||
529 | + // TODO(andrewharp): No need to do this verification twice, eliminate | ||
530 | + // one of the score checks (the other being in OnDetection). | ||
531 | + if (jump_distance < allowed_distance && | ||
532 | + overlap > best_overlap && | ||
533 | + tracked_object->GetMatchScore() + kMatchScoreBuffer < | ||
534 | + detection.GetMatchScore()) { | ||
535 | + best_match = tracked_object; | ||
536 | + best_overlap = overlap; | ||
537 | + } else if (overlap > 0.0f) { | ||
538 | + return false; | ||
539 | + } | ||
540 | + } | ||
541 | + | ||
542 | + *match = best_match; | ||
543 | + return true; | ||
544 | +} | ||
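`PascalScore` is, by its name, presumably the PASCAL VOC overlap criterion (intersection area over union area). A minimal sketch with a hypothetical axis-aligned `Box` type, purely for illustration and not the project's `BoundingBox` implementation:

```cpp
#include <algorithm>

struct Box { float left, top, right, bottom; };

// Intersection-over-union of two axis-aligned boxes, in [0, 1].
float PascalScore(const Box& a, const Box& b) {
  const float iw = std::min(a.right, b.right) - std::max(a.left, b.left);
  const float ih = std::min(a.bottom, b.bottom) - std::max(a.top, b.top);
  if (iw <= 0.0f || ih <= 0.0f) return 0.0f;  // Disjoint boxes.
  const float intersection = iw * ih;
  const float union_area = (a.right - a.left) * (a.bottom - a.top) +
                           (b.right - b.left) * (b.bottom - b.top) -
                           intersection;
  return intersection / union_area;
}
```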
545 | + | ||
546 | + | ||
547 | +void ObjectTracker::ProcessDetections( | ||
548 | + std::vector<Detection>* const detections) { | ||
549 | + LOGV("Initial detection done, iterating over %zu detections now.", | ||
550 | + detections->size()); | ||
551 | + | ||
552 | + const bool spontaneous_detections_allowed = | ||
553 | + detector_->AllowSpontaneousDetections(); | ||
554 | + for (std::vector<Detection>::const_iterator it = detections->begin(); | ||
555 | + it != detections->end(); ++it) { | ||
556 | + const Detection& detection = *it; | ||
557 | + SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()), | ||
558 | + "Frame does not contain bounding box!"); | ||
559 | + | ||
560 | + TrackedObject* best_match = NULL; | ||
561 | + | ||
562 | + const bool no_collisions = | ||
563 | + GetBestObjectForDetection(detection, &best_match); | ||
564 | + | ||
565 | + // Need to get a non-const version of the model, or create a new one if it | ||
566 | + // wasn't given. | ||
567 | + ObjectModelBase* model = | ||
568 | + const_cast<ObjectModelBase*>(detection.GetObjectModel()); | ||
569 | + | ||
570 | + if (best_match != NULL) { | ||
571 | + if (model != best_match->GetModel()) { | ||
572 | + CHECK_ALWAYS(detector_->AllowSpontaneousDetections(), | ||
573 | + "Model for object changed but spontaneous detections not allowed!"); | ||
574 | + } | ||
575 | + best_match->OnDetection(model, | ||
576 | + detection.GetObjectBoundingBox(), | ||
577 | + detection.GetMatchScore(), | ||
578 | + curr_time_, *frame2_); | ||
579 | + } else if (no_collisions && spontaneous_detections_allowed) { | ||
580 | + if (detection.GetMatchScore() > kMinimumMatchScore) { | ||
581 | + LOGV("No match, adding it!"); | ||
582 | + const ObjectModelBase* model = detection.GetObjectModel(); | ||
583 | + std::ostringstream ss; | ||
584 | + // TODO(andrewharp): Generate this in a more general fashion. | ||
585 | + ss << "hand_" << num_detected_++; | ||
586 | + std::string object_name = ss.str(); | ||
587 | + MaybeAddObject(object_name, *frame2_->GetImage(), | ||
588 | + detection.GetObjectBoundingBox(), model); | ||
589 | + } | ||
590 | + } | ||
591 | + } | ||
592 | +} | ||
593 | + | ||
594 | + | ||
595 | +void ObjectTracker::DetectTargets() { | ||
596 | + // Detect all object model types that we're currently tracking. | ||
597 | + std::vector<const ObjectModelBase*> object_models; | ||
598 | + detector_->GetObjectModels(&object_models); | ||
599 | + if (object_models.size() == 0) { | ||
600 | + LOGV("No objects to search for, aborting."); | ||
601 | + return; | ||
602 | + } | ||
603 | + | ||
604 | + LOGV("Trying to detect %zu models", object_models.size()); | ||
605 | + | ||
606 | + LOGV("Creating test vector!"); | ||
607 | + std::vector<BoundingSquare> positions; | ||
608 | + | ||
609 | + for (TrackedObjectMap::iterator object_iter = objects_.begin(); | ||
610 | + object_iter != objects_.end(); ++object_iter) { | ||
611 | + TrackedObject* const tracked_object = object_iter->second; | ||
612 | + | ||
613 | +#if DEBUG_PREDATOR | ||
614 | + positions.push_back(GetCenteredSquare( | ||
615 | + frame2_->GetImage()->GetContainingBox(), 32.0f)); | ||
616 | +#else | ||
617 | + const BoundingBox& position = tracked_object->GetPosition(); | ||
618 | + | ||
619 | + const float square_size = MAX( | ||
620 | + kScanMinSquareSize / (kLastKnownPositionScaleFactor * | ||
621 | + kLastKnownPositionScaleFactor), | ||
622 | + MIN(position.GetWidth(), | ||
623 | + position.GetHeight())) / kLastKnownPositionScaleFactor; | ||
624 | + | ||
625 | + FillWithSquares(frame2_->GetImage()->GetContainingBox(), | ||
626 | + tracked_object->GetPosition(), | ||
627 | + square_size, | ||
628 | + kScanMinSquareSize, | ||
629 | + kLastKnownPositionScaleFactor, | ||
630 | + &positions); | ||
631 | +#endif | ||
632 | + }  // End per-object loop (closes in both preprocessor branches). | ||
633 | + | ||
634 | + LOGV("Created test vector!"); | ||
635 | + | ||
636 | + std::vector<Detection> detections; | ||
637 | + LOGV("Detecting!"); | ||
638 | + detector_->Detect(positions, &detections); | ||
639 | + LOGV("Found %zu detections", detections.size()); | ||
640 | + | ||
641 | + TimeLog("Finished detection."); | ||
642 | + | ||
643 | + ProcessDetections(&detections); | ||
644 | + | ||
645 | + TimeLog("iterated over detections"); | ||
646 | + | ||
647 | + LOGV("Done detecting!"); | ||
648 | +} | ||
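`FillWithSquares` is not shown in this diff; from its arguments it presumably tiles scan squares over the frame at sizes stepping up from `square_size` by the scale factor around the last known position. A heavily hedged sketch of that idea, where all names, the 50%-overlap stride, and the assumption that `scale_factor > 1` are illustrative choices rather than the real implementation:

```cpp
#include <algorithm>
#include <vector>

struct Square { float x, y, size; };  // Top-left corner plus edge length.

// Illustrative only: tile squares of growing size over an area, with a
// half-size stride so neighboring windows overlap by 50%.
// Requires scale_factor > 1 to terminate.
void FillWithSquares(float area_w, float area_h, float start_size,
                     float min_size, float scale_factor,
                     std::vector<Square>* out) {
  const float max_size = std::min(area_w, area_h);
  for (float size = std::max(start_size, min_size); size <= max_size;
       size *= scale_factor) {
    const float step = size / 2.0f;
    for (float y = 0.0f; y + size <= area_h; y += step) {
      for (float x = 0.0f; x + size <= area_w; x += step) {
        out->push_back({x, y, size});
      }
    }
  }
}
```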
649 | + | ||
650 | + | ||
651 | +void ObjectTracker::TrackObjects() { | ||
652 | + // TODO(andrewharp): Correlation should be allowed to remove objects too. | ||
653 | + const bool automatic_removal_allowed = detector_.get() != NULL ? | ||
654 | + detector_->AllowSpontaneousDetections() : false; | ||
655 | + | ||
656 | + LOGV("Tracking %zu objects!", objects_.size()); | ||
657 | + std::vector<std::string> dead_objects; | ||
658 | + for (TrackedObjectMap::iterator iter = objects_.begin(); | ||
659 | + iter != objects_.end(); ++iter) { | ||
660 | + TrackedObject* object = iter->second; | ||
661 | + const BoundingBox tracked_position = TrackBox( | ||
662 | + object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]); | ||
663 | + object->UpdatePosition(tracked_position, curr_time_, *frame2_, false); | ||
664 | + | ||
665 | + if (automatic_removal_allowed && | ||
666 | + object->GetNumConsecutiveFramesBelowThreshold() > | ||
667 | + kMaxNumDetectionFailures * 5) { | ||
668 | + dead_objects.push_back(iter->first); | ||
669 | + } | ||
670 | + } | ||
671 | + | ||
672 | + if (detector_ != NULL && automatic_removal_allowed) { | ||
673 | + for (std::vector<std::string>::iterator iter = dead_objects.begin(); | ||
674 | + iter != dead_objects.end(); ++iter) { | ||
675 | + LOGE("Removing object! %s", iter->c_str()); | ||
676 | + ForgetTarget(*iter); | ||
677 | + } | ||
678 | + } | ||
679 | + TimeLog("Tracked all objects."); | ||
680 | + | ||
681 | + LOGV("%zu objects tracked!", objects_.size()); | ||
682 | +} | ||
683 | + | ||
684 | +} // namespace tf_tracking |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ | ||
18 | + | ||
19 | +#include <map> | ||
20 | +#include <string> | ||
21 | + | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/logging.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
27 | + | ||
28 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
29 | +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" | ||
30 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" | ||
31 | +#include "tensorflow/examples/android/jni/object_tracking/object_model.h" | ||
32 | +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" | ||
33 | +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" | ||
34 | + | ||
35 | +namespace tf_tracking { | ||
36 | + | ||
37 | +typedef std::map<std::string, TrackedObject*> TrackedObjectMap; | ||
38 | + | ||
39 | +inline std::ostream& operator<<(std::ostream& stream, | ||
40 | + const TrackedObjectMap& map) { | ||
41 | + for (TrackedObjectMap::const_iterator iter = map.begin(); | ||
42 | + iter != map.end(); ++iter) { | ||
43 | + const TrackedObject& tracked_object = *iter->second; | ||
44 | + const std::string& key = iter->first; | ||
45 | + stream << key << ": " << tracked_object; | ||
46 | + } | ||
47 | + return stream; | ||
48 | +} | ||
49 | + | ||
50 | + | ||
51 | +// ObjectTracker is the highest-level class in the tracking/detection framework. | ||
52 | +// It handles basic image processing, keypoint detection, keypoint tracking, | ||
53 | +// object tracking, and object detection/relocalization. | ||
54 | +class ObjectTracker { | ||
55 | + public: | ||
56 | + ObjectTracker(const TrackerConfig* const config, | ||
57 | + ObjectDetectorBase* const detector); | ||
58 | + virtual ~ObjectTracker(); | ||
59 | + | ||
60 | + virtual void NextFrame(const uint8_t* const new_frame, | ||
61 | + const int64_t timestamp, | ||
62 | + const float* const alignment_matrix_2x3) { | ||
63 | + NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3); | ||
64 | + } | ||
65 | + | ||
66 | + // Called upon the arrival of a new frame of raw data. | ||
67 | + // Does all image processing, keypoint detection, and object | ||
68 | + // tracking/detection for registered objects. | ||
69 | + // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that | ||
70 | + // represents the main transformation that has happened between the last | ||
71 | + // and the current frame. | ||
72 | +  // Argument uv_frame optionally provides the frame's chroma (UV) plane; | ||
73 | +  // it may be NULL. | ||
74 | + virtual void NextFrame(const uint8_t* const new_frame, | ||
75 | + const uint8_t* const uv_frame, const int64_t timestamp, | ||
76 | + const float* const alignment_matrix_2x3); | ||
77 | + | ||
78 | + virtual void RegisterNewObjectWithAppearance(const std::string& id, | ||
79 | + const uint8_t* const new_frame, | ||
80 | + const BoundingBox& bounding_box); | ||
81 | + | ||
82 | + // Updates the position of a tracked object, given that it was known to be at | ||
83 | + // a certain position at some point in the past. | ||
84 | + virtual void SetPreviousPositionOfObject(const std::string& id, | ||
85 | + const BoundingBox& bounding_box, | ||
86 | + const int64_t timestamp); | ||
87 | + | ||
88 | + // Sets the current position of the object in the most recent frame provided. | ||
89 | + virtual void SetCurrentPositionOfObject(const std::string& id, | ||
90 | + const BoundingBox& bounding_box); | ||
91 | + | ||
92 | + // Tells the ObjectTracker to stop tracking a target. | ||
93 | + void ForgetTarget(const std::string& id); | ||
94 | + | ||
95 | + // Fills the given out_data buffer with the latest detected keypoint | ||
96 | + // correspondences, first scaled by scale_factor (to adjust for downsampling | ||
97 | + // that may have occurred elsewhere), then packed in a fixed-point format. | ||
98 | + int GetKeypointsPacked(uint16_t* const out_data, | ||
99 | + const float scale_factor) const; | ||
100 | + | ||
101 | + // Copy the keypoint arrays after computeFlow is called. | ||
102 | + // out_data should be at least kMaxKeypoints * kKeypointStep long. | ||
103 | + // Currently, its format is [x1 y1 found x2 y2 score] repeated N times, | ||
104 | + // where N is the number of keypoints tracked. N is returned as the result. | ||
105 | + int GetKeypoints(const bool only_found, float* const out_data) const; | ||
106 | + | ||
107 | + // Returns the current position of a box, given that it was at a certain | ||
108 | + // position at the given time. | ||
109 | + BoundingBox TrackBox(const BoundingBox& region, | ||
110 | + const int64_t timestamp) const; | ||
111 | + | ||
112 | + // Returns the number of frames that have been passed to NextFrame(). | ||
113 | + inline int GetNumFrames() const { | ||
114 | + return num_frames_; | ||
115 | + } | ||
116 | + | ||
117 | + inline bool HaveObject(const std::string& id) const { | ||
118 | + return objects_.find(id) != objects_.end(); | ||
119 | + } | ||
120 | + | ||
121 | + // Returns the TrackedObject associated with the given id. | ||
122 | + inline const TrackedObject* GetObject(const std::string& id) const { | ||
123 | + TrackedObjectMap::const_iterator iter = objects_.find(id); | ||
124 | + CHECK_ALWAYS(iter != objects_.end(), | ||
125 | + "Unknown object key! \"%s\"", id.c_str()); | ||
126 | + TrackedObject* const object = iter->second; | ||
127 | + return object; | ||
128 | + } | ||
129 | + | ||
130 | + // Returns the TrackedObject associated with the given id. | ||
131 | + inline TrackedObject* GetObject(const std::string& id) { | ||
132 | + TrackedObjectMap::iterator iter = objects_.find(id); | ||
133 | + CHECK_ALWAYS(iter != objects_.end(), | ||
134 | + "Unknown object key! \"%s\"", id.c_str()); | ||
135 | + TrackedObject* const object = iter->second; | ||
136 | + return object; | ||
137 | + } | ||
138 | + | ||
139 | + bool IsObjectVisible(const std::string& id) const { | ||
140 | + SCHECK(HaveObject(id), "Don't have this object."); | ||
141 | + | ||
142 | + const TrackedObject* object = GetObject(id); | ||
143 | + return object->IsVisible(); | ||
144 | + } | ||
145 | + | ||
146 | + virtual void Draw(const int canvas_width, const int canvas_height, | ||
147 | + const float* const frame_to_canvas) const; | ||
148 | + | ||
149 | + protected: | ||
150 | + // Creates a new tracked object at the given position. | ||
151 | + // If an object model is provided, then that model will be associated with the | ||
152 | + // object. If not, a new model may be created from the appearance at the | ||
153 | + // initial position and registered with the object detector. | ||
154 | + virtual TrackedObject* MaybeAddObject(const std::string& id, | ||
155 | + const Image<uint8_t>& image, | ||
156 | + const BoundingBox& bounding_box, | ||
157 | + const ObjectModelBase* object_model); | ||
158 | + | ||
159 | + // Find the keypoints in the frame before the current frame. | ||
160 | + // If only one frame exists, keypoints will be found in that frame. | ||
161 | + void ComputeKeypoints(const bool cached_ok = false); | ||
162 | + | ||
163 | + // Finds the correspondences for all the points in the current pair of frames. | ||
164 | + // Stores the results in the given FramePair. | ||
165 | + void FindCorrespondences(FramePair* const curr_change) const; | ||
166 | + | ||
167 | + inline int GetNthIndexFromEnd(const int offset) const { | ||
168 | + return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset); | ||
169 | + } | ||
170 | + | ||
171 | + BoundingBox TrackBox(const BoundingBox& region, | ||
172 | + const FramePair& frame_pair) const; | ||
173 | + | ||
174 | + inline void IncrementFrameIndex() { | ||
175 | +    // Advance the current frame-change index. | ||
176 | + ++num_frames_; | ||
177 | + ++curr_num_frame_pairs_; | ||
178 | + | ||
179 | + // If we've got too many, push up the start of the queue. | ||
180 | + if (curr_num_frame_pairs_ > kNumFrames) { | ||
181 | + first_frame_index_ = GetNthIndexFromStart(1); | ||
182 | + --curr_num_frame_pairs_; | ||
183 | + } | ||
184 | + } | ||
185 | + | ||
186 | + inline int GetNthIndexFromStart(const int offset) const { | ||
187 | + SCHECK(offset >= 0 && offset < curr_num_frame_pairs_, | ||
188 | + "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_); | ||
189 | + return (first_frame_index_ + offset) % kNumFrames; | ||
190 | + } | ||
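`GetNthIndexFromEnd`, `IncrementFrameIndex`, and `GetNthIndexFromStart` together implement a fixed-capacity ring buffer over `frame_pairs_`: once `kNumFrames` pairs are stored, each new pair recycles the oldest slot. The same bookkeeping as a self-contained sketch, with the capacity shrunk to an illustrative value:

```cpp
#include <cassert>

// Mirror of the frame-pair ring-buffer indexing used by ObjectTracker.
class FrameRing {
 public:
  static const int kNumFrames = 4;  // Illustrative capacity.

  void Push() {  // Corresponds to IncrementFrameIndex().
    ++count_;
    if (count_ > kNumFrames) {
      first_ = (first_ + 1) % kNumFrames;  // Drop the oldest slot.
      --count_;
    }
  }

  int FromStart(int offset) const {  // 0 == oldest stored pair.
    assert(offset >= 0 && offset < count_);
    return (first_ + offset) % kNumFrames;
  }

  int FromEnd(int offset) const {  // 0 == most recently stored pair.
    return FromStart(count_ - 1 - offset);
  }

 private:
  int first_ = 0;  // Physical index of the oldest pair.
  int count_ = 0;  // Number of pairs currently stored.
};
```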
191 | + | ||
192 | + void TrackObjects(); | ||
193 | + | ||
194 | + const std::unique_ptr<const TrackerConfig> config_; | ||
195 | + | ||
196 | + const int frame_width_; | ||
197 | + const int frame_height_; | ||
198 | + | ||
199 | + int64_t curr_time_; | ||
200 | + | ||
201 | + int num_frames_; | ||
202 | + | ||
203 | + TrackedObjectMap objects_; | ||
204 | + | ||
205 | + FlowCache flow_cache_; | ||
206 | + | ||
207 | + KeypointDetector keypoint_detector_; | ||
208 | + | ||
209 | + int curr_num_frame_pairs_; | ||
210 | + int first_frame_index_; | ||
211 | + | ||
212 | + std::unique_ptr<ImageData> frame1_; | ||
213 | + std::unique_ptr<ImageData> frame2_; | ||
214 | + | ||
215 | + FramePair frame_pairs_[kNumFrames]; | ||
216 | + | ||
217 | + std::unique_ptr<ObjectDetectorBase> detector_; | ||
218 | + | ||
219 | + int num_detected_; | ||
220 | + | ||
221 | + private: | ||
222 | + void TrackTarget(TrackedObject* const object); | ||
223 | + | ||
224 | + bool GetBestObjectForDetection( | ||
225 | + const Detection& detection, TrackedObject** match) const; | ||
226 | + | ||
227 | + void ProcessDetections(std::vector<Detection>* const detections); | ||
228 | + | ||
229 | + void DetectTargets(); | ||
230 | + | ||
231 | + // Temp object used in ObjectTracker::CreateNewExample. | ||
232 | + mutable std::vector<BoundingSquare> squares; | ||
233 | + | ||
234 | + friend std::ostream& operator<<(std::ostream& stream, | ||
235 | + const ObjectTracker& tracker); | ||
236 | + | ||
237 | + TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker); | ||
238 | +}; | ||
239 | + | ||
240 | +inline std::ostream& operator<<(std::ostream& stream, | ||
241 | + const ObjectTracker& tracker) { | ||
242 | + stream << "Frame size: " << tracker.frame_width_ << "x" | ||
243 | + << tracker.frame_height_ << std::endl; | ||
244 | + | ||
245 | + stream << "Num frames: " << tracker.num_frames_ << std::endl; | ||
246 | + | ||
247 | + stream << "Curr time: " << tracker.curr_time_ << std::endl; | ||
248 | + | ||
249 | + const int first_frame_index = tracker.GetNthIndexFromStart(0); | ||
250 | + const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index]; | ||
251 | + | ||
252 | + const int last_frame_index = tracker.GetNthIndexFromEnd(0); | ||
253 | + const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index]; | ||
254 | + | ||
255 | + stream << "first frame: " << first_frame_index << "," | ||
256 | + << first_frame_pair.end_time_ << " " | ||
257 | + << "last frame: " << last_frame_index << "," | ||
258 | + << last_frame_pair.end_time_ << " diff: " | ||
259 | + << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms" | ||
260 | + << std::endl; | ||
261 | + | ||
262 | + stream << "Tracked targets:"; | ||
263 | + stream << tracker.objects_; | ||
264 | + | ||
265 | + return stream; | ||
266 | +} | ||
267 | + | ||
268 | +} // namespace tf_tracking | ||
269 | + | ||
270 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ |
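Pulling the public API together, a hedged usage sketch of a per-frame driver loop. The frame size, object id, box coordinates, and the decision to run without a detector are all illustrative; the `TrackerConfig` and `Size` constructors are used the same way in the JNI bindings below:

```cpp
#include <stdint.h>

#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"

using namespace tf_tracking;

// Called once per incoming grayscale frame (all concrete values made up).
void OnFrame(const uint8_t* frame, int64_t timestamp) {
  // The tracker takes ownership of the config; a NULL detector means
  // tracking-only mode (no detection/relocalization pass).
  static ObjectTracker tracker(new TrackerConfig(Size(640, 480)),
                               /*detector=*/NULL);

  tracker.NextFrame(frame, timestamp, /*alignment_matrix_2x3=*/NULL);

  if (!tracker.HaveObject("target")) {
    // Seed the tracker once with an initial appearance and position.
    tracker.RegisterNewObjectWithAppearance(
        "target", frame, BoundingBox(100.0f, 100.0f, 200.0f, 200.0f));
  } else if (tracker.IsObjectVisible("target")) {
    const BoundingBox position = tracker.GetObject("target")->GetPosition();
    (void)position;  // ... hand off to rendering / app logic.
  }
}
```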
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#include <android/log.h> | ||
17 | +#include <jni.h> | ||
18 | +#include <stdint.h> | ||
19 | +#include <stdlib.h> | ||
20 | +#include <string.h> | ||
21 | +#include <cstdint> | ||
22 | + | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
27 | + | ||
28 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
29 | +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" | ||
30 | + | ||
31 | +namespace tf_tracking { | ||
32 | + | ||
33 | +#define OBJECT_TRACKER_METHOD(METHOD_NAME) \ | ||
34 | + Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT | ||
35 | + | ||
36 | +JniLongField object_tracker_field("nativeObjectTracker"); | ||
37 | + | ||
38 | +ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) { | ||
39 | + ObjectTracker* const object_tracker = | ||
40 | + reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz)); | ||
41 | + CHECK_ALWAYS(object_tracker != NULL, "null object tracker!"); | ||
42 | + return object_tracker; | ||
43 | +} | ||
44 | + | ||
45 | +void set_object_tracker(JNIEnv* env, jobject thiz, | ||
46 | + const ObjectTracker* object_tracker) { | ||
47 | + object_tracker_field.set(env, thiz, | ||
48 | + reinterpret_cast<intptr_t>(object_tracker)); | ||
49 | +} | ||
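`JniLongField` (defined elsewhere in this change) evidently stores the native `ObjectTracker` pointer in the Java `long` field named `nativeObjectTracker`. The underlying JNI pattern, sketched with raw calls for clarity; the per-call class lookup is the simple, unoptimized form of what the helper presumably caches:

```cpp
#include <jni.h>
#include <stdint.h>

// Read a native pointer previously stashed in a Java `long` field.
void* GetNativeHandle(JNIEnv* env, jobject thiz, const char* field_name) {
  jclass clazz = env->GetObjectClass(thiz);
  jfieldID field = env->GetFieldID(clazz, field_name, "J");  // "J" == long.
  return reinterpret_cast<void*>(
      static_cast<intptr_t>(env->GetLongField(thiz, field)));
}

// Stash a native pointer into the Java `long` field.
void SetNativeHandle(JNIEnv* env, jobject thiz, const char* field_name,
                     void* ptr) {
  jclass clazz = env->GetObjectClass(thiz);
  jfieldID field = env->GetFieldID(clazz, field_name, "J");
  env->SetLongField(thiz, field,
                    static_cast<jlong>(reinterpret_cast<intptr_t>(ptr)));
}
```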
50 | + | ||
51 | +#ifdef __cplusplus | ||
52 | +extern "C" { | ||
53 | +#endif | ||
54 | +JNIEXPORT | ||
55 | +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz, | ||
56 | + jint width, jint height, | ||
57 | + jboolean always_track); | ||
58 | + | ||
59 | +JNIEXPORT | ||
60 | +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env, | ||
61 | + jobject thiz); | ||
62 | + | ||
63 | +JNIEXPORT | ||
64 | +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)( | ||
65 | + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, | ||
66 | + jfloat x2, jfloat y2, jbyteArray frame_data); | ||
67 | + | ||
68 | +JNIEXPORT | ||
69 | +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)( | ||
70 | + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, | ||
71 | + jfloat x2, jfloat y2, jlong timestamp); | ||
72 | + | ||
73 | +JNIEXPORT | ||
74 | +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)( | ||
75 | + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, | ||
76 | + jfloat x2, jfloat y2); | ||
77 | + | ||
78 | +JNIEXPORT | ||
79 | +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz, | ||
80 | + jstring object_id); | ||
81 | + | ||
82 | +JNIEXPORT | ||
83 | +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env, | ||
84 | + jobject thiz, | ||
85 | + jstring object_id); | ||
86 | + | ||
87 | +JNIEXPORT | ||
88 | +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env, | ||
89 | + jobject thiz, | ||
90 | + jstring object_id); | ||
91 | + | ||
92 | +JNIEXPORT | ||
93 | +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env, | ||
94 | + jobject thiz, | ||
95 | + jstring object_id); | ||
96 | + | ||
97 | +JNIEXPORT | ||
98 | +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz, | ||
99 | + jstring object_id); | ||
100 | + | ||
101 | +JNIEXPORT | ||
102 | +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)( | ||
103 | + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array); | ||
104 | + | ||
105 | +JNIEXPORT | ||
106 | +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz, | ||
107 | + jbyteArray y_data, | ||
108 | + jbyteArray uv_data, | ||
109 | + jlong timestamp, | ||
110 | + jfloatArray vg_matrix_2x3); | ||
111 | + | ||
112 | +JNIEXPORT | ||
113 | +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz, | ||
114 | + jstring object_id); | ||
115 | + | ||
116 | +JNIEXPORT | ||
117 | +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)( | ||
118 | + JNIEnv* env, jobject thiz, jfloat scale_factor); | ||
119 | + | ||
120 | +JNIEXPORT | ||
121 | +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)( | ||
122 | + JNIEnv* env, jobject thiz, jboolean only_found_); | ||
123 | + | ||
124 | +JNIEXPORT | ||
125 | +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)( | ||
126 | + JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1, | ||
127 | + jfloat position_y1, jfloat position_x2, jfloat position_y2, | ||
128 | + jfloatArray delta); | ||
129 | + | ||
130 | +JNIEXPORT | ||
131 | +void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj, | ||
132 | + jint view_width, | ||
133 | + jint view_height, | ||
134 | + jfloatArray delta); | ||
135 | + | ||
136 | +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)( | ||
137 | + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride, | ||
138 | + jbyteArray input, jint factor, jbyteArray output); | ||
139 | + | ||
140 | +#ifdef __cplusplus | ||
141 | +} | ||
142 | +#endif | ||
143 | + | ||
144 | +JNIEXPORT | ||
145 | +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz, | ||
146 | + jint width, jint height, | ||
147 | + jboolean always_track) { | ||
148 | + LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz); | ||
149 | + const Size image_size(width, height); | ||
150 | + TrackerConfig* const tracker_config = new TrackerConfig(image_size); | ||
151 | + tracker_config->always_track = always_track; | ||
152 | + | ||
153 | +  // No detector is supplied here; the tracker runs in tracking-only mode. | ||
154 | + ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL); | ||
155 | + set_object_tracker(env, thiz, tracker); | ||
156 | + LOGI("Initialized!"); | ||
157 | + | ||
158 | + CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker, | ||
159 | + "Failure to set hand tracker!"); | ||
160 | +} | ||
161 | + | ||
162 | +JNIEXPORT | ||
163 | +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env, | ||
164 | + jobject thiz) { | ||
165 | + delete get_object_tracker(env, thiz); | ||
166 | + set_object_tracker(env, thiz, NULL); | ||
167 | +} | ||
168 | + | ||
169 | +JNIEXPORT | ||
170 | +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)( | ||
171 | + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, | ||
172 | + jfloat x2, jfloat y2, jbyteArray frame_data) { | ||
173 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
174 | + | ||
175 | + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1, | ||
176 | + x2, y2); | ||
177 | + | ||
178 | + jboolean iCopied = JNI_FALSE; | ||
179 | + | ||
180 | + // Copy image into currFrame. | ||
181 | + jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied); | ||
182 | + | ||
183 | + BoundingBox bounding_box(x1, y1, x2, y2); | ||
184 | + get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance( | ||
185 | + id_str, reinterpret_cast<const uint8_t*>(pixels), bounding_box); | ||
186 | + | ||
187 | + env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT); | ||
188 | + | ||
189 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
190 | +} | ||
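A note on the release modes used throughout this file: `ReleaseByteArrayElements` with `JNI_ABORT` frees the native copy without writing changes back to the Java array (correct here, since the pixels are only read), while mode `0` writes back and frees. A hypothetical RAII guard for the read-only case, just to illustrate the Get/Release pairing; it is not part of the original sources:

```cpp
#include <jni.h>
#include <stdint.h>

// Illustrative RAII wrapper for read-only access to a jbyteArray.
class ScopedByteArrayRO {
 public:
  ScopedByteArrayRO(JNIEnv* env, jbyteArray array)
      : env_(env),
        array_(array),
        data_(env->GetByteArrayElements(array, NULL)) {}

  ~ScopedByteArrayRO() {
    // JNI_ABORT: release without copy-back, since we never modified it.
    env_->ReleaseByteArrayElements(array_, data_, JNI_ABORT);
  }

  const uint8_t* get() const {
    return reinterpret_cast<const uint8_t*>(data_);
  }

 private:
  JNIEnv* const env_;
  const jbyteArray array_;
  jbyte* const data_;
};
```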
191 | + | ||
192 | +JNIEXPORT | ||
193 | +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)( | ||
194 | + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, | ||
195 | + jfloat x2, jfloat y2, jlong timestamp) { | ||
196 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
197 | + | ||
198 | + LOGI( | ||
199 | + "Registering the position of %s at %.2f,%.2f,%.2f,%.2f" | ||
200 | + " at time %lld", | ||
201 | + id_str, x1, y1, x2, y2, static_cast<long long>(timestamp)); | ||
202 | + | ||
203 | + get_object_tracker(env, thiz)->SetPreviousPositionOfObject( | ||
204 | + id_str, BoundingBox(x1, y1, x2, y2), timestamp); | ||
205 | + | ||
206 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
207 | +} | ||
208 | + | ||
209 | +JNIEXPORT | ||
210 | +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)( | ||
211 | + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, | ||
212 | + jfloat x2, jfloat y2) { | ||
213 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
214 | + | ||
215 | + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1, | ||
216 | + x2, y2); | ||
217 | + | ||
218 | + get_object_tracker(env, thiz)->SetCurrentPositionOfObject( | ||
219 | + id_str, BoundingBox(x1, y1, x2, y2)); | ||
220 | + | ||
221 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
222 | +} | ||
223 | + | ||
224 | +JNIEXPORT | ||
225 | +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz, | ||
226 | + jstring object_id) { | ||
227 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
228 | + | ||
229 | + const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str); | ||
230 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
231 | + return haveObject; | ||
232 | +} | ||
233 | + | ||
234 | +JNIEXPORT | ||
235 | +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env, | ||
236 | + jobject thiz, | ||
237 | + jstring object_id) { | ||
238 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
239 | + | ||
240 | + const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str); | ||
241 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
242 | + return visible; | ||
243 | +} | ||
244 | + | ||
245 | +JNIEXPORT | ||
246 | +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env, | ||
247 | + jobject thiz, | ||
248 | + jstring object_id) { | ||
249 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
250 | + const TrackedObject* const object = | ||
251 | + get_object_tracker(env, thiz)->GetObject(id_str); | ||
252 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
253 | + jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str()); | ||
254 | + return model_name; | ||
255 | +} | ||
256 | + | ||
257 | +JNIEXPORT | ||
258 | +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env, | ||
259 | + jobject thiz, | ||
260 | + jstring object_id) { | ||
261 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
262 | + | ||
263 | + const float correlation = | ||
264 | + get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation(); | ||
265 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
266 | + return correlation; | ||
267 | +} | ||
268 | + | ||
269 | +JNIEXPORT | ||
270 | +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz, | ||
271 | + jstring object_id) { | ||
272 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
273 | + | ||
274 | + const float match_score = | ||
275 | + get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value; | ||
276 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
277 | + return match_score; | ||
278 | +} | ||
279 | + | ||
280 | +JNIEXPORT | ||
281 | +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)( | ||
282 | + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) { | ||
283 | + jboolean iCopied = JNI_FALSE; | ||
284 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
285 | + | ||
286 | + const BoundingBox bounding_box = | ||
287 | + get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition(); | ||
288 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
289 | + | ||
290 | + jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied); | ||
291 | + bounding_box.CopyToArray(reinterpret_cast<float*>(rect)); | ||
292 | + env->ReleaseFloatArrayElements(rect_array, rect, 0); | ||
293 | +} | ||
294 | + | ||
295 | +JNIEXPORT | ||
296 | +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz, | ||
297 | + jbyteArray y_data, | ||
298 | + jbyteArray uv_data, | ||
299 | + jlong timestamp, | ||
300 | + jfloatArray vg_matrix_2x3) { | ||
301 | + TimeLog("Starting object tracker"); | ||
302 | + | ||
303 | + jboolean iCopied = JNI_FALSE; | ||
304 | + | ||
305 | + float vision_gyro_matrix_array[6]; | ||
306 | + jfloat* jmat = NULL; | ||
307 | + | ||
308 | + if (vg_matrix_2x3 != NULL) { | ||
309 | + // Copy the alignment matrix into a float array. | ||
310 | + jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied); | ||
311 | + for (int i = 0; i < 6; ++i) { | ||
312 | + vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]); | ||
313 | + } | ||
314 | + } | ||
315 | + // Copy image into currFrame. | ||
316 | + jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied); | ||
317 | + jbyte* uv_pixels = | ||
318 | + uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL; | ||
319 | + | ||
320 | + TimeLog("Got elements"); | ||
321 | + | ||
322 | + // Add the frame to the object tracker object. | ||
323 | + get_object_tracker(env, thiz)->NextFrame( | ||
324 | + reinterpret_cast<uint8_t*>(pixels), reinterpret_cast<uint8_t*>(uv_pixels), | ||
325 | + timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL); | ||
326 | + | ||
327 | + env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT); | ||
328 | + | ||
329 | + if (uv_data != NULL) { | ||
330 | + env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT); | ||
331 | + } | ||
332 | + | ||
333 | + if (vg_matrix_2x3 != NULL) { | ||
334 | + env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT); | ||
335 | + } | ||
336 | + | ||
337 | + TimeLog("Released elements"); | ||
338 | + | ||
339 | + PrintTimeLog(); | ||
340 | + ResetTimeLog(); | ||
341 | +} | ||
342 | + | ||
343 | +JNIEXPORT | ||
344 | +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz, | ||
345 | + jstring object_id) { | ||
346 | + const char* const id_str = env->GetStringUTFChars(object_id, 0); | ||
347 | + | ||
348 | + get_object_tracker(env, thiz)->ForgetTarget(id_str); | ||
349 | + | ||
350 | + env->ReleaseStringUTFChars(object_id, id_str); | ||
351 | +} | ||
352 | + | ||
353 | +JNIEXPORT | ||
354 | +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)( | ||
355 | + JNIEnv* env, jobject thiz, jboolean only_found) { | ||
356 | + jfloat keypoint_arr[kMaxKeypoints * kKeypointStep]; | ||
357 | + | ||
358 | + const int number_of_keypoints = | ||
359 | + get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr); | ||
360 | + | ||
361 | + // Create and return the array that will be passed back to Java. | ||
362 | + jfloatArray keypoints = | ||
363 | + env->NewFloatArray(number_of_keypoints * kKeypointStep); | ||
364 | + if (keypoints == NULL) { | ||
365 | + LOGE("null array!"); | ||
366 | + return NULL; | ||
367 | + } | ||
368 | + env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep, | ||
369 | + keypoint_arr); | ||
370 | + | ||
371 | + return keypoints; | ||
372 | +} | ||
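For reference, the buffer marshaled here follows the layout documented in the header: `[x1 y1 found x2 y2 score]` per keypoint, `kKeypointStep` floats apart. A native-side sketch of walking it; the encoding of `found` as a positive float is an assumption:

```cpp
#include <cstdio>

// Illustrative consumer of the GetKeypoints() buffer layout.
void PrintCorrespondences(const tf_tracking::ObjectTracker& tracker) {
  float data[tf_tracking::kMaxKeypoints * tf_tracking::kKeypointStep];
  const int n = tracker.GetKeypoints(/*only_found=*/true, data);
  for (int i = 0; i < n; ++i) {
    const float* kp = data + i * tf_tracking::kKeypointStep;
    printf("(%.1f, %.1f) -> (%.1f, %.1f) found=%d score=%.3f\n",
           kp[0], kp[1], kp[3], kp[4], kp[2] > 0.0f ? 1 : 0, kp[5]);
  }
}
```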
373 | + | ||
374 | +JNIEXPORT | ||
375 | +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)( | ||
376 | + JNIEnv* env, jobject thiz, jfloat scale_factor) { | ||
377 | +  // Each keypoint packs two (x, y) pairs of 2-byte uint16_t values. | ||
378 | + const int bytes_per_keypoint = sizeof(uint16_t) * 2 * 2; | ||
379 | + jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint]; | ||
380 | + | ||
381 | + const int number_of_keypoints = | ||
382 | + get_object_tracker(env, thiz)->GetKeypointsPacked( | ||
383 | + reinterpret_cast<uint16_t*>(keypoint_arr), scale_factor); | ||
384 | + | ||
385 | + // Create and return the array that will be passed back to Java. | ||
386 | + jbyteArray keypoints = | ||
387 | + env->NewByteArray(number_of_keypoints * bytes_per_keypoint); | ||
388 | + | ||
389 | + if (keypoints == NULL) { | ||
390 | + LOGE("null array!"); | ||
391 | + return NULL; | ||
392 | + } | ||
393 | + | ||
394 | + env->SetByteArrayRegion( | ||
395 | + keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr); | ||
396 | + | ||
397 | + return keypoints; | ||
398 | +} | ||
399 | + | ||
400 | +JNIEXPORT | ||
401 | +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)( | ||
402 | + JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1, | ||
403 | + jfloat position_y1, jfloat position_x2, jfloat position_y2, | ||
404 | + jfloatArray delta) { | ||
405 | + jfloat point_arr[4]; | ||
406 | + | ||
407 | + const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox( | ||
408 | + BoundingBox(position_x1, position_y1, position_x2, position_y2), | ||
409 | + timestamp); | ||
410 | + | ||
411 | + new_position.CopyToArray(point_arr); | ||
412 | + env->SetFloatArrayRegion(delta, 0, 4, point_arr); | ||
413 | +} | ||
414 | + | ||
415 | +JNIEXPORT | ||
416 | +void JNICALL OBJECT_TRACKER_METHOD(drawNative)( | ||
417 | + JNIEnv* env, jobject thiz, jint view_width, jint view_height, | ||
418 | + jfloatArray frame_to_canvas_arr) { | ||
419 | + ObjectTracker* object_tracker = get_object_tracker(env, thiz); | ||
420 | + if (object_tracker != NULL) { | ||
421 | + jfloat* frame_to_canvas = | ||
422 | + env->GetFloatArrayElements(frame_to_canvas_arr, NULL); | ||
423 | + | ||
424 | + object_tracker->Draw(view_width, view_height, frame_to_canvas); | ||
425 | + env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas, | ||
426 | + JNI_ABORT); | ||
427 | + } | ||
428 | +} | ||
429 | + | ||
430 | +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)( | ||
431 | + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride, | ||
432 | + jbyteArray input, jint factor, jbyteArray output) { | ||
433 | + if (input == NULL || output == NULL) { | ||
434 | + LOGW("Received null arrays, hopefully this is a test!"); | ||
435 | + return; | ||
436 | + } | ||
437 | + | ||
438 | + jbyte* const input_array = env->GetByteArrayElements(input, 0); | ||
439 | + jbyte* const output_array = env->GetByteArrayElements(output, 0); | ||
440 | + | ||
441 | + { | ||
442 | + tf_tracking::Image<uint8_t> full_image( | ||
443 | + width, height, reinterpret_cast<uint8_t*>(input_array), false); | ||
444 | + | ||
445 | + const int new_width = (width + factor - 1) / factor; | ||
446 | + const int new_height = (height + factor - 1) / factor; | ||
447 | + | ||
448 | + tf_tracking::Image<uint8_t> downsampled_image( | ||
449 | + new_width, new_height, reinterpret_cast<uint8_t*>(output_array), false); | ||
450 | + | ||
451 | + downsampled_image.DownsampleAveraged( | ||
452 | + reinterpret_cast<uint8_t*>(input_array), row_stride, factor); | ||
453 | + } | ||
454 | + | ||
455 | + env->ReleaseByteArrayElements(input, input_array, JNI_ABORT); | ||
456 | + env->ReleaseByteArrayElements(output, output_array, 0); | ||
457 | +} | ||
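`DownsampleAveraged` presumably averages each `factor` x `factor` block of the source into one destination pixel, matching the ceiling division used for the output dimensions above. A standalone sketch of that operation, with all names illustrative:

```cpp
#include <stdint.h>

// Block-averaged downsampling sketch. Edge blocks are clipped so the kernel
// also works when width/height are not multiples of `factor`.
void DownsampleAveraged(const uint8_t* in, int width, int height,
                        int row_stride, int factor, uint8_t* out) {
  const int new_width = (width + factor - 1) / factor;
  const int new_height = (height + factor - 1) / factor;
  for (int oy = 0; oy < new_height; ++oy) {
    for (int ox = 0; ox < new_width; ++ox) {
      int sum = 0, count = 0;
      for (int y = oy * factor; y < (oy + 1) * factor && y < height; ++y) {
        for (int x = ox * factor; x < (ox + 1) * factor && x < width; ++x) {
          sum += in[y * row_stride + x];
          ++count;
        }
      }
      out[oy * new_width + ox] = static_cast<uint8_t>(sum / count);
    }
  }
}
```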
458 | + | ||
459 | +} // namespace tf_tracking |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#include <math.h> | ||
17 | + | ||
18 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
23 | + | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" | ||
27 | +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" | ||
28 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" | ||
29 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" | ||
30 | +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" | ||
31 | + | ||
32 | +namespace tf_tracking { | ||
33 | + | ||
34 | +OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config) | ||
35 | + : config_(config), | ||
36 | + frame1_(NULL), | ||
37 | + frame2_(NULL), | ||
38 | + working_size_(config->image_size) {} | ||
39 | + | ||
40 | + | ||
41 | +void OpticalFlow::NextFrame(const ImageData* const image_data) { | ||
42 | + // Special case for the first frame: make sure the image ends up in | ||
43 | + // frame1_ so that keypoint detection can be done on it if desired. | ||
44 | + frame1_ = (frame1_ == NULL) ? image_data : frame2_; | ||
45 | + frame2_ = image_data; | ||
46 | +} | ||
47 | + | ||
48 | + | ||
49 | +// Static heart of the optical flow computation: the iterative | ||
50 | +// Lucas-Kanade algorithm. | ||
51 | +bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8_t>& img_I, | ||
52 | + const Image<uint8_t>& img_J, | ||
53 | + const Image<int32_t>& I_x, | ||
54 | + const Image<int32_t>& I_y, const float p_x, | ||
55 | + const float p_y, float* out_g_x, | ||
56 | + float* out_g_y) { | ||
57 | + float g_x = *out_g_x; | ||
58 | + float g_y = *out_g_y; | ||
59 | + // Get values for frame 1. They remain constant through the inner | ||
60 | + // iteration loop. | ||
61 | + float vals_I[kFlowArraySize]; | ||
62 | + float vals_I_x[kFlowArraySize]; | ||
63 | + float vals_I_y[kFlowArraySize]; | ||
64 | + | ||
65 | + const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1; | ||
66 | + const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize); | ||
67 | + | ||
68 | +#if USE_FIXED_POINT_FLOW | ||
69 | + const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1; | ||
70 | + const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1; | ||
71 | +#else | ||
72 | + const float real_x_max = I_x.width_less_one_ - EPSILON; | ||
73 | + const float real_y_max = I_x.height_less_one_ - EPSILON; | ||
74 | +#endif | ||
75 | + | ||
76 | + // Get the window around the original point. | ||
77 | + const float src_left_real = p_x - kWindowSizeFloat; | ||
78 | + const float src_top_real = p_y - kWindowSizeFloat; | ||
79 | + float* vals_I_ptr = vals_I; | ||
80 | + float* vals_I_x_ptr = vals_I_x; | ||
81 | + float* vals_I_y_ptr = vals_I_y; | ||
82 | +#if USE_FIXED_POINT_FLOW | ||
83 | + // Source integer coordinates. | ||
84 | + const int src_left_fixed = RealToFixed1616(src_left_real); | ||
85 | + const int src_top_fixed = RealToFixed1616(src_top_real); | ||
86 | + | ||
87 | + for (int y = 0; y < kPatchSize; ++y) { | ||
88 | + const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max); | ||
89 | + | ||
90 | + for (int x = 0; x < kPatchSize; ++x) { | ||
91 | + const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max); | ||
92 | + | ||
93 | + *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y); | ||
94 | + *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y); | ||
95 | + *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y); | ||
96 | + } | ||
97 | + } | ||
98 | +#else | ||
99 | + for (int y = 0; y < kPatchSize; ++y) { | ||
100 | + const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max); | ||
101 | + | ||
102 | + for (int x = 0; x < kPatchSize; ++x) { | ||
103 | + const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max); | ||
104 | + | ||
105 | + *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos); | ||
106 | + *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos); | ||
107 | + *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos); | ||
108 | + } | ||
109 | + } | ||
110 | +#endif | ||
111 | + | ||
112 | + // Compute the spatial gradient matrix about point p. | ||
113 | + float G[] = { 0, 0, 0, 0 }; | ||
114 | + CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G); | ||
115 | + | ||
116 | + // Find the inverse of G. | ||
117 | + float G_inv[4]; | ||
118 | + if (!Invert2x2(G, G_inv)) { | ||
119 | + return false; | ||
120 | + } | ||
121 | + | ||
122 | +#if NORMALIZE | ||
123 | + const float mean_I = ComputeMean(vals_I, kFlowArraySize); | ||
124 | + const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I); | ||
125 | +#endif | ||
126 | + | ||
127 | + // Iterate kNumIterations times or until we converge. | ||
128 | + for (int iteration = 0; iteration < kNumIterations; ++iteration) { | ||
129 | + // Get values for frame 2. | ||
130 | + float vals_J[kFlowArraySize]; | ||
131 | + | ||
132 | + // Get the window around the destination point. | ||
133 | + const float left_real = p_x + g_x - kWindowSizeFloat; | ||
134 | + const float top_real = p_y + g_y - kWindowSizeFloat; | ||
135 | + float* vals_J_ptr = vals_J; | ||
136 | +#if USE_FIXED_POINT_FLOW | ||
137 | + // The top-left sub-pixel is set for the current iteration (in 16:16 | ||
138 | + // fixed). This is constant over one iteration. | ||
139 | + const int left_fixed = RealToFixed1616(left_real); | ||
140 | + const int top_fixed = RealToFixed1616(top_real); | ||
141 | + | ||
142 | + for (int win_y = 0; win_y < kPatchSize; ++win_y) { | ||
143 | + const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max); | ||
144 | + for (int win_x = 0; win_x < kPatchSize; ++win_x) { | ||
145 | + const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max); | ||
146 | + *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y); | ||
147 | + } | ||
148 | + } | ||
149 | +#else | ||
150 | + for (int win_y = 0; win_y < kPatchSize; ++win_y) { | ||
151 | + const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max); | ||
152 | + for (int win_x = 0; win_x < kPatchSize; ++win_x) { | ||
153 | + const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max); | ||
154 | + *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos); | ||
155 | + } | ||
156 | + } | ||
157 | +#endif | ||
158 | + | ||
159 | +#if NORMALIZE | ||
160 | + const float mean_J = ComputeMean(vals_J, kFlowArraySize); | ||
161 | + const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J); | ||
162 | + | ||
163 | + // TODO(andrewharp): Probably better to completely detect and handle the | ||
164 | + // "corner case" where the patch is fully outside the image diagonally. | ||
165 | + const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f; | ||
166 | +#endif | ||
167 | + | ||
168 | + // Compute image mismatch vector. | ||
169 | + float b_x = 0.0f; | ||
170 | + float b_y = 0.0f; | ||
171 | + | ||
172 | + vals_I_ptr = vals_I; | ||
173 | + vals_J_ptr = vals_J; | ||
174 | + vals_I_x_ptr = vals_I_x; | ||
175 | + vals_I_y_ptr = vals_I_y; | ||
176 | + | ||
177 | + for (int win_y = 0; win_y < kPatchSize; ++win_y) { | ||
178 | + for (int win_x = 0; win_x < kPatchSize; ++win_x) { | ||
179 | +#if NORMALIZE | ||
180 | + // Normalized image difference. | ||
181 | + const float dI = | ||
182 | + (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio; | ||
183 | +#else | ||
184 | + const float dI = *vals_I_ptr++ - *vals_J_ptr++; | ||
185 | +#endif | ||
186 | + b_x += dI * *vals_I_x_ptr++; | ||
187 | + b_y += dI * *vals_I_y_ptr++; | ||
188 | + } | ||
189 | + } | ||
190 | + | ||
191 | + // Optical flow... solve n = G^-1 * b | ||
192 | + const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y); | ||
193 | + const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y); | ||
194 | + | ||
195 | + // Update best guess with residual displacement from this level and | ||
196 | + // iteration. | ||
197 | + g_x += n_x; | ||
198 | + g_y += n_y; | ||
199 | + | ||
200 | + // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y); | ||
201 | + | ||
202 | + // Abort early if we're already below the threshold. | ||
203 | + if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) { | ||
204 | + break; | ||
205 | + } | ||
206 | + } // Iteration. | ||
207 | + | ||
208 | + // Copy value back into output. | ||
209 | + *out_g_x = g_x; | ||
210 | + *out_g_y = g_y; | ||
211 | + return true; | ||
212 | +} | ||
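In the standard Lucas-Kanade notation, each pass of the iteration loop above solves a 2x2 linear system: `CalculateG` accumulates the spatial gradient matrix over the window W, the `b_x`/`b_y` terms accumulate the (optionally brightness-normalized) image mismatch, and the displacement guess g is refined until the step falls below the abort threshold:

```latex
G = \sum_{p \in W}
    \begin{pmatrix} I_x^2 & I_x I_y \\ I_x I_y & I_y^2 \end{pmatrix}, \qquad
b = \sum_{p \in W} \delta I(p) \begin{pmatrix} I_x \\ I_y \end{pmatrix}, \qquad
\delta I(p) = I(p) - J(p + g),
```

```latex
\eta = G^{-1} b, \qquad g \leftarrow g + \eta, \qquad
\text{stop when } \lVert \eta \rVert^2 < \text{kTrackingAbortThreshold}^2 .
```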
213 | + | ||
214 | + | ||
215 | +// Pointwise flow using translational 2dof ESM. | ||
216 | +bool OpticalFlow::FindFlowAtPoint_ESM( | ||
217 | + const Image<uint8_t>& img_I, const Image<uint8_t>& img_J, | ||
218 | + const Image<int32_t>& I_x, const Image<int32_t>& I_y, | ||
219 | + const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x, | ||
220 | + const float p_y, float* out_g_x, float* out_g_y) { | ||
221 | + float g_x = *out_g_x; | ||
222 | + float g_y = *out_g_y; | ||
223 | + const float area_inv = 1.0f / static_cast<float>(kFlowArraySize); | ||
224 | + | ||
225 | + // Get values for frame 1. They remain constant through the inner | ||
226 | + // iteration loop. | ||
227 | + uint8_t vals_I[kFlowArraySize]; | ||
228 | + uint8_t vals_J[kFlowArraySize]; | ||
229 | + int16_t src_gradient_x[kFlowArraySize]; | ||
230 | + int16_t src_gradient_y[kFlowArraySize]; | ||
231 | + | ||
232 | + // TODO(rspring): try out the IntegerPatchAlign() method once | ||
233 | + // the code for that is in ../common. | ||
234 | + const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize); | ||
235 | + const int src_left_fixed = RealToFixed1616(p_x - wsize_float); | ||
236 | + const int src_top_fixed = RealToFixed1616(p_y - wsize_float); | ||
237 | + const int patch_size = 2 * kFlowIntegrationWindowSize + 1; | ||
238 | + | ||
239 | + // Create the keypoint template patch from a subpixel location. | ||
240 | + if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, | ||
241 | + patch_size, patch_size, vals_I) || | ||
242 | + !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, | ||
243 | + patch_size, patch_size, | ||
244 | + src_gradient_x) || | ||
245 | + !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, | ||
246 | + patch_size, patch_size, | ||
247 | + src_gradient_y)) { | ||
248 | + return false; | ||
249 | + } | ||
250 | + | ||
251 | + int bright_offset = 0; | ||
252 | + int sum_diff = 0; | ||
253 | + | ||
254 | + // The top-left sub-pixel is set for the current iteration (in 16:16 fixed). | ||
255 | + // This is constant over one iteration. | ||
256 | + int left_fixed = RealToFixed1616(p_x + g_x - wsize_float); | ||
257 | + int top_fixed = RealToFixed1616(p_y + g_y - wsize_float); | ||
258 | + | ||
259 | + // The truncated version gives the most top-left pixel that is used. | ||
260 | + int left_trunc = left_fixed >> 16; | ||
261 | + int top_trunc = top_fixed >> 16; | ||
262 | + | ||
263 | + // Compute an initial brightness offset. | ||
264 | + if (kDoBrightnessNormalize && | ||
265 | + left_trunc >= 0 && top_trunc >= 0 && | ||
266 | + (left_trunc + patch_size) < img_J.width_less_one_ && | ||
267 | + (top_trunc + patch_size) < img_J.height_less_one_) { | ||
268 | + int templ_index = 0; | ||
269 | + const uint8_t* j_row = img_J[top_trunc] + left_trunc; | ||
270 | + | ||
271 | + const int j_stride = img_J.stride(); | ||
272 | + | ||
273 | + for (int y = 0; y < patch_size; ++y, j_row += j_stride) { | ||
274 | + for (int x = 0; x < patch_size; ++x) { | ||
275 | + sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++]; | ||
276 | + } | ||
277 | + } | ||
278 | + | ||
279 | + bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv); | ||
280 | + } | ||
281 | + | ||
282 | + // Iterate kNumIterations times or until we go out of image. | ||
283 | + for (int iteration = 0; iteration < kNumIterations; ++iteration) { | ||
284 | + int jtj[3] = { 0, 0, 0 }; | ||
285 | + int jtr[2] = { 0, 0 }; | ||
286 | + sum_diff = 0; | ||
287 | + | ||
288 | + // Extract the target image values. | ||
289 | + // Extract the gradient from the target image patch and accumulate to | ||
290 | + // the gradient of the source image patch. | ||
291 | + if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed, | ||
292 | + patch_size, patch_size, | ||
293 | + vals_J)) { | ||
294 | + break; | ||
295 | + } | ||
296 | + | ||
297 | + const uint8_t* templ_row = vals_I; | ||
298 | + const uint8_t* extract_row = vals_J; | ||
299 | + const int16_t* src_dx_row = src_gradient_x; | ||
300 | + const int16_t* src_dy_row = src_gradient_y; | ||
301 | + | ||
302 | + for (int y = 0; y < patch_size; ++y, templ_row += patch_size, | ||
303 | + src_dx_row += patch_size, src_dy_row += patch_size, | ||
304 | + extract_row += patch_size) { | ||
305 | + const int fp_y = top_fixed + (y << 16); | ||
306 | + for (int x = 0; x < patch_size; ++x) { | ||
307 | + const int fp_x = left_fixed + (x << 16); | ||
308 | + int32_t target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y); | ||
309 | + int32_t target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y); | ||
310 | + | ||
311 | + // Combine the two Jacobians. | ||
312 | + // Right-shift by one to account for the fact that we add | ||
313 | + // two Jacobians. | ||
314 | + int32_t dx = (src_dx_row[x] + target_dx) >> 1; | ||
315 | + int32_t dy = (src_dy_row[x] + target_dy) >> 1; | ||
316 | + | ||
317 | + // The current residual b - h(q) == extracted - (template + offset) | ||
318 | + int32_t diff = static_cast<int32_t>(extract_row[x]) - | ||
319 | + static_cast<int32_t>(templ_row[x]) - bright_offset; | ||
320 | + | ||
321 | + jtj[0] += dx * dx; | ||
322 | + jtj[1] += dx * dy; | ||
323 | + jtj[2] += dy * dy; | ||
324 | + | ||
325 | + jtr[0] += dx * diff; | ||
326 | + jtr[1] += dy * diff; | ||
327 | + | ||
328 | + sum_diff += diff; | ||
329 | + } | ||
330 | + } | ||
331 | + | ||
332 | + const float jtr1_float = static_cast<float>(jtr[0]); | ||
333 | + const float jtr2_float = static_cast<float>(jtr[1]); | ||
334 | + | ||
335 | + // Add some baseline stability to the system. | ||
336 | + jtj[0] += kEsmRegularizer; | ||
337 | + jtj[2] += kEsmRegularizer; | ||
338 | + | ||
339 | + const int64_t prod1 = static_cast<int64_t>(jtj[0]) * jtj[2]; | ||
340 | + const int64_t prod2 = static_cast<int64_t>(jtj[1]) * jtj[1]; | ||
341 | + | ||
342 | + // One ESM step. | ||
343 | + const float jtj_1[4] = { static_cast<float>(jtj[2]), | ||
344 | + static_cast<float>(-jtj[1]), | ||
345 | + static_cast<float>(-jtj[1]), | ||
346 | + static_cast<float>(jtj[0]) }; | ||
347 | + const double det_inv = 1.0 / static_cast<double>(prod1 - prod2); | ||
348 | + | ||
349 | + g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float); | ||
350 | + g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float); | ||
351 | + | ||
352 | + if (kDoBrightnessNormalize) { | ||
353 | + bright_offset += | ||
354 | + static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f); | ||
355 | + } | ||
356 | + | ||
357 | + // Update top left position. | ||
358 | + left_fixed = RealToFixed1616(p_x + g_x - wsize_float); | ||
359 | + top_fixed = RealToFixed1616(p_y + g_y - wsize_float); | ||
360 | + | ||
361 | + left_trunc = left_fixed >> 16; | ||
362 | + top_trunc = top_fixed >> 16; | ||
363 | + | ||
364 | + // Abort iterations if we go out of borders. | ||
365 | + if (left_trunc < 0 || top_trunc < 0 || | ||
366 | + (left_trunc + patch_size) >= J_x.width_less_one_ || | ||
367 | + (top_trunc + patch_size) >= J_y.height_less_one_) { | ||
368 | + break; | ||
369 | + } | ||
370 | + } // Iteration. | ||
371 | + | ||
372 | + // Copy value back into output. | ||
373 | + *out_g_x = g_x; | ||
374 | + *out_g_y = g_y; | ||
375 | + return true; | ||
376 | +} | ||
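The ESM variant differs from plain Lucas-Kanade in one key step: the Jacobian is the average of the template and target gradients (the right-shift by one above), which for this translational model yields faster, second-order convergence. Writing `kEsmRegularizer` as lambda and the brightness-offset-corrected residual as r, each iteration applies:

```latex
J_{\mathrm{esm}} = \tfrac{1}{2}\,(\nabla I + \nabla J), \qquad
r(p) = J(p + g) - I(p) - \text{offset},
```

```latex
\Delta g = -\bigl(J_{\mathrm{esm}}^{\top} J_{\mathrm{esm}} + \lambda I_2\bigr)^{-1}
           J_{\mathrm{esm}}^{\top}\, r .
```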
377 | + | ||
378 | + | ||
379 | +bool OpticalFlow::FindFlowAtPointReversible( | ||
380 | + const int level, const float u_x, const float u_y, | ||
381 | + const bool reverse_flow, | ||
382 | + float* flow_x, float* flow_y) const { | ||
383 | + const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_; | ||
384 | + const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_; | ||
385 | + | ||
386 | + // Images I (prev) and J (next). | ||
387 | + const Image<uint8_t>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2); | ||
388 | + const Image<uint8_t>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2); | ||
389 | + | ||
390 | + // Computed gradients. | ||
391 | + const Image<int32_t>& I_x = *frame_a.GetSpatialX(level); | ||
392 | + const Image<int32_t>& I_y = *frame_a.GetSpatialY(level); | ||
393 | + const Image<int32_t>& J_x = *frame_b.GetSpatialX(level); | ||
394 | + const Image<int32_t>& J_y = *frame_b.GetSpatialY(level); | ||
395 | + | ||
396 | + // Shrink factor from original. | ||
397 | + const float shrink_factor = (1 << level); | ||
398 | + | ||
399 | + // Image position vector (p := u^l), scaled for this level. | ||
400 | + const float scaled_p_x = u_x / shrink_factor; | ||
401 | + const float scaled_p_y = u_y / shrink_factor; | ||
402 | + | ||
403 | + float scaled_flow_x = *flow_x / shrink_factor; | ||
404 | + float scaled_flow_y = *flow_y / shrink_factor; | ||
405 | + | ||
406 | + // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level, | ||
407 | +  //      scaled_p_x, scaled_p_y, scaled_flow_x, scaled_flow_y); | ||
408 | + | ||
409 | + const bool success = kUseEsm ? | ||
410 | + FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y, | ||
411 | + scaled_p_x, scaled_p_y, | ||
412 | + &scaled_flow_x, &scaled_flow_y) : | ||
413 | + FindFlowAtPoint_LK(img_I, img_J, I_x, I_y, | ||
414 | + scaled_p_x, scaled_p_y, | ||
415 | + &scaled_flow_x, &scaled_flow_y); | ||
416 | + | ||
417 | + *flow_x = scaled_flow_x * shrink_factor; | ||
418 | + *flow_y = scaled_flow_y * shrink_factor; | ||
419 | + | ||
420 | + return success; | ||
421 | +} | ||
422 | + | ||
423 | + | ||
424 | +bool OpticalFlow::FindFlowAtPointSingleLevel( | ||
425 | + const int level, | ||
426 | + const float u_x, const float u_y, | ||
427 | + const bool filter_by_fb_error, | ||
428 | + float* flow_x, float* flow_y) const { | ||
429 | + if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) { | ||
430 | + return false; | ||
431 | + } | ||
432 | + | ||
433 | + if (filter_by_fb_error) { | ||
434 | + const float new_position_x = u_x + *flow_x; | ||
435 | + const float new_position_y = u_y + *flow_y; | ||
436 | + | ||
437 | + float reverse_flow_x = 0.0f; | ||
438 | + float reverse_flow_y = 0.0f; | ||
439 | + | ||
440 | + // Now find the backwards flow and confirm it lines up with the original | ||
441 | + // starting point. | ||
442 | + if (!FindFlowAtPointReversible(level, new_position_x, new_position_y, | ||
443 | + true, | ||
444 | + &reverse_flow_x, &reverse_flow_y)) { | ||
445 | + LOGE("Backward error!"); | ||
446 | + return false; | ||
447 | + } | ||
448 | + | ||
449 | + const float discrepancy_length = | ||
450 | + sqrtf(Square(*flow_x + reverse_flow_x) + | ||
451 | + Square(*flow_y + reverse_flow_y)); | ||
452 | + | ||
453 | + const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y)); | ||
454 | + | ||
455 | + return discrepancy_length < | ||
456 | + (kMaxForwardBackwardErrorAllowed * flow_length); | ||
457 | + } | ||
458 | + | ||
459 | + return true; | ||
460 | +} | ||
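The forward-backward filter above can be read as a standalone predicate. A hedged restatement with illustrative names (max_relative_error stands in for kMaxForwardBackwardErrorAllowed; none of this is code from the change itself):

#include <cmath>

// Sketch: accept a flow vector only if tracking the destination backwards
// lands near the starting point, relative to the flow's own magnitude.
inline bool PassesForwardBackwardCheck(const float flow_x, const float flow_y,
                                       const float reverse_flow_x,
                                       const float reverse_flow_y,
                                       const float max_relative_error) {
  // A perfect track gives reverse_flow == -flow, so the sum is the error.
  const float discrepancy =
      std::sqrt((flow_x + reverse_flow_x) * (flow_x + reverse_flow_x) +
                (flow_y + reverse_flow_y) * (flow_y + reverse_flow_y));
  const float flow_length = std::sqrt(flow_x * flow_x + flow_y * flow_y);
  return discrepancy < max_relative_error * flow_length;
}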
461 | + | ||
462 | + | ||
463 | +// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm. | ||
464 | +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details. | ||
465 | +bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y, | ||
466 | + const bool filter_by_fb_error, | ||
467 | + float* flow_x, float* flow_y) const { | ||
468 | + const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment, | ||
469 | + kNumPyramidLevels - kNumCacheLevels); | ||
470 | + | ||
471 | + // For every level in the pyramid, update the coordinates of the best match. | ||
472 | + for (int l = max_level - 1; l >= 0; --l) { | ||
473 | + if (!FindFlowAtPointSingleLevel(l, u_x, u_y, | ||
474 | + filter_by_fb_error, flow_x, flow_y)) { | ||
475 | + return false; | ||
476 | + } | ||
477 | + } | ||
478 | + | ||
479 | + return true; | ||
480 | +} | ||
481 | + | ||
482 | +} // namespace tf_tracking |
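End to end, the pyramid walks from the coarsest usable level down to level 0, refining a single flow estimate. A hedged usage sketch, assuming an initialized OpticalFlow instance named flow and a feature location (neither is defined in this diff):

float flow_x = 0.0f;  // Refined in place at every pyramid level.
float flow_y = 0.0f;
if (flow.FindFlowAtPointPyramidal(feature_x, feature_y,
                                  true /* filter_by_fb_error */,
                                  &flow_x, &flow_y)) {
  // (feature_x + flow_x, feature_y + flow_y) is the match in the next frame.
}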
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ | ||
18 | + | ||
19 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
21 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
23 | + | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/config.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" | ||
27 | +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" | ||
28 | + | ||
29 | +namespace tf_tracking { | ||
30 | + | ||
31 | +class FlowCache; | ||
32 | + | ||
33 | +// Class encapsulating all the data and logic necessary for performing optical | ||
34 | +// flow. | ||
35 | +class OpticalFlow { | ||
36 | + public: | ||
37 | + explicit OpticalFlow(const OpticalFlowConfig* const config); | ||
38 | + | ||
39 | + // Add a new frame to the optical flow. Will update all the non-keypoint | ||
40 | + // related member variables. | ||
41 | + // | ||
42 | +  // image_data should wrap a buffer of grayscale values, one byte per pixel, | ||
43 | + // at the original frame_width and frame_height used to initialize the | ||
44 | + // OpticalFlow object. Downsampling will be handled internally. | ||
45 | + // | ||
46 | +  // The frame's associated timestamp should be a time in milliseconds that | ||
47 | +  // later calls to this and other methods will be relative to. | ||
48 | + void NextFrame(const ImageData* const image_data); | ||
49 | + | ||
50 | + // An implementation of the Lucas-Kanade Optical Flow algorithm. | ||
51 | + static bool FindFlowAtPoint_LK(const Image<uint8_t>& img_I, | ||
52 | + const Image<uint8_t>& img_J, | ||
53 | + const Image<int32_t>& I_x, | ||
54 | + const Image<int32_t>& I_y, const float p_x, | ||
55 | + const float p_y, float* out_g_x, | ||
56 | + float* out_g_y); | ||
57 | + | ||
58 | + // Pointwise flow using translational 2dof ESM. | ||
59 | + static bool FindFlowAtPoint_ESM( | ||
60 | + const Image<uint8_t>& img_I, const Image<uint8_t>& img_J, | ||
61 | + const Image<int32_t>& I_x, const Image<int32_t>& I_y, | ||
62 | + const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x, | ||
63 | + const float p_y, float* out_g_x, float* out_g_y); | ||
64 | + | ||
65 | + // Finds the flow using a specific level, in either direction. | ||
66 | + // If reversed, the coordinates are in the context of the latest | ||
67 | + // frame, not the frame before it. | ||
68 | + // All coordinates used in parameters are global, not scaled. | ||
69 | + bool FindFlowAtPointReversible( | ||
70 | + const int level, const float u_x, const float u_y, | ||
71 | + const bool reverse_flow, | ||
72 | + float* final_x, float* final_y) const; | ||
73 | + | ||
74 | + // Finds the flow using a specific level, filterable by forward-backward | ||
75 | + // error. All coordinates used in parameters are global, not scaled. | ||
76 | + bool FindFlowAtPointSingleLevel(const int level, | ||
77 | + const float u_x, const float u_y, | ||
78 | + const bool filter_by_fb_error, | ||
79 | + float* flow_x, float* flow_y) const; | ||
80 | + | ||
81 | + // Pyramidal optical-flow using all levels. | ||
82 | + bool FindFlowAtPointPyramidal(const float u_x, const float u_y, | ||
83 | + const bool filter_by_fb_error, | ||
84 | + float* flow_x, float* flow_y) const; | ||
85 | + | ||
86 | + private: | ||
87 | + const OpticalFlowConfig* const config_; | ||
88 | + | ||
89 | + const ImageData* frame1_; | ||
90 | + const ImageData* frame2_; | ||
91 | + | ||
92 | + // Size of the internally allocated images (after original is downsampled). | ||
93 | + const Size working_size_; | ||
94 | + | ||
95 | + TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow); | ||
96 | +}; | ||
97 | + | ||
98 | +} // namespace tf_tracking | ||
99 | + | ||
100 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ |
android/android/jni/object_tracking/sprite.h
0 → 100644
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ | ||
18 | + | ||
19 | +#ifdef __RENDER_OPENGL__ | ||
20 | + | ||
21 | +#include <GLES/gl.h> | ||
22 | +#include <GLES/glext.h> | ||
23 | + | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
26 | + | ||
27 | +namespace tf_tracking { | ||
28 | + | ||
29 | +// This class encapsulates the logic necessary to load and render image data | ||
30 | +// at the same aspect ratio as the original source. | ||
31 | +class Sprite { | ||
32 | + public: | ||
33 | + // Only create Sprites when you have an OpenGl context. | ||
34 | + explicit Sprite(const Image<uint8_t>& image) { LoadTexture(image, NULL); } | ||
35 | + | ||
36 | + Sprite(const Image<uint8_t>& image, const BoundingBox* const area) { | ||
37 | + LoadTexture(image, area); | ||
38 | + } | ||
39 | + | ||
40 | + // Also, try to only delete a Sprite when holding an OpenGl context. | ||
41 | + ~Sprite() { | ||
42 | + glDeleteTextures(1, &texture_); | ||
43 | + } | ||
44 | + | ||
45 | + inline int GetWidth() const { | ||
46 | + return actual_width_; | ||
47 | + } | ||
48 | + | ||
49 | + inline int GetHeight() const { | ||
50 | + return actual_height_; | ||
51 | + } | ||
52 | + | ||
53 | +  // Draw the sprite at (0, 0) with its original width/height in the current | ||
54 | +  // reference frame. Any desired transformations must be applied before | ||
55 | +  // calling this function. | ||
56 | + void Draw() const { | ||
57 | + const float float_width = static_cast<float>(actual_width_); | ||
58 | + const float float_height = static_cast<float>(actual_height_); | ||
59 | + | ||
60 | + // Where it gets rendered to. | ||
61 | + const float vertices[] = { 0.0f, 0.0f, 0.0f, | ||
62 | + 0.0f, float_height, 0.0f, | ||
63 | + float_width, 0.0f, 0.0f, | ||
64 | + float_width, float_height, 0.0f, | ||
65 | + }; | ||
66 | + | ||
67 | + // The coordinates the texture gets drawn from. | ||
68 | + const float max_x = float_width / texture_width_; | ||
69 | + const float max_y = float_height / texture_height_; | ||
70 | + const float textureVertices[] = { | ||
71 | + 0, 0, | ||
72 | + 0, max_y, | ||
73 | + max_x, 0, | ||
74 | + max_x, max_y, | ||
75 | + }; | ||
76 | + | ||
77 | + glEnable(GL_TEXTURE_2D); | ||
78 | + glBindTexture(GL_TEXTURE_2D, texture_); | ||
79 | + | ||
80 | + glEnableClientState(GL_VERTEX_ARRAY); | ||
81 | + glEnableClientState(GL_TEXTURE_COORD_ARRAY); | ||
82 | + | ||
83 | + glVertexPointer(3, GL_FLOAT, 0, vertices); | ||
84 | + glTexCoordPointer(2, GL_FLOAT, 0, textureVertices); | ||
85 | + | ||
86 | + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); | ||
87 | + | ||
88 | + glDisableClientState(GL_VERTEX_ARRAY); | ||
89 | + glDisableClientState(GL_TEXTURE_COORD_ARRAY); | ||
90 | + } | ||
91 | + | ||
92 | + private: | ||
93 | + inline int GetNextPowerOfTwo(const int number) const { | ||
94 | + int power_of_two = 1; | ||
95 | + while (power_of_two < number) { | ||
96 | + power_of_two *= 2; | ||
97 | + } | ||
98 | + return power_of_two; | ||
99 | + } | ||
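The loop above doubles until it reaches number; an equivalent branch-free bit-smearing version is common for 32-bit sizes. A sketch, assuming number >= 1 (this helper is not part of the change):

// Round up to the next power of two by smearing the highest set bit right.
inline int GetNextPowerOfTwoBits(const int number) {
  int v = number - 1;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  return v + 1;
}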
100 | + | ||
101 | + // TODO(andrewharp): Allow sprites to have their textures reloaded. | ||
102 | + void LoadTexture(const Image<uint8_t>& texture_source, | ||
103 | + const BoundingBox* const area) { | ||
104 | + glEnable(GL_TEXTURE_2D); | ||
105 | + | ||
106 | + glGenTextures(1, &texture_); | ||
107 | + | ||
108 | + glBindTexture(GL_TEXTURE_2D, texture_); | ||
109 | + | ||
110 | + int left = 0; | ||
111 | + int top = 0; | ||
112 | + | ||
113 | + if (area != NULL) { | ||
114 | + // If a sub-region was provided to pull the texture from, use that. | ||
115 | + left = area->left_; | ||
116 | + top = area->top_; | ||
117 | + actual_width_ = area->GetWidth(); | ||
118 | + actual_height_ = area->GetHeight(); | ||
119 | + } else { | ||
120 | + actual_width_ = texture_source.GetWidth(); | ||
121 | + actual_height_ = texture_source.GetHeight(); | ||
122 | + } | ||
123 | + | ||
124 | + // The textures must be a power of two, so find the sizes that are large | ||
125 | + // enough to contain the image data. | ||
126 | + texture_width_ = GetNextPowerOfTwo(actual_width_); | ||
127 | + texture_height_ = GetNextPowerOfTwo(actual_height_); | ||
128 | + | ||
129 | + bool allocated_data = false; | ||
130 | + uint8_t* texture_data; | ||
131 | + | ||
132 | + // Except in the lucky case where we're not using a sub-region of the | ||
133 | +    // original image AND the source data has dimensions that are powers of two, | ||
134 | + // care must be taken to copy data at the appropriate source and destination | ||
135 | + // strides so that the final block can be copied directly into texture | ||
136 | + // memory. | ||
137 | + // TODO(andrewharp): Figure out if data can be pulled directly from the | ||
138 | + // source image with some alignment modifications. | ||
139 | + if (left != 0 || top != 0 || | ||
140 | + actual_width_ != texture_source.GetWidth() || | ||
141 | + actual_height_ != texture_source.GetHeight()) { | ||
142 | + texture_data = new uint8_t[actual_width_ * actual_height_]; | ||
143 | + | ||
144 | + for (int y = 0; y < actual_height_; ++y) { | ||
145 | + memcpy(texture_data + actual_width_ * y, texture_source[top + y] + left, | ||
146 | + actual_width_ * sizeof(uint8_t)); | ||
147 | + } | ||
148 | + allocated_data = true; | ||
149 | + } else { | ||
150 | + // Cast away const-ness because for some reason glTexSubImage2D wants | ||
151 | + // a non-const data pointer. | ||
152 | + texture_data = const_cast<uint8_t*>(texture_source.data()); | ||
153 | + } | ||
154 | + | ||
155 | + glTexImage2D(GL_TEXTURE_2D, | ||
156 | + 0, | ||
157 | + GL_LUMINANCE, | ||
158 | + texture_width_, | ||
159 | + texture_height_, | ||
160 | + 0, | ||
161 | + GL_LUMINANCE, | ||
162 | + GL_UNSIGNED_BYTE, | ||
163 | + NULL); | ||
164 | + | ||
165 | + glPixelStorei(GL_UNPACK_ALIGNMENT, 1); | ||
166 | + glTexSubImage2D(GL_TEXTURE_2D, | ||
167 | + 0, | ||
168 | + 0, | ||
169 | + 0, | ||
170 | + actual_width_, | ||
171 | + actual_height_, | ||
172 | + GL_LUMINANCE, | ||
173 | + GL_UNSIGNED_BYTE, | ||
174 | + texture_data); | ||
175 | + | ||
176 | + if (allocated_data) { | ||
177 | +      delete[] texture_data;  // Must match the new uint8_t[] above. | ||
178 | + } | ||
179 | + | ||
180 | + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); | ||
181 | + } | ||
182 | + | ||
183 | + // The id for the texture on the GPU. | ||
184 | + GLuint texture_; | ||
185 | + | ||
186 | + // The width and height to be used for display purposes, referring to the | ||
187 | + // dimensions of the original texture. | ||
188 | + int actual_width_; | ||
189 | + int actual_height_; | ||
190 | + | ||
191 | + // The allocated dimensions of the texture data, which must be powers of 2. | ||
192 | + int texture_width_; | ||
193 | + int texture_height_; | ||
194 | + | ||
195 | + TF_DISALLOW_COPY_AND_ASSIGN(Sprite); | ||
196 | +}; | ||
197 | + | ||
198 | +} // namespace tf_tracking | ||
199 | + | ||
200 | +#endif // __RENDER_OPENGL__ | ||
201 | + | ||
202 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" | ||
17 | + | ||
18 | +#ifdef LOG_TIME | ||
19 | +// Storage for logging functionality. | ||
20 | +int num_time_logs = 0; | ||
21 | +LogEntry time_logs[NUM_LOGS]; | ||
22 | + | ||
23 | +int num_avg_entries = 0; | ||
24 | +AverageEntry avg_entries[NUM_LOGS]; | ||
25 | +#endif |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// Utility functions for performance profiling. | ||
17 | + | ||
18 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ | ||
19 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ | ||
20 | + | ||
21 | +#include <stdint.h> | ||
22 | + | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/logging.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
25 | + | ||
26 | +#ifdef LOG_TIME | ||
27 | + | ||
28 | +// Blend constant for running average. | ||
29 | +#define ALPHA 0.98f | ||
30 | +#define NUM_LOGS 100 | ||
31 | + | ||
32 | +struct LogEntry { | ||
33 | + const char* id; | ||
34 | + int64_t time_stamp; | ||
35 | +}; | ||
36 | + | ||
37 | +struct AverageEntry { | ||
38 | + const char* id; | ||
39 | + float average_duration; | ||
40 | +}; | ||
41 | + | ||
42 | +// Storage for keeping track of this frame's values. | ||
43 | +extern int num_time_logs; | ||
44 | +extern LogEntry time_logs[NUM_LOGS]; | ||
45 | + | ||
46 | +// Storage for keeping track of average values (each entry may not be printed | ||
47 | +// out each frame). | ||
48 | +extern AverageEntry avg_entries[NUM_LOGS]; | ||
49 | +extern int num_avg_entries; | ||
50 | + | ||
51 | +// Call this at the start of a logging phase. | ||
52 | +inline static void ResetTimeLog() { | ||
53 | + num_time_logs = 0; | ||
54 | +} | ||
55 | + | ||
56 | + | ||
57 | +// Log a message to be printed out when printTimeLog is called, along with the | ||
58 | +// amount of time in ms that has passed since the last call to this function. | ||
59 | +inline static void TimeLog(const char* const str) { | ||
60 | + LOGV("%s", str); | ||
61 | + if (num_time_logs >= NUM_LOGS) { | ||
62 | + LOGE("Out of log entries!"); | ||
63 | + return; | ||
64 | + } | ||
65 | + | ||
66 | + time_logs[num_time_logs].id = str; | ||
67 | + time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos(); | ||
68 | + ++num_time_logs; | ||
69 | +} | ||
70 | + | ||
71 | + | ||
72 | +inline static float Blend(float old_val, float new_val) { | ||
73 | + return ALPHA * old_val + (1.0f - ALPHA) * new_val; | ||
74 | +} | ||
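Blend is an exponential moving average with a heavy history weight (ALPHA = 0.98f), so new samples move it slowly. A quick numeric illustration using the function above:

float avg = 0.0f;
for (int i = 0; i < 100; ++i) {
  avg = Blend(avg, 10.0f);  // Moves ~2% per step; avg is ~8.7 after 100 steps.
}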
75 | + | ||
76 | + | ||
77 | +inline static float UpdateAverage(const char* str, const float new_val) { | ||
78 | + for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) { | ||
79 | + AverageEntry* const entry = avg_entries + entry_num; | ||
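    // Note: this compares pointer identity, not string contents. It works
    // because each call site passes the same string literal every frame
    // (as TimeLog's callers do); distinct-but-equal strings would not match.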
80 | + if (str == entry->id) { | ||
81 | + entry->average_duration = Blend(entry->average_duration, new_val); | ||
82 | + return entry->average_duration; | ||
83 | + } | ||
84 | + } | ||
85 | + | ||
86 | +  if (num_avg_entries >= NUM_LOGS) { | ||
87 | +    LOGE("Too many log entries!"); | ||
88 | +    return new_val;  // Avoid writing past the end of avg_entries below. | ||
89 | +  } | ||
89 | + | ||
90 | + // If it wasn't there already, add it. | ||
91 | + avg_entries[num_avg_entries].id = str; | ||
92 | + avg_entries[num_avg_entries].average_duration = new_val; | ||
93 | + ++num_avg_entries; | ||
94 | + | ||
95 | + return new_val; | ||
96 | +} | ||
97 | + | ||
98 | + | ||
99 | +// Prints out all the timeLog statements in chronological order with the | ||
100 | +// interval that passed between subsequent statements. The total time between | ||
101 | +// the first and last statements is printed last. | ||
102 | +inline static void PrintTimeLog() { | ||
103 | +  if (num_time_logs == 0) { return; }  // Nothing was logged this frame. | ||
104 | +  LogEntry* last_time = time_logs; | ||
104 | + | ||
105 | + float average_running_total = 0.0f; | ||
106 | + | ||
107 | + for (int i = 0; i < num_time_logs; ++i) { | ||
108 | + LogEntry* const this_time = time_logs + i; | ||
109 | + | ||
110 | + const float curr_time = | ||
111 | + (this_time->time_stamp - last_time->time_stamp) / 1000000.0f; | ||
112 | + | ||
113 | + const float avg_time = UpdateAverage(this_time->id, curr_time); | ||
114 | + average_running_total += avg_time; | ||
115 | + | ||
116 | + LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time); | ||
117 | + last_time = this_time; | ||
118 | + } | ||
119 | + | ||
120 | + const float total_time = | ||
121 | + (last_time->time_stamp - time_logs->time_stamp) / 1000000.0f; | ||
122 | + | ||
123 | + LOGD("TOTAL TIME: %6.3fms %6.4fms\n", | ||
124 | + total_time, average_running_total); | ||
125 | + LOGD(" "); | ||
126 | +} | ||
127 | +#else | ||
128 | +inline static void ResetTimeLog() {} | ||
129 | + | ||
130 | +inline static void TimeLog(const char* const str) { | ||
131 | + LOGV("%s", str); | ||
132 | +} | ||
133 | + | ||
134 | +inline static void PrintTimeLog() {} | ||
135 | +#endif | ||
136 | + | ||
137 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" | ||
17 | + | ||
18 | +namespace tf_tracking { | ||
19 | + | ||
20 | +static const float kInitialDistance = 20.0f; | ||
21 | + | ||
22 | +static void InitNormalized(const Image<uint8_t>& src_image, | ||
23 | + const BoundingBox& position, | ||
24 | + Image<float>* const dst_image) { | ||
25 | + BoundingBox scaled_box(position); | ||
26 | + CopyArea(src_image, scaled_box, dst_image); | ||
27 | + NormalizeImage(dst_image); | ||
28 | +} | ||
29 | + | ||
30 | +TrackedObject::TrackedObject(const std::string& id, const Image<uint8_t>& image, | ||
31 | + const BoundingBox& bounding_box, | ||
32 | + ObjectModelBase* const model) | ||
33 | + : id_(id), | ||
34 | + last_known_position_(bounding_box), | ||
35 | + last_detection_position_(bounding_box), | ||
36 | + position_last_computed_time_(-1), | ||
37 | + object_model_(model), | ||
38 | + last_detection_thumbnail_(kNormalizedThumbnailSize, | ||
39 | + kNormalizedThumbnailSize), | ||
40 | + last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize), | ||
41 | + tracked_correlation_(0.0f), | ||
42 | + tracked_match_score_(0.0), | ||
43 | + num_consecutive_frames_below_threshold_(0), | ||
44 | + allowable_detection_distance_(Square(kInitialDistance)) { | ||
45 | + InitNormalized(image, bounding_box, &last_detection_thumbnail_); | ||
46 | +} | ||
47 | + | ||
48 | +TrackedObject::~TrackedObject() {} | ||
49 | + | ||
50 | +void TrackedObject::UpdatePosition(const BoundingBox& new_position, | ||
51 | + const int64_t timestamp, | ||
52 | + const ImageData& image_data, | ||
53 | + const bool authoritative) { | ||
54 | + last_known_position_ = new_position; | ||
55 | + position_last_computed_time_ = timestamp; | ||
56 | + | ||
57 | + InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_); | ||
58 | + | ||
59 | + const float last_localization_correlation = ComputeCrossCorrelation( | ||
60 | + last_detection_thumbnail_.data(), | ||
61 | + last_frame_thumbnail_.data(), | ||
62 | + last_frame_thumbnail_.data_size_); | ||
63 | + LOGV("Tracked correlation to last localization: %.6f", | ||
64 | + last_localization_correlation); | ||
65 | + | ||
66 | + // Correlation to object model, if it exists. | ||
67 | + if (object_model_ != NULL) { | ||
68 | + tracked_correlation_ = | ||
69 | + object_model_->GetMaxCorrelation(last_frame_thumbnail_); | ||
70 | + LOGV("Tracked correlation to model: %.6f", | ||
71 | + tracked_correlation_); | ||
72 | + | ||
73 | + tracked_match_score_ = | ||
74 | + object_model_->GetMatchScore(new_position, image_data); | ||
75 | + LOGV("Tracked match score with model: %.6f", | ||
76 | + tracked_match_score_.value); | ||
77 | + } else { | ||
78 | + // If there's no model to check against, set the tracked correlation to | ||
79 | + // simply be the correlation to the last set position. | ||
80 | + tracked_correlation_ = last_localization_correlation; | ||
81 | + tracked_match_score_ = MatchScore(0.0f); | ||
82 | + } | ||
83 | + | ||
84 | + // Determine if it's still being tracked. | ||
85 | + if (tracked_correlation_ >= kMinimumCorrelationForTracking && | ||
86 | + tracked_match_score_ >= kMinimumMatchScore) { | ||
87 | + num_consecutive_frames_below_threshold_ = 0; | ||
88 | + | ||
89 | + if (object_model_ != NULL) { | ||
90 | + object_model_->TrackStep(last_known_position_, *image_data.GetImage(), | ||
91 | + *image_data.GetIntegralImage(), authoritative); | ||
92 | + } | ||
93 | + } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) { | ||
94 | + if (num_consecutive_frames_below_threshold_ < 1000) { | ||
95 | + LOGD("Tracked match score is way too low (%.6f), aborting track.", | ||
96 | + tracked_match_score_.value); | ||
97 | + } | ||
98 | + | ||
99 | + // Add an absurd amount of missed frames so that all heuristics will | ||
100 | + // consider it a lost track. | ||
101 | + num_consecutive_frames_below_threshold_ += 1000; | ||
102 | + | ||
103 | + if (object_model_ != NULL) { | ||
104 | + object_model_->TrackLost(); | ||
105 | + } | ||
106 | + } else { | ||
107 | + ++num_consecutive_frames_below_threshold_; | ||
108 | + allowable_detection_distance_ *= 1.1f; | ||
109 | + } | ||
110 | +} | ||
111 | + | ||
112 | +void TrackedObject::OnDetection(ObjectModelBase* const model, | ||
113 | + const BoundingBox& detection_position, | ||
114 | + const MatchScore match_score, | ||
115 | + const int64_t timestamp, | ||
116 | + const ImageData& image_data) { | ||
117 | + const float overlap = detection_position.PascalScore(last_known_position_); | ||
118 | + if (overlap > kPositionOverlapThreshold) { | ||
119 | + // If the position agreement with the current tracked position is good | ||
120 | + // enough, lock all the current unlocked examples. | ||
121 | + object_model_->TrackConfirmed(); | ||
122 | + num_consecutive_frames_below_threshold_ = 0; | ||
123 | + } | ||
124 | + | ||
125 | + // Before relocalizing, make sure the new proposed position is better than | ||
126 | + // the existing position by a small amount to prevent thrashing. | ||
127 | + if (match_score <= tracked_match_score_ + kMatchScoreBuffer) { | ||
128 | + LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f", | ||
129 | + match_score.value, tracked_match_score_.value, | ||
130 | + kMatchScoreBuffer.value); | ||
131 | + return; | ||
132 | + } | ||
133 | + | ||
134 | + LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to " | ||
135 | + "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f", | ||
136 | + last_known_position_.left_, last_known_position_.top_, | ||
137 | + last_known_position_.GetWidth(), last_known_position_.GetHeight(), | ||
138 | + detection_position.left_, detection_position.top_, | ||
139 | + detection_position.GetWidth(), detection_position.GetHeight(), | ||
140 | + match_score.value, tracked_match_score_.value); | ||
141 | + | ||
142 | + if (overlap < kPositionOverlapThreshold) { | ||
143 | + // The path might be good, it might be bad, but it's no longer a path | ||
144 | + // since we're moving the box to a new position, so just nuke it from | ||
145 | + // orbit to be safe. | ||
146 | + object_model_->TrackLost(); | ||
147 | + } | ||
148 | + | ||
149 | + object_model_ = model; | ||
150 | + | ||
151 | + // Reset the last detected appearance. | ||
152 | + InitNormalized( | ||
153 | + *image_data.GetImage(), detection_position, &last_detection_thumbnail_); | ||
154 | + | ||
155 | + num_consecutive_frames_below_threshold_ = 0; | ||
156 | + last_detection_position_ = detection_position; | ||
157 | + | ||
158 | + UpdatePosition(detection_position, timestamp, image_data, false); | ||
159 | + allowable_detection_distance_ = Square(kInitialDistance); | ||
160 | +} | ||
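The kMatchScoreBuffer check above is a small hysteresis margin. Restated as a standalone sketch with illustrative names (not part of this change):

// Sketch: only relocalize when the candidate beats the incumbent by a
// margin, so two near-equal hypotheses cannot make the box thrash.
inline bool ShouldRelocalize(const float candidate_score,
                             const float current_score, const float margin) {
  return candidate_score > current_score + margin;
}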
161 | + | ||
162 | +} // namespace tf_tracking |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ | ||
18 | + | ||
19 | +#ifdef __RENDER_OPENGL__ | ||
20 | +#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h" | ||
21 | +#endif | ||
22 | +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" | ||
23 | + | ||
24 | +namespace tf_tracking { | ||
25 | + | ||
26 | +// A TrackedObject is a specific instance of an ObjectModel, with a known | ||
27 | +// position in the world. | ||
28 | +// It provides the last known position and number of recent detection failures, | ||
29 | +// in addition to the more general appearance data associated with the object | ||
30 | +// class (which is in ObjectModel). | ||
31 | +// TODO(andrewharp): Make getters/setters follow styleguide. | ||
32 | +class TrackedObject { | ||
33 | + public: | ||
34 | + TrackedObject(const std::string& id, const Image<uint8_t>& image, | ||
35 | + const BoundingBox& bounding_box, ObjectModelBase* const model); | ||
36 | + | ||
37 | + ~TrackedObject(); | ||
38 | + | ||
39 | + void UpdatePosition(const BoundingBox& new_position, const int64_t timestamp, | ||
40 | + const ImageData& image_data, const bool authoritative); | ||
41 | + | ||
42 | + // This method is called when the tracked object is detected at a | ||
43 | + // given position, and allows the associated Model to grow and/or prune | ||
44 | + // itself based on where the detection occurred. | ||
45 | + void OnDetection(ObjectModelBase* const model, | ||
46 | + const BoundingBox& detection_position, | ||
47 | + const MatchScore match_score, const int64_t timestamp, | ||
48 | + const ImageData& image_data); | ||
49 | + | ||
50 | + // Called when there's no detection of the tracked object. This will cause | ||
51 | + // a tracking failure after enough consecutive failures if the area under | ||
52 | + // the current bounding box also doesn't meet a minimum correlation threshold | ||
53 | + // with the model. | ||
54 | + void OnDetectionFailure() {} | ||
55 | + | ||
56 | + inline bool IsVisible() const { | ||
57 | + return tracked_correlation_ >= kMinimumCorrelationForTracking || | ||
58 | + num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures; | ||
59 | + } | ||
60 | + | ||
61 | + inline float GetCorrelation() { | ||
62 | + return tracked_correlation_; | ||
63 | + } | ||
64 | + | ||
65 | + inline MatchScore GetMatchScore() { | ||
66 | + return tracked_match_score_; | ||
67 | + } | ||
68 | + | ||
69 | + inline BoundingBox GetPosition() const { | ||
70 | + return last_known_position_; | ||
71 | + } | ||
72 | + | ||
73 | + inline BoundingBox GetLastDetectionPosition() const { | ||
74 | + return last_detection_position_; | ||
75 | + } | ||
76 | + | ||
77 | + inline const ObjectModelBase* GetModel() const { | ||
78 | + return object_model_; | ||
79 | + } | ||
80 | + | ||
81 | + inline const std::string& GetName() const { | ||
82 | + return id_; | ||
83 | + } | ||
84 | + | ||
85 | + inline void Draw() const { | ||
86 | +#ifdef __RENDER_OPENGL__ | ||
87 | + if (tracked_correlation_ < kMinimumCorrelationForTracking) { | ||
88 | + glColor4f(MAX(0.0f, -tracked_correlation_), | ||
89 | + MAX(0.0f, tracked_correlation_), | ||
90 | + 0.0f, | ||
91 | + 1.0f); | ||
92 | + } else { | ||
93 | + glColor4f(MAX(0.0f, -tracked_correlation_), | ||
94 | + MAX(0.0f, tracked_correlation_), | ||
95 | + 1.0f, | ||
96 | + 1.0f); | ||
97 | + } | ||
98 | + | ||
99 | + // Render the box itself. | ||
100 | + BoundingBox temp_box(last_known_position_); | ||
101 | + DrawBox(temp_box); | ||
102 | + | ||
103 | +    // Render a box outside this one (in case the actual box is hidden). | ||
104 | + const float kBufferSize = 1.0f; | ||
105 | + temp_box.left_ -= kBufferSize; | ||
106 | + temp_box.top_ -= kBufferSize; | ||
107 | + temp_box.right_ += kBufferSize; | ||
108 | + temp_box.bottom_ += kBufferSize; | ||
109 | + DrawBox(temp_box); | ||
110 | + | ||
111 | +    // Render one inside as well. | ||
112 | +    temp_box.left_ += 2.0f * kBufferSize; | ||
113 | +    temp_box.top_ += 2.0f * kBufferSize; | ||
114 | +    temp_box.right_ -= 2.0f * kBufferSize; | ||
115 | +    temp_box.bottom_ -= 2.0f * kBufferSize; | ||
116 | + DrawBox(temp_box); | ||
117 | +#endif | ||
118 | + } | ||
119 | + | ||
120 | + // Get current object's num_consecutive_frames_below_threshold_. | ||
121 | + inline int64_t GetNumConsecutiveFramesBelowThreshold() { | ||
122 | + return num_consecutive_frames_below_threshold_; | ||
123 | + } | ||
124 | + | ||
125 | + // Reset num_consecutive_frames_below_threshold_ to 0. | ||
126 | + inline void resetNumConsecutiveFramesBelowThreshold() { | ||
127 | + num_consecutive_frames_below_threshold_ = 0; | ||
128 | + } | ||
129 | + | ||
130 | + inline float GetAllowableDistanceSquared() const { | ||
131 | + return allowable_detection_distance_; | ||
132 | + } | ||
133 | + | ||
134 | + private: | ||
135 | + // The unique id used throughout the system to identify this | ||
136 | + // tracked object. | ||
137 | + const std::string id_; | ||
138 | + | ||
139 | + // The last known position of the object. | ||
140 | + BoundingBox last_known_position_; | ||
141 | + | ||
142 | +  // The position at which the object was last detected. | ||
143 | + BoundingBox last_detection_position_; | ||
144 | + | ||
145 | + // When the position was last computed. | ||
146 | + int64_t position_last_computed_time_; | ||
147 | + | ||
148 | + // The object model this tracked object is representative of. | ||
149 | + ObjectModelBase* object_model_; | ||
150 | + | ||
151 | + Image<float> last_detection_thumbnail_; | ||
152 | + | ||
153 | + Image<float> last_frame_thumbnail_; | ||
154 | + | ||
155 | + // The correlation of the object model with the preview frame at its last | ||
156 | + // tracked position. | ||
157 | + float tracked_correlation_; | ||
158 | + | ||
159 | + MatchScore tracked_match_score_; | ||
160 | + | ||
161 | + // The number of consecutive frames that the tracked position for this object | ||
162 | + // has been under the correlation threshold. | ||
163 | + int num_consecutive_frames_below_threshold_; | ||
164 | + | ||
165 | + float allowable_detection_distance_; | ||
166 | + | ||
167 | + friend std::ostream& operator<<(std::ostream& stream, | ||
168 | + const TrackedObject& tracked_object); | ||
169 | + | ||
170 | + TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject); | ||
171 | +}; | ||
172 | + | ||
173 | +inline std::ostream& operator<<(std::ostream& stream, | ||
174 | + const TrackedObject& tracked_object) { | ||
175 | + stream << tracked_object.id_ | ||
176 | + << " " << tracked_object.last_known_position_ | ||
177 | + << " " << tracked_object.position_last_computed_time_ | ||
178 | + << " " << tracked_object.num_consecutive_frames_below_threshold_ | ||
179 | + << " " << tracked_object.object_model_ | ||
180 | + << " " << tracked_object.tracked_correlation_; | ||
181 | + return stream; | ||
182 | +} | ||
183 | + | ||
184 | +} // namespace tf_tracking | ||
185 | + | ||
186 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ |
android/android/jni/object_tracking/utils.h
0 → 100644
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ | ||
18 | + | ||
19 | +#include <math.h> | ||
20 | +#include <stdint.h> | ||
21 | +#include <stdlib.h> | ||
22 | +#include <time.h> | ||
23 | + | ||
24 | +#include <cmath> // for std::abs(float) | ||
25 | + | ||
26 | +#ifndef HAVE_CLOCK_GETTIME | ||
27 | +// Use gettimeofday() instead of clock_gettime(). | ||
28 | +#include <sys/time.h> | ||
29 | +#endif // ifdef HAVE_CLOCK_GETTIME | ||
30 | + | ||
31 | +#include "tensorflow/examples/android/jni/object_tracking/logging.h" | ||
32 | + | ||
33 | +// TODO(andrewharp): clean up these macros to use the codebase standard. | ||
34 | + | ||
35 | +// A very small number, generally used as the tolerance for accumulated | ||
36 | +// floating point errors in bounds-checks. | ||
37 | +#define EPSILON 0.00001f | ||
38 | + | ||
39 | +#define SAFE_DELETE(pointer) {\ | ||
40 | + if ((pointer) != NULL) {\ | ||
41 | + LOGV("Safe deleting pointer: %s", #pointer);\ | ||
42 | + delete (pointer);\ | ||
43 | + (pointer) = NULL;\ | ||
44 | + } else {\ | ||
45 | + LOGV("Pointer already null: %s", #pointer);\ | ||
46 | + }\ | ||
47 | +} | ||
48 | + | ||
49 | + | ||
50 | +#ifdef __GOOGLE__ | ||
51 | + | ||
52 | +#define CHECK_ALWAYS(condition, format, ...) {\ | ||
53 | + CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\ | ||
54 | +} | ||
55 | + | ||
56 | +#define SCHECK(condition, format, ...) {\ | ||
57 | + DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\ | ||
58 | +} | ||
59 | + | ||
60 | +#else | ||
61 | + | ||
62 | +#define CHECK_ALWAYS(condition, format, ...) {\ | ||
63 | + if (!(condition)) {\ | ||
64 | + LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\ | ||
65 | + abort();\ | ||
66 | + }\ | ||
67 | +} | ||
68 | + | ||
69 | +#ifdef SANITY_CHECKS | ||
70 | +#define SCHECK(condition, format, ...) {\ | ||
71 | + CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\ | ||
72 | +} | ||
73 | +#else | ||
74 | +#define SCHECK(condition, format, ...) {} | ||
75 | +#endif // SANITY_CHECKS | ||
76 | + | ||
77 | +#endif // __GOOGLE__ | ||
78 | + | ||
79 | + | ||
80 | +#ifndef MAX | ||
81 | +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) | ||
82 | +#endif | ||
83 | +#ifndef MIN | ||
84 | +#define MIN(a, b) (((a) > (b)) ? (b) : (a)) | ||
85 | +#endif | ||
86 | + | ||
87 | +inline static int64_t CurrentThreadTimeNanos() { | ||
88 | +#ifdef HAVE_CLOCK_GETTIME | ||
89 | + struct timespec tm; | ||
90 | + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm); | ||
91 | + return tm.tv_sec * 1000000000LL + tm.tv_nsec; | ||
92 | +#else | ||
93 | + struct timeval tv; | ||
94 | + gettimeofday(&tv, NULL); | ||
95 | +  return tv.tv_sec * 1000000000LL + tv.tv_usec * 1000LL; | ||
96 | +#endif | ||
97 | +} | ||
98 | + | ||
99 | +inline static int64_t CurrentRealTimeMillis() { | ||
100 | +#ifdef HAVE_CLOCK_GETTIME | ||
101 | + struct timespec tm; | ||
102 | + clock_gettime(CLOCK_MONOTONIC, &tm); | ||
103 | + return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL; | ||
104 | +#else | ||
105 | + struct timeval tv; | ||
106 | + gettimeofday(&tv, NULL); | ||
107 | + return tv.tv_sec * 1000 + tv.tv_usec / 1000; | ||
108 | +#endif | ||
109 | +} | ||
110 | + | ||
111 | + | ||
112 | +template<typename T> | ||
113 | +inline static T Square(const T a) { | ||
114 | + return a * a; | ||
115 | +} | ||
116 | + | ||
117 | + | ||
118 | +template<typename T> | ||
119 | +inline static T Clip(const T a, const T floor, const T ceil) { | ||
120 | + SCHECK(ceil >= floor, "Bounds mismatch!"); | ||
121 | + return (a <= floor) ? floor : ((a >= ceil) ? ceil : a); | ||
122 | +} | ||
123 | + | ||
124 | + | ||
125 | +template<typename T> | ||
126 | +inline static int Floor(const T a) { | ||
127 | + return static_cast<int>(a); | ||
128 | +} | ||
129 | + | ||
130 | + | ||
131 | +template<typename T> | ||
132 | +inline static int Ceil(const T a) { | ||
133 | + return Floor(a) + 1; | ||
134 | +} | ||
135 | + | ||
136 | + | ||
137 | +template<typename T> | ||
138 | +inline static bool InRange(const T a, const T min, const T max) { | ||
139 | + return (a >= min) && (a <= max); | ||
140 | +} | ||
141 | + | ||
142 | + | ||
143 | +inline static bool ValidIndex(const int a, const int max) { | ||
144 | + return (a >= 0) && (a < max); | ||
145 | +} | ||
146 | + | ||
147 | + | ||
148 | +inline bool NearlyEqual(const float a, const float b, const float tolerance) { | ||
149 | + return std::abs(a - b) < tolerance; | ||
150 | +} | ||
151 | + | ||
152 | + | ||
153 | +inline bool NearlyEqual(const float a, const float b) { | ||
154 | + return NearlyEqual(a, b, EPSILON); | ||
155 | +} | ||
156 | + | ||
157 | + | ||
158 | +template<typename T> | ||
159 | +inline static int Round(const float a) { | ||
160 | +  return static_cast<int>((a - floor(a)) > 0.5f ? ceil(a) : floor(a)); | ||
161 | +} | ||
162 | + | ||
163 | + | ||
164 | +template<typename T> | ||
165 | +inline static void Swap(T* const a, T* const b) { | ||
166 | + // Cache out the VALUE of what's at a. | ||
167 | + T tmp = *a; | ||
168 | + *a = *b; | ||
169 | + | ||
170 | + *b = tmp; | ||
171 | +} | ||
172 | + | ||
173 | + | ||
174 | +static inline float randf() { | ||
175 | + return rand() / static_cast<float>(RAND_MAX); | ||
176 | +} | ||
177 | + | ||
178 | +static inline float randf(const float min_value, const float max_value) { | ||
179 | + return randf() * (max_value - min_value) + min_value; | ||
180 | +} | ||
181 | + | ||
182 | +static inline uint16_t RealToFixed115(const float real_number) { | ||
183 | + SCHECK(InRange(real_number, 0.0f, 2048.0f), | ||
184 | + "Value out of range! %.2f", real_number); | ||
185 | + | ||
186 | + static const float kMult = 32.0f; | ||
187 | + const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f; | ||
188 | + return static_cast<uint16_t>(real_number * kMult + round_add); | ||
189 | +} | ||
190 | + | ||
191 | +static inline float FixedToFloat115(const uint16_t fp_number) { | ||
192 | + const float kDiv = 32.0f; | ||
193 | + return (static_cast<float>(fp_number) / kDiv); | ||
194 | +} | ||
195 | + | ||
196 | +static inline int RealToFixed1616(const float real_number) { | ||
197 | + static const float kMult = 65536.0f; | ||
198 | + SCHECK(InRange(real_number, -kMult, kMult), | ||
199 | + "Value out of range! %.2f", real_number); | ||
200 | + | ||
201 | + const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f; | ||
202 | + return static_cast<int>(real_number * kMult + round_add); | ||
203 | +} | ||
204 | + | ||
205 | +static inline float FixedToFloat1616(const int fp_number) { | ||
206 | + const float kDiv = 65536.0f; | ||
207 | + return (static_cast<float>(fp_number) / kDiv); | ||
208 | +} | ||
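These helpers implement 16.16 fixed point: the high 16 bits hold the integer part, the low 16 bits the fraction. A worked round trip using the functions above:

const int fixed = RealToFixed1616(1.5f);     // 1.5 * 65536 = 98304 = 0x00018000
const int whole = fixed >> 16;               // 1 (the truncated integer part).
const float back = FixedToFloat1616(fixed);  // Exactly 1.5f again.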
209 | + | ||
210 | +template<typename T> | ||
211 | +// produces numbers in range [0,2*M_PI] (rather than -PI,PI) | ||
212 | +inline T FastAtan2(const T y, const T x) { | ||
213 | + static const T coeff_1 = (T)(M_PI / 4.0); | ||
214 | + static const T coeff_2 = (T)(3.0 * coeff_1); | ||
215 | + const T abs_y = fabs(y); | ||
216 | + T angle; | ||
217 | + if (x >= 0) { | ||
218 | + T r = (x - abs_y) / (x + abs_y); | ||
219 | + angle = coeff_1 - coeff_1 * r; | ||
220 | + } else { | ||
221 | + T r = (x + abs_y) / (abs_y - x); | ||
222 | + angle = coeff_2 - coeff_1 * r; | ||
223 | + } | ||
224 | +  static const T TWO_PI = 2.0 * M_PI; | ||
225 | +  return y < 0 ? TWO_PI - angle : angle; | ||
226 | +} | ||
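FastAtan2 trades accuracy for a branch-light approximation; for this classic formulation the worst-case error is on the order of 0.07 radians (about 4 degrees). Note the [0, 2*pi) output range when using it:

const float diag = FastAtan2(1.0f, 1.0f);   // Exactly M_PI / 4 here (r == 0).
const float down = FastAtan2(-1.0f, 0.0f);  // ~3 * M_PI / 2, not -M_PI / 2.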
227 | + | ||
228 | +#define NELEMS(X) (sizeof(X) / sizeof(X[0])) | ||
229 | + | ||
230 | +namespace tf_tracking { | ||
231 | + | ||
232 | +#ifdef __ARM_NEON | ||
233 | +float ComputeMeanNeon(const float* const values, const int num_vals); | ||
234 | + | ||
235 | +float ComputeStdDevNeon(const float* const values, const int num_vals, | ||
236 | + const float mean); | ||
237 | + | ||
238 | +float ComputeWeightedMeanNeon(const float* const values, | ||
239 | + const float* const weights, const int num_vals); | ||
240 | + | ||
241 | +float ComputeCrossCorrelationNeon(const float* const values1, | ||
242 | + const float* const values2, | ||
243 | + const int num_vals); | ||
244 | +#endif | ||
245 | + | ||
246 | +inline float ComputeMeanCpu(const float* const values, const int num_vals) { | ||
247 | + // Get mean. | ||
248 | + float sum = values[0]; | ||
249 | + for (int i = 1; i < num_vals; ++i) { | ||
250 | + sum += values[i]; | ||
251 | + } | ||
252 | + return sum / static_cast<float>(num_vals); | ||
253 | +} | ||
254 | + | ||
255 | + | ||
256 | +inline float ComputeMean(const float* const values, const int num_vals) { | ||
257 | + return | ||
258 | +#ifdef __ARM_NEON | ||
259 | + (num_vals >= 8) ? ComputeMeanNeon(values, num_vals) : | ||
260 | +#endif | ||
261 | + ComputeMeanCpu(values, num_vals); | ||
262 | +} | ||
263 | + | ||
264 | + | ||
265 | +inline float ComputeStdDevCpu(const float* const values, | ||
266 | + const int num_vals, | ||
267 | + const float mean) { | ||
268 | + // Get Std dev. | ||
269 | + float squared_sum = 0.0f; | ||
270 | + for (int i = 0; i < num_vals; ++i) { | ||
271 | + squared_sum += Square(values[i] - mean); | ||
272 | + } | ||
273 | + return sqrt(squared_sum / static_cast<float>(num_vals)); | ||
274 | +} | ||
275 | + | ||
276 | + | ||
277 | +inline float ComputeStdDev(const float* const values, | ||
278 | + const int num_vals, | ||
279 | + const float mean) { | ||
280 | + return | ||
281 | +#ifdef __ARM_NEON | ||
282 | + (num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) : | ||
283 | +#endif | ||
284 | + ComputeStdDevCpu(values, num_vals, mean); | ||
285 | +} | ||
286 | + | ||
287 | + | ||
288 | +// TODO(andrewharp): Accelerate with NEON. | ||
289 | +inline float ComputeWeightedMean(const float* const values, | ||
290 | + const float* const weights, | ||
291 | + const int num_vals) { | ||
292 | + float sum = 0.0f; | ||
293 | + float total_weight = 0.0f; | ||
294 | + for (int i = 0; i < num_vals; ++i) { | ||
295 | + sum += values[i] * weights[i]; | ||
296 | + total_weight += weights[i]; | ||
297 | + } | ||
298 | +  return sum / total_weight; | ||
299 | +} | ||
300 | + | ||
301 | + | ||
302 | +inline float ComputeCrossCorrelationCpu(const float* const values1, | ||
303 | + const float* const values2, | ||
304 | + const int num_vals) { | ||
305 | + float sxy = 0.0f; | ||
306 | + for (int offset = 0; offset < num_vals; ++offset) { | ||
307 | + sxy += values1[offset] * values2[offset]; | ||
308 | + } | ||
309 | + | ||
310 | + const float cross_correlation = sxy / num_vals; | ||
311 | + | ||
312 | + return cross_correlation; | ||
313 | +} | ||
314 | + | ||
315 | + | ||
316 | +inline float ComputeCrossCorrelation(const float* const values1, | ||
317 | + const float* const values2, | ||
318 | + const int num_vals) { | ||
319 | + return | ||
320 | +#ifdef __ARM_NEON | ||
321 | + (num_vals >= 8) ? ComputeCrossCorrelationNeon(values1, values2, num_vals) | ||
322 | + : | ||
323 | +#endif | ||
324 | + ComputeCrossCorrelationCpu(values1, values2, num_vals); | ||
325 | +} | ||
326 | + | ||
327 | + | ||
328 | +inline void NormalizeNumbers(float* const values, const int num_vals) { | ||
329 | + // Find the mean and then subtract so that the new mean is 0.0. | ||
330 | + const float mean = ComputeMean(values, num_vals); | ||
331 | + VLOG(2) << "Mean is " << mean; | ||
332 | + float* curr_data = values; | ||
333 | + for (int i = 0; i < num_vals; ++i) { | ||
334 | + *curr_data -= mean; | ||
335 | + curr_data++; | ||
336 | + } | ||
337 | + | ||
338 | + // Now divide by the std deviation so the new standard deviation is 1.0. | ||
339 | + // The numbers might all be identical (and thus shifted to 0.0 now), | ||
340 | + // so only scale by the standard deviation if this is not the case. | ||
341 | + const float std_dev = ComputeStdDev(values, num_vals, 0.0f); | ||
342 | + if (std_dev > 0.0f) { | ||
343 | + VLOG(2) << "Std dev is " << std_dev; | ||
344 | + curr_data = values; | ||
345 | + for (int i = 0; i < num_vals; ++i) { | ||
346 | + *curr_data /= std_dev; | ||
347 | + curr_data++; | ||
348 | + } | ||
349 | + } | ||
350 | +} | ||
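After NormalizeNumbers the buffer has zero mean and, unless it was constant, unit standard deviation, which is what lets the raw dot product in ComputeCrossCorrelation act as a normalized correlation. A small worked example:

float vals[4] = { 2.0f, 4.0f, 4.0f, 6.0f };  // Mean 4, std dev sqrt(2).
NormalizeNumbers(vals, 4);
// vals is now approximately { -1.414f, 0.0f, 0.0f, 1.414f }.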
351 | + | ||
352 | + | ||
353 | +// Returns the determinant of a 2x2 matrix. | ||
354 | +template<class T> | ||
355 | +inline T FindDeterminant2x2(const T* const a) { | ||
356 | + // Determinant: (ad - bc) | ||
357 | + return a[0] * a[3] - a[1] * a[2]; | ||
358 | +} | ||
359 | + | ||
360 | + | ||
361 | +// Finds the inverse of a 2x2 matrix. | ||
362 | +// Returns true upon success, false if the matrix is not invertible. | ||
363 | +template<class T> | ||
364 | +inline bool Invert2x2(const T* const a, float* const a_inv) { | ||
365 | + const float det = static_cast<float>(FindDeterminant2x2(a)); | ||
366 | + if (fabs(det) < EPSILON) { | ||
367 | + return false; | ||
368 | + } | ||
369 | + const float inv_det = 1.0f / det; | ||
370 | + | ||
371 | + a_inv[0] = inv_det * static_cast<float>(a[3]); // d | ||
372 | + a_inv[1] = inv_det * static_cast<float>(-a[1]); // -b | ||
373 | + a_inv[2] = inv_det * static_cast<float>(-a[2]); // -c | ||
374 | + a_inv[3] = inv_det * static_cast<float>(a[0]); // a | ||
375 | + | ||
376 | + return true; | ||
377 | +} | ||
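Invert2x2 is the same adjugate-over-determinant identity used by the ESM step earlier in this change. A usage sketch solving A * x = b for a 2x2 system:

const float a[4] = { 4.0f, 7.0f, 2.0f, 6.0f };  // Row-major [4 7; 2 6], det 10.
float a_inv[4];
if (Invert2x2(a, a_inv)) {
  const float b0 = 1.0f, b1 = 0.0f;
  const float x0 = a_inv[0] * b0 + a_inv[1] * b1;  // 0.6f
  const float x1 = a_inv[2] * b0 + a_inv[3] * b1;  // -0.2f
}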
378 | + | ||
379 | +} // namespace tf_tracking | ||
380 | + | ||
381 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// NEON implementations of Image methods for compatible devices. Control | ||
17 | +// should never enter this compilation unit on incompatible devices. | ||
18 | + | ||
19 | +#ifdef __ARM_NEON | ||
20 | + | ||
21 | +#include <arm_neon.h> | ||
22 | + | ||
23 | +#include "tensorflow/examples/android/jni/object_tracking/geom.h" | ||
24 | +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" | ||
25 | +#include "tensorflow/examples/android/jni/object_tracking/image.h" | ||
26 | +#include "tensorflow/examples/android/jni/object_tracking/utils.h" | ||
27 | + | ||
28 | +namespace tf_tracking { | ||
29 | + | ||
30 | +inline static float GetSum(const float32x4_t& values) { | ||
31 | +  float32_t summed_values[4];  // Stack-local keeps this helper thread-safe. | ||
32 | + vst1q_f32(summed_values, values); | ||
33 | + return summed_values[0] | ||
34 | + + summed_values[1] | ||
35 | + + summed_values[2] | ||
36 | + + summed_values[3]; | ||
37 | +} | ||
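GetSum round-trips the vector through memory; a register-only pairwise reduction is a common alternative. A sketch with standard NEON intrinsics (vpadd_f32 exists on both ARMv7 NEON and AArch64):

// Sketch: horizontal sum of a float32x4_t without a store/load round trip.
inline float GetSumPairwise(const float32x4_t values) {
  // Add the high and low halves, then fold the remaining pair together.
  float32x2_t sum2 = vadd_f32(vget_low_f32(values), vget_high_f32(values));
  sum2 = vpadd_f32(sum2, sum2);
  return vget_lane_f32(sum2, 0);
}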
38 | + | ||
39 | + | ||
40 | +float ComputeMeanNeon(const float* const values, const int num_vals) { | ||
41 | + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); | ||
42 | + | ||
43 | + const float32_t* const arm_vals = (const float32_t* const) values; | ||
44 | + float32x4_t accum = vdupq_n_f32(0.0f); | ||
45 | + | ||
46 | + int offset = 0; | ||
47 | + for (; offset <= num_vals - 4; offset += 4) { | ||
48 | + accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset])); | ||
49 | + } | ||
50 | + | ||
51 | + // Pull the accumulated values into a single variable. | ||
52 | + float sum = GetSum(accum); | ||
53 | + | ||
54 | + // Get the remaining 1 to 3 values. | ||
55 | + for (; offset < num_vals; ++offset) { | ||
56 | + sum += values[offset]; | ||
57 | + } | ||
58 | + | ||
59 | + const float mean_neon = sum / static_cast<float>(num_vals); | ||
60 | + | ||
61 | +#ifdef SANITY_CHECKS | ||
62 | + const float mean_cpu = ComputeMeanCpu(values, num_vals); | ||
63 | + SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals), | ||
64 | + "Neon mismatch with CPU mean! %.10f vs %.10f", | ||
65 | + mean_neon, mean_cpu); | ||
66 | +#endif | ||
67 | + | ||
68 | + return mean_neon; | ||
69 | +} | ||
70 | + | ||
71 | + | ||
72 | +float ComputeStdDevNeon(const float* const values, | ||
73 | + const int num_vals, const float mean) { | ||
74 | + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); | ||
75 | + | ||
76 | + const float32_t* const arm_vals = (const float32_t* const) values; | ||
77 | + const float32x4_t mean_vec = vdupq_n_f32(-mean); | ||
78 | + | ||
79 | + float32x4_t accum = vdupq_n_f32(0.0f); | ||
80 | + | ||
81 | + int offset = 0; | ||
82 | + for (; offset <= num_vals - 4; offset += 4) { | ||
83 | + const float32x4_t deltas = | ||
84 | + vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset])); | ||
85 | + | ||
86 | + accum = vmlaq_f32(accum, deltas, deltas); | ||
87 | + } | ||
88 | + | ||
89 | + // Pull the accumulated values into a single variable. | ||
90 | + float squared_sum = GetSum(accum); | ||
91 | + | ||
92 | + // Get the remaining 1 to 3 values. | ||
93 | + for (; offset < num_vals; ++offset) { | ||
94 | + squared_sum += Square(values[offset] - mean); | ||
95 | + } | ||
96 | + | ||
97 | + const float std_dev_neon = sqrt(squared_sum / static_cast<float>(num_vals)); | ||
98 | + | ||
99 | +#ifdef SANITY_CHECKS | ||
100 | + const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean); | ||
101 | + SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals), | ||
102 | + "Neon mismatch with CPU std dev! %.10f vs %.10f", | ||
103 | + std_dev_neon, std_dev_cpu); | ||
104 | +#endif | ||
105 | + | ||
106 | + return std_dev_neon; | ||
107 | +} | ||
108 | + | ||
109 | + | ||
110 | +float ComputeCrossCorrelationNeon(const float* const values1, | ||
111 | + const float* const values2, | ||
112 | + const int num_vals) { | ||
113 | + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); | ||
114 | + | ||
115 | + const float32_t* const arm_vals1 = (const float32_t* const) values1; | ||
116 | + const float32_t* const arm_vals2 = (const float32_t* const) values2; | ||
117 | + | ||
118 | + float32x4_t accum = vdupq_n_f32(0.0f); | ||
119 | + | ||
120 | + int offset = 0; | ||
121 | + for (; offset <= num_vals - 4; offset += 4) { | ||
122 | + accum = vmlaq_f32(accum, | ||
123 | + vld1q_f32(&arm_vals1[offset]), | ||
124 | + vld1q_f32(&arm_vals2[offset])); | ||
125 | + } | ||
126 | + | ||
127 | + // Pull the accumulated values into a single variable. | ||
128 | + float sxy = GetSum(accum); | ||
129 | + | ||
130 | + // Get the remaining 1 to 3 values. | ||
131 | + for (; offset < num_vals; ++offset) { | ||
132 | + sxy += values1[offset] * values2[offset]; | ||
133 | + } | ||
134 | + | ||
135 | + const float cross_correlation_neon = sxy / static_cast<float>(num_vals); | ||
136 | + | ||
137 | +#ifdef SANITY_CHECKS | ||
138 | + const float cross_correlation_cpu = | ||
139 | + ComputeCrossCorrelationCpu(values1, values2, num_vals); | ||
140 | + SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu, | ||
141 | + EPSILON * num_vals), | ||
142 | + "Neon mismatch with CPU cross correlation! %.10f vs %.10f", | ||
143 | + cross_correlation_neon, cross_correlation_cpu); | ||
144 | +#endif | ||
145 | + | ||
146 | + return cross_correlation_neon; | ||
147 | +} | ||
148 | + | ||
149 | +} // namespace tf_tracking | ||
150 | + | ||
151 | +#endif // __ARM_NEON |
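The SANITY_CHECKS branches above compare each NEON kernel against a scalar implementation
declared elsewhere in object_tracking (ComputeMeanCpu, ComputeStdDevCpu,
ComputeCrossCorrelationCpu). A minimal sketch of what those counterparts compute — the bodies
below are illustrative assumptions, not the shipped definitions:

    #include <cmath>

    float ComputeMeanCpu(const float* const values, const int num_vals) {
      float sum = 0.0f;
      for (int i = 0; i < num_vals; ++i) sum += values[i];
      return sum / static_cast<float>(num_vals);
    }

    float ComputeStdDevCpu(const float* const values, const int num_vals,
                           const float mean) {
      float squared_sum = 0.0f;
      for (int i = 0; i < num_vals; ++i) {
        const float delta = values[i] - mean;
        squared_sum += delta * delta;  // Accumulate squared deviations.
      }
      return std::sqrt(squared_sum / static_cast<float>(num_vals));
    }

    float ComputeCrossCorrelationCpu(const float* const values1,
                                     const float* const values2,
                                     const int num_vals) {
      float sxy = 0.0f;
      for (int i = 0; i < num_vals; ++i) sxy += values1[i] * values2[i];
      return sxy / static_cast<float>(num_vals);
    }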
android/android/jni/rgb2yuv.cc
0 → 100644
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// These utility functions allow for the conversion of RGB data to YUV data. | ||
17 | + | ||
18 | +#include "tensorflow/examples/android/jni/rgb2yuv.h" | ||
19 | + | ||
20 | +static inline void WriteYUV(const int x, const int y, const int width, | ||
21 | + const int r8, const int g8, const int b8, | ||
22 | + uint8_t* const pY, uint8_t* const pUV) { | ||
23 | + // Using formulas from http://msdn.microsoft.com/en-us/library/ms893078 | ||
24 | + *pY = ((66 * r8 + 129 * g8 + 25 * b8 + 128) >> 8) + 16; | ||
25 | + | ||
26 | + // Odd widths get rounded up so that UV blocks on the side don't get cut off. | ||
27 | + const int blocks_per_row = (width + 1) / 2; | ||
28 | + | ||
29 | + // 2 bytes per UV block | ||
30 | + const int offset = 2 * (((y / 2) * blocks_per_row + (x / 2))); | ||
31 | + | ||
32 | + // U and V are the average values of all 4 pixels in the block. | ||
33 | + if (!(x & 1) && !(y & 1)) { | ||
34 | + // Explicitly clear the block if this is the first pixel in it. | ||
35 | + pUV[offset] = 0; | ||
36 | + pUV[offset + 1] = 0; | ||
37 | + } | ||
38 | + | ||
39 | + // V (divide by 4 folded into the >> 10; the +32 below is the 128 chroma bias / 4) | ||
40 | +#ifdef __APPLE__ | ||
41 | + const int u_offset = 0; | ||
42 | + const int v_offset = 1; | ||
43 | +#else | ||
44 | + const int u_offset = 1; | ||
45 | + const int v_offset = 0; | ||
46 | +#endif | ||
47 | + pUV[offset + v_offset] += ((112 * r8 - 94 * g8 - 18 * b8 + 128) >> 10) + 32; | ||
48 | + | ||
49 | + // U (same folding: the >> 10 divides by 4 and the +32 is the 128 bias / 4) | ||
50 | + pUV[offset + u_offset] += ((-38 * r8 - 74 * g8 + 112 * b8 + 128) >> 10) + 32; | ||
51 | +} | ||
52 | + | ||
53 | +void ConvertARGB8888ToYUV420SP(const uint32_t* const input, | ||
54 | + uint8_t* const output, int width, int height) { | ||
55 | + uint8_t* pY = output; | ||
56 | + uint8_t* pUV = output + (width * height); | ||
57 | + const uint32_t* in = input; | ||
58 | + | ||
59 | + for (int y = 0; y < height; y++) { | ||
60 | + for (int x = 0; x < width; x++) { | ||
61 | + const uint32_t rgb = *in++; | ||
62 | +#ifdef __APPLE__ | ||
63 | + const int nB = (rgb >> 8) & 0xFF; | ||
64 | + const int nG = (rgb >> 16) & 0xFF; | ||
65 | + const int nR = (rgb >> 24) & 0xFF; | ||
66 | +#else | ||
67 | + const int nR = (rgb >> 16) & 0xFF; | ||
68 | + const int nG = (rgb >> 8) & 0xFF; | ||
69 | + const int nB = rgb & 0xFF; | ||
70 | +#endif | ||
71 | + WriteYUV(x, y, width, nR, nG, nB, pY++, pUV); | ||
72 | + } | ||
73 | + } | ||
74 | +} | ||
75 | + | ||
76 | +void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output, | ||
77 | + const int width, const int height) { | ||
78 | + uint8_t* pY = output; | ||
79 | + uint8_t* pUV = output + (width * height); | ||
80 | + const uint16_t* in = input; | ||
81 | + | ||
82 | + for (int y = 0; y < height; y++) { | ||
83 | + for (int x = 0; x < width; x++) { | ||
84 | + const uint32_t rgb = *in++; | ||
85 | + | ||
86 | + const int r5 = ((rgb >> 11) & 0x1F); | ||
87 | + const int g6 = ((rgb >> 5) & 0x3F); | ||
88 | + const int b5 = (rgb & 0x1F); | ||
89 | + | ||
90 | + // Shift left, then fill in the empty low bits with a copy of the high | ||
91 | + // bits so we can stretch across the entire 0 - 255 range. | ||
92 | + const int r8 = r5 << 3 | r5 >> 2; | ||
93 | + const int g8 = g6 << 2 | g6 >> 4; | ||
94 | + const int b8 = b5 << 3 | b5 >> 2; | ||
95 | + | ||
96 | + WriteYUV(x, y, width, r8, g8, b8, pY++, pUV); | ||
97 | + } | ||
98 | + } | ||
99 | +} |
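The bit-replication trick in ConvertRGB565ToYUV420SP generalizes: to widen an n-bit channel
to 8 bits, shift left by (8 - n) and copy the top (8 - n) bits into the vacated low bits,
which maps 0 to 0 and the channel maximum to 255 exactly. A small self-contained check
(Expand5To8 is a hypothetical helper name, not part of the file above):

    #include <cassert>

    // Widen a 5-bit channel value (0..31) to 8 bits (0..255).
    static inline int Expand5To8(const int c5) { return (c5 << 3) | (c5 >> 2); }

    int main() {
      assert(Expand5To8(0) == 0);     // Black stays black.
      assert(Expand5To8(31) == 255);  // Full intensity maps to full intensity.
      assert(Expand5To8(16) == 132);  // 0b10000 -> 0b10000100.
      return 0;
    }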
android/android/jni/rgb2yuv.h
0 → 100644
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_ | ||
17 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_ | ||
18 | + | ||
19 | +#include <stdint.h> | ||
20 | + | ||
21 | +#ifdef __cplusplus | ||
22 | +extern "C" { | ||
23 | +#endif | ||
24 | + | ||
25 | +void ConvertARGB8888ToYUV420SP(const uint32_t* const input, | ||
26 | + uint8_t* const output, int width, int height); | ||
27 | + | ||
28 | +void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output, | ||
29 | + const int width, const int height); | ||
30 | + | ||
31 | +#ifdef __cplusplus | ||
32 | +} | ||
33 | +#endif | ||
34 | + | ||
35 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_ |
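A caller must size the output buffer for YUV420SP: width * height luma bytes followed by
2 bytes per 2x2 chroma block, with odd dimensions rounded up to match the
(width + 1) / 2 blocks-per-row computation in rgb2yuv.cc. A minimal usage sketch
(ToYuv420Sp is a hypothetical wrapper):

    #include <cstdint>
    #include <vector>
    #include "tensorflow/examples/android/jni/rgb2yuv.h"

    std::vector<uint8_t> ToYuv420Sp(const std::vector<uint32_t>& argb,
                                    const int width, const int height) {
      const int luma_bytes = width * height;
      // One 2-byte UV block per 2x2 pixel block, rounding odd dimensions up.
      const int chroma_bytes = 2 * ((width + 1) / 2) * ((height + 1) / 2);
      std::vector<uint8_t> yuv(luma_bytes + chroma_bytes);
      ConvertARGB8888ToYUV420SP(argb.data(), yuv.data(), width, height);
      return yuv;
    }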
android/android/jni/version_script.lds
0 → 100644
android/android/jni/yuv2rgb.cc
0 → 100644
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// This is a collection of routines which converts various YUV image formats | ||
17 | +// to ARGB. | ||
18 | + | ||
19 | +#include "tensorflow/examples/android/jni/yuv2rgb.h" | ||
20 | + | ||
21 | +#ifndef MAX | ||
22 | +#define MAX(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a > _b ? _a : _b; }) | ||
23 | +#define MIN(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a < _b ? _a : _b; }) | ||
24 | +#endif | ||
25 | + | ||
26 | +// This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges | ||
27 | +// are normalized to eight bits. | ||
28 | +static const int kMaxChannelValue = 262143; | ||
29 | + | ||
30 | +static inline uint32_t YUV2RGB(int nY, int nU, int nV) { | ||
31 | + nY -= 16; | ||
32 | + nU -= 128; | ||
33 | + nV -= 128; | ||
34 | + if (nY < 0) nY = 0; | ||
35 | + | ||
36 | + // This is the floating point equivalent. We do the conversion in integer | ||
37 | + // because some Android devices do not have floating point in hardware. | ||
38 | + // nR = (int)(1.164 * nY + 1.596 * nV); | ||
39 | + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); | ||
40 | + // nB = (int)(1.164 * nY + 2.018 * nU); | ||
41 | + | ||
42 | + int nR = 1192 * nY + 1634 * nV; | ||
43 | + int nG = 1192 * nY - 833 * nV - 400 * nU; | ||
44 | + int nB = 1192 * nY + 2066 * nU; | ||
45 | + | ||
46 | + nR = MIN(kMaxChannelValue, MAX(0, nR)); | ||
47 | + nG = MIN(kMaxChannelValue, MAX(0, nG)); | ||
48 | + nB = MIN(kMaxChannelValue, MAX(0, nB)); | ||
49 | + | ||
50 | + nR = (nR >> 10) & 0xff; | ||
51 | + nG = (nG >> 10) & 0xff; | ||
52 | + nB = (nB >> 10) & 0xff; | ||
53 | + | ||
54 | + return 0xff000000 | (nR << 16) | (nG << 8) | nB; | ||
55 | +} | ||
56 | + | ||
57 | +// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by | ||
58 | +// separate u and v planes with arbitrary row and column strides, | ||
59 | +// containing 8 bit 2x2 subsampled chroma samples. | ||
60 | +// Converts to a packed ARGB 32 bit output of the same pixel dimensions. | ||
61 | +void ConvertYUV420ToARGB8888(const uint8_t* const yData, | ||
62 | + const uint8_t* const uData, | ||
63 | + const uint8_t* const vData, uint32_t* const output, | ||
64 | + const int width, const int height, | ||
65 | + const int y_row_stride, const int uv_row_stride, | ||
66 | + const int uv_pixel_stride) { | ||
67 | + uint32_t* out = output; | ||
68 | + | ||
69 | + for (int y = 0; y < height; y++) { | ||
70 | + const uint8_t* pY = yData + y_row_stride * y; | ||
71 | + | ||
72 | + const int uv_row_start = uv_row_stride * (y >> 1); | ||
73 | + const uint8_t* pU = uData + uv_row_start; | ||
74 | + const uint8_t* pV = vData + uv_row_start; | ||
75 | + | ||
76 | + for (int x = 0; x < width; x++) { | ||
77 | + const int uv_offset = (x >> 1) * uv_pixel_stride; | ||
78 | + *out++ = YUV2RGB(pY[x], pU[uv_offset], pV[uv_offset]); | ||
79 | + } | ||
80 | + } | ||
81 | +} | ||
82 | + | ||
83 | +// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an | ||
84 | +// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples, | ||
85 | +// except the interleave order of U and V is reversed. Converts to a packed | ||
86 | +// ARGB 32 bit output of the same pixel dimensions. | ||
87 | +void ConvertYUV420SPToARGB8888(const uint8_t* const yData, | ||
88 | + const uint8_t* const uvData, | ||
89 | + uint32_t* const output, const int width, | ||
90 | + const int height) { | ||
91 | + const uint8_t* pY = yData; | ||
92 | + const uint8_t* pUV = uvData; | ||
93 | + uint32_t* out = output; | ||
94 | + | ||
95 | + for (int y = 0; y < height; y++) { | ||
96 | + for (int x = 0; x < width; x++) { | ||
97 | + int nY = *pY++; | ||
98 | + int offset = (y >> 1) * width + 2 * (x >> 1); | ||
99 | +#ifdef __APPLE__ | ||
100 | + int nU = pUV[offset]; | ||
101 | + int nV = pUV[offset + 1]; | ||
102 | +#else | ||
103 | + int nV = pUV[offset]; | ||
104 | + int nU = pUV[offset + 1]; | ||
105 | +#endif | ||
106 | + | ||
107 | + *out++ = YUV2RGB(nY, nU, nV); | ||
108 | + } | ||
109 | + } | ||
110 | +} | ||
111 | + | ||
112 | +// The same as above, but downsamples each dimension to half size. | ||
113 | +void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input, | ||
114 | + uint32_t* const output, int width, | ||
115 | + int height) { | ||
116 | + const uint8_t* pY = input; | ||
117 | + const uint8_t* pUV = input + (width * height); | ||
118 | + uint32_t* out = output; | ||
119 | + int stride = width; | ||
120 | + width >>= 1; | ||
121 | + height >>= 1; | ||
122 | + | ||
123 | + for (int y = 0; y < height; y++) { | ||
124 | + for (int x = 0; x < width; x++) { | ||
125 | + int nY = (pY[0] + pY[1] + pY[stride] + pY[stride + 1]) >> 2; | ||
126 | + pY += 2; | ||
127 | +#ifdef __APPLE__ | ||
128 | + int nU = *pUV++; | ||
129 | + int nV = *pUV++; | ||
130 | +#else | ||
131 | + int nV = *pUV++; | ||
132 | + int nU = *pUV++; | ||
133 | +#endif | ||
134 | + | ||
135 | + *out++ = YUV2RGB(nY, nU, nV); | ||
136 | + } | ||
137 | + pY += stride; | ||
138 | + } | ||
139 | +} | ||
140 | + | ||
141 | +// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an | ||
142 | +// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples, | ||
143 | +// except the interleave order of U and V is reversed. Converts to a packed | ||
144 | +// RGB 565 bit output of the same pixel dimensions. | ||
145 | +void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output, | ||
146 | + const int width, const int height) { | ||
147 | + const uint8_t* pY = input; | ||
148 | + const uint8_t* pUV = input + (width * height); | ||
149 | + uint16_t* out = output; | ||
150 | + | ||
151 | + for (int y = 0; y < height; y++) { | ||
152 | + for (int x = 0; x < width; x++) { | ||
153 | + int nY = *pY++; | ||
154 | + int offset = (y >> 1) * width + 2 * (x >> 1); | ||
155 | +#ifdef __APPLE__ | ||
156 | + int nU = pUV[offset]; | ||
157 | + int nV = pUV[offset + 1]; | ||
158 | +#else | ||
159 | + int nV = pUV[offset]; | ||
160 | + int nU = pUV[offset + 1]; | ||
161 | +#endif | ||
162 | + | ||
163 | + nY -= 16; | ||
164 | + nU -= 128; | ||
165 | + nV -= 128; | ||
166 | + if (nY < 0) nY = 0; | ||
167 | + | ||
168 | + // This is the floating point equivalent. We do the conversion in integer | ||
169 | + // because some Android devices do not have floating point in hardware. | ||
170 | + // nR = (int)(1.164 * nY + 1.596 * nV); | ||
171 | + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); | ||
172 | + // nB = (int)(1.164 * nY + 2.018 * nU); | ||
173 | + | ||
174 | + int nR = 1192 * nY + 1634 * nV; | ||
175 | + int nG = 1192 * nY - 833 * nV - 400 * nU; | ||
176 | + int nB = 1192 * nY + 2066 * nU; | ||
177 | + | ||
178 | + nR = MIN(kMaxChannelValue, MAX(0, nR)); | ||
179 | + nG = MIN(kMaxChannelValue, MAX(0, nG)); | ||
180 | + nB = MIN(kMaxChannelValue, MAX(0, nB)); | ||
181 | + | ||
182 | + // Shift more than for ARGB8888 and apply appropriate bitmask. | ||
183 | + nR = (nR >> 13) & 0x1f; | ||
184 | + nG = (nG >> 12) & 0x3f; | ||
185 | + nB = (nB >> 13) & 0x1f; | ||
186 | + | ||
187 | + // R is high 5 bits, G is middle 6 bits, and B is low 5 bits. | ||
188 | + *out++ = (nR << 11) | (nG << 5) | nB; | ||
189 | + } | ||
190 | + } | ||
191 | +} |
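Both conversion paths use the same 10-bit fixed-point constants: each is the floating-point
coefficient from the comment scaled by 1024 and rounded (1.164 * 1024 ≈ 1192,
1.596 * 1024 ≈ 1634, 0.813 * 1024 ≈ 833, 0.391 * 1024 ≈ 400, 2.018 * 1024 ≈ 2066), which is
why products are shifted right by 10 (or 12/13 for RGB565) before masking. kMaxChannelValue
is 2^18 - 1 so that after >> 10 the clamped result is at most 255. A quick self-contained
check of that arithmetic:

    #include <cassert>
    #include <cmath>

    int main() {
      const double kScale = 1024.0;                 // Matches the >> 10 in YUV2RGB.
      assert(std::lround(1.164 * kScale) == 1192);  // Y
      assert(std::lround(1.596 * kScale) == 1634);  // V -> R
      assert(std::lround(0.813 * kScale) == 833);   // V -> G
      assert(std::lround(0.391 * kScale) == 400);   // U -> G
      assert(std::lround(2.018 * kScale) == 2066);  // U -> B
      assert((262143 >> 10) == 255);                // kMaxChannelValue clamps to 8 bits.
      return 0;
    }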
android/android/jni/yuv2rgb.h
0 → 100644
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +// This is a collection of routines which converts various YUV image formats | ||
17 | +// to (A)RGB. | ||
18 | + | ||
19 | +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_ | ||
20 | +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_ | ||
21 | + | ||
22 | +#include <stdint.h> | ||
23 | + | ||
24 | +#ifdef __cplusplus | ||
25 | +extern "C" { | ||
26 | +#endif | ||
27 | + | ||
28 | +void ConvertYUV420ToARGB8888(const uint8_t* const yData, | ||
29 | + const uint8_t* const uData, | ||
30 | + const uint8_t* const vData, uint32_t* const output, | ||
31 | + const int width, const int height, | ||
32 | + const int y_row_stride, const int uv_row_stride, | ||
33 | + const int uv_pixel_stride); | ||
34 | + | ||
35 | +// Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width | ||
36 | +// and height. The input and output must already be allocated and non-null. | ||
37 | +// For efficiency, no error checking is performed. | ||
38 | +void ConvertYUV420SPToARGB8888(const uint8_t* const pY, | ||
39 | + const uint8_t* const pUV, uint32_t* const output, | ||
40 | + const int width, const int height); | ||
41 | + | ||
42 | +// The same as above, but downsamples each dimension to half size. | ||
43 | +void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input, | ||
44 | + uint32_t* const output, int width, | ||
45 | + int height); | ||
46 | + | ||
47 | +// Converts YUV420 semi-planar data to RGB 565 data using the supplied width | ||
48 | +// and height. The input and output must already be allocated and non-null. | ||
49 | +// For efficiency, no error checking is performed. | ||
50 | +void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output, | ||
51 | + const int width, const int height); | ||
52 | + | ||
53 | +#ifdef __cplusplus | ||
54 | +} | ||
55 | +#endif | ||
56 | + | ||
57 | +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_ |
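A minimal usage sketch for the planar entry point, assuming a tightly packed I420 frame
with even dimensions (full-size Y plane followed by half-size U and V planes); real camera
buffers may pad rows, in which case the strides must come from the camera API
(I420ToArgb is a hypothetical wrapper):

    #include <cstdint>
    #include <vector>
    #include "tensorflow/examples/android/jni/yuv2rgb.h"

    std::vector<uint32_t> I420ToArgb(const std::vector<uint8_t>& i420,
                                     const int width, const int height) {
      const uint8_t* y = i420.data();
      const uint8_t* u = y + width * height;             // U plane follows Y.
      const uint8_t* v = u + (width / 2) * (height / 2); // V plane follows U.
      std::vector<uint32_t> argb(width * height);
      ConvertYUV420ToARGB8888(y, u, v, argb.data(), width, height,
                              /*y_row_stride=*/width,
                              /*uv_row_stride=*/width / 2,
                              /*uv_pixel_stride=*/1);
      return argb;
    }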
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<set xmlns:android="http://schemas.android.com/apk/res/android" | ||
17 | + android:ordering="sequentially"> | ||
18 | + <objectAnimator | ||
19 | + android:propertyName="backgroundColor" | ||
20 | + android:duration="375" | ||
21 | + android:valueFrom="0x00b3ccff" | ||
22 | + android:valueTo="0xffb3ccff" | ||
23 | + android:valueType="colorType"/> | ||
24 | + <objectAnimator | ||
25 | + android:propertyName="backgroundColor" | ||
26 | + android:duration="375" | ||
27 | + android:valueFrom="0xffb3ccff" | ||
28 | + android:valueTo="0x00b3ccff" | ||
29 | + android:valueType="colorType"/> | ||
30 | +</set> |
1 KB
4.21 KB
android/android/res/drawable-hdpi/tile.9.png
0 → 100644
196 Bytes
665 Bytes
2.21 KB
1.32 KB
6.53 KB
2.21 KB
12.4 KB
android/android/res/drawable/border.xml
0 → 100644
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle" > | ||
17 | + <solid android:color="#00000000" /> | ||
18 | + <stroke android:width="1dip" android:color="#cccccc" /> | ||
19 | +</shape> |
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android" | ||
17 | + xmlns:tools="http://schemas.android.com/tools" | ||
18 | + android:id="@+id/container" | ||
19 | + android:layout_width="match_parent" | ||
20 | + android:layout_height="match_parent" | ||
21 | + android:background="#000" | ||
22 | + tools:context="org.tensorflow.demo.CameraActivity" /> |
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<FrameLayout | ||
17 | + xmlns:android="http://schemas.android.com/apk/res/android" | ||
18 | + xmlns:app="http://schemas.android.com/apk/res-auto" | ||
19 | + xmlns:tools="http://schemas.android.com/tools" | ||
20 | + android:layout_width="match_parent" | ||
21 | + android:layout_height="match_parent" | ||
22 | + tools:context="org.tensorflow.demo.SpeechActivity"> | ||
23 | + | ||
24 | + <TextView | ||
25 | + android:layout_width="wrap_content" | ||
26 | + android:layout_height="wrap_content" | ||
27 | + android:text="Say one of the words below!" | ||
28 | + android:id="@+id/textView" | ||
29 | + android:textAlignment="center" | ||
30 | + android:layout_gravity="top" | ||
31 | + android:textSize="24dp" | ||
32 | + android:layout_marginTop="10dp" | ||
33 | + android:layout_marginLeft="10dp" | ||
34 | + /> | ||
35 | + | ||
36 | + <ListView | ||
37 | + android:id="@+id/list_view" | ||
38 | + android:layout_width="240dp" | ||
39 | + android:layout_height="wrap_content" | ||
40 | + android:background="@drawable/border" | ||
41 | + android:layout_gravity="top|center_horizontal" | ||
42 | + android:textAlignment="center" | ||
43 | + android:layout_marginTop="100dp" | ||
44 | + /> | ||
45 | + | ||
46 | + <Button | ||
47 | + android:id="@+id/quit" | ||
48 | + android:layout_width="wrap_content" | ||
49 | + android:layout_height="wrap_content" | ||
50 | + android:text="Quit" | ||
51 | + android:layout_gravity="bottom|center_horizontal" | ||
52 | + android:layout_marginBottom="10dp" | ||
53 | + /> | ||
54 | + | ||
55 | +</FrameLayout> |
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" | ||
17 | + android:layout_width="match_parent" | ||
18 | + android:layout_height="match_parent"> | ||
19 | + | ||
20 | + <org.tensorflow.demo.AutoFitTextureView | ||
21 | + android:id="@+id/texture" | ||
22 | + android:layout_width="wrap_content" | ||
23 | + android:layout_height="wrap_content" | ||
24 | + android:layout_alignParentBottom="true" /> | ||
25 | + | ||
26 | + <org.tensorflow.demo.RecognitionScoreView | ||
27 | + android:id="@+id/results" | ||
28 | + android:layout_width="match_parent" | ||
29 | + android:layout_height="112dp" | ||
30 | + android:layout_alignParentTop="true" /> | ||
31 | + | ||
32 | + <org.tensorflow.demo.OverlayView | ||
33 | + android:id="@+id/debug_overlay" | ||
34 | + android:layout_width="match_parent" | ||
35 | + android:layout_height="match_parent" | ||
36 | + android:layout_alignParentBottom="true" /> | ||
37 | + | ||
38 | +</RelativeLayout> |
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" | ||
17 | + android:orientation="vertical" | ||
18 | + android:layout_width="match_parent" | ||
19 | + android:layout_height="match_parent"> | ||
20 | + <org.tensorflow.demo.AutoFitTextureView | ||
21 | + android:id="@+id/texture" | ||
22 | + android:layout_width="wrap_content" | ||
23 | + android:layout_height="wrap_content" | ||
24 | + android:layout_alignParentTop="true" /> | ||
25 | + | ||
26 | + <RelativeLayout | ||
27 | + android:id="@+id/black" | ||
28 | + android:layout_width="match_parent" | ||
29 | + android:layout_height="match_parent" | ||
30 | + android:background="#FF000000" /> | ||
31 | + | ||
32 | + <GridView | ||
33 | + android:id="@+id/grid_layout" | ||
34 | + android:numColumns="7" | ||
35 | + android:stretchMode="columnWidth" | ||
36 | + android:layout_alignParentBottom="true" | ||
37 | + android:layout_width="match_parent" | ||
38 | + android:layout_height="wrap_content" /> | ||
39 | + | ||
40 | + <org.tensorflow.demo.OverlayView | ||
41 | + android:id="@+id/overlay" | ||
42 | + android:layout_width="match_parent" | ||
43 | + android:layout_height="match_parent" | ||
44 | + android:layout_alignParentTop="true" /> | ||
45 | + | ||
46 | + <org.tensorflow.demo.OverlayView | ||
47 | + android:id="@+id/debug_overlay" | ||
48 | + android:layout_width="match_parent" | ||
49 | + android:layout_height="match_parent" | ||
50 | + android:layout_alignParentTop="true" /> | ||
51 | +</RelativeLayout> |
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android" | ||
17 | + android:layout_width="match_parent" | ||
18 | + android:layout_height="match_parent"> | ||
19 | + | ||
20 | + <org.tensorflow.demo.AutoFitTextureView | ||
21 | + android:id="@+id/texture" | ||
22 | + android:layout_width="wrap_content" | ||
23 | + android:layout_height="wrap_content"/> | ||
24 | + | ||
25 | + <org.tensorflow.demo.OverlayView | ||
26 | + android:id="@+id/tracking_overlay" | ||
27 | + android:layout_width="match_parent" | ||
28 | + android:layout_height="match_parent"/> | ||
29 | + | ||
30 | + <org.tensorflow.demo.OverlayView | ||
31 | + android:id="@+id/debug_overlay" | ||
32 | + android:layout_width="match_parent" | ||
33 | + android:layout_height="match_parent"/> | ||
34 | +</FrameLayout> |
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<TextView | ||
17 | + xmlns:android="http://schemas.android.com/apk/res/android" | ||
18 | + android:id="@+id/list_text_item" | ||
19 | + android:layout_width="match_parent" | ||
20 | + android:layout_height="wrap_content" | ||
21 | + android:text="TextView" | ||
22 | + android:textSize="24dp" | ||
23 | + android:textAlignment="center" | ||
24 | + android:gravity="center_horizontal" | ||
25 | + /> |
1 | +<!-- | ||
2 | + Copyright 2013 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | + --> | ||
16 | + | ||
17 | +<resources> | ||
18 | + | ||
19 | + <!-- Semantic definitions --> | ||
20 | + | ||
21 | + <dimen name="horizontal_page_margin">@dimen/margin_huge</dimen> | ||
22 | + <dimen name="vertical_page_margin">@dimen/margin_medium</dimen> | ||
23 | + | ||
24 | +</resources> |
1 | +<!-- | ||
2 | + Copyright 2013 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | + --> | ||
16 | + | ||
17 | +<resources> | ||
18 | + | ||
19 | + <style name="Widget.SampleMessage"> | ||
20 | + <item name="android:textAppearance">?android:textAppearanceLarge</item> | ||
21 | + <item name="android:lineSpacingMultiplier">1.2</item> | ||
22 | + <item name="android:shadowDy">-6.5</item> | ||
23 | + </style> | ||
24 | + | ||
25 | +</resources> |
android/android/res/values-v11/styles.xml
0 → 100644
1 | +<?xml version="1.0" encoding="utf-8"?> | ||
2 | +<resources> | ||
3 | + | ||
4 | + <!-- | ||
5 | + Base application theme for API 11+. This theme completely replaces | ||
6 | + AppBaseTheme from res/values/styles.xml on API 11+ devices. | ||
7 | + --> | ||
8 | + <style name="AppBaseTheme" parent="android:Theme.Holo.Light"> | ||
9 | + <!-- API 11 theme customizations can go here. --> | ||
10 | + </style> | ||
11 | + | ||
12 | + <style name="FullscreenTheme" parent="android:Theme.Holo"> | ||
13 | + <item name="android:actionBarStyle">@style/FullscreenActionBarStyle</item> | ||
14 | + <item name="android:windowActionBarOverlay">true</item> | ||
15 | + <item name="android:windowBackground">@null</item> | ||
16 | + <item name="metaButtonBarStyle">?android:attr/buttonBarStyle</item> | ||
17 | + <item name="metaButtonBarButtonStyle">?android:attr/buttonBarButtonStyle</item> | ||
18 | + </style> | ||
19 | + | ||
20 | + <style name="FullscreenActionBarStyle" parent="android:Widget.Holo.ActionBar"> | ||
21 | + <!-- <item name="android:background">@color/black_overlay</item> --> | ||
22 | + </style> | ||
23 | + | ||
24 | +</resources> |
1 | +<!-- | ||
2 | + Copyright 2013 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | + --> | ||
16 | + | ||
17 | +<resources> | ||
18 | + | ||
19 | + <!-- Activity themes --> | ||
20 | + <style name="Theme.Base" parent="android:Theme.Holo.Light" /> | ||
21 | + | ||
22 | +</resources> |
android/android/res/values-v14/styles.xml
0 → 100644
1 | +<resources> | ||
2 | + | ||
3 | + <!-- | ||
4 | + Base application theme for API 14+. This theme completely replaces | ||
5 | + AppBaseTheme from BOTH res/values/styles.xml and | ||
6 | + res/values-v11/styles.xml on API 14+ devices. | ||
7 | + --> | ||
8 | + <style name="AppBaseTheme" parent="android:Theme.Holo.Light.DarkActionBar"> | ||
9 | + <!-- API 14 theme customizations can go here. --> | ||
10 | + </style> | ||
11 | + | ||
12 | +</resources> |
1 | +<?xml version="1.0" encoding="UTF-8"?> | ||
2 | +<!-- | ||
3 | + Copyright 2013 The TensorFlow Authors. All Rights Reserved. | ||
4 | + | ||
5 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | + you may not use this file except in compliance with the License. | ||
7 | + You may obtain a copy of the License at | ||
8 | + | ||
9 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | + | ||
11 | + Unless required by applicable law or agreed to in writing, software | ||
12 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | + See the License for the specific language governing permissions and | ||
15 | + limitations under the License. | ||
16 | +--> | ||
17 | + | ||
18 | +<resources> | ||
19 | + | ||
20 | + | ||
21 | +</resources> |
1 | +<?xml version="1.0" encoding="UTF-8"?> | ||
2 | +<!-- | ||
3 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
4 | + | ||
5 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | + you may not use this file except in compliance with the License. | ||
7 | + You may obtain a copy of the License at | ||
8 | + | ||
9 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | + | ||
11 | + Unless required by applicable law or agreed to in writing, software | ||
12 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | + See the License for the specific language governing permissions and | ||
15 | + limitations under the License. | ||
16 | +--> | ||
17 | + | ||
18 | +<resources> | ||
19 | + | ||
20 | + <!-- Activity themes --> | ||
21 | + <style name="Theme.Base" parent="android:Theme.Material.Light"> | ||
22 | + </style> | ||
23 | + | ||
24 | +</resources> |
android/android/res/values/attrs.xml
0 → 100644
1 | +<resources> | ||
2 | + | ||
3 | + <!-- | ||
4 | + Declare custom theme attributes that allow changing which styles are | ||
5 | + used for button bars depending on the API level. | ||
6 | + ?android:attr/buttonBarStyle is new as of API 11 so this is | ||
7 | + necessary to support previous API levels. | ||
8 | + --> | ||
9 | + <declare-styleable name="ButtonBarContainerTheme"> | ||
10 | + <attr name="metaButtonBarStyle" format="reference" /> | ||
11 | + <attr name="metaButtonBarButtonStyle" format="reference" /> | ||
12 | + </declare-styleable> | ||
13 | + | ||
14 | +</resources> |
android/android/res/values/base-strings.xml
0 → 100644
1 | +<?xml version="1.0" encoding="UTF-8"?> | ||
2 | +<!-- | ||
3 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
4 | + | ||
5 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | + you may not use this file except in compliance with the License. | ||
7 | + You may obtain a copy of the License at | ||
8 | + | ||
9 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | + | ||
11 | + Unless required by applicable law or agreed to in writing, software | ||
12 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | + See the License for the specific language governing permissions and | ||
15 | + limitations under the License. | ||
16 | +--> | ||
17 | + | ||
18 | +<resources> | ||
19 | + <string name="app_name">TensorFlow Demo</string> | ||
20 | + <string name="activity_name_classification">TF Classify</string> | ||
21 | + <string name="activity_name_detection">TF Detect</string> | ||
22 | + <string name="activity_name_stylize">TF Stylize</string> | ||
23 | + <string name="activity_name_speech">TF Speech</string> | ||
24 | +</resources> |
android/android/res/values/colors.xml
0 → 100644
1 | +<?xml version="1.0" encoding="utf-8"?> | ||
2 | +<!-- | ||
3 | + Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
4 | + | ||
5 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | + you may not use this file except in compliance with the License. | ||
7 | + You may obtain a copy of the License at | ||
8 | + | ||
9 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | + | ||
11 | + Unless required by applicable law or agreed to in writing, software | ||
12 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | + See the License for the specific language governing permissions and | ||
15 | + limitations under the License. | ||
16 | +--> | ||
17 | +<resources> | ||
18 | + <color name="control_background">#cc4285f4</color> | ||
19 | +</resources> |
android/android/res/values/strings.xml
0 → 100644
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<resources> | ||
17 | + <string name="description_info">Info</string> | ||
18 | + <string name="request_permission">This sample needs camera permission.</string> | ||
19 | + <string name="camera_error">This device doesn\'t support Camera2 API.</string> | ||
20 | +</resources> |
android/android/res/values/styles.xml
0 → 100644
1 | +<?xml version="1.0" encoding="utf-8"?><!-- | ||
2 | + Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | +--> | ||
16 | +<resources> | ||
17 | + <style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" /> | ||
18 | +</resources> |
1 | +<!-- | ||
2 | + Copyright 2013 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | + --> | ||
16 | + | ||
17 | +<resources> | ||
18 | + | ||
19 | + <!-- Define standard dimensions to comply with Holo-style grids and rhythm. --> | ||
20 | + | ||
21 | + <dimen name="margin_tiny">4dp</dimen> | ||
22 | + <dimen name="margin_small">8dp</dimen> | ||
23 | + <dimen name="margin_medium">16dp</dimen> | ||
24 | + <dimen name="margin_large">32dp</dimen> | ||
25 | + <dimen name="margin_huge">64dp</dimen> | ||
26 | + | ||
27 | + <!-- Semantic definitions --> | ||
28 | + | ||
29 | + <dimen name="horizontal_page_margin">@dimen/margin_medium</dimen> | ||
30 | + <dimen name="vertical_page_margin">@dimen/margin_medium</dimen> | ||
31 | + | ||
32 | +</resources> |
1 | +<!-- | ||
2 | + Copyright 2013 The TensorFlow Authors. All Rights Reserved. | ||
3 | + | ||
4 | + Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + you may not use this file except in compliance with the License. | ||
6 | + You may obtain a copy of the License at | ||
7 | + | ||
8 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + | ||
10 | + Unless required by applicable law or agreed to in writing, software | ||
11 | + distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + See the License for the specific language governing permissions and | ||
14 | + limitations under the License. | ||
15 | + --> | ||
16 | + | ||
17 | +<resources> | ||
18 | + | ||
19 | + <!-- Activity themes --> | ||
20 | + | ||
21 | + <style name="Theme.Base" parent="android:Theme.Light" /> | ||
22 | + | ||
23 | + <style name="Theme.Sample" parent="Theme.Base" /> | ||
24 | + | ||
25 | + <style name="AppTheme" parent="Theme.Sample" /> | ||
26 | + <!-- Widget styling --> | ||
27 | + | ||
28 | + <style name="Widget" /> | ||
29 | + | ||
30 | + <style name="Widget.SampleMessage"> | ||
31 | + <item name="android:textAppearance">?android:textAppearanceMedium</item> | ||
32 | + <item name="android:lineSpacingMultiplier">1.1</item> | ||
33 | + </style> | ||
34 | + | ||
35 | + <style name="Widget.SampleMessageTile"> | ||
36 | + <item name="android:background">@drawable/tile</item> | ||
37 | + <item name="android:shadowColor">#7F000000</item> | ||
38 | + <item name="android:shadowDy">-3.5</item> | ||
39 | + <item name="android:shadowRadius">2</item> | ||
40 | + </style> | ||
41 | + | ||
42 | +</resources> |
android/android/sample_images/classify1.jpg
0 → 100644
23.9 KB
android/android/sample_images/detect1.jpg
0 → 100644
44.7 KB
android/android/sample_images/stylize1.jpg
0 → 100644
51.2 KB
1 | +/* | ||
2 | + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +package org.tensorflow.demo; | ||
18 | + | ||
19 | +import android.content.Context; | ||
20 | +import android.util.AttributeSet; | ||
21 | +import android.view.TextureView; | ||
22 | + | ||
23 | +/** | ||
24 | + * A {@link TextureView} that can be adjusted to a specified aspect ratio. | ||
25 | + */ | ||
26 | +public class AutoFitTextureView extends TextureView { | ||
27 | + private int ratioWidth = 0; | ||
28 | + private int ratioHeight = 0; | ||
29 | + | ||
30 | + public AutoFitTextureView(final Context context) { | ||
31 | + this(context, null); | ||
32 | + } | ||
33 | + | ||
34 | + public AutoFitTextureView(final Context context, final AttributeSet attrs) { | ||
35 | + this(context, attrs, 0); | ||
36 | + } | ||
37 | + | ||
38 | + public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { | ||
39 | + super(context, attrs, defStyle); | ||
40 | + } | ||
41 | + | ||
42 | + /** | ||
43 | + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio | ||
44 | + * calculated from the parameters. Note that the actual sizes of the parameters don't matter; | ||
45 | + * that is, calling setAspectRatio(2, 3) and setAspectRatio(4, 6) produce the same result. | ||
46 | + * | ||
47 | + * @param width Relative horizontal size | ||
48 | + * @param height Relative vertical size | ||
49 | + */ | ||
50 | + public void setAspectRatio(final int width, final int height) { | ||
51 | + if (width < 0 || height < 0) { | ||
52 | + throw new IllegalArgumentException("Size cannot be negative."); | ||
53 | + } | ||
54 | + ratioWidth = width; | ||
55 | + ratioHeight = height; | ||
56 | + requestLayout(); | ||
57 | + } | ||
58 | + | ||
59 | + @Override | ||
60 | + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { | ||
61 | + super.onMeasure(widthMeasureSpec, heightMeasureSpec); | ||
62 | + final int width = MeasureSpec.getSize(widthMeasureSpec); | ||
63 | + final int height = MeasureSpec.getSize(heightMeasureSpec); | ||
64 | + if (0 == ratioWidth || 0 == ratioHeight) { | ||
65 | + setMeasuredDimension(width, height); | ||
66 | + } else { | ||
67 | + if (width < height * ratioWidth / ratioHeight) { | ||
68 | + setMeasuredDimension(width, width * ratioHeight / ratioWidth); | ||
69 | + } else { | ||
70 | + setMeasuredDimension(height * ratioWidth / ratioHeight, height); | ||
71 | + } | ||
72 | + } | ||
73 | + } | ||
74 | +} |
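The measuring rule in onMeasure reduces to a pure function: shrink whichever dimension
overshoots the target ratio. A minimal sketch of that rule (FitToRatio is a hypothetical
name, written here in C++ for consistency with the other examples):

    #include <utility>

    // Given measured bounds and a target ratio, shrink one dimension so the
    // result matches ratio_w : ratio_h.
    std::pair<int, int> FitToRatio(const int width, const int height,
                                   const int ratio_w, const int ratio_h) {
      if (width < height * ratio_w / ratio_h) {
        return {width, width * ratio_h / ratio_w};  // Width-limited.
      }
      return {height * ratio_w / ratio_h, height};  // Height-limited.
    }

For example, FitToRatio(1080, 1920, 4, 3) yields {1080, 810}: 1080 is less than
1920 * 4 / 3 = 2560, so the width is kept and the height is reduced.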
1 | +/* | ||
2 | + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +package org.tensorflow.demo; | ||
18 | + | ||
19 | +import android.Manifest; | ||
20 | +import android.app.Activity; | ||
21 | +import android.app.Fragment; | ||
22 | +import android.content.Context; | ||
23 | +import android.content.pm.PackageManager; | ||
24 | +import android.hardware.Camera; | ||
25 | +import android.hardware.camera2.CameraAccessException; | ||
26 | +import android.hardware.camera2.CameraCharacteristics; | ||
27 | +import android.hardware.camera2.CameraManager; | ||
28 | +import android.hardware.camera2.params.StreamConfigurationMap; | ||
29 | +import android.media.Image; | ||
30 | +import android.media.Image.Plane; | ||
31 | +import android.media.ImageReader; | ||
32 | +import android.media.ImageReader.OnImageAvailableListener; | ||
33 | +import android.os.Build; | ||
34 | +import android.os.Bundle; | ||
35 | +import android.os.Handler; | ||
36 | +import android.os.HandlerThread; | ||
37 | +import android.os.Trace; | ||
38 | +import android.util.Size; | ||
39 | +import android.view.KeyEvent; | ||
40 | +import android.view.Surface; | ||
41 | +import android.view.WindowManager; | ||
42 | +import android.widget.Toast; | ||
43 | +import java.nio.ByteBuffer; | ||
44 | +import org.tensorflow.demo.env.ImageUtils; | ||
45 | +import org.tensorflow.demo.env.Logger; | ||
46 | +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds. | ||
47 | + | ||
48 | +public abstract class CameraActivity extends Activity | ||
49 | + implements OnImageAvailableListener, Camera.PreviewCallback { | ||
50 | + private static final Logger LOGGER = new Logger(); | ||
51 | + | ||
52 | + private static final int PERMISSIONS_REQUEST = 1; | ||
53 | + | ||
54 | + private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA; | ||
55 | + private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE; | ||
56 | + | ||
57 | + private boolean debug = false; | ||
58 | + | ||
59 | + private Handler handler; | ||
60 | + private HandlerThread handlerThread; | ||
61 | + private boolean useCamera2API; | ||
62 | + private boolean isProcessingFrame = false; | ||
63 | + private byte[][] yuvBytes = new byte[3][]; | ||
64 | + private int[] rgbBytes = null; | ||
65 | + private int yRowStride; | ||
66 | + | ||
67 | + protected int previewWidth = 0; | ||
68 | + protected int previewHeight = 0; | ||
69 | + | ||
70 | + private Runnable postInferenceCallback; | ||
71 | + private Runnable imageConverter; | ||
72 | + | ||
73 | + @Override | ||
74 | + protected void onCreate(final Bundle savedInstanceState) { | ||
75 | + LOGGER.d("onCreate " + this); | ||
75 | + super.onCreate(null);  // Deliberately pass null so stale fragment state is not restored. | ||
77 | + getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); | ||
78 | + | ||
79 | + setContentView(R.layout.activity_camera); | ||
80 | + | ||
81 | + if (hasPermission()) { | ||
82 | + setFragment(); | ||
83 | + } else { | ||
84 | + requestPermission(); | ||
85 | + } | ||
86 | + } | ||
87 | + | ||
88 | + private byte[] lastPreviewFrame; | ||
89 | + | ||
90 | + protected int[] getRgbBytes() { | ||
91 | + imageConverter.run(); | ||
92 | + return rgbBytes; | ||
93 | + } | ||
94 | + | ||
95 | + protected int getLuminanceStride() { | ||
96 | + return yRowStride; | ||
97 | + } | ||
98 | + | ||
99 | + protected byte[] getLuminance() { | ||
100 | + return yuvBytes[0]; | ||
101 | + } | ||
102 | + | ||
103 | + /** | ||
104 | + * Callback for android.hardware.Camera API | ||
105 | + */ | ||
106 | + @Override | ||
107 | + public void onPreviewFrame(final byte[] bytes, final Camera camera) { | ||
108 | + if (isProcessingFrame) { | ||
109 | + LOGGER.w("Dropping frame!"); | ||
110 | + return; | ||
111 | + } | ||
112 | + | ||
113 | + try { | ||
114 | + // Initialize the storage bitmaps once when the resolution is known. | ||
115 | + if (rgbBytes == null) { | ||
116 | + Camera.Size previewSize = camera.getParameters().getPreviewSize(); | ||
117 | + previewHeight = previewSize.height; | ||
118 | + previewWidth = previewSize.width; | ||
119 | + rgbBytes = new int[previewWidth * previewHeight]; | ||
120 | + onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90); | ||
121 | + } | ||
122 | + } catch (final Exception e) { | ||
123 | + LOGGER.e(e, "Exception!"); | ||
124 | + return; | ||
125 | + } | ||
126 | + | ||
127 | + isProcessingFrame = true; | ||
128 | + lastPreviewFrame = bytes; | ||
129 | + yuvBytes[0] = bytes; | ||
130 | + yRowStride = previewWidth; | ||
131 | + | ||
132 | + imageConverter = | ||
133 | + new Runnable() { | ||
134 | + @Override | ||
135 | + public void run() { | ||
136 | + ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes); | ||
137 | + } | ||
138 | + }; | ||
139 | + | ||
140 | + postInferenceCallback = | ||
141 | + new Runnable() { | ||
142 | + @Override | ||
143 | + public void run() { | ||
144 | + camera.addCallbackBuffer(bytes); | ||
145 | + isProcessingFrame = false; | ||
146 | + } | ||
147 | + }; | ||
148 | + processImage(); | ||
149 | + } | ||
150 | + | ||
151 | + /** | ||
152 | + * Callback for Camera2 API | ||
153 | + */ | ||
154 | + @Override | ||
155 | + public void onImageAvailable(final ImageReader reader) { | ||
156 | + // We need to wait until we have some size from onPreviewSizeChosen. | ||
157 | + if (previewWidth == 0 || previewHeight == 0) { | ||
158 | + return; | ||
159 | + } | ||
160 | + if (rgbBytes == null) { | ||
161 | + rgbBytes = new int[previewWidth * previewHeight]; | ||
162 | + } | ||
163 | + try { | ||
164 | + final Image image = reader.acquireLatestImage(); | ||
165 | + | ||
166 | + if (image == null) { | ||
167 | + return; | ||
168 | + } | ||
169 | + | ||
170 | + if (isProcessingFrame) { | ||
171 | + image.close(); | ||
172 | + return; | ||
173 | + } | ||
174 | + isProcessingFrame = true; | ||
175 | + Trace.beginSection("imageAvailable"); | ||
176 | + final Plane[] planes = image.getPlanes(); | ||
177 | + fillBytes(planes, yuvBytes); | ||
178 | + yRowStride = planes[0].getRowStride(); | ||
179 | + final int uvRowStride = planes[1].getRowStride(); | ||
180 | + final int uvPixelStride = planes[1].getPixelStride(); | ||
181 | + | ||
182 | + imageConverter = | ||
183 | + new Runnable() { | ||
184 | + @Override | ||
185 | + public void run() { | ||
186 | + ImageUtils.convertYUV420ToARGB8888( | ||
187 | + yuvBytes[0], | ||
188 | + yuvBytes[1], | ||
189 | + yuvBytes[2], | ||
190 | + previewWidth, | ||
191 | + previewHeight, | ||
192 | + yRowStride, | ||
193 | + uvRowStride, | ||
194 | + uvPixelStride, | ||
195 | + rgbBytes); | ||
196 | + } | ||
197 | + }; | ||
198 | + | ||
199 | + postInferenceCallback = | ||
200 | + new Runnable() { | ||
201 | + @Override | ||
202 | + public void run() { | ||
203 | + image.close(); | ||
204 | + isProcessingFrame = false; | ||
205 | + } | ||
206 | + }; | ||
207 | + | ||
208 | + processImage(); | ||
209 | + } catch (final Exception e) { | ||
210 | + LOGGER.e(e, "Exception!"); | ||
211 | + Trace.endSection(); | ||
212 | + return; | ||
213 | + } | ||
214 | + Trace.endSection(); | ||
215 | + } | ||
216 | + | ||
217 | + @Override | ||
218 | + public synchronized void onStart() { | ||
219 | + LOGGER.d("onStart " + this); | ||
220 | + super.onStart(); | ||
221 | + } | ||
222 | + | ||
223 | + @Override | ||
224 | + public synchronized void onResume() { | ||
225 | + LOGGER.d("onResume " + this); | ||
226 | + super.onResume(); | ||
227 | + | ||
228 | + handlerThread = new HandlerThread("inference"); | ||
229 | + handlerThread.start(); | ||
230 | + handler = new Handler(handlerThread.getLooper()); | ||
231 | + } | ||
232 | + | ||
233 | + @Override | ||
234 | + public synchronized void onPause() { | ||
235 | + LOGGER.d("onPause " + this); | ||
236 | + | ||
237 | + if (!isFinishing()) { | ||
238 | + LOGGER.d("Requesting finish"); | ||
239 | + finish(); | ||
240 | + } | ||
241 | + | ||
242 | + handlerThread.quitSafely(); | ||
243 | + try { | ||
244 | + handlerThread.join(); | ||
245 | + handlerThread = null; | ||
246 | + handler = null; | ||
247 | + } catch (final InterruptedException e) { | ||
248 | + LOGGER.e(e, "Exception!"); | ||
249 | + } | ||
250 | + | ||
251 | + super.onPause(); | ||
252 | + } | ||
253 | + | ||
254 | + @Override | ||
255 | + public synchronized void onStop() { | ||
256 | + LOGGER.d("onStop " + this); | ||
257 | + super.onStop(); | ||
258 | + } | ||
259 | + | ||
260 | + @Override | ||
261 | + public synchronized void onDestroy() { | ||
262 | + LOGGER.d("onDestroy " + this); | ||
263 | + super.onDestroy(); | ||
264 | + } | ||
265 | + | ||
266 | + protected synchronized void runInBackground(final Runnable r) { | ||
267 | + if (handler != null) { | ||
268 | + handler.post(r); | ||
269 | + } | ||
270 | + } | ||
271 | + | ||
272 | + @Override | ||
273 | + public void onRequestPermissionsResult( | ||
274 | + final int requestCode, final String[] permissions, final int[] grantResults) { | ||
275 | + if (requestCode == PERMISSIONS_REQUEST) { | ||
276 | + if (grantResults.length > 1 // require results for both requested permissions | ||
277 | + && grantResults[0] == PackageManager.PERMISSION_GRANTED | ||
278 | + && grantResults[1] == PackageManager.PERMISSION_GRANTED) { | ||
279 | + setFragment(); | ||
280 | + } else { | ||
281 | + requestPermission(); | ||
282 | + } | ||
283 | + } | ||
284 | + } | ||
285 | + | ||
286 | + private boolean hasPermission() { | ||
287 | + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { | ||
288 | + return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED && | ||
289 | + checkSelfPermission(PERMISSION_STORAGE) == PackageManager.PERMISSION_GRANTED; | ||
290 | + } else { | ||
291 | + return true; | ||
292 | + } | ||
293 | + } | ||
294 | + | ||
295 | + private void requestPermission() { | ||
296 | + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { | ||
297 | + if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA) || | ||
298 | + shouldShowRequestPermissionRationale(PERMISSION_STORAGE)) { | ||
299 | + Toast.makeText(CameraActivity.this, | ||
300 | + "Camera AND storage permission are required for this demo", Toast.LENGTH_LONG).show(); | ||
301 | + } | ||
302 | + requestPermissions(new String[] {PERMISSION_CAMERA, PERMISSION_STORAGE}, PERMISSIONS_REQUEST); | ||
303 | + } | ||
304 | + } | ||
305 | + | ||
306 | + // Returns true if the device supports the required hardware level, or better. | ||
307 | + private boolean isHardwareLevelSupported( | ||
308 | + CameraCharacteristics characteristics, int requiredLevel) { | ||
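| + // LEGACY is numbered 2, above FULL (1), so it cannot be compared numerically. | ||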
309 | + int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL); | ||
310 | + if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) { | ||
311 | + return requiredLevel == deviceLevel; | ||
312 | + } | ||
313 | + // deviceLevel is not LEGACY, can use numerical sort | ||
314 | + return requiredLevel <= deviceLevel; | ||
315 | + } | ||
316 | + | ||
317 | + private String chooseCamera() { | ||
318 | + final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE); | ||
319 | + try { | ||
320 | + for (final String cameraId : manager.getCameraIdList()) { | ||
321 | + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); | ||
322 | + | ||
323 | + // We don't use a front facing camera in this sample. | ||
324 | + final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); | ||
325 | + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { | ||
326 | + continue; | ||
327 | + } | ||
328 | + | ||
329 | + final StreamConfigurationMap map = | ||
330 | + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); | ||
331 | + | ||
332 | + if (map == null) { | ||
333 | + continue; | ||
334 | + } | ||
335 | + | ||
336 | + // Fallback to camera1 API for internal cameras that don't have full support. | ||
337 | + // This should help with legacy situations where using the camera2 API causes | ||
338 | + // distorted or otherwise broken previews. | ||
339 | + useCamera2API = (facing != null && facing == CameraCharacteristics.LENS_FACING_EXTERNAL) | ||
340 | + || isHardwareLevelSupported(characteristics, | ||
341 | + CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL); | ||
342 | + LOGGER.i("Camera API lv2?: %s", useCamera2API); | ||
343 | + return cameraId; | ||
344 | + } | ||
345 | + } catch (CameraAccessException e) { | ||
346 | + LOGGER.e(e, "Not allowed to access camera"); | ||
347 | + } | ||
348 | + | ||
349 | + return null; | ||
350 | + } | ||
351 | + | ||
352 | + protected void setFragment() { | ||
353 | + String cameraId = chooseCamera(); | ||
354 | + if (cameraId == null) { | ||
355 | + Toast.makeText(this, "No Camera Detected", Toast.LENGTH_SHORT).show(); | ||
356 | + finish(); | ||
| + // Without a camera there is nothing to attach; bail out before creating a fragment. | ||
| + return; | ||
357 | + } | ||
358 | + | ||
359 | + Fragment fragment; | ||
360 | + if (useCamera2API) { | ||
361 | + CameraConnectionFragment camera2Fragment = | ||
362 | + CameraConnectionFragment.newInstance( | ||
363 | + new CameraConnectionFragment.ConnectionCallback() { | ||
364 | + @Override | ||
365 | + public void onPreviewSizeChosen(final Size size, final int rotation) { | ||
366 | + previewHeight = size.getHeight(); | ||
367 | + previewWidth = size.getWidth(); | ||
368 | + CameraActivity.this.onPreviewSizeChosen(size, rotation); | ||
369 | + } | ||
370 | + }, | ||
371 | + this, | ||
372 | + getLayoutId(), | ||
373 | + getDesiredPreviewFrameSize()); | ||
374 | + | ||
375 | + camera2Fragment.setCamera(cameraId); | ||
376 | + fragment = camera2Fragment; | ||
377 | + } else { | ||
378 | + fragment = | ||
379 | + new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize()); | ||
380 | + } | ||
381 | + | ||
382 | + getFragmentManager() | ||
383 | + .beginTransaction() | ||
384 | + .replace(R.id.container, fragment) | ||
385 | + .commit(); | ||
386 | + } | ||
387 | + | ||
388 | + protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) { | ||
389 | + // Because of the variable row stride it's not possible to know in | ||
390 | + // advance the actual necessary dimensions of the yuv planes. | ||
391 | + for (int i = 0; i < planes.length; ++i) { | ||
392 | + final ByteBuffer buffer = planes[i].getBuffer(); | ||
393 | + if (yuvBytes[i] == null) { | ||
394 | + LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity()); | ||
395 | + yuvBytes[i] = new byte[buffer.capacity()]; | ||
396 | + } | ||
397 | + buffer.get(yuvBytes[i]); | ||
398 | + } | ||
399 | + } | ||
400 | + | ||
401 | + public boolean isDebug() { | ||
402 | + return debug; | ||
403 | + } | ||
404 | + | ||
405 | + public void requestRender() { | ||
406 | + final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay); | ||
407 | + if (overlay != null) { | ||
408 | + overlay.postInvalidate(); | ||
409 | + } | ||
410 | + } | ||
411 | + | ||
412 | + public void addCallback(final OverlayView.DrawCallback callback) { | ||
413 | + final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay); | ||
414 | + if (overlay != null) { | ||
415 | + overlay.addCallback(callback); | ||
416 | + } | ||
417 | + } | ||
418 | + | ||
419 | + public void onSetDebug(final boolean debug) {} | ||
420 | + | ||
421 | + @Override | ||
422 | + public boolean onKeyDown(final int keyCode, final KeyEvent event) { | ||
423 | + if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP | ||
424 | + || keyCode == KeyEvent.KEYCODE_BUTTON_L1 || keyCode == KeyEvent.KEYCODE_DPAD_CENTER) { | ||
425 | + debug = !debug; | ||
426 | + requestRender(); | ||
427 | + onSetDebug(debug); | ||
428 | + return true; | ||
429 | + } | ||
430 | + return super.onKeyDown(keyCode, event); | ||
431 | + } | ||
432 | + | ||
433 | + protected void readyForNextImage() { | ||
434 | + if (postInferenceCallback != null) { | ||
435 | + postInferenceCallback.run(); | ||
436 | + } | ||
437 | + } | ||
438 | + | ||
439 | + protected int getScreenOrientation() { | ||
440 | + switch (getWindowManager().getDefaultDisplay().getRotation()) { | ||
441 | + case Surface.ROTATION_270: | ||
442 | + return 270; | ||
443 | + case Surface.ROTATION_180: | ||
444 | + return 180; | ||
445 | + case Surface.ROTATION_90: | ||
446 | + return 90; | ||
447 | + default: | ||
448 | + return 0; | ||
449 | + } | ||
450 | + } | ||
451 | + | ||
452 | + protected abstract void processImage(); | ||
453 | + | ||
454 | + protected abstract void onPreviewSizeChosen(final Size size, final int rotation); | ||
455 | + protected abstract int getLayoutId(); | ||
456 | + protected abstract Size getDesiredPreviewFrameSize(); | ||
457 | +} |
1 | +/* | ||
2 | + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +package org.tensorflow.demo; | ||
18 | + | ||
19 | +import android.app.Activity; | ||
20 | +import android.app.AlertDialog; | ||
21 | +import android.app.Dialog; | ||
22 | +import android.app.DialogFragment; | ||
23 | +import android.app.Fragment; | ||
24 | +import android.content.Context; | ||
25 | +import android.content.DialogInterface; | ||
26 | +import android.content.res.Configuration; | ||
27 | +import android.graphics.ImageFormat; | ||
28 | +import android.graphics.Matrix; | ||
29 | +import android.graphics.RectF; | ||
30 | +import android.graphics.SurfaceTexture; | ||
31 | +import android.hardware.camera2.CameraAccessException; | ||
32 | +import android.hardware.camera2.CameraCaptureSession; | ||
33 | +import android.hardware.camera2.CameraCharacteristics; | ||
34 | +import android.hardware.camera2.CameraDevice; | ||
35 | +import android.hardware.camera2.CameraManager; | ||
36 | +import android.hardware.camera2.CaptureRequest; | ||
37 | +import android.hardware.camera2.CaptureResult; | ||
38 | +import android.hardware.camera2.TotalCaptureResult; | ||
39 | +import android.hardware.camera2.params.StreamConfigurationMap; | ||
40 | +import android.media.ImageReader; | ||
41 | +import android.media.ImageReader.OnImageAvailableListener; | ||
42 | +import android.os.Bundle; | ||
43 | +import android.os.Handler; | ||
44 | +import android.os.HandlerThread; | ||
45 | +import android.text.TextUtils; | ||
46 | +import android.util.Size; | ||
47 | +import android.util.SparseIntArray; | ||
48 | +import android.view.LayoutInflater; | ||
49 | +import android.view.Surface; | ||
50 | +import android.view.TextureView; | ||
51 | +import android.view.View; | ||
52 | +import android.view.ViewGroup; | ||
53 | +import android.widget.Toast; | ||
54 | +import java.util.ArrayList; | ||
55 | +import java.util.Arrays; | ||
56 | +import java.util.Collections; | ||
57 | +import java.util.Comparator; | ||
58 | +import java.util.List; | ||
59 | +import java.util.concurrent.Semaphore; | ||
60 | +import java.util.concurrent.TimeUnit; | ||
61 | +import org.tensorflow.demo.env.Logger; | ||
62 | +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds. | ||
63 | + | ||
64 | +public class CameraConnectionFragment extends Fragment { | ||
65 | + private static final Logger LOGGER = new Logger(); | ||
66 | + | ||
67 | + /** | ||
68 | + * The camera preview size will be chosen to be the smallest frame by pixel size capable of | ||
69 | + * containing a square of the desired input size, and never smaller than | ||
| + * MINIMUM_PREVIEW_SIZE x MINIMUM_PREVIEW_SIZE. | ||
70 | + */ | ||
71 | + private static final int MINIMUM_PREVIEW_SIZE = 320; | ||
72 | + | ||
73 | + /** | ||
74 | + * Conversion from screen rotation to JPEG orientation. | ||
75 | + */ | ||
76 | + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); | ||
77 | + private static final String FRAGMENT_DIALOG = "dialog"; | ||
78 | + | ||
79 | + static { | ||
80 | + ORIENTATIONS.append(Surface.ROTATION_0, 90); | ||
81 | + ORIENTATIONS.append(Surface.ROTATION_90, 0); | ||
82 | + ORIENTATIONS.append(Surface.ROTATION_180, 270); | ||
83 | + ORIENTATIONS.append(Surface.ROTATION_270, 180); | ||
84 | + } | ||
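| + // Note: this table assumes a sensor orientation of 90 degrees; it is retained | ||
| + // from the upstream sample and nothing in this fragment references it. | ||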
85 | + | ||
86 | + /** | ||
87 | + * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a | ||
88 | + * {@link TextureView}. | ||
89 | + */ | ||
90 | + private final TextureView.SurfaceTextureListener surfaceTextureListener = | ||
91 | + new TextureView.SurfaceTextureListener() { | ||
92 | + @Override | ||
93 | + public void onSurfaceTextureAvailable( | ||
94 | + final SurfaceTexture texture, final int width, final int height) { | ||
95 | + openCamera(width, height); | ||
96 | + } | ||
97 | + | ||
98 | + @Override | ||
99 | + public void onSurfaceTextureSizeChanged( | ||
100 | + final SurfaceTexture texture, final int width, final int height) { | ||
101 | + configureTransform(width, height); | ||
102 | + } | ||
103 | + | ||
104 | + @Override | ||
105 | + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { | ||
106 | + return true; | ||
107 | + } | ||
108 | + | ||
109 | + @Override | ||
110 | + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} | ||
111 | + }; | ||
112 | + | ||
113 | + /** | ||
114 | + * Callback for Activities to use to initialize their data once the | ||
115 | + * selected preview size is known. | ||
116 | + */ | ||
117 | + public interface ConnectionCallback { | ||
118 | + void onPreviewSizeChosen(Size size, int cameraRotation); | ||
119 | + } | ||
120 | + | ||
121 | + /** | ||
122 | + * ID of the current {@link CameraDevice}. | ||
123 | + */ | ||
124 | + private String cameraId; | ||
125 | + | ||
126 | + /** | ||
127 | + * An {@link AutoFitTextureView} for camera preview. | ||
128 | + */ | ||
129 | + private AutoFitTextureView textureView; | ||
130 | + | ||
131 | + /** | ||
132 | + * A {@link CameraCaptureSession } for camera preview. | ||
133 | + */ | ||
134 | + private CameraCaptureSession captureSession; | ||
135 | + | ||
136 | + /** | ||
137 | + * A reference to the opened {@link CameraDevice}. | ||
138 | + */ | ||
139 | + private CameraDevice cameraDevice; | ||
140 | + | ||
141 | + /** | ||
142 | + * The rotation in degrees of the camera sensor from the display. | ||
143 | + */ | ||
144 | + private Integer sensorOrientation; | ||
145 | + | ||
146 | + /** | ||
147 | + * The {@link android.util.Size} of camera preview. | ||
148 | + */ | ||
149 | + private Size previewSize; | ||
150 | + | ||
151 | + /** | ||
152 | + * {@link android.hardware.camera2.CameraDevice.StateCallback} | ||
153 | + * is called when {@link CameraDevice} changes its state. | ||
154 | + */ | ||
155 | + private final CameraDevice.StateCallback stateCallback = | ||
156 | + new CameraDevice.StateCallback() { | ||
157 | + @Override | ||
158 | + public void onOpened(final CameraDevice cd) { | ||
159 | + // This method is called when the camera is opened. We start camera preview here. | ||
160 | + cameraOpenCloseLock.release(); | ||
161 | + cameraDevice = cd; | ||
162 | + createCameraPreviewSession(); | ||
163 | + } | ||
164 | + | ||
165 | + @Override | ||
166 | + public void onDisconnected(final CameraDevice cd) { | ||
167 | + cameraOpenCloseLock.release(); | ||
168 | + cd.close(); | ||
169 | + cameraDevice = null; | ||
170 | + } | ||
171 | + | ||
172 | + @Override | ||
173 | + public void onError(final CameraDevice cd, final int error) { | ||
174 | + cameraOpenCloseLock.release(); | ||
175 | + cd.close(); | ||
176 | + cameraDevice = null; | ||
177 | + final Activity activity = getActivity(); | ||
178 | + if (null != activity) { | ||
179 | + activity.finish(); | ||
180 | + } | ||
181 | + } | ||
182 | + }; | ||
183 | + | ||
184 | + /** | ||
185 | + * An additional thread for running tasks that shouldn't block the UI. | ||
186 | + */ | ||
187 | + private HandlerThread backgroundThread; | ||
188 | + | ||
189 | + /** | ||
190 | + * A {@link Handler} for running tasks in the background. | ||
191 | + */ | ||
192 | + private Handler backgroundHandler; | ||
193 | + | ||
194 | + /** | ||
195 | + * An {@link ImageReader} that handles preview frame capture. | ||
196 | + */ | ||
197 | + private ImageReader previewReader; | ||
198 | + | ||
199 | + /** | ||
200 | + * {@link android.hardware.camera2.CaptureRequest.Builder} for the camera preview | ||
201 | + */ | ||
202 | + private CaptureRequest.Builder previewRequestBuilder; | ||
203 | + | ||
204 | + /** | ||
205 | + * {@link CaptureRequest} generated by {@link #previewRequestBuilder} | ||
206 | + */ | ||
207 | + private CaptureRequest previewRequest; | ||
208 | + | ||
209 | + /** | ||
210 | + * A {@link Semaphore} to prevent the app from exiting before closing the camera. | ||
211 | + */ | ||
212 | + private final Semaphore cameraOpenCloseLock = new Semaphore(1); | ||
213 | + | ||
214 | + /** | ||
215 | + * A {@link OnImageAvailableListener} to receive frames as they are available. | ||
216 | + */ | ||
217 | + private final OnImageAvailableListener imageListener; | ||
218 | + | ||
219 | + /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */ | ||
220 | + private final Size inputSize; | ||
221 | + | ||
222 | + /** | ||
223 | + * The layout identifier to inflate for this Fragment. | ||
224 | + */ | ||
225 | + private final int layout; | ||
226 | + | ||
228 | + private final ConnectionCallback cameraConnectionCallback; | ||
229 | + | ||
230 | + private CameraConnectionFragment( | ||
231 | + final ConnectionCallback connectionCallback, | ||
232 | + final OnImageAvailableListener imageListener, | ||
233 | + final int layout, | ||
234 | + final Size inputSize) { | ||
235 | + this.cameraConnectionCallback = connectionCallback; | ||
236 | + this.imageListener = imageListener; | ||
237 | + this.layout = layout; | ||
238 | + this.inputSize = inputSize; | ||
239 | + } | ||
240 | + | ||
241 | + /** | ||
242 | + * Shows a {@link Toast} on the UI thread. | ||
243 | + * | ||
244 | + * @param text The message to show | ||
245 | + */ | ||
246 | + private void showToast(final String text) { | ||
247 | + final Activity activity = getActivity(); | ||
248 | + if (activity != null) { | ||
249 | + activity.runOnUiThread( | ||
250 | + new Runnable() { | ||
251 | + @Override | ||
252 | + public void run() { | ||
253 | + Toast.makeText(activity, text, Toast.LENGTH_SHORT).show(); | ||
254 | + } | ||
255 | + }); | ||
256 | + } | ||
257 | + } | ||
258 | + | ||
259 | + /** | ||
260 | + * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose | ||
261 | + * width and height are both at least as large as the smaller requested dimension (and no | ||
| + * smaller than MINIMUM_PREVIEW_SIZE), preferring an exact match when one exists. | ||
262 | + * | ||
263 | + * @param choices The list of sizes that the camera supports for the intended output class | ||
264 | + * @param width The minimum desired width | ||
265 | + * @param height The minimum desired height | ||
266 | + * @return The optimal {@code Size}, or an arbitrary one if none were big enough | ||
267 | + */ | ||
268 | + protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) { | ||
269 | + final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE); | ||
270 | + final Size desiredSize = new Size(width, height); | ||
271 | + | ||
272 | + // Collect the supported resolutions that are at least as big as the preview Surface | ||
273 | + boolean exactSizeFound = false; | ||
274 | + final List<Size> bigEnough = new ArrayList<Size>(); | ||
275 | + final List<Size> tooSmall = new ArrayList<Size>(); | ||
276 | + for (final Size option : choices) { | ||
277 | + if (option.equals(desiredSize)) { | ||
278 | + // Set the size but don't return yet so that remaining sizes will still be logged. | ||
279 | + exactSizeFound = true; | ||
280 | + } | ||
281 | + | ||
282 | + if (option.getHeight() >= minSize && option.getWidth() >= minSize) { | ||
283 | + bigEnough.add(option); | ||
284 | + } else { | ||
285 | + tooSmall.add(option); | ||
286 | + } | ||
287 | + } | ||
288 | + | ||
289 | + LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize); | ||
290 | + LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]"); | ||
291 | + LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]"); | ||
292 | + | ||
293 | + if (exactSizeFound) { | ||
294 | + LOGGER.i("Exact size match found."); | ||
295 | + return desiredSize; | ||
296 | + } | ||
297 | + | ||
298 | + // Pick the smallest of those, assuming we found any | ||
299 | + if (bigEnough.size() > 0) { | ||
300 | + final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea()); | ||
301 | + LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight()); | ||
302 | + return chosenSize; | ||
303 | + } else { | ||
304 | + LOGGER.e("Couldn't find any suitable preview size"); | ||
305 | + return choices[0]; | ||
306 | + } | ||
307 | + } | ||
308 | + | ||
309 | + public static CameraConnectionFragment newInstance( | ||
310 | + final ConnectionCallback callback, | ||
311 | + final OnImageAvailableListener imageListener, | ||
312 | + final int layout, | ||
313 | + final Size inputSize) { | ||
314 | + return new CameraConnectionFragment(callback, imageListener, layout, inputSize); | ||
315 | + } | ||
316 | + | ||
317 | + @Override | ||
318 | + public View onCreateView( | ||
319 | + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { | ||
320 | + return inflater.inflate(layout, container, false); | ||
321 | + } | ||
322 | + | ||
323 | + @Override | ||
324 | + public void onViewCreated(final View view, final Bundle savedInstanceState) { | ||
325 | + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); | ||
326 | + } | ||
327 | + | ||
328 | + @Override | ||
329 | + public void onActivityCreated(final Bundle savedInstanceState) { | ||
330 | + super.onActivityCreated(savedInstanceState); | ||
331 | + } | ||
332 | + | ||
333 | + @Override | ||
334 | + public void onResume() { | ||
335 | + super.onResume(); | ||
336 | + startBackgroundThread(); | ||
337 | + | ||
338 | + // When the screen is turned off and turned back on, the SurfaceTexture is already | ||
339 | + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open | ||
340 | + // a camera and start preview from here (otherwise, we wait until the surface is ready in | ||
341 | + // the SurfaceTextureListener). | ||
342 | + if (textureView.isAvailable()) { | ||
343 | + openCamera(textureView.getWidth(), textureView.getHeight()); | ||
344 | + } else { | ||
345 | + textureView.setSurfaceTextureListener(surfaceTextureListener); | ||
346 | + } | ||
347 | + } | ||
348 | + | ||
349 | + @Override | ||
350 | + public void onPause() { | ||
351 | + closeCamera(); | ||
352 | + stopBackgroundThread(); | ||
353 | + super.onPause(); | ||
354 | + } | ||
355 | + | ||
356 | + public void setCamera(String cameraId) { | ||
357 | + this.cameraId = cameraId; | ||
358 | + } | ||
359 | + | ||
360 | + /** | ||
361 | + * Sets up member variables related to the camera. | ||
362 | + */ | ||
363 | + private void setUpCameraOutputs() { | ||
364 | + final Activity activity = getActivity(); | ||
365 | + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); | ||
366 | + try { | ||
367 | + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); | ||
368 | + | ||
369 | + final StreamConfigurationMap map = | ||
370 | + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); | ||
371 | + | ||
372 | + // For still image captures, we use the largest available size. | ||
373 | + final Size largest = | ||
374 | + Collections.max( | ||
375 | + Arrays.asList(map.getOutputSizes(ImageFormat.YUV_420_888)), | ||
376 | + new CompareSizesByArea()); | ||
377 | + | ||
378 | + sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); | ||
379 | + | ||
380 | + // Danger, W.R.! Attempting to use too large a preview size could exceed the camera | ||
381 | + // bus' bandwidth limitation, resulting in gorgeous previews but the storage of | ||
382 | + // garbage capture data. | ||
383 | + previewSize = | ||
384 | + chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class), | ||
385 | + inputSize.getWidth(), | ||
386 | + inputSize.getHeight()); | ||
387 | + | ||
388 | + // We fit the aspect ratio of TextureView to the size of preview we picked. | ||
389 | + final int orientation = getResources().getConfiguration().orientation; | ||
390 | + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { | ||
391 | + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); | ||
392 | + } else { | ||
393 | + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); | ||
394 | + } | ||
395 | + } catch (final CameraAccessException e) { | ||
396 | + LOGGER.e(e, "Exception!"); | ||
397 | + } catch (final NullPointerException e) { | ||
398 | + // Currently an NPE is thrown when the Camera2 API is used but is not supported on the | ||
399 | + // device this code runs on. | ||
400 | + // TODO(andrewharp): abstract ErrorDialog/RuntimeException handling out into new method and | ||
401 | + // reuse throughout app. | ||
402 | + ErrorDialog.newInstance(getString(R.string.camera_error)) | ||
403 | + .show(getChildFragmentManager(), FRAGMENT_DIALOG); | ||
404 | + throw new RuntimeException(getString(R.string.camera_error)); | ||
405 | + } | ||
406 | + | ||
407 | + cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation); | ||
408 | + } | ||
409 | + | ||
410 | + /** | ||
411 | + * Opens the camera specified by {@link CameraConnectionFragment#cameraId}. | ||
412 | + */ | ||
413 | + private void openCamera(final int width, final int height) { | ||
414 | + setUpCameraOutputs(); | ||
415 | + configureTransform(width, height); | ||
416 | + final Activity activity = getActivity(); | ||
417 | + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); | ||
418 | + try { | ||
419 | + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { | ||
420 | + throw new RuntimeException("Time out waiting to lock camera opening."); | ||
421 | + } | ||
422 | + manager.openCamera(cameraId, stateCallback, backgroundHandler); | ||
423 | + } catch (final CameraAccessException e) { | ||
424 | + LOGGER.e(e, "Exception!"); | ||
425 | + } catch (final InterruptedException e) { | ||
426 | + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); | ||
427 | + } | ||
428 | + } | ||
429 | + | ||
430 | + /** | ||
431 | + * Closes the current {@link CameraDevice}. | ||
432 | + */ | ||
433 | + private void closeCamera() { | ||
434 | + try { | ||
435 | + cameraOpenCloseLock.acquire(); | ||
436 | + if (null != captureSession) { | ||
437 | + captureSession.close(); | ||
438 | + captureSession = null; | ||
439 | + } | ||
440 | + if (null != cameraDevice) { | ||
441 | + cameraDevice.close(); | ||
442 | + cameraDevice = null; | ||
443 | + } | ||
444 | + if (null != previewReader) { | ||
445 | + previewReader.close(); | ||
446 | + previewReader = null; | ||
447 | + } | ||
448 | + } catch (final InterruptedException e) { | ||
449 | + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); | ||
450 | + } finally { | ||
451 | + cameraOpenCloseLock.release(); | ||
452 | + } | ||
453 | + } | ||
454 | + | ||
455 | + /** | ||
456 | + * Starts a background thread and its {@link Handler}. | ||
457 | + */ | ||
458 | + private void startBackgroundThread() { | ||
459 | + backgroundThread = new HandlerThread("ImageListener"); | ||
460 | + backgroundThread.start(); | ||
461 | + backgroundHandler = new Handler(backgroundThread.getLooper()); | ||
462 | + } | ||
463 | + | ||
464 | + /** | ||
465 | + * Stops the background thread and its {@link Handler}. | ||
466 | + */ | ||
467 | + private void stopBackgroundThread() { | ||
468 | + backgroundThread.quitSafely(); | ||
469 | + try { | ||
470 | + backgroundThread.join(); | ||
471 | + backgroundThread = null; | ||
472 | + backgroundHandler = null; | ||
473 | + } catch (final InterruptedException e) { | ||
474 | + LOGGER.e(e, "Exception!"); | ||
475 | + } | ||
476 | + } | ||
477 | + | ||
478 | + private final CameraCaptureSession.CaptureCallback captureCallback = | ||
479 | + new CameraCaptureSession.CaptureCallback() { | ||
480 | + @Override | ||
481 | + public void onCaptureProgressed( | ||
482 | + final CameraCaptureSession session, | ||
483 | + final CaptureRequest request, | ||
484 | + final CaptureResult partialResult) {} | ||
485 | + | ||
486 | + @Override | ||
487 | + public void onCaptureCompleted( | ||
488 | + final CameraCaptureSession session, | ||
489 | + final CaptureRequest request, | ||
490 | + final TotalCaptureResult result) {} | ||
491 | + }; | ||
492 | + | ||
493 | + /** | ||
494 | + * Creates a new {@link CameraCaptureSession} for camera preview. | ||
495 | + */ | ||
496 | + private void createCameraPreviewSession() { | ||
497 | + try { | ||
498 | + final SurfaceTexture texture = textureView.getSurfaceTexture(); | ||
499 | + assert texture != null; | ||
500 | + | ||
501 | + // We configure the size of the default buffer to be the size of the camera preview we want. | ||
502 | + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); | ||
503 | + | ||
504 | + // This is the output Surface we need to start preview. | ||
505 | + final Surface surface = new Surface(texture); | ||
506 | + | ||
507 | + // We set up a CaptureRequest.Builder with the output Surface. | ||
508 | + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); | ||
509 | + previewRequestBuilder.addTarget(surface); | ||
510 | + | ||
511 | + LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight()); | ||
512 | + | ||
513 | + // Create the reader for the preview frames. | ||
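| + // maxImages == 2 allows one frame to be processed while the next is queued. | ||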
514 | + previewReader = | ||
515 | + ImageReader.newInstance( | ||
516 | + previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2); | ||
517 | + | ||
518 | + previewReader.setOnImageAvailableListener(imageListener, backgroundHandler); | ||
519 | + previewRequestBuilder.addTarget(previewReader.getSurface()); | ||
520 | + | ||
521 | + // Here, we create a CameraCaptureSession for camera preview. | ||
522 | + cameraDevice.createCaptureSession( | ||
523 | + Arrays.asList(surface, previewReader.getSurface()), | ||
524 | + new CameraCaptureSession.StateCallback() { | ||
525 | + | ||
526 | + @Override | ||
527 | + public void onConfigured(final CameraCaptureSession cameraCaptureSession) { | ||
528 | + // The camera is already closed | ||
529 | + if (null == cameraDevice) { | ||
530 | + return; | ||
531 | + } | ||
532 | + | ||
533 | + // When the session is ready, we start displaying the preview. | ||
534 | + captureSession = cameraCaptureSession; | ||
535 | + try { | ||
536 | + // Auto focus should be continuous for camera preview. | ||
537 | + previewRequestBuilder.set( | ||
538 | + CaptureRequest.CONTROL_AF_MODE, | ||
539 | + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); | ||
540 | + // Flash is automatically enabled when necessary. | ||
541 | + previewRequestBuilder.set( | ||
542 | + CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH); | ||
543 | + | ||
544 | + // Finally, we start displaying the camera preview. | ||
545 | + previewRequest = previewRequestBuilder.build(); | ||
546 | + captureSession.setRepeatingRequest( | ||
547 | + previewRequest, captureCallback, backgroundHandler); | ||
548 | + } catch (final CameraAccessException e) { | ||
549 | + LOGGER.e(e, "Exception!"); | ||
550 | + } | ||
551 | + } | ||
552 | + | ||
553 | + @Override | ||
554 | + public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) { | ||
555 | + showToast("Failed"); | ||
556 | + } | ||
557 | + }, | ||
558 | + null); | ||
559 | + } catch (final CameraAccessException e) { | ||
560 | + LOGGER.e(e, "Exception!"); | ||
561 | + } | ||
562 | + } | ||
563 | + | ||
564 | + /** | ||
565 | + * Configures the necessary {@link android.graphics.Matrix} transformation to {@code textureView}. | ||
566 | + * This method should be called after the camera preview size is determined in | ||
567 | + * setUpCameraOutputs and also after the size of {@code textureView} is fixed. | ||
568 | + * | ||
569 | + * @param viewWidth The width of {@code textureView} | ||
570 | + * @param viewHeight The height of {@code textureView} | ||
571 | + */ | ||
572 | + private void configureTransform(final int viewWidth, final int viewHeight) { | ||
573 | + final Activity activity = getActivity(); | ||
574 | + if (null == textureView || null == previewSize || null == activity) { | ||
575 | + return; | ||
576 | + } | ||
577 | + final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); | ||
578 | + final Matrix matrix = new Matrix(); | ||
579 | + final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); | ||
580 | + final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); | ||
581 | + final float centerX = viewRect.centerX(); | ||
582 | + final float centerY = viewRect.centerY(); | ||
583 | + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { | ||
584 | + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); | ||
585 | + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); | ||
586 | + final float scale = | ||
587 | + Math.max( | ||
588 | + (float) viewHeight / previewSize.getHeight(), | ||
589 | + (float) viewWidth / previewSize.getWidth()); | ||
590 | + matrix.postScale(scale, scale, centerX, centerY); | ||
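| + // ROTATION_90 == 1 and ROTATION_270 == 3, so this rotates by -90 or +90 degrees. | ||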
591 | + matrix.postRotate(90 * (rotation - 2), centerX, centerY); | ||
592 | + } else if (Surface.ROTATION_180 == rotation) { | ||
593 | + matrix.postRotate(180, centerX, centerY); | ||
594 | + } | ||
595 | + textureView.setTransform(matrix); | ||
596 | + } | ||
597 | + | ||
598 | + /** | ||
599 | + * Compares two {@code Size}s based on their areas. | ||
600 | + */ | ||
601 | + static class CompareSizesByArea implements Comparator<Size> { | ||
602 | + @Override | ||
603 | + public int compare(final Size lhs, final Size rhs) { | ||
604 | + // We cast here to ensure the multiplications won't overflow | ||
605 | + return Long.signum( | ||
606 | + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); | ||
607 | + } | ||
608 | + } | ||
609 | + | ||
610 | + /** | ||
611 | + * Shows an error message dialog. | ||
612 | + */ | ||
613 | + public static class ErrorDialog extends DialogFragment { | ||
614 | + private static final String ARG_MESSAGE = "message"; | ||
615 | + | ||
616 | + public static ErrorDialog newInstance(final String message) { | ||
617 | + final ErrorDialog dialog = new ErrorDialog(); | ||
618 | + final Bundle args = new Bundle(); | ||
619 | + args.putString(ARG_MESSAGE, message); | ||
620 | + dialog.setArguments(args); | ||
621 | + return dialog; | ||
622 | + } | ||
623 | + | ||
624 | + @Override | ||
625 | + public Dialog onCreateDialog(final Bundle savedInstanceState) { | ||
626 | + final Activity activity = getActivity(); | ||
627 | + return new AlertDialog.Builder(activity) | ||
628 | + .setMessage(getArguments().getString(ARG_MESSAGE)) | ||
629 | + .setPositiveButton( | ||
630 | + android.R.string.ok, | ||
631 | + new DialogInterface.OnClickListener() { | ||
632 | + @Override | ||
633 | + public void onClick(final DialogInterface dialogInterface, final int i) { | ||
634 | + activity.finish(); | ||
635 | + } | ||
636 | + }) | ||
637 | + .create(); | ||
638 | + } | ||
639 | + } | ||
640 | +} |
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import android.graphics.Bitmap; | ||
19 | +import android.graphics.RectF; | ||
20 | +import java.util.List; | ||
21 | + | ||
22 | +/** | ||
23 | + * Generic interface for interacting with different recognition engines. | ||
24 | + */ | ||
25 | +public interface Classifier { | ||
26 | + /** | ||
27 | + * An immutable result returned by a Classifier describing what was recognized. | ||
28 | + */ | ||
29 | + public class Recognition { | ||
30 | + /** | ||
31 | + * A unique identifier for what has been recognized. Specific to the class, not the instance of | ||
32 | + * the object. | ||
33 | + */ | ||
34 | + private final String id; | ||
35 | + | ||
36 | + /** | ||
37 | + * Display name for the recognition. | ||
38 | + */ | ||
39 | + private final String title; | ||
40 | + | ||
41 | + /** | ||
42 | + * A sortable score for how good the recognition is relative to others. Higher should be better. | ||
43 | + */ | ||
44 | + private final Float confidence; | ||
45 | + | ||
46 | + /** Optional location within the source image for the location of the recognized object. */ | ||
47 | + private RectF location; | ||
48 | + | ||
49 | + public Recognition( | ||
50 | + final String id, final String title, final Float confidence, final RectF location) { | ||
51 | + this.id = id; | ||
52 | + this.title = title; | ||
53 | + this.confidence = confidence; | ||
54 | + this.location = location; | ||
55 | + } | ||
56 | + | ||
57 | + public String getId() { | ||
58 | + return id; | ||
59 | + } | ||
60 | + | ||
61 | + public String getTitle() { | ||
62 | + return title; | ||
63 | + } | ||
64 | + | ||
65 | + public Float getConfidence() { | ||
66 | + return confidence; | ||
67 | + } | ||
68 | + | ||
69 | + public RectF getLocation() { | ||
70 | + return new RectF(location); | ||
71 | + } | ||
72 | + | ||
73 | + public void setLocation(RectF location) { | ||
74 | + this.location = location; | ||
75 | + } | ||
76 | + | ||
77 | + @Override | ||
78 | + public String toString() { | ||
79 | + String resultString = ""; | ||
80 | + if (id != null) { | ||
81 | + resultString += "[" + id + "] "; | ||
82 | + } | ||
83 | + | ||
84 | + if (title != null) { | ||
85 | + resultString += title + " "; | ||
86 | + } | ||
87 | + | ||
88 | + if (confidence != null) { | ||
89 | + resultString += String.format("(%.1f%%) ", confidence * 100.0f); | ||
90 | + } | ||
91 | + | ||
92 | + if (location != null) { | ||
93 | + resultString += location + " "; | ||
94 | + } | ||
95 | + | ||
96 | + return resultString.trim(); | ||
97 | + } | ||
98 | + } | ||
99 | + | ||
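| + // Implementations such as TensorFlowImageClassifier load their graph at creation | ||
| + // time, are called once per preview frame via recognizeImage(), and are expected | ||
| + // to release native resources in close(). | ||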
100 | + List<Recognition> recognizeImage(Bitmap bitmap); | ||
101 | + | ||
102 | + void enableStatLogging(final boolean debug); | ||
103 | + | ||
104 | + String getStatString(); | ||
105 | + | ||
106 | + void close(); | ||
107 | +} |
1 | +/* | ||
2 | + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +package org.tensorflow.demo; | ||
18 | + | ||
19 | +import android.graphics.Bitmap; | ||
20 | +import android.graphics.Bitmap.Config; | ||
21 | +import android.graphics.Canvas; | ||
22 | +import android.graphics.Matrix; | ||
23 | +import android.graphics.Paint; | ||
24 | +import android.graphics.Typeface; | ||
25 | +import android.media.ImageReader.OnImageAvailableListener; | ||
26 | +import android.os.SystemClock; | ||
27 | +import android.util.Size; | ||
28 | +import android.util.TypedValue; | ||
29 | +import android.view.Display; | ||
30 | +import android.view.Surface; | ||
31 | +import java.util.List; | ||
32 | +import java.util.Vector; | ||
33 | +import org.tensorflow.demo.OverlayView.DrawCallback; | ||
34 | +import org.tensorflow.demo.env.BorderedText; | ||
35 | +import org.tensorflow.demo.env.ImageUtils; | ||
36 | +import org.tensorflow.demo.env.Logger; | ||
37 | +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds. | ||
38 | + | ||
39 | +public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener { | ||
40 | + private static final Logger LOGGER = new Logger(); | ||
41 | + | ||
42 | + protected static final boolean SAVE_PREVIEW_BITMAP = false; | ||
43 | + | ||
44 | + private ResultsView resultsView; | ||
45 | + | ||
46 | + private Bitmap rgbFrameBitmap = null; | ||
47 | + private Bitmap croppedBitmap = null; | ||
48 | + private Bitmap cropCopyBitmap = null; | ||
49 | + | ||
50 | + private long lastProcessingTimeMs; | ||
51 | + | ||
52 | + // These are the settings for the original v1 Inception model. If you want to | ||
53 | + // use a model that's been produced from the TensorFlow for Poets codelab, | ||
54 | + // you'll need to set INPUT_SIZE = 299, IMAGE_MEAN = 128, IMAGE_STD = 128, | ||
55 | + // INPUT_NAME = "Mul", and OUTPUT_NAME = "final_result". | ||
56 | + // You'll also need to update the MODEL_FILE and LABEL_FILE paths to point to | ||
57 | + // the ones you produced. | ||
58 | + // | ||
59 | + // To use v3 Inception model, strip the DecodeJpeg Op from your retrained | ||
60 | + // model first: | ||
61 | + // | ||
62 | + // python strip_unused.py \ | ||
63 | + // --input_graph=<retrained-pb-file> \ | ||
64 | + // --output_graph=<your-stripped-pb-file> \ | ||
65 | + // --input_node_names="Mul" \ | ||
66 | + // --output_node_names="final_result" \ | ||
67 | + // --input_binary=true | ||
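| + // | ||
| + // As a sketch, a Poets-retrained graph would use settings like the following | ||
| + // (the asset file names below are placeholders, not files shipped with the demo): | ||
| + // | ||
| + // private static final int INPUT_SIZE = 299; | ||
| + // private static final int IMAGE_MEAN = 128; | ||
| + // private static final float IMAGE_STD = 128; | ||
| + // private static final String INPUT_NAME = "Mul"; | ||
| + // private static final String OUTPUT_NAME = "final_result"; | ||
| + // private static final String MODEL_FILE = "file:///android_asset/retrained_graph.pb"; | ||
| + // private static final String LABEL_FILE = "file:///android_asset/retrained_labels.txt"; | ||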
68 | + private static final int INPUT_SIZE = 224; | ||
69 | + private static final int IMAGE_MEAN = 117; | ||
70 | + private static final float IMAGE_STD = 1; | ||
71 | + private static final String INPUT_NAME = "input"; | ||
72 | + private static final String OUTPUT_NAME = "output"; | ||
73 | + | ||
75 | + private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb"; | ||
76 | + private static final String LABEL_FILE = | ||
77 | + "file:///android_asset/imagenet_comp_graph_label_strings.txt"; | ||
78 | + | ||
80 | + private static final boolean MAINTAIN_ASPECT = true; | ||
81 | + | ||
82 | + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); | ||
83 | + | ||
85 | + private Integer sensorOrientation; | ||
86 | + private Classifier classifier; | ||
87 | + private Matrix frameToCropTransform; | ||
88 | + private Matrix cropToFrameTransform; | ||
89 | + | ||
90 | + | ||
91 | + private BorderedText borderedText; | ||
92 | + | ||
94 | + @Override | ||
95 | + protected int getLayoutId() { | ||
96 | + return R.layout.camera_connection_fragment; | ||
97 | + } | ||
98 | + | ||
99 | + @Override | ||
100 | + protected Size getDesiredPreviewFrameSize() { | ||
101 | + return DESIRED_PREVIEW_SIZE; | ||
102 | + } | ||
103 | + | ||
104 | + private static final float TEXT_SIZE_DIP = 10; | ||
105 | + | ||
106 | + @Override | ||
107 | + public void onPreviewSizeChosen(final Size size, final int rotation) { | ||
108 | + final float textSizePx = TypedValue.applyDimension( | ||
109 | + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); | ||
110 | + borderedText = new BorderedText(textSizePx); | ||
111 | + borderedText.setTypeface(Typeface.MONOSPACE); | ||
112 | + | ||
113 | + classifier = | ||
114 | + TensorFlowImageClassifier.create( | ||
115 | + getAssets(), | ||
116 | + MODEL_FILE, | ||
117 | + LABEL_FILE, | ||
118 | + INPUT_SIZE, | ||
119 | + IMAGE_MEAN, | ||
120 | + IMAGE_STD, | ||
121 | + INPUT_NAME, | ||
122 | + OUTPUT_NAME); | ||
123 | + | ||
124 | + previewWidth = size.getWidth(); | ||
125 | + previewHeight = size.getHeight(); | ||
126 | + | ||
127 | + sensorOrientation = rotation - getScreenOrientation(); | ||
128 | + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); | ||
129 | + | ||
130 | + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); | ||
131 | + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); | ||
132 | + croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888); | ||
133 | + | ||
134 | + frameToCropTransform = ImageUtils.getTransformationMatrix( | ||
135 | + previewWidth, previewHeight, | ||
136 | + INPUT_SIZE, INPUT_SIZE, | ||
137 | + sensorOrientation, MAINTAIN_ASPECT); | ||
138 | + | ||
139 | + cropToFrameTransform = new Matrix(); | ||
140 | + frameToCropTransform.invert(cropToFrameTransform); | ||
141 | + | ||
142 | + addCallback( | ||
143 | + new DrawCallback() { | ||
144 | + @Override | ||
145 | + public void drawCallback(final Canvas canvas) { | ||
146 | + renderDebug(canvas); | ||
147 | + } | ||
148 | + }); | ||
149 | + } | ||
150 | + | ||
151 | + @Override | ||
152 | + protected void processImage() { | ||
153 | + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); | ||
154 | + final Canvas canvas = new Canvas(croppedBitmap); | ||
155 | + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); | ||
156 | + | ||
157 | + // For examining the actual TF input. | ||
158 | + if (SAVE_PREVIEW_BITMAP) { | ||
159 | + ImageUtils.saveBitmap(croppedBitmap); | ||
160 | + } | ||
161 | + runInBackground( | ||
162 | + new Runnable() { | ||
163 | + @Override | ||
164 | + public void run() { | ||
165 | + final long startTime = SystemClock.uptimeMillis(); | ||
166 | + final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap); | ||
167 | + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; | ||
168 | + LOGGER.i("Detect: %s", results); | ||
169 | + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); | ||
170 | + if (resultsView == null) { | ||
171 | + resultsView = (ResultsView) findViewById(R.id.results); | ||
172 | + } | ||
173 | + resultsView.setResults(results); | ||
174 | + requestRender(); | ||
175 | + readyForNextImage(); | ||
176 | + } | ||
177 | + }); | ||
178 | + } | ||
179 | + | ||
180 | + @Override | ||
181 | + public void onSetDebug(boolean debug) { | ||
182 | + classifier.enableStatLogging(debug); | ||
183 | + } | ||
184 | + | ||
185 | + private void renderDebug(final Canvas canvas) { | ||
186 | + if (!isDebug()) { | ||
187 | + return; | ||
188 | + } | ||
189 | + final Bitmap copy = cropCopyBitmap; | ||
190 | + if (copy != null) { | ||
191 | + final Matrix matrix = new Matrix(); | ||
192 | + final float scaleFactor = 2; | ||
193 | + matrix.postScale(scaleFactor, scaleFactor); | ||
194 | + matrix.postTranslate( | ||
195 | + canvas.getWidth() - copy.getWidth() * scaleFactor, | ||
196 | + canvas.getHeight() - copy.getHeight() * scaleFactor); | ||
197 | + canvas.drawBitmap(copy, matrix, new Paint()); | ||
198 | + | ||
199 | + final Vector<String> lines = new Vector<String>(); | ||
200 | + if (classifier != null) { | ||
201 | + String statString = classifier.getStatString(); | ||
202 | + String[] statLines = statString.split("\n"); | ||
203 | + for (String line : statLines) { | ||
204 | + lines.add(line); | ||
205 | + } | ||
206 | + } | ||
207 | + | ||
208 | + lines.add("Frame: " + previewWidth + "x" + previewHeight); | ||
209 | + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); | ||
210 | + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); | ||
211 | + lines.add("Rotation: " + sensorOrientation); | ||
212 | + lines.add("Inference time: " + lastProcessingTimeMs + "ms"); | ||
213 | + | ||
214 | + borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); | ||
215 | + } | ||
216 | + } | ||
217 | +} |
1 | +/* | ||
2 | + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +package org.tensorflow.demo; | ||
18 | + | ||
19 | +import android.graphics.Bitmap; | ||
20 | +import android.graphics.Bitmap.Config; | ||
21 | +import android.graphics.Canvas; | ||
22 | +import android.graphics.Color; | ||
23 | +import android.graphics.Matrix; | ||
24 | +import android.graphics.Paint; | ||
25 | +import android.graphics.Paint.Style; | ||
26 | +import android.graphics.RectF; | ||
27 | +import android.graphics.Typeface; | ||
28 | +import android.media.ImageReader.OnImageAvailableListener; | ||
29 | +import android.os.SystemClock; | ||
30 | +import android.util.Size; | ||
31 | +import android.util.TypedValue; | ||
32 | +import android.view.Display; | ||
33 | +import android.view.Surface; | ||
34 | +import android.widget.Toast; | ||
35 | +import java.io.IOException; | ||
36 | +import java.util.LinkedList; | ||
37 | +import java.util.List; | ||
38 | +import java.util.Vector; | ||
39 | +import org.tensorflow.demo.OverlayView.DrawCallback; | ||
40 | +import org.tensorflow.demo.env.BorderedText; | ||
41 | +import org.tensorflow.demo.env.ImageUtils; | ||
42 | +import org.tensorflow.demo.env.Logger; | ||
43 | +import org.tensorflow.demo.tracking.MultiBoxTracker; | ||
44 | +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds. | ||
45 | + | ||
46 | +/** | ||
47 | + * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track | ||
48 | + * objects. | ||
49 | + */ | ||
50 | +public class DetectorActivity extends CameraActivity implements OnImageAvailableListener { | ||
51 | + private static final Logger LOGGER = new Logger(); | ||
52 | + | ||
53 | + // Configuration values for the prepackaged multibox model. | ||
54 | + private static final int MB_INPUT_SIZE = 224; | ||
55 | + private static final int MB_IMAGE_MEAN = 128; | ||
56 | + private static final float MB_IMAGE_STD = 128; | ||
57 | + private static final String MB_INPUT_NAME = "ResizeBilinear"; | ||
58 | + private static final String MB_OUTPUT_LOCATIONS_NAME = "output_locations/Reshape"; | ||
59 | + private static final String MB_OUTPUT_SCORES_NAME = "output_scores/Reshape"; | ||
60 | + private static final String MB_MODEL_FILE = "file:///android_asset/multibox_model.pb"; | ||
61 | + private static final String MB_LOCATION_FILE = | ||
62 | + "file:///android_asset/multibox_location_priors.txt"; | ||
63 | + | ||
64 | + private static final int TF_OD_API_INPUT_SIZE = 300; | ||
65 | + private static final String TF_OD_API_MODEL_FILE = | ||
66 | + "file:///android_asset/ssd_mobilenet_v1_android_export.pb"; | ||
67 | + private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt"; | ||
68 | + | ||
69 | + // Configuration values for YOLO (the constants below reference yolov3.pb). Note that the graph is not included with TensorFlow and | ||
70 | + // must be manually placed in the assets/ directory by the user. | ||
71 | + // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via | ||
72 | + // DarkFlow (https://github.com/thtrieu/darkflow). Sample command: | ||
73 | + // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise | ||
74 | + private static final String YOLO_MODEL_FILE = "file:///android_asset/yolov3.pb"; | ||
75 | + private static final int YOLO_INPUT_SIZE = 416; | ||
76 | + private static final String YOLO_INPUT_NAME = "input"; | ||
77 | + private static final String YOLO_OUTPUT_NAMES = "output"; | ||
78 | + private static final int YOLO_BLOCK_SIZE = 32; | ||
79 | + | ||
80 | + // Which detection model to use: a TensorFlow Object Detection API frozen | ||
81 | + // checkpoint, legacy Multibox (trained using an older version of the API), | ||
82 | + // or YOLO. MODE below is set to YOLO in this build. | ||
83 | + private enum DetectorMode { | ||
84 | + TF_OD_API, MULTIBOX, YOLO; | ||
85 | + } | ||
86 | + private static final DetectorMode MODE = DetectorMode.YOLO; | ||
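| + // MODE selects which detector onPreviewSizeChosen() constructs; with YOLO the | ||
| + // yolov3.pb graph must be placed in assets/ manually, as noted above. | ||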
87 | + | ||
88 | + // Minimum detection confidence to track a detection. | ||
89 | + private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f; | ||
90 | + private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f; | ||
91 | + private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f; | ||
92 | + | ||
93 | + private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO; | ||
94 | + | ||
95 | + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); | ||
96 | + | ||
97 | + private static final boolean SAVE_PREVIEW_BITMAP = false; | ||
98 | + private static final float TEXT_SIZE_DIP = 10; | ||
99 | + | ||
100 | + private Integer sensorOrientation; | ||
101 | + | ||
102 | + private Classifier detector; | ||
103 | + | ||
104 | + private long lastProcessingTimeMs; | ||
105 | + private Bitmap rgbFrameBitmap = null; | ||
106 | + private Bitmap croppedBitmap = null; | ||
107 | + private Bitmap cropCopyBitmap = null; | ||
108 | + | ||
109 | + private boolean computingDetection = false; | ||
110 | + | ||
111 | + private long timestamp = 0; | ||
112 | + | ||
113 | + private Matrix frameToCropTransform; | ||
114 | + private Matrix cropToFrameTransform; | ||
115 | + | ||
116 | + private MultiBoxTracker tracker; | ||
117 | + | ||
118 | + private byte[] luminanceCopy; | ||
119 | + | ||
120 | + private BorderedText borderedText; | ||
 | + | ||
121 | + @Override | ||
122 | + public void onPreviewSizeChosen(final Size size, final int rotation) { | ||
123 | + final float textSizePx = | ||
124 | + TypedValue.applyDimension( | ||
125 | + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); | ||
126 | + borderedText = new BorderedText(textSizePx); | ||
127 | + borderedText.setTypeface(Typeface.MONOSPACE); | ||
128 | + | ||
129 | + tracker = new MultiBoxTracker(this); | ||
130 | + | ||
131 | + int cropSize = TF_OD_API_INPUT_SIZE; | ||
132 | + if (MODE == DetectorMode.YOLO) { | ||
133 | + detector = | ||
134 | + TensorFlowYoloDetector.create( | ||
135 | + getAssets(), | ||
136 | + YOLO_MODEL_FILE, | ||
137 | + YOLO_INPUT_SIZE, | ||
138 | + YOLO_INPUT_NAME, | ||
139 | + YOLO_OUTPUT_NAMES, | ||
140 | + YOLO_BLOCK_SIZE); | ||
141 | + cropSize = YOLO_INPUT_SIZE; | ||
142 | + } else if (MODE == DetectorMode.MULTIBOX) { | ||
143 | + detector = | ||
144 | + TensorFlowMultiBoxDetector.create( | ||
145 | + getAssets(), | ||
146 | + MB_MODEL_FILE, | ||
147 | + MB_LOCATION_FILE, | ||
148 | + MB_IMAGE_MEAN, | ||
149 | + MB_IMAGE_STD, | ||
150 | + MB_INPUT_NAME, | ||
151 | + MB_OUTPUT_LOCATIONS_NAME, | ||
152 | + MB_OUTPUT_SCORES_NAME); | ||
153 | + cropSize = MB_INPUT_SIZE; | ||
154 | + } else { | ||
155 | + try { | ||
156 | + detector = TensorFlowObjectDetectionAPIModel.create( | ||
157 | + getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE); | ||
158 | + cropSize = TF_OD_API_INPUT_SIZE; | ||
159 | + } catch (final IOException e) { | ||
160 | + LOGGER.e(e, "Exception initializing classifier!"); | ||
161 | + Toast toast = | ||
162 | + Toast.makeText( | ||
163 | + getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); | ||
164 | + toast.show(); | ||
165 | + finish(); | ||
166 | + } | ||
167 | + } | ||
168 | + | ||
169 | + previewWidth = size.getWidth(); | ||
170 | + previewHeight = size.getHeight(); | ||
171 | + | ||
172 | + sensorOrientation = rotation - getScreenOrientation(); | ||
173 | + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); | ||
174 | + | ||
175 | + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); | ||
176 | + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); | ||
177 | + croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888); | ||
178 | + | ||
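 | + // Build a matrix that rotates by sensorOrientation and scales the preview frame into | ||
 | + // the square cropSize x cropSize model input (keeping the source aspect ratio when | ||
 | + // MAINTAIN_ASPECT is set). | ||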
179 | + frameToCropTransform = | ||
180 | + ImageUtils.getTransformationMatrix( | ||
181 | + previewWidth, previewHeight, | ||
182 | + cropSize, cropSize, | ||
183 | + sensorOrientation, MAINTAIN_ASPECT); | ||
184 | + | ||
185 | + cropToFrameTransform = new Matrix(); | ||
186 | + frameToCropTransform.invert(cropToFrameTransform); | ||
187 | + | ||
188 | + trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay); | ||
189 | + trackingOverlay.addCallback( | ||
190 | + new DrawCallback() { | ||
191 | + @Override | ||
192 | + public void drawCallback(final Canvas canvas) { | ||
193 | + tracker.draw(canvas); | ||
194 | + if (isDebug()) { | ||
195 | + tracker.drawDebug(canvas); | ||
196 | + } | ||
197 | + } | ||
198 | + }); | ||
199 | + | ||
200 | + addCallback( | ||
201 | + new DrawCallback() { | ||
202 | + @Override | ||
203 | + public void drawCallback(final Canvas canvas) { | ||
204 | + if (!isDebug()) { | ||
205 | + return; | ||
206 | + } | ||
207 | + final Bitmap copy = cropCopyBitmap; | ||
208 | + if (copy == null) { | ||
209 | + return; | ||
210 | + } | ||
211 | + | ||
212 | + final int backgroundColor = Color.argb(100, 0, 0, 0); | ||
213 | + canvas.drawColor(backgroundColor); | ||
214 | + | ||
215 | + final Matrix matrix = new Matrix(); | ||
216 | + final float scaleFactor = 2; | ||
217 | + matrix.postScale(scaleFactor, scaleFactor); | ||
218 | + matrix.postTranslate( | ||
219 | + canvas.getWidth() - copy.getWidth() * scaleFactor, | ||
220 | + canvas.getHeight() - copy.getHeight() * scaleFactor); | ||
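 | + // Pins the 2x-scaled crop copy to the bottom-right corner of the canvas. | ||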
221 | + canvas.drawBitmap(copy, matrix, new Paint()); | ||
222 | + | ||
223 | + final Vector<String> lines = new Vector<String>(); | ||
224 | + if (detector != null) { | ||
225 | + final String statString = detector.getStatString(); | ||
226 | + final String[] statLines = statString.split("\n"); | ||
227 | + for (final String line : statLines) { | ||
228 | + lines.add(line); | ||
229 | + } | ||
230 | + } | ||
231 | + lines.add(""); | ||
232 | + | ||
233 | + lines.add("Frame: " + previewWidth + "x" + previewHeight); | ||
234 | + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); | ||
235 | + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); | ||
236 | + lines.add("Rotation: " + sensorOrientation); | ||
237 | + lines.add("Inference time: " + lastProcessingTimeMs + "ms"); | ||
238 | + | ||
239 | + borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); | ||
240 | + } | ||
241 | + }); | ||
242 | + } | ||
243 | + | ||
244 | + OverlayView trackingOverlay; | ||
245 | + | ||
246 | + @Override | ||
247 | + protected void processImage() { | ||
248 | + ++timestamp; | ||
249 | + final long currTimestamp = timestamp; | ||
250 | + byte[] originalLuminance = getLuminance(); | ||
251 | + tracker.onFrame( | ||
252 | + previewWidth, | ||
253 | + previewHeight, | ||
254 | + getLuminanceStride(), | ||
255 | + sensorOrientation, | ||
256 | + originalLuminance, | ||
257 | + timestamp); | ||
258 | + trackingOverlay.postInvalidate(); | ||
259 | + | ||
260 | + // No mutex needed as this method is not reentrant. | ||
261 | + if (computingDetection) { | ||
262 | + readyForNextImage(); | ||
263 | + return; | ||
264 | + } | ||
265 | + computingDetection = true; | ||
266 | + LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread."); | ||
267 | + | ||
268 | + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); | ||
269 | + | ||
270 | + if (luminanceCopy == null) { | ||
271 | + luminanceCopy = new byte[originalLuminance.length]; | ||
272 | + } | ||
273 | + System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length); | ||
274 | + readyForNextImage(); | ||
275 | + | ||
276 | + final Canvas canvas = new Canvas(croppedBitmap); | ||
277 | + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); | ||
278 | + // For examining the actual TF input. | ||
279 | + if (SAVE_PREVIEW_BITMAP) { | ||
280 | + ImageUtils.saveBitmap(croppedBitmap); | ||
281 | + } | ||
282 | + | ||
283 | + runInBackground( | ||
284 | + new Runnable() { | ||
285 | + @Override | ||
286 | + public void run() { | ||
287 | + LOGGER.i("Running detection on image " + currTimestamp); | ||
288 | + final long startTime = SystemClock.uptimeMillis(); | ||
289 | + final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap); | ||
290 | + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; | ||
291 | + | ||
292 | + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); | ||
293 | + final Canvas canvas = new Canvas(cropCopyBitmap); | ||
294 | + final Paint paint = new Paint(); | ||
295 | + paint.setColor(Color.RED); | ||
296 | + paint.setStyle(Style.STROKE); | ||
297 | + paint.setStrokeWidth(2.0f); | ||
298 | + | ||
299 | + float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; | ||
300 | + switch (MODE) { | ||
301 | + case TF_OD_API: | ||
302 | + minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; | ||
303 | + break; | ||
304 | + case MULTIBOX: | ||
305 | + minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX; | ||
306 | + break; | ||
307 | + case YOLO: | ||
308 | + minimumConfidence = MINIMUM_CONFIDENCE_YOLO; | ||
309 | + break; | ||
310 | + } | ||
311 | + | ||
312 | + final List<Classifier.Recognition> mappedRecognitions = | ||
313 | + new LinkedList<Classifier.Recognition>(); | ||
314 | + | ||
315 | + for (final Classifier.Recognition result : results) { | ||
316 | + final RectF location = result.getLocation(); | ||
317 | + if (location != null && result.getConfidence() >= minimumConfidence) { | ||
318 | + canvas.drawRect(location, paint); | ||
319 | + | ||
320 | + cropToFrameTransform.mapRect(location); | ||
321 | + result.setLocation(location); | ||
322 | + mappedRecognitions.add(result); | ||
323 | + } | ||
324 | + } | ||
325 | + | ||
326 | + tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp); | ||
327 | + trackingOverlay.postInvalidate(); | ||
328 | + | ||
329 | + requestRender(); | ||
330 | + computingDetection = false; | ||
331 | + } | ||
332 | + }); | ||
333 | + } | ||
334 | + | ||
335 | + @Override | ||
336 | + protected int getLayoutId() { | ||
337 | + return R.layout.camera_connection_fragment_tracking; | ||
338 | + } | ||
339 | + | ||
340 | + @Override | ||
341 | + protected Size getDesiredPreviewFrameSize() { | ||
342 | + return DESIRED_PREVIEW_SIZE; | ||
343 | + } | ||
344 | + | ||
345 | + @Override | ||
346 | + public void onSetDebug(final boolean debug) { | ||
347 | + detector.enableStatLogging(debug); | ||
348 | + } | ||
349 | +} |
1 | +package org.tensorflow.demo; | ||
2 | + | ||
3 | +/* | ||
4 | + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
5 | + * | ||
6 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
7 | + * you may not use this file except in compliance with the License. | ||
8 | + * You may obtain a copy of the License at | ||
9 | + * | ||
10 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
11 | + * | ||
12 | + * Unless required by applicable law or agreed to in writing, software | ||
13 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
15 | + * See the License for the specific language governing permissions and | ||
16 | + * limitations under the License. | ||
17 | + */ | ||
18 | + | ||
19 | +import android.app.Fragment; | ||
20 | +import android.graphics.SurfaceTexture; | ||
21 | +import android.hardware.Camera; | ||
22 | +import android.hardware.Camera.CameraInfo; | ||
23 | +import android.os.Bundle; | ||
24 | +import android.os.Handler; | ||
25 | +import android.os.HandlerThread; | ||
26 | +import android.util.Size; | ||
27 | +import android.util.SparseIntArray; | ||
28 | +import android.view.LayoutInflater; | ||
29 | +import android.view.Surface; | ||
30 | +import android.view.TextureView; | ||
31 | +import android.view.View; | ||
32 | +import android.view.ViewGroup; | ||
33 | +import java.io.IOException; | ||
34 | +import java.util.List; | ||
35 | +import org.tensorflow.demo.env.ImageUtils; | ||
36 | +import org.tensorflow.demo.env.Logger; | ||
37 | +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds. | ||
38 | + | ||
39 | +public class LegacyCameraConnectionFragment extends Fragment { | ||
40 | + private Camera camera; | ||
41 | + private static final Logger LOGGER = new Logger(); | ||
42 | + private Camera.PreviewCallback imageListener; | ||
43 | + private Size desiredSize; | ||
44 | + | ||
45 | + /** | ||
46 | + * The layout identifier to inflate for this Fragment. | ||
47 | + */ | ||
48 | + private int layout; | ||
49 | + | ||
50 | + public LegacyCameraConnectionFragment( | ||
51 | + final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) { | ||
52 | + this.imageListener = imageListener; | ||
53 | + this.layout = layout; | ||
54 | + this.desiredSize = desiredSize; | ||
55 | + } | ||
56 | + | ||
57 | + /** | ||
58 | + * Conversion from screen rotation to JPEG orientation. | ||
59 | + */ | ||
60 | + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); | ||
61 | + | ||
62 | + static { | ||
63 | + ORIENTATIONS.append(Surface.ROTATION_0, 90); | ||
64 | + ORIENTATIONS.append(Surface.ROTATION_90, 0); | ||
65 | + ORIENTATIONS.append(Surface.ROTATION_180, 270); | ||
66 | + ORIENTATIONS.append(Surface.ROTATION_270, 180); | ||
67 | + } | ||
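 | + // Note: this map appears unused in this legacy fragment (the preview is hard-coded to | ||
 | + // setDisplayOrientation(90) below); it is presumably kept for parity with the | ||
 | + // Camera2-based CameraConnectionFragment. | ||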
68 | + | ||
69 | + /** | ||
70 | + * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a | ||
71 | + * {@link TextureView}. | ||
72 | + */ | ||
73 | + private final TextureView.SurfaceTextureListener surfaceTextureListener = | ||
74 | + new TextureView.SurfaceTextureListener() { | ||
75 | + @Override | ||
76 | + public void onSurfaceTextureAvailable( | ||
77 | + final SurfaceTexture texture, final int width, final int height) { | ||
78 | + | ||
79 | + int index = getCameraId(); | ||
80 | + camera = Camera.open(index); | ||
81 | + | ||
82 | + try { | ||
83 | + Camera.Parameters parameters = camera.getParameters(); | ||
84 | + List<String> focusModes = parameters.getSupportedFocusModes(); | ||
85 | + if (focusModes != null | ||
86 | + && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { | ||
87 | + parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); | ||
88 | + } | ||
89 | + List<Camera.Size> cameraSizes = parameters.getSupportedPreviewSizes(); | ||
90 | + Size[] sizes = new Size[cameraSizes.size()]; | ||
91 | + int i = 0; | ||
92 | + for (Camera.Size size : cameraSizes) { | ||
93 | + sizes[i++] = new Size(size.width, size.height); | ||
94 | + } | ||
95 | + Size previewSize = | ||
96 | + CameraConnectionFragment.chooseOptimalSize( | ||
97 | + sizes, desiredSize.getWidth(), desiredSize.getHeight()); | ||
98 | + parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight()); | ||
99 | + camera.setDisplayOrientation(90); | ||
100 | + camera.setParameters(parameters); | ||
101 | + camera.setPreviewTexture(texture); | ||
102 | + } catch (IOException exception) { | ||
103 | + camera.release(); | ||
 | + // Bail out so the released Camera isn't used below. | ||
 | + camera = null; | ||
 | + return; | ||
104 | + } | ||
105 | + | ||
106 | + camera.setPreviewCallbackWithBuffer(imageListener); | ||
107 | + Camera.Size s = camera.getParameters().getPreviewSize(); | ||
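 | + // Allocate one YUV420 frame: width * height luma bytes plus half as many chroma bytes, | ||
 | + // i.e. roughly 1.5 bytes per pixel. | ||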
108 | + camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]); | ||
109 | + | ||
110 | + textureView.setAspectRatio(s.height, s.width); | ||
111 | + | ||
112 | + camera.startPreview(); | ||
113 | + } | ||
114 | + | ||
115 | + @Override | ||
116 | + public void onSurfaceTextureSizeChanged( | ||
117 | + final SurfaceTexture texture, final int width, final int height) {} | ||
118 | + | ||
119 | + @Override | ||
120 | + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { | ||
121 | + return true; | ||
122 | + } | ||
123 | + | ||
124 | + @Override | ||
125 | + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} | ||
126 | + }; | ||
127 | + | ||
128 | + /** | ||
129 | + * An {@link AutoFitTextureView} for camera preview. | ||
130 | + */ | ||
131 | + private AutoFitTextureView textureView; | ||
132 | + | ||
133 | + /** | ||
134 | + * An additional thread for running tasks that shouldn't block the UI. | ||
135 | + */ | ||
136 | + private HandlerThread backgroundThread; | ||
137 | + | ||
138 | + @Override | ||
139 | + public View onCreateView( | ||
140 | + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { | ||
141 | + return inflater.inflate(layout, container, false); | ||
142 | + } | ||
143 | + | ||
144 | + @Override | ||
145 | + public void onViewCreated(final View view, final Bundle savedInstanceState) { | ||
146 | + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); | ||
147 | + } | ||
148 | + | ||
149 | + @Override | ||
150 | + public void onActivityCreated(final Bundle savedInstanceState) { | ||
151 | + super.onActivityCreated(savedInstanceState); | ||
152 | + } | ||
153 | + | ||
154 | + @Override | ||
155 | + public void onResume() { | ||
156 | + super.onResume(); | ||
157 | + startBackgroundThread(); | ||
158 | + // When the screen is turned off and turned back on, the SurfaceTexture is already | ||
159 | + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open | ||
160 | + // a camera and start preview from here (otherwise, we wait until the surface is ready in | ||
161 | + // the SurfaceTextureListener). | ||
162 | + | ||
163 | + if (textureView.isAvailable()) { | ||
 | + // stopCamera() released the camera in onPause(), so reopen it through the listener | ||
 | + // rather than calling startPreview() on a null reference. | ||
 | + surfaceTextureListener.onSurfaceTextureAvailable( | ||
 | + textureView.getSurfaceTexture(), textureView.getWidth(), textureView.getHeight()); | ||
165 | + } else { | ||
166 | + textureView.setSurfaceTextureListener(surfaceTextureListener); | ||
167 | + } | ||
168 | + } | ||
169 | + | ||
170 | + @Override | ||
171 | + public void onPause() { | ||
172 | + stopCamera(); | ||
173 | + stopBackgroundThread(); | ||
174 | + super.onPause(); | ||
175 | + } | ||
176 | + | ||
177 | + /** | ||
178 | + * Starts a background thread and its {@link Handler}. | ||
179 | + */ | ||
180 | + private void startBackgroundThread() { | ||
181 | + backgroundThread = new HandlerThread("CameraBackground"); | ||
182 | + backgroundThread.start(); | ||
183 | + } | ||
184 | + | ||
185 | + /** | ||
186 | + * Stops the background thread and its {@link Handler}. | ||
187 | + */ | ||
188 | + private void stopBackgroundThread() { | ||
189 | + backgroundThread.quitSafely(); | ||
190 | + try { | ||
191 | + backgroundThread.join(); | ||
192 | + backgroundThread = null; | ||
193 | + } catch (final InterruptedException e) { | ||
194 | + LOGGER.e(e, "Exception!"); | ||
195 | + } | ||
196 | + } | ||
197 | + | ||
198 | + protected void stopCamera() { | ||
199 | + if (camera != null) { | ||
200 | + camera.stopPreview(); | ||
201 | + camera.setPreviewCallback(null); | ||
202 | + camera.release(); | ||
203 | + camera = null; | ||
204 | + } | ||
205 | + } | ||
206 | + | ||
207 | + private int getCameraId() { | ||
208 | + CameraInfo ci = new CameraInfo(); | ||
209 | + for (int i = 0; i < Camera.getNumberOfCameras(); i++) { | ||
210 | + Camera.getCameraInfo(i, ci); | ||
211 | + if (ci.facing == CameraInfo.CAMERA_FACING_BACK) { | ||
212 | + return i; | ||
 | + } | ||
213 | + } | ||
214 | + return -1; // No camera found | ||
215 | + } | ||
216 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import android.content.Context; | ||
19 | +import android.graphics.Canvas; | ||
20 | +import android.util.AttributeSet; | ||
21 | +import android.view.View; | ||
22 | +import java.util.LinkedList; | ||
23 | +import java.util.List; | ||
24 | + | ||
25 | +/** | ||
26 | + * A simple View providing a render callback to other classes. | ||
27 | + */ | ||
28 | +public class OverlayView extends View { | ||
29 | + private final List<DrawCallback> callbacks = new LinkedList<DrawCallback>(); | ||
30 | + | ||
31 | + public OverlayView(final Context context, final AttributeSet attrs) { | ||
32 | + super(context, attrs); | ||
33 | + } | ||
34 | + | ||
35 | + /** | ||
36 | + * Interface defining the callback for client classes. | ||
37 | + */ | ||
38 | + public interface DrawCallback { | ||
39 | + public void drawCallback(final Canvas canvas); | ||
40 | + } | ||
41 | + | ||
42 | + public void addCallback(final DrawCallback callback) { | ||
43 | + callbacks.add(callback); | ||
44 | + } | ||
45 | + | ||
46 | + @Override | ||
47 | + public synchronized void draw(final Canvas canvas) { | ||
48 | + for (final DrawCallback callback : callbacks) { | ||
49 | + callback.drawCallback(canvas); | ||
50 | + } | ||
51 | + } | ||
52 | +} |
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import android.content.Context; | ||
19 | +import android.graphics.Canvas; | ||
20 | +import android.graphics.Paint; | ||
21 | +import android.util.AttributeSet; | ||
22 | +import android.util.TypedValue; | ||
23 | +import android.view.View; | ||
24 | + | ||
25 | +import org.tensorflow.demo.Classifier.Recognition; | ||
26 | + | ||
27 | +import java.util.List; | ||
28 | + | ||
29 | +public class RecognitionScoreView extends View implements ResultsView { | ||
30 | + private static final float TEXT_SIZE_DIP = 24; | ||
31 | + private List<Recognition> results; | ||
32 | + private final float textSizePx; | ||
33 | + private final Paint fgPaint; | ||
34 | + private final Paint bgPaint; | ||
35 | + | ||
36 | + public RecognitionScoreView(final Context context, final AttributeSet set) { | ||
37 | + super(context, set); | ||
38 | + | ||
39 | + textSizePx = | ||
40 | + TypedValue.applyDimension( | ||
41 | + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); | ||
42 | + fgPaint = new Paint(); | ||
43 | + fgPaint.setTextSize(textSizePx); | ||
44 | + | ||
45 | + bgPaint = new Paint(); | ||
46 | + bgPaint.setColor(0xcc4285f4); | ||
47 | + } | ||
48 | + | ||
49 | + @Override | ||
50 | + public void setResults(final List<Recognition> results) { | ||
51 | + this.results = results; | ||
52 | + postInvalidate(); | ||
53 | + } | ||
54 | + | ||
55 | + @Override | ||
56 | + public void onDraw(final Canvas canvas) { | ||
57 | + final int x = 10; | ||
58 | + int y = (int) (fgPaint.getTextSize() * 1.5f); | ||
59 | + | ||
60 | + canvas.drawPaint(bgPaint); | ||
61 | + | ||
62 | + if (results != null) { | ||
63 | + for (final Recognition recog : results) { | ||
64 | + canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); | ||
65 | + y += fgPaint.getTextSize() * 1.5f; | ||
66 | + } | ||
67 | + } | ||
68 | + } | ||
69 | +} |
1 | +/* | ||
2 | + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +package org.tensorflow.demo; | ||
18 | + | ||
19 | +import android.util.Log; | ||
20 | +import android.util.Pair; | ||
21 | +import java.util.ArrayDeque; | ||
22 | +import java.util.ArrayList; | ||
23 | +import java.util.Arrays; | ||
24 | +import java.util.Deque; | ||
25 | +import java.util.List; | ||
26 | + | ||
27 | +/** Reads in results from an instantaneous audio recognition model and smoothes them over time. */ | ||
28 | +public class RecognizeCommands { | ||
29 | + // Configuration settings. | ||
30 | + private List<String> labels = new ArrayList<String>(); | ||
31 | + private long averageWindowDurationMs; | ||
32 | + private float detectionThreshold; | ||
33 | + private int suppressionMs; | ||
34 | + private int minimumCount; | ||
35 | + private long minimumTimeBetweenSamplesMs; | ||
36 | + | ||
37 | + // Working variables. | ||
38 | + private Deque<Pair<Long, float[]>> previousResults = new ArrayDeque<Pair<Long, float[]>>(); | ||
39 | + private String previousTopLabel; | ||
40 | + private int labelsCount; | ||
41 | + private long previousTopLabelTime; | ||
42 | + private float previousTopLabelScore; | ||
43 | + | ||
44 | + private static final String SILENCE_LABEL = "_silence_"; | ||
45 | + private static final long MINIMUM_TIME_FRACTION = 4; | ||
46 | + | ||
47 | + public RecognizeCommands( | ||
48 | + List<String> inLabels, | ||
49 | + long inAverageWindowDurationMs, | ||
50 | + float inDetectionThreshold, | ||
51 | + int inSuppressionMS, | ||
52 | + int inMinimumCount, | ||
53 | + long inMinimumTimeBetweenSamplesMS) { | ||
54 | + labels = inLabels; | ||
55 | + averageWindowDurationMs = inAverageWindowDurationMs; | ||
56 | + detectionThreshold = inDetectionThreshold; | ||
57 | + suppressionMs = inSuppressionMS; | ||
58 | + minimumCount = inMinimumCount; | ||
59 | + labelsCount = inLabels.size(); | ||
60 | + previousTopLabel = SILENCE_LABEL; | ||
61 | + previousTopLabelTime = Long.MIN_VALUE; | ||
62 | + previousTopLabelScore = 0.0f; | ||
63 | + minimumTimeBetweenSamplesMs = inMinimumTimeBetweenSamplesMS; | ||
64 | + } | ||
65 | + | ||
66 | + /** Holds information about what's been recognized. */ | ||
67 | + public static class RecognitionResult { | ||
68 | + public final String foundCommand; | ||
69 | + public final float score; | ||
70 | + public final boolean isNewCommand; | ||
71 | + | ||
72 | + public RecognitionResult(String inFoundCommand, float inScore, boolean inIsNewCommand) { | ||
73 | + foundCommand = inFoundCommand; | ||
74 | + score = inScore; | ||
75 | + isNewCommand = inIsNewCommand; | ||
76 | + } | ||
77 | + } | ||
78 | + | ||
79 | + private static class ScoreForSorting implements Comparable<ScoreForSorting> { | ||
80 | + public final float score; | ||
81 | + public final int index; | ||
82 | + | ||
83 | + public ScoreForSorting(float inScore, int inIndex) { | ||
84 | + score = inScore; | ||
85 | + index = inIndex; | ||
86 | + } | ||
87 | + | ||
88 | + @Override | ||
89 | + public int compareTo(ScoreForSorting other) { | ||
90 | + if (this.score > other.score) { | ||
91 | + return -1; | ||
92 | + } else if (this.score < other.score) { | ||
93 | + return 1; | ||
94 | + } else { | ||
95 | + return 0; | ||
96 | + } | ||
97 | + } | ||
98 | + } | ||
99 | + | ||
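 | + // Minimal usage sketch (illustrative only; "scores" stands for one softmax output | ||
 | + // per model invocation, fed in increasing time order): | ||
 | + // | ||
 | + //   RecognizeCommands commands = | ||
 | + //       new RecognizeCommands(labels, 500, 0.70f, 1500, 3, 30); | ||
 | + //   RecognitionResult result = | ||
 | + //       commands.processLatestResults(scores, System.currentTimeMillis()); | ||
 | + //   if (result.isNewCommand) { | ||
 | + //     // React to result.foundCommand and result.score. | ||
 | + //   } | ||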
100 | + public RecognitionResult processLatestResults(float[] currentResults, long currentTimeMS) { | ||
101 | + if (currentResults.length != labelsCount) { | ||
102 | + throw new RuntimeException( | ||
103 | + "The results for recognition should contain " | ||
104 | + + labelsCount | ||
105 | + + " elements, but there are " | ||
106 | + + currentResults.length); | ||
107 | + } | ||
108 | + | ||
109 | + if ((!previousResults.isEmpty()) && (currentTimeMS < previousResults.getFirst().first)) { | ||
110 | + throw new RuntimeException( | ||
111 | + "You must feed results in increasing time order, but received a timestamp of " | ||
112 | + + currentTimeMS | ||
113 | + + " that was earlier than the previous one of " | ||
114 | + + previousResults.getFirst().first); | ||
115 | + } | ||
116 | + | ||
117 | + final int howManyResults = previousResults.size(); | ||
118 | + // Ignore any results that are coming in too frequently. | ||
119 | + if (howManyResults > 1) { | ||
120 | + final long timeSinceMostRecent = currentTimeMS - previousResults.getLast().first; | ||
121 | + if (timeSinceMostRecent < minimumTimeBetweenSamplesMs) { | ||
122 | + return new RecognitionResult(previousTopLabel, previousTopLabelScore, false); | ||
123 | + } | ||
124 | + } | ||
125 | + | ||
126 | + // Add the latest results to the tail of the deque. | ||
127 | + previousResults.addLast(new Pair<Long, float[]>(currentTimeMS, currentResults)); | ||
128 | + | ||
129 | + // Prune any earlier results that are too old for the averaging window. | ||
130 | + final long timeLimit = currentTimeMS - averageWindowDurationMs; | ||
131 | + while (previousResults.getFirst().first < timeLimit) { | ||
132 | + previousResults.removeFirst(); | ||
133 | + } | ||
134 | + | ||
135 | + // If there are too few results, assume the result will be unreliable and | ||
136 | + // bail. | ||
137 | + final long earliestTime = previousResults.getFirst().first; | ||
138 | + final long samplesDuration = currentTimeMS - earliestTime; | ||
139 | + if ((howManyResults < minimumCount) | ||
140 | + || (samplesDuration < (averageWindowDurationMs / MINIMUM_TIME_FRACTION))) { | ||
141 | + Log.v("RecognizeResult", "Too few results"); | ||
142 | + return new RecognitionResult(previousTopLabel, 0.0f, false); | ||
143 | + } | ||
144 | + | ||
145 | + // Calculate the average score across all the results in the window. Use the deque's | ||
 | + // current size as the denominator: howManyResults was sampled before the newest result | ||
 | + // was added and older ones were pruned, so dividing by it would denormalize the average. | ||
 | + final int resultsToAverage = previousResults.size(); | ||
146 | + float[] averageScores = new float[labelsCount]; | ||
147 | + for (Pair<Long, float[]> previousResult : previousResults) { | ||
148 | + final float[] scoresTensor = previousResult.second; | ||
149 | + int i = 0; | ||
150 | + while (i < scoresTensor.length) { | ||
151 | + averageScores[i] += scoresTensor[i] / resultsToAverage; | ||
152 | + ++i; | ||
153 | + } | ||
154 | + } | ||
155 | + | ||
156 | + // Sort the averaged results in descending score order. | ||
157 | + ScoreForSorting[] sortedAverageScores = new ScoreForSorting[labelsCount]; | ||
158 | + for (int i = 0; i < labelsCount; ++i) { | ||
159 | + sortedAverageScores[i] = new ScoreForSorting(averageScores[i], i); | ||
160 | + } | ||
161 | + Arrays.sort(sortedAverageScores); | ||
162 | + | ||
163 | + // See if the latest top score is enough to trigger a detection. | ||
164 | + final int currentTopIndex = sortedAverageScores[0].index; | ||
165 | + final String currentTopLabel = labels.get(currentTopIndex); | ||
166 | + final float currentTopScore = sortedAverageScores[0].score; | ||
167 | + // If we've recently had another label trigger, assume one that occurs too | ||
168 | + // soon afterwards is a bad result. | ||
169 | + long timeSinceLastTop; | ||
170 | + if (previousTopLabel.equals(SILENCE_LABEL) || (previousTopLabelTime == Long.MIN_VALUE)) { | ||
171 | + timeSinceLastTop = Long.MAX_VALUE; | ||
172 | + } else { | ||
173 | + timeSinceLastTop = currentTimeMS - previousTopLabelTime; | ||
174 | + } | ||
175 | + boolean isNewCommand; | ||
176 | + if ((currentTopScore > detectionThreshold) && (timeSinceLastTop > suppressionMs)) { | ||
177 | + previousTopLabel = currentTopLabel; | ||
178 | + previousTopLabelTime = currentTimeMS; | ||
179 | + previousTopLabelScore = currentTopScore; | ||
180 | + isNewCommand = true; | ||
181 | + } else { | ||
182 | + isNewCommand = false; | ||
183 | + } | ||
184 | + return new RecognitionResult(currentTopLabel, currentTopScore, isNewCommand); | ||
185 | + } | ||
186 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import org.tensorflow.demo.Classifier.Recognition; | ||
19 | + | ||
20 | +import java.util.List; | ||
21 | + | ||
22 | +public interface ResultsView { | ||
23 | + public void setResults(final List<Recognition> results); | ||
24 | +} |
1 | +/* | ||
2 | + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +/* Demonstrates how to run an audio recognition model in Android. | ||
18 | + | ||
19 | +This example loads a simple speech recognition model trained by the tutorial at | ||
20 | +https://www.tensorflow.org/tutorials/audio_training | ||
21 | + | ||
22 | +The model files should be downloaded automatically from the TensorFlow website, | ||
23 | +but if you have a custom model you can update the LABEL_FILENAME and | ||
24 | +MODEL_FILENAME constants to point to your own files. | ||
25 | + | ||
26 | +The example application displays a list view with all of the known audio labels, | ||
27 | +and highlights each one when it thinks it has detected one through the | ||
28 | +microphone. The averaging of results to give a more reliable signal happens in | ||
29 | +the RecognizeCommands helper class. | ||
30 | +*/ | ||
31 | + | ||
32 | +package org.tensorflow.demo; | ||
33 | + | ||
34 | +import android.animation.AnimatorInflater; | ||
35 | +import android.animation.AnimatorSet; | ||
36 | +import android.app.Activity; | ||
37 | +import android.content.pm.PackageManager; | ||
38 | +import android.media.AudioFormat; | ||
39 | +import android.media.AudioRecord; | ||
40 | +import android.media.MediaRecorder; | ||
41 | +import android.os.Build; | ||
42 | +import android.os.Bundle; | ||
43 | +import android.util.Log; | ||
44 | +import android.view.View; | ||
45 | +import android.widget.ArrayAdapter; | ||
46 | +import android.widget.Button; | ||
47 | +import android.widget.ListView; | ||
48 | +import java.io.BufferedReader; | ||
49 | +import java.io.IOException; | ||
50 | +import java.io.InputStreamReader; | ||
51 | +import java.util.ArrayList; | ||
52 | +import java.util.List; | ||
53 | +import java.util.concurrent.locks.ReentrantLock; | ||
54 | +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; | ||
55 | +import org.tensorflow.demo.R; | ||
56 | + | ||
57 | +/** | ||
58 | + * An activity that listens for audio and then uses a TensorFlow model to detect particular classes, | ||
59 | + * by default a small set of action words. | ||
60 | + */ | ||
61 | +public class SpeechActivity extends Activity { | ||
62 | + | ||
63 | + // Constants that control the behavior of the recognition code and model | ||
64 | + // settings. See the audio recognition tutorial for a detailed explanation of | ||
65 | + // all these, but you should customize them to match your training settings if | ||
66 | + // you are running your own model. | ||
67 | + private static final int SAMPLE_RATE = 16000; | ||
68 | + private static final int SAMPLE_DURATION_MS = 1000; | ||
69 | + private static final int RECORDING_LENGTH = (int) (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000); | ||
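 | + // With the values above this is 16000 samples, i.e. exactly one second of audio at 16 kHz. | ||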
70 | + private static final long AVERAGE_WINDOW_DURATION_MS = 500; | ||
71 | + private static final float DETECTION_THRESHOLD = 0.70f; | ||
72 | + private static final int SUPPRESSION_MS = 1500; | ||
73 | + private static final int MINIMUM_COUNT = 3; | ||
74 | + private static final long MINIMUM_TIME_BETWEEN_SAMPLES_MS = 30; | ||
75 | + private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_labels.txt"; | ||
76 | + private static final String MODEL_FILENAME = "file:///android_asset/conv_actions_frozen.pb"; | ||
77 | + private static final String INPUT_DATA_NAME = "decoded_sample_data:0"; | ||
78 | + private static final String SAMPLE_RATE_NAME = "decoded_sample_data:1"; | ||
79 | + private static final String OUTPUT_SCORES_NAME = "labels_softmax"; | ||
80 | + | ||
81 | + // UI elements. | ||
82 | + private static final int REQUEST_RECORD_AUDIO = 13; | ||
83 | + private Button quitButton; | ||
84 | + private ListView labelsListView; | ||
85 | + private static final String LOG_TAG = SpeechActivity.class.getSimpleName(); | ||
86 | + | ||
87 | + // Working variables. | ||
88 | + short[] recordingBuffer = new short[RECORDING_LENGTH]; | ||
89 | + int recordingOffset = 0; | ||
90 | + boolean shouldContinue = true; | ||
91 | + private Thread recordingThread; | ||
92 | + boolean shouldContinueRecognition = true; | ||
93 | + private Thread recognitionThread; | ||
94 | + private final ReentrantLock recordingBufferLock = new ReentrantLock(); | ||
95 | + private TensorFlowInferenceInterface inferenceInterface; | ||
96 | + private List<String> labels = new ArrayList<String>(); | ||
97 | + private List<String> displayedLabels = new ArrayList<>(); | ||
98 | + private RecognizeCommands recognizeCommands = null; | ||
99 | + | ||
100 | + @Override | ||
101 | + protected void onCreate(Bundle savedInstanceState) { | ||
102 | + // Set up the UI. | ||
103 | + super.onCreate(savedInstanceState); | ||
104 | + setContentView(R.layout.activity_speech); | ||
105 | + quitButton = (Button) findViewById(R.id.quit); | ||
106 | + quitButton.setOnClickListener( | ||
107 | + new View.OnClickListener() { | ||
108 | + @Override | ||
109 | + public void onClick(View view) { | ||
110 | + moveTaskToBack(true); | ||
111 | + android.os.Process.killProcess(android.os.Process.myPid()); | ||
112 | + System.exit(1); | ||
113 | + } | ||
114 | + }); | ||
115 | + labelsListView = (ListView) findViewById(R.id.list_view); | ||
116 | + | ||
117 | + // Load the labels for the model, but only display those that don't start | ||
118 | + // with an underscore. | ||
119 | + String actualFilename = LABEL_FILENAME.split("file:///android_asset/")[1]; | ||
120 | + Log.i(LOG_TAG, "Reading labels from: " + actualFilename); | ||
121 | + BufferedReader br = null; | ||
122 | + try { | ||
123 | + br = new BufferedReader(new InputStreamReader(getAssets().open(actualFilename))); | ||
124 | + String line; | ||
125 | + while ((line = br.readLine()) != null) { | ||
126 | + labels.add(line); | ||
127 | + if (line.charAt(0) != '_') { | ||
128 | + displayedLabels.add(line.substring(0, 1).toUpperCase() + line.substring(1)); | ||
129 | + } | ||
130 | + } | ||
131 | + br.close(); | ||
132 | + } catch (IOException e) { | ||
133 | + throw new RuntimeException("Problem reading label file!", e); | ||
134 | + } | ||
135 | + | ||
136 | + // Build a list view based on these labels. | ||
137 | + ArrayAdapter<String> arrayAdapter = | ||
138 | + new ArrayAdapter<String>(this, R.layout.list_text_item, displayedLabels); | ||
139 | + labelsListView.setAdapter(arrayAdapter); | ||
140 | + | ||
141 | + // Set up an object to smooth recognition results to increase accuracy. | ||
142 | + recognizeCommands = | ||
143 | + new RecognizeCommands( | ||
144 | + labels, | ||
145 | + AVERAGE_WINDOW_DURATION_MS, | ||
146 | + DETECTION_THRESHOLD, | ||
147 | + SUPPRESSION_MS, | ||
148 | + MINIMUM_COUNT, | ||
149 | + MINIMUM_TIME_BETWEEN_SAMPLES_MS); | ||
150 | + | ||
151 | + // Load the TensorFlow model. | ||
152 | + inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME); | ||
153 | + | ||
154 | + // Start the recording and recognition threads once we hold the microphone permission. | ||
 | + // On Android M and later that happens in onRequestPermissionsResult(); starting them | ||
 | + // unconditionally here would run AudioRecord without RECORD_AUDIO being granted. | ||
155 | + requestMicrophonePermission(); | ||
158 | + } | ||
159 | + | ||
160 | + private void requestMicrophonePermission() { | ||
161 | + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { | ||
162 | + requestPermissions( | ||
163 | + new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO); | ||
 | + } else { | ||
 | + // Before M, permissions are granted at install time. | ||
 | + startRecording(); | ||
 | + startRecognition(); | ||
164 | + } | ||
165 | + } | ||
166 | + | ||
167 | + @Override | ||
168 | + public void onRequestPermissionsResult( | ||
169 | + int requestCode, String[] permissions, int[] grantResults) { | ||
170 | + if (requestCode == REQUEST_RECORD_AUDIO | ||
171 | + && grantResults.length > 0 | ||
172 | + && grantResults[0] == PackageManager.PERMISSION_GRANTED) { | ||
173 | + startRecording(); | ||
174 | + startRecognition(); | ||
175 | + } | ||
176 | + } | ||
177 | + | ||
178 | + public synchronized void startRecording() { | ||
179 | + if (recordingThread != null) { | ||
180 | + return; | ||
181 | + } | ||
182 | + shouldContinue = true; | ||
183 | + recordingThread = | ||
184 | + new Thread( | ||
185 | + new Runnable() { | ||
186 | + @Override | ||
187 | + public void run() { | ||
188 | + record(); | ||
189 | + } | ||
190 | + }); | ||
191 | + recordingThread.start(); | ||
192 | + } | ||
193 | + | ||
194 | + public synchronized void stopRecording() { | ||
195 | + if (recordingThread == null) { | ||
196 | + return; | ||
197 | + } | ||
198 | + shouldContinue = false; | ||
199 | + recordingThread = null; | ||
200 | + } | ||
201 | + | ||
202 | + private void record() { | ||
203 | + android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO); | ||
204 | + | ||
205 | + // Estimate the buffer size we'll need for this device. | ||
206 | + int bufferSize = | ||
207 | + AudioRecord.getMinBufferSize( | ||
208 | + SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT); | ||
209 | + if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) { | ||
210 | + bufferSize = SAMPLE_RATE * 2; | ||
211 | + } | ||
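 | + // bufferSize is in bytes; each 16-bit PCM sample takes two, hence the halving. | ||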
212 | + short[] audioBuffer = new short[bufferSize / 2]; | ||
213 | + | ||
214 | + AudioRecord record = | ||
215 | + new AudioRecord( | ||
216 | + MediaRecorder.AudioSource.DEFAULT, | ||
217 | + SAMPLE_RATE, | ||
218 | + AudioFormat.CHANNEL_IN_MONO, | ||
219 | + AudioFormat.ENCODING_PCM_16BIT, | ||
220 | + bufferSize); | ||
221 | + | ||
222 | + if (record.getState() != AudioRecord.STATE_INITIALIZED) { | ||
223 | + Log.e(LOG_TAG, "Audio Record can't initialize!"); | ||
224 | + return; | ||
225 | + } | ||
226 | + | ||
227 | + record.startRecording(); | ||
228 | + | ||
229 | + Log.v(LOG_TAG, "Start recording"); | ||
230 | + | ||
231 | + // Loop, gathering audio data and copying it to a round-robin buffer. | ||
232 | + while (shouldContinue) { | ||
233 | + int numberRead = record.read(audioBuffer, 0, audioBuffer.length); | ||
234 | + int maxLength = recordingBuffer.length; | ||
235 | + int newRecordingOffset = recordingOffset + numberRead; | ||
236 | + int secondCopyLength = Math.max(0, newRecordingOffset - maxLength); | ||
237 | + int firstCopyLength = numberRead - secondCopyLength; | ||
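 | + // Worked example: with maxLength 16000, recordingOffset 15800 and numberRead 480, | ||
 | + // secondCopyLength is 280 and firstCopyLength is 200, so 200 samples land at the end | ||
 | + // of recordingBuffer and the remaining 280 wrap around to the start. | ||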
238 | + // We store off all the data for the recognition thread to access. The ML | ||
239 | + // thread will copy out of this buffer into its own, while holding the | ||
240 | + // lock, so this should be thread safe. | ||
241 | + recordingBufferLock.lock(); | ||
242 | + try { | ||
243 | + System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength); | ||
244 | + System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength); | ||
245 | + recordingOffset = newRecordingOffset % maxLength; | ||
246 | + } finally { | ||
247 | + recordingBufferLock.unlock(); | ||
248 | + } | ||
249 | + } | ||
250 | + | ||
251 | + record.stop(); | ||
252 | + record.release(); | ||
253 | + } | ||
254 | + | ||
255 | + public synchronized void startRecognition() { | ||
256 | + if (recognitionThread != null) { | ||
257 | + return; | ||
258 | + } | ||
259 | + shouldContinueRecognition = true; | ||
260 | + recognitionThread = | ||
261 | + new Thread( | ||
262 | + new Runnable() { | ||
263 | + @Override | ||
264 | + public void run() { | ||
265 | + recognize(); | ||
266 | + } | ||
267 | + }); | ||
268 | + recognitionThread.start(); | ||
269 | + } | ||
270 | + | ||
271 | + public synchronized void stopRecognition() { | ||
272 | + if (recognitionThread == null) { | ||
273 | + return; | ||
274 | + } | ||
275 | + shouldContinueRecognition = false; | ||
276 | + recognitionThread = null; | ||
277 | + } | ||
278 | + | ||
279 | + private void recognize() { | ||
280 | + Log.v(LOG_TAG, "Start recognition"); | ||
281 | + | ||
282 | + short[] inputBuffer = new short[RECORDING_LENGTH]; | ||
283 | + float[] floatInputBuffer = new float[RECORDING_LENGTH]; | ||
284 | + float[] outputScores = new float[labels.size()]; | ||
285 | + String[] outputScoresNames = new String[] {OUTPUT_SCORES_NAME}; | ||
286 | + int[] sampleRateList = new int[] {SAMPLE_RATE}; | ||
287 | + | ||
288 | + // Loop, grabbing recorded data and running the recognition model on it. | ||
289 | + while (shouldContinueRecognition) { | ||
290 | + // The recording thread places data in this round-robin buffer, so lock to | ||
291 | + // make sure there's no writing happening and then copy it to our own | ||
292 | + // local version. | ||
293 | + recordingBufferLock.lock(); | ||
294 | + try { | ||
295 | + int maxLength = recordingBuffer.length; | ||
296 | + int firstCopyLength = maxLength - recordingOffset; | ||
297 | + int secondCopyLength = recordingOffset; | ||
298 | + System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, firstCopyLength); | ||
299 | + System.arraycopy(recordingBuffer, 0, inputBuffer, firstCopyLength, secondCopyLength); | ||
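 | + // The two copies unroll the ring: the oldest samples (from recordingOffset to the | ||
 | + // end) come first, then the newest (from the start), so inputBuffer holds one | ||
 | + // second of audio in chronological order. | ||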
300 | + } finally { | ||
301 | + recordingBufferLock.unlock(); | ||
302 | + } | ||
303 | + | ||
304 | + // We need to feed in float values between -1.0f and 1.0f, so divide the | ||
305 | + // signed 16-bit inputs. | ||
306 | + for (int i = 0; i < RECORDING_LENGTH; ++i) { | ||
307 | + floatInputBuffer[i] = inputBuffer[i] / 32767.0f; | ||
308 | + } | ||
309 | + | ||
310 | + // Run the model. | ||
311 | + inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList); | ||
312 | + inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1); | ||
313 | + inferenceInterface.run(outputScoresNames); | ||
314 | + inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores); | ||
315 | + | ||
316 | + // Use the smoother to figure out if we've had a real recognition event. | ||
317 | + long currentTime = System.currentTimeMillis(); | ||
318 | + final RecognizeCommands.RecognitionResult result = | ||
319 | + recognizeCommands.processLatestResults(outputScores, currentTime); | ||
320 | + | ||
321 | + runOnUiThread( | ||
322 | + new Runnable() { | ||
323 | + @Override | ||
324 | + public void run() { | ||
325 | + // If we do have a new command, highlight the right list entry. | ||
326 | + if (!result.foundCommand.startsWith("_") && result.isNewCommand) { | ||
327 | + int labelIndex = -1; | ||
328 | + for (int i = 0; i < labels.size(); ++i) { | ||
329 | + if (labels.get(i).equals(result.foundCommand)) { | ||
330 | + labelIndex = i; | ||
331 | + } | ||
332 | + } | ||
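 | + // The displayed list filters out labels starting with '_', so the offset of two | ||
 | + // assumes exactly two leading underscore labels ("_silence_" and "_unknown_"). | ||
 | + // Note that getChildAt() can also return null for off-screen rows. | ||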
333 | + final View labelView = labelsListView.getChildAt(labelIndex - 2); | ||
334 | + | ||
335 | + AnimatorSet colorAnimation = | ||
336 | + (AnimatorSet) | ||
337 | + AnimatorInflater.loadAnimator( | ||
338 | + SpeechActivity.this, R.animator.color_animation); | ||
339 | + colorAnimation.setTarget(labelView); | ||
340 | + colorAnimation.start(); | ||
341 | + } | ||
342 | + } | ||
343 | + }); | ||
344 | + try { | ||
345 | + // We don't need to run too frequently, so snooze for a bit. | ||
346 | + Thread.sleep(MINIMUM_TIME_BETWEEN_SAMPLES_MS); | ||
347 | + } catch (InterruptedException e) { | ||
348 | + // Ignore | ||
349 | + } | ||
350 | + } | ||
351 | + | ||
352 | + Log.v(LOG_TAG, "End recognition"); | ||
353 | + } | ||
354 | +} |
1 | +/* | ||
2 | + * Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
3 | + * | ||
4 | + * Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | + * you may not use this file except in compliance with the License. | ||
6 | + * You may obtain a copy of the License at | ||
7 | + * | ||
8 | + * http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | + * | ||
10 | + * Unless required by applicable law or agreed to in writing, software | ||
11 | + * distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | + * See the License for the specific language governing permissions and | ||
14 | + * limitations under the License. | ||
15 | + */ | ||
16 | + | ||
17 | +package org.tensorflow.demo; | ||
18 | + | ||
19 | +import android.app.UiModeManager; | ||
20 | +import android.content.Context; | ||
21 | +import android.content.res.AssetManager; | ||
22 | +import android.content.res.Configuration; | ||
23 | +import android.graphics.Bitmap; | ||
24 | +import android.graphics.Bitmap.Config; | ||
25 | +import android.graphics.BitmapFactory; | ||
26 | +import android.graphics.Canvas; | ||
27 | +import android.graphics.Color; | ||
28 | +import android.graphics.Matrix; | ||
29 | +import android.graphics.Paint; | ||
30 | +import android.graphics.Paint.Style; | ||
31 | +import android.graphics.Rect; | ||
32 | +import android.graphics.Typeface; | ||
33 | +import android.media.ImageReader.OnImageAvailableListener; | ||
34 | +import android.os.Bundle; | ||
35 | +import android.os.SystemClock; | ||
36 | +import android.util.DisplayMetrics; | ||
37 | +import android.util.Size; | ||
38 | +import android.util.TypedValue; | ||
39 | +import android.view.Display; | ||
40 | +import android.view.KeyEvent; | ||
41 | +import android.view.MotionEvent; | ||
42 | +import android.view.View; | ||
43 | +import android.view.View.OnClickListener; | ||
44 | +import android.view.View.OnTouchListener; | ||
45 | +import android.view.ViewGroup; | ||
46 | +import android.widget.BaseAdapter; | ||
47 | +import android.widget.Button; | ||
48 | +import android.widget.GridView; | ||
49 | +import android.widget.ImageView; | ||
50 | +import android.widget.RelativeLayout; | ||
51 | +import android.widget.Toast; | ||
52 | +import java.io.IOException; | ||
53 | +import java.io.InputStream; | ||
54 | +import java.util.ArrayList; | ||
55 | +import java.util.Collections; | ||
56 | +import java.util.Vector; | ||
57 | +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; | ||
58 | +import org.tensorflow.demo.OverlayView.DrawCallback; | ||
59 | +import org.tensorflow.demo.env.BorderedText; | ||
60 | +import org.tensorflow.demo.env.ImageUtils; | ||
61 | +import org.tensorflow.demo.env.Logger; | ||
62 | +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds. | ||
63 | + | ||
64 | +/** | ||
65 | + * Sample activity that stylizes the camera preview according to "A Learned Representation For | ||
66 | + * Artistic Style" (https://arxiv.org/abs/1610.07629) | ||
67 | + */ | ||
68 | +public class StylizeActivity extends CameraActivity implements OnImageAvailableListener { | ||
69 | + private static final Logger LOGGER = new Logger(); | ||
70 | + | ||
71 | + private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb"; | ||
72 | + private static final String INPUT_NODE = "input"; | ||
73 | + private static final String STYLE_NODE = "style_num"; | ||
74 | + private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid"; | ||
75 | + private static final int NUM_STYLES = 26; | ||
76 | + | ||
77 | + private static final boolean SAVE_PREVIEW_BITMAP = false; | ||
78 | + | ||
79 | + // Whether to actively manipulate non-selected sliders so that sum of activations always appears | ||
80 | + // to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless. | ||
81 | + private static final boolean NORMALIZE_SLIDERS = true; | ||
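 | + // e.g. raw slider values of (0.5, 0.5, 1.0) reach the network as (0.25, 0.25, 0.5). | ||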
82 | + | ||
83 | + private static final float TEXT_SIZE_DIP = 12; | ||
84 | + | ||
85 | + private static final boolean DEBUG_MODEL = false; | ||
86 | + | ||
87 | + private static final int[] SIZES = {128, 192, 256, 384, 512, 720}; | ||
88 | + | ||
89 | + private static final Size DESIRED_PREVIEW_SIZE = new Size(1280, 720); | ||
90 | + | ||
91 | + // Start at a medium size, but let the user step up through smaller sizes so they don't get | ||
92 | + // immediately stuck processing a large image. | ||
93 | + private int desiredSizeIndex = -1; | ||
94 | + private int desiredSize = 256; | ||
95 | + private int initializedSize = 0; | ||
96 | + | ||
97 | + private Integer sensorOrientation; | ||
98 | + | ||
99 | + private long lastProcessingTimeMs; | ||
100 | + private Bitmap rgbFrameBitmap = null; | ||
101 | + private Bitmap croppedBitmap = null; | ||
102 | + private Bitmap cropCopyBitmap = null; | ||
103 | + | ||
104 | + private final float[] styleVals = new float[NUM_STYLES]; | ||
105 | + private int[] intValues; | ||
106 | + private float[] floatValues; | ||
107 | + | ||
108 | + private int frameNum = 0; | ||
109 | + | ||
110 | + private Bitmap textureCopyBitmap; | ||
111 | + | ||
112 | + private Matrix frameToCropTransform; | ||
113 | + private Matrix cropToFrameTransform; | ||
114 | + | ||
115 | + private BorderedText borderedText; | ||
116 | + | ||
117 | + private TensorFlowInferenceInterface inferenceInterface; | ||
118 | + | ||
119 | + private int lastOtherStyle = 1; | ||
120 | + | ||
121 | + private boolean allZero = false; | ||
122 | + | ||
123 | + private ImageGridAdapter adapter; | ||
124 | + private GridView grid; | ||
125 | + | ||
126 | + private final OnTouchListener gridTouchAdapter = | ||
127 | + new OnTouchListener() { | ||
128 | + ImageSlider slider = null; | ||
129 | + | ||
130 | + @Override | ||
131 | + public boolean onTouch(final View v, final MotionEvent event) { | ||
132 | + switch (event.getActionMasked()) { | ||
133 | + case MotionEvent.ACTION_DOWN: | ||
134 | + for (int i = 0; i < NUM_STYLES; ++i) { | ||
135 | + final ImageSlider child = adapter.items[i]; | ||
136 | + final Rect rect = new Rect(); | ||
137 | + child.getHitRect(rect); | ||
138 | + if (rect.contains((int) event.getX(), (int) event.getY())) { | ||
139 | + slider = child; | ||
140 | + slider.setHighlighted(true); | ||
141 | + } | ||
142 | + } | ||
143 | + break; | ||
144 | + | ||
145 | + case MotionEvent.ACTION_MOVE: | ||
146 | + if (slider != null) { | ||
147 | + final Rect rect = new Rect(); | ||
148 | + slider.getHitRect(rect); | ||
149 | + | ||
150 | + final float newSliderVal = | ||
151 | + (float) | ||
152 | + Math.min( | ||
153 | + 1.0, | ||
154 | + Math.max( | ||
155 | + 0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight())); | ||
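 | + // Maps the touch position to [0, 1]: the top edge of the slider yields 1.0 and the | ||
 | + // bottom edge yields 0.0, clamped at both ends. | ||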
156 | + | ||
157 | + setStyle(slider, newSliderVal); | ||
158 | + } | ||
159 | + break; | ||
160 | + | ||
161 | + case MotionEvent.ACTION_UP: | ||
162 | + if (slider != null) { | ||
163 | + slider.setHighlighted(false); | ||
164 | + slider = null; | ||
165 | + } | ||
166 | + break; | ||
167 | + | ||
168 | + default: // fall out | ||
169 | + | ||
170 | + } | ||
171 | + return true; | ||
172 | + } | ||
173 | + }; | ||
174 | + | ||
175 | + @Override | ||
176 | + public void onCreate(final Bundle savedInstanceState) { | ||
177 | + super.onCreate(savedInstanceState); | ||
178 | + } | ||
179 | + | ||
180 | + @Override | ||
181 | + protected int getLayoutId() { | ||
182 | + return R.layout.camera_connection_fragment_stylize; | ||
183 | + } | ||
184 | + | ||
185 | + @Override | ||
186 | + protected Size getDesiredPreviewFrameSize() { | ||
187 | + return DESIRED_PREVIEW_SIZE; | ||
188 | + } | ||
189 | + | ||
190 | + public static Bitmap getBitmapFromAsset(final Context context, final String filePath) { | ||
191 | + final AssetManager assetManager = context.getAssets(); | ||
192 | + | ||
193 | + Bitmap bitmap = null; | ||
194 | + try { | ||
195 | + final InputStream inputStream = assetManager.open(filePath); | ||
196 | + bitmap = BitmapFactory.decodeStream(inputStream); | ||
197 | + } catch (final IOException e) { | ||
198 | + LOGGER.e("Error opening bitmap!", e); | ||
199 | + } | ||
200 | + | ||
201 | + return bitmap; | ||
202 | + } | ||
203 | + | ||
204 | + private class ImageSlider extends ImageView { | ||
205 | + private float value = 0.0f; | ||
206 | + private boolean highlighted = false; | ||
207 | + | ||
208 | + private final Paint boxPaint; | ||
209 | + private final Paint linePaint; | ||
210 | + | ||
211 | + public ImageSlider(final Context context) { | ||
212 | + super(context); | ||
213 | + value = 0.0f; | ||
214 | + | ||
215 | + boxPaint = new Paint(); | ||
216 | + boxPaint.setColor(Color.BLACK); | ||
217 | + boxPaint.setAlpha(128); | ||
218 | + | ||
219 | + linePaint = new Paint(); | ||
220 | + linePaint.setColor(Color.WHITE); | ||
221 | + linePaint.setStrokeWidth(10.0f); | ||
222 | + linePaint.setStyle(Style.STROKE); | ||
223 | + } | ||
224 | + | ||
225 | + @Override | ||
226 | + public void onDraw(final Canvas canvas) { | ||
227 | + super.onDraw(canvas); | ||
228 | + final float y = (1.0f - value) * canvas.getHeight(); | ||
229 | + | ||
230 | + // If all sliders are zero, don't bother shading anything. | ||
231 | + if (!allZero) { | ||
232 | + canvas.drawRect(0, 0, canvas.getWidth(), y, boxPaint); | ||
233 | + } | ||
234 | + | ||
235 | + if (value > 0.0f) { | ||
236 | + canvas.drawLine(0, y, canvas.getWidth(), y, linePaint); | ||
237 | + } | ||
238 | + | ||
239 | + if (highlighted) { | ||
240 | + canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint); | ||
241 | + } | ||
242 | + } | ||
243 | + | ||
244 | + @Override | ||
245 | + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { | ||
246 | + super.onMeasure(widthMeasureSpec, heightMeasureSpec); | ||
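| + // Force the view to be square by reusing the measured width as the height. | ||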
247 | + setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth()); | ||
248 | + } | ||
249 | + | ||
250 | + public void setValue(final float value) { | ||
251 | + this.value = value; | ||
252 | + postInvalidate(); | ||
253 | + } | ||
254 | + | ||
255 | + public void setHighlighted(final boolean highlighted) { | ||
256 | + this.highlighted = highlighted; | ||
257 | + postInvalidate(); | ||
258 | + } | ||
259 | + } | ||
260 | + | ||
261 | + private class ImageGridAdapter extends BaseAdapter { | ||
262 | + final ImageSlider[] items = new ImageSlider[NUM_STYLES]; | ||
263 | + final ArrayList<Button> buttons = new ArrayList<>(); | ||
264 | + | ||
265 | + { | ||
266 | + final Button sizeButton = | ||
267 | + new Button(StylizeActivity.this) { | ||
268 | + @Override | ||
269 | + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { | ||
270 | + super.onMeasure(widthMeasureSpec, heightMeasureSpec); | ||
271 | + setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth()); | ||
272 | + } | ||
273 | + }; | ||
274 | + sizeButton.setText("" + desiredSize); | ||
275 | + sizeButton.setOnClickListener( | ||
276 | + new OnClickListener() { | ||
277 | + @Override | ||
278 | + public void onClick(final View v) { | ||
279 | + desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length; | ||
280 | + desiredSize = SIZES[desiredSizeIndex]; | ||
281 | + sizeButton.setText("" + desiredSize); | ||
282 | + sizeButton.postInvalidate(); | ||
283 | + } | ||
284 | + }); | ||
285 | + | ||
286 | + final Button saveButton = | ||
287 | + new Button(StylizeActivity.this) { | ||
288 | + @Override | ||
289 | + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { | ||
290 | + super.onMeasure(widthMeasureSpec, heightMeasureSpec); | ||
291 | + setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth()); | ||
292 | + } | ||
293 | + }; | ||
294 | + saveButton.setText("save"); | ||
295 | + saveButton.setTextSize(12); | ||
296 | + | ||
297 | + saveButton.setOnClickListener( | ||
298 | + new OnClickListener() { | ||
299 | + @Override | ||
300 | + public void onClick(final View v) { | ||
301 | + if (textureCopyBitmap != null) { | ||
302 | + // TODO(andrewharp): Save as jpeg with guaranteed unique filename. | ||
303 | + ImageUtils.saveBitmap(textureCopyBitmap, "stylized" + frameNum + ".png"); | ||
304 | + Toast.makeText( | ||
305 | + StylizeActivity.this, | ||
306 | + "Saved image to: /sdcard/tensorflow/" + "stylized" + frameNum + ".png", | ||
307 | + Toast.LENGTH_LONG) | ||
308 | + .show(); | ||
309 | + } | ||
310 | + } | ||
311 | + }); | ||
312 | + | ||
313 | + buttons.add(sizeButton); | ||
314 | + buttons.add(saveButton); | ||
315 | + | ||
316 | + for (int i = 0; i < NUM_STYLES; ++i) { | ||
317 | + LOGGER.v("Creating item %d", i); | ||
318 | + | ||
319 | + if (items[i] == null) { | ||
320 | + final ImageSlider slider = new ImageSlider(StylizeActivity.this); | ||
321 | + final Bitmap bm = | ||
322 | + getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg"); | ||
323 | + slider.setImageBitmap(bm); | ||
324 | + | ||
325 | + items[i] = slider; | ||
326 | + } | ||
327 | + } | ||
328 | + } | ||
329 | + | ||
330 | + @Override | ||
331 | + public int getCount() { | ||
332 | + return buttons.size() + NUM_STYLES; | ||
333 | + } | ||
334 | + | ||
335 | + @Override | ||
336 | + public Object getItem(final int position) { | ||
337 | + if (position < buttons.size()) { | ||
338 | + return buttons.get(position); | ||
339 | + } else { | ||
340 | + return items[position - buttons.size()]; | ||
341 | + } | ||
342 | + } | ||
343 | + | ||
344 | + @Override | ||
345 | + public long getItemId(final int position) { | ||
346 | + return getItem(position).hashCode(); | ||
347 | + } | ||
348 | + | ||
349 | + @Override | ||
350 | + public View getView(final int position, final View convertView, final ViewGroup parent) { | ||
351 | + if (convertView != null) { | ||
352 | + return convertView; | ||
353 | + } | ||
354 | + return (View) getItem(position); | ||
355 | + } | ||
356 | + } | ||
357 | + | ||
358 | + @Override | ||
359 | + public void onPreviewSizeChosen(final Size size, final int rotation) { | ||
360 | + final float textSizePx = TypedValue.applyDimension( | ||
361 | + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); | ||
362 | + borderedText = new BorderedText(textSizePx); | ||
363 | + borderedText.setTypeface(Typeface.MONOSPACE); | ||
364 | + | ||
365 | + inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE); | ||
366 | + | ||
367 | + previewWidth = size.getWidth(); | ||
368 | + previewHeight = size.getHeight(); | ||
369 | + | ||
370 | + final Display display = getWindowManager().getDefaultDisplay(); | ||
371 | + final int screenOrientation = display.getRotation(); | ||
372 | + | ||
373 | + LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation); | ||
374 | + | ||
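| + // Combine the camera sensor's orientation with the current screen rotation. | ||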
375 | + sensorOrientation = rotation + screenOrientation; | ||
376 | + | ||
377 | + addCallback( | ||
378 | + new DrawCallback() { | ||
379 | + @Override | ||
380 | + public void drawCallback(final Canvas canvas) { | ||
381 | + renderDebug(canvas); | ||
382 | + } | ||
383 | + }); | ||
384 | + | ||
385 | + adapter = new ImageGridAdapter(); | ||
386 | + grid = (GridView) findViewById(R.id.grid_layout); | ||
387 | + grid.setAdapter(adapter); | ||
388 | + grid.setOnTouchListener(gridTouchAdapter); | ||
389 | + | ||
390 | + // Change UI on Android TV | ||
391 | + UiModeManager uiModeManager = (UiModeManager) getSystemService(UI_MODE_SERVICE); | ||
392 | + if (uiModeManager.getCurrentModeType() == Configuration.UI_MODE_TYPE_TELEVISION) { | ||
393 | + DisplayMetrics displayMetrics = new DisplayMetrics(); | ||
394 | + getWindowManager().getDefaultDisplay().getMetrics(displayMetrics); | ||
395 | + int styleSelectorHeight = displayMetrics.heightPixels; | ||
396 | + int styleSelectorWidth = displayMetrics.widthPixels - styleSelectorHeight; | ||
397 | + RelativeLayout.LayoutParams layoutParams = new RelativeLayout.LayoutParams(styleSelectorWidth, ViewGroup.LayoutParams.MATCH_PARENT); | ||
398 | + | ||
399 | + // Calculate the number of styles per row so that every style fits on screen without scrolling. | ||
400 | + int numOfStylePerRow = 3; | ||
401 | + while (styleSelectorWidth / numOfStylePerRow * Math.ceil((float) (adapter.getCount() - 2) / numOfStylePerRow) > styleSelectorHeight) { | ||
402 | + numOfStylePerRow++; | ||
403 | + } | ||
404 | + grid.setNumColumns(numOfStylePerRow); | ||
405 | + layoutParams.addRule(RelativeLayout.ALIGN_PARENT_RIGHT); | ||
406 | + grid.setLayoutParams(layoutParams); | ||
407 | + adapter.buttons.clear(); | ||
408 | + } | ||
409 | + | ||
410 | + setStyle(adapter.items[0], 1.0f); | ||
411 | + } | ||
412 | + | ||
413 | + private void setStyle(final ImageSlider slider, final float value) { | ||
414 | + slider.setValue(value); | ||
415 | + | ||
416 | + if (NORMALIZE_SLIDERS) { | ||
417 | + // Slider values feed the input tensor directly; normalization is kept visually | ||
418 | + // consistent by rescaling the non-selected sliders. | ||
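| + // For example, raising one slider to 0.6 rescales the others so they sum to 0.4. | ||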
419 | + float otherSum = 0.0f; | ||
420 | + | ||
421 | + for (int i = 0; i < NUM_STYLES; ++i) { | ||
422 | + if (adapter.items[i] != slider) { | ||
423 | + otherSum += adapter.items[i].value; | ||
424 | + } | ||
425 | + } | ||
426 | + | ||
427 | + if (otherSum > 0.0) { | ||
428 | + float highestOtherVal = 0; | ||
429 | + final float factor = (1.0f - value) / otherSum; // otherSum > 0 is guaranteed by the enclosing check. | ||
430 | + for (int i = 0; i < NUM_STYLES; ++i) { | ||
431 | + final ImageSlider child = adapter.items[i]; | ||
432 | + if (child == slider) { | ||
433 | + continue; | ||
434 | + } | ||
435 | + final float newVal = child.value * factor; | ||
436 | + child.setValue(newVal > 0.01f ? newVal : 0.0f); | ||
437 | + | ||
438 | + if (child.value > highestOtherVal) { | ||
439 | + lastOtherStyle = i; | ||
440 | + highestOtherVal = child.value; | ||
441 | + } | ||
442 | + } | ||
443 | + } else { | ||
444 | + // Everything else is 0, so just pick a suitable slider to push up when the | ||
445 | + // selected one goes down. | ||
446 | + if (adapter.items[lastOtherStyle] == slider) { | ||
447 | + lastOtherStyle = (lastOtherStyle + 1) % NUM_STYLES; | ||
448 | + } | ||
449 | + adapter.items[lastOtherStyle].setValue(1.0f - value); | ||
450 | + } | ||
451 | + } | ||
452 | + | ||
453 | + final boolean lastAllZero = allZero; | ||
454 | + float sum = 0.0f; | ||
455 | + for (int i = 0; i < NUM_STYLES; ++i) { | ||
456 | + sum += adapter.items[i].value; | ||
457 | + } | ||
458 | + allZero = sum == 0.0f; | ||
459 | + | ||
460 | + // Now update the values used for the input tensor. If nothing is set, mix in everything | ||
461 | + // equally. Otherwise everything is normalized to sum to 1.0. | ||
462 | + for (int i = 0; i < NUM_STYLES; ++i) { | ||
463 | + styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum; | ||
464 | + | ||
465 | + if (lastAllZero != allZero) { | ||
466 | + adapter.items[i].postInvalidate(); | ||
467 | + } | ||
468 | + } | ||
469 | + } | ||
470 | + | ||
471 | + private void resetPreviewBuffers() { | ||
472 | + croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888); | ||
473 | + | ||
474 | + frameToCropTransform = ImageUtils.getTransformationMatrix( | ||
475 | + previewWidth, previewHeight, | ||
476 | + desiredSize, desiredSize, | ||
477 | + sensorOrientation, true); | ||
478 | + | ||
479 | + cropToFrameTransform = new Matrix(); | ||
480 | + frameToCropTransform.invert(cropToFrameTransform); | ||
481 | + intValues = new int[desiredSize * desiredSize]; | ||
482 | + floatValues = new float[desiredSize * desiredSize * 3]; | ||
483 | + initializedSize = desiredSize; | ||
484 | + } | ||
485 | + | ||
486 | + @Override | ||
487 | + protected void processImage() { | ||
488 | + if (desiredSize != initializedSize) { | ||
489 | + LOGGER.i( | ||
490 | + "Initializing at preview size %dx%d, stylize size %d", | ||
491 | + previewWidth, previewHeight, desiredSize); | ||
492 | + | ||
493 | + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); | ||
| + // The remaining buffer setup is identical to resetPreviewBuffers(), so reuse it. | ||
494 | + resetPreviewBuffers(); | ||
505 | + } | ||
506 | + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); | ||
507 | + final Canvas canvas = new Canvas(croppedBitmap); | ||
508 | + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); | ||
509 | + | ||
510 | + // For examining the actual TF input. | ||
511 | + if (SAVE_PREVIEW_BITMAP) { | ||
512 | + ImageUtils.saveBitmap(croppedBitmap); | ||
513 | + } | ||
514 | + | ||
515 | + runInBackground( | ||
516 | + new Runnable() { | ||
517 | + @Override | ||
518 | + public void run() { | ||
519 | + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); | ||
520 | + final long startTime = SystemClock.uptimeMillis(); | ||
521 | + stylizeImage(croppedBitmap); | ||
522 | + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; | ||
523 | + textureCopyBitmap = Bitmap.createBitmap(croppedBitmap); | ||
524 | + requestRender(); | ||
525 | + readyForNextImage(); | ||
526 | + } | ||
527 | + }); | ||
528 | + if (desiredSize != initializedSize) { | ||
529 | + resetPreviewBuffers(); | ||
530 | + } | ||
531 | + } | ||
532 | + | ||
533 | + private void stylizeImage(final Bitmap bitmap) { | ||
534 | + ++frameNum; | ||
535 | + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); | ||
536 | + | ||
537 | + if (DEBUG_MODEL) { | ||
538 | + // Create a white square that steps through a black background 1 pixel per frame. | ||
539 | + final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth(); | ||
540 | + final int centerY = bitmap.getHeight() / 2; | ||
541 | + final int squareSize = 10; | ||
542 | + for (int i = 0; i < intValues.length; ++i) { | ||
543 | + final int x = i % bitmap.getWidth(); | ||
544 | + final int y = i / bitmap.getWidth(); // The row index comes from dividing by the width, not the height. | ||
545 | + final float val = | ||
546 | + Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f; | ||
547 | + floatValues[i * 3] = val; | ||
548 | + floatValues[i * 3 + 1] = val; | ||
549 | + floatValues[i * 3 + 2] = val; | ||
550 | + } | ||
551 | + } else { | ||
552 | + for (int i = 0; i < intValues.length; ++i) { | ||
553 | + final int val = intValues[i]; | ||
554 | + floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f; | ||
555 | + floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f; | ||
556 | + floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f; | ||
557 | + } | ||
558 | + } | ||
559 | + | ||
560 | + // Copy the input data into TensorFlow. | ||
561 | + LOGGER.i("Width: %s, Height: %s", bitmap.getWidth(), bitmap.getHeight()); | ||
562 | + inferenceInterface.feed( | ||
563 | + INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3); | ||
564 | + inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES); | ||
565 | + | ||
566 | + inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug()); | ||
567 | + inferenceInterface.fetch(OUTPUT_NODE, floatValues); | ||
568 | + | ||
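| + // Repack the float RGB output into ARGB_8888 pixels, forcing alpha to 0xFF. | ||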
569 | + for (int i = 0; i < intValues.length; ++i) { | ||
570 | + intValues[i] = | ||
571 | + 0xFF000000 | ||
572 | + | (((int) (floatValues[i * 3] * 255)) << 16) | ||
573 | + | (((int) (floatValues[i * 3 + 1] * 255)) << 8) | ||
574 | + | ((int) (floatValues[i * 3 + 2] * 255)); | ||
575 | + } | ||
576 | + | ||
577 | + bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); | ||
578 | + } | ||
579 | + | ||
580 | + private void renderDebug(final Canvas canvas) { | ||
581 | + // TODO(andrewharp): move result display to its own View instead of using debug overlay. | ||
582 | + final Bitmap texture = textureCopyBitmap; | ||
583 | + if (texture != null) { | ||
584 | + final Matrix matrix = new Matrix(); | ||
585 | + final float scaleFactor = | ||
586 | + DEBUG_MODEL | ||
587 | + ? 4.0f | ||
588 | + : Math.min( | ||
589 | + (float) canvas.getWidth() / texture.getWidth(), | ||
590 | + (float) canvas.getHeight() / texture.getHeight()); | ||
591 | + matrix.postScale(scaleFactor, scaleFactor); | ||
592 | + canvas.drawBitmap(texture, matrix, new Paint()); | ||
593 | + } | ||
594 | + | ||
595 | + if (!isDebug()) { | ||
596 | + return; | ||
597 | + } | ||
598 | + | ||
599 | + final Bitmap copy = cropCopyBitmap; | ||
600 | + if (copy == null) { | ||
601 | + return; | ||
602 | + } | ||
603 | + | ||
604 | + canvas.drawColor(0x55000000); | ||
605 | + | ||
606 | + final Matrix matrix = new Matrix(); | ||
607 | + final float scaleFactor = 2; | ||
608 | + matrix.postScale(scaleFactor, scaleFactor); | ||
609 | + matrix.postTranslate( | ||
610 | + canvas.getWidth() - copy.getWidth() * scaleFactor, | ||
611 | + canvas.getHeight() - copy.getHeight() * scaleFactor); | ||
612 | + canvas.drawBitmap(copy, matrix, new Paint()); | ||
613 | + | ||
614 | + final Vector<String> lines = new Vector<>(); | ||
615 | + | ||
616 | + final String[] statLines = inferenceInterface.getStatString().split("\n"); | ||
617 | + Collections.addAll(lines, statLines); | ||
618 | + | ||
619 | + lines.add(""); | ||
620 | + | ||
621 | + lines.add("Frame: " + previewWidth + "x" + previewHeight); | ||
622 | + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); | ||
623 | + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); | ||
624 | + lines.add("Rotation: " + sensorOrientation); | ||
625 | + lines.add("Inference time: " + lastProcessingTimeMs + "ms"); | ||
626 | + lines.add("Desired size: " + desiredSize); | ||
627 | + lines.add("Initialized size: " + initializedSize); | ||
628 | + | ||
629 | + borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); | ||
630 | + } | ||
631 | + | ||
632 | + @Override | ||
633 | + public boolean onKeyDown(int keyCode, KeyEvent event) { | ||
634 | + int moveOffset = 0; | ||
635 | + switch (keyCode) { | ||
636 | + case KeyEvent.KEYCODE_DPAD_LEFT: | ||
637 | + moveOffset = -1; | ||
638 | + break; | ||
639 | + case KeyEvent.KEYCODE_DPAD_RIGHT: | ||
640 | + moveOffset = 1; | ||
641 | + break; | ||
642 | + case KeyEvent.KEYCODE_DPAD_UP: | ||
643 | + moveOffset = -1 * grid.getNumColumns(); | ||
644 | + break; | ||
645 | + case KeyEvent.KEYCODE_DPAD_DOWN: | ||
646 | + moveOffset = grid.getNumColumns(); | ||
647 | + break; | ||
648 | + default: | ||
649 | + return super.onKeyDown(keyCode, event); | ||
650 | + } | ||
651 | + | ||
652 | + // Find the currently selected style: the one with the highest slider value. | ||
653 | + int currentSelect = 0; | ||
654 | + float highestValue = 0; | ||
| + // Iterate over the styles only; adapter.getCount() also counts the buttons and | ||
| + // would index past the end of the items array. | ||
655 | + for (int i = 0; i < NUM_STYLES; i++) { | ||
656 | + if (adapter.items[i].value > highestValue) { | ||
657 | + currentSelect = i; | ||
658 | + highestValue = adapter.items[i].value; | ||
659 | + } | ||
660 | + } | ||
661 | + setStyle(adapter.items[(currentSelect + moveOffset + NUM_STYLES) % NUM_STYLES], 1); | ||
662 | + | ||
663 | + return true; | ||
664 | + } | ||
665 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import android.content.res.AssetManager; | ||
19 | +import android.graphics.Bitmap; | ||
20 | +import android.os.Trace; | ||
21 | +import android.util.Log; | ||
22 | +import java.io.BufferedReader; | ||
23 | +import java.io.IOException; | ||
24 | +import java.io.InputStreamReader; | ||
25 | +import java.util.ArrayList; | ||
26 | +import java.util.Comparator; | ||
27 | +import java.util.List; | ||
28 | +import java.util.PriorityQueue; | ||
29 | +import java.util.Vector; | ||
30 | +import org.tensorflow.Operation; | ||
31 | +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; | ||
32 | + | ||
33 | +/** A classifier specialized to label images using TensorFlow. */ | ||
34 | +public class TensorFlowImageClassifier implements Classifier { | ||
35 | + private static final String TAG = "TensorFlowImageClassifier"; | ||
36 | + | ||
37 | + // Only return this many results with at least this confidence. | ||
38 | + private static final int MAX_RESULTS = 3; | ||
39 | + private static final float THRESHOLD = 0.1f; | ||
40 | + | ||
41 | + // Config values. | ||
42 | + private String inputName; | ||
43 | + private String outputName; | ||
44 | + private int inputSize; | ||
45 | + private int imageMean; | ||
46 | + private float imageStd; | ||
47 | + | ||
48 | + // Pre-allocated buffers. | ||
49 | + private Vector<String> labels = new Vector<String>(); | ||
50 | + private int[] intValues; | ||
51 | + private float[] floatValues; | ||
52 | + private float[] outputs; | ||
53 | + private String[] outputNames; | ||
54 | + | ||
55 | + private boolean logStats = false; | ||
56 | + | ||
57 | + private TensorFlowInferenceInterface inferenceInterface; | ||
58 | + | ||
59 | + private TensorFlowImageClassifier() {} | ||
60 | + | ||
61 | + /** | ||
62 | + * Initializes a native TensorFlow session for classifying images. | ||
63 | + * | ||
64 | + * @param assetManager The asset manager to be used to load assets. | ||
65 | + * @param modelFilename The filepath of the model GraphDef protocol buffer. | ||
66 | + * @param labelFilename The filepath of label file for classes. | ||
67 | + * @param inputSize The input size. A square image of inputSize x inputSize is assumed. | ||
68 | + * @param imageMean The assumed mean of the image values. | ||
69 | + * @param imageStd The assumed std of the image values. | ||
70 | + * @param inputName The label of the image input node. | ||
71 | + * @param outputName The label of the output node. | ||
72 | + * @throws RuntimeException if the label file cannot be read from assets. | ||
73 | + */ | ||
74 | + public static Classifier create( | ||
75 | + AssetManager assetManager, | ||
76 | + String modelFilename, | ||
77 | + String labelFilename, | ||
78 | + int inputSize, | ||
79 | + int imageMean, | ||
80 | + float imageStd, | ||
81 | + String inputName, | ||
82 | + String outputName) { | ||
83 | + TensorFlowImageClassifier c = new TensorFlowImageClassifier(); | ||
84 | + c.inputName = inputName; | ||
85 | + c.outputName = outputName; | ||
86 | + | ||
87 | + // Read the label names into memory. | ||
88 | + // TODO(andrewharp): make this handle non-assets. | ||
89 | + String actualFilename = labelFilename.split("file:///android_asset/")[1]; | ||
90 | + Log.i(TAG, "Reading labels from: " + actualFilename); | ||
91 | + BufferedReader br = null; | ||
92 | + try { | ||
93 | + br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename))); | ||
94 | + String line; | ||
95 | + while ((line = br.readLine()) != null) { | ||
96 | + c.labels.add(line); | ||
97 | + } | ||
98 | + br.close(); | ||
99 | + } catch (IOException e) { | ||
100 | + throw new RuntimeException("Problem reading label file!", e); | ||
101 | + } | ||
102 | + | ||
103 | + c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); | ||
104 | + | ||
105 | + // The shape of the output is [N, NUM_CLASSES], where N is the batch size. | ||
106 | + final Operation operation = c.inferenceInterface.graphOperation(outputName); | ||
107 | + final int numClasses = (int) operation.output(0).shape().size(1); | ||
108 | + Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses); | ||
109 | + | ||
110 | + // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas, | ||
111 | + // the placeholder node for input in the graphdef typically used does not specify a shape, so it | ||
112 | + // must be passed in as a parameter. | ||
113 | + c.inputSize = inputSize; | ||
114 | + c.imageMean = imageMean; | ||
115 | + c.imageStd = imageStd; | ||
116 | + | ||
117 | + // Pre-allocate buffers. | ||
118 | + c.outputNames = new String[] {outputName}; | ||
119 | + c.intValues = new int[inputSize * inputSize]; | ||
120 | + c.floatValues = new float[inputSize * inputSize * 3]; | ||
121 | + c.outputs = new float[numClasses]; | ||
122 | + | ||
123 | + return c; | ||
124 | + } | ||
125 | + | ||
126 | + @Override | ||
127 | + public List<Recognition> recognizeImage(final Bitmap bitmap) { | ||
128 | + // Log this method so that it can be analyzed with systrace. | ||
129 | + Trace.beginSection("recognizeImage"); | ||
130 | + | ||
131 | + Trace.beginSection("preprocessBitmap"); | ||
132 | + // Preprocess the image data from 0-255 int to normalized float based | ||
133 | + // on the provided parameters. | ||
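| + // Each channel value v is normalized as (v - imageMean) / imageStd. | ||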
134 | + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); | ||
135 | + for (int i = 0; i < intValues.length; ++i) { | ||
136 | + final int val = intValues[i]; | ||
137 | + floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd; | ||
138 | + floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd; | ||
139 | + floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd; | ||
140 | + } | ||
141 | + Trace.endSection(); | ||
142 | + | ||
143 | + // Copy the input data into TensorFlow. | ||
144 | + Trace.beginSection("feed"); | ||
145 | + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); | ||
146 | + Trace.endSection(); | ||
147 | + | ||
148 | + // Run the inference call. | ||
149 | + Trace.beginSection("run"); | ||
150 | + inferenceInterface.run(outputNames, logStats); | ||
151 | + Trace.endSection(); | ||
152 | + | ||
153 | + // Copy the output Tensor back into the output array. | ||
154 | + Trace.beginSection("fetch"); | ||
155 | + inferenceInterface.fetch(outputName, outputs); | ||
156 | + Trace.endSection(); | ||
157 | + | ||
158 | + // Find the best classifications. | ||
159 | + PriorityQueue<Recognition> pq = | ||
160 | + new PriorityQueue<Recognition>( | ||
161 | + 3, | ||
162 | + new Comparator<Recognition>() { | ||
163 | + @Override | ||
164 | + public int compare(Recognition lhs, Recognition rhs) { | ||
165 | + // Intentionally reversed to put high confidence at the head of the queue. | ||
166 | + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); | ||
167 | + } | ||
168 | + }); | ||
169 | + for (int i = 0; i < outputs.length; ++i) { | ||
170 | + if (outputs[i] > THRESHOLD) { | ||
171 | + pq.add( | ||
172 | + new Recognition( | ||
173 | + "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null)); | ||
174 | + } | ||
175 | + } | ||
176 | + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); | ||
177 | + int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); | ||
178 | + for (int i = 0; i < recognitionsSize; ++i) { | ||
179 | + recognitions.add(pq.poll()); | ||
180 | + } | ||
181 | + Trace.endSection(); // "recognizeImage" | ||
182 | + return recognitions; | ||
183 | + } | ||
184 | + | ||
185 | + @Override | ||
186 | + public void enableStatLogging(boolean logStats) { | ||
187 | + this.logStats = logStats; | ||
188 | + } | ||
189 | + | ||
190 | + @Override | ||
191 | + public String getStatString() { | ||
192 | + return inferenceInterface.getStatString(); | ||
193 | + } | ||
194 | + | ||
195 | + @Override | ||
196 | + public void close() { | ||
197 | + inferenceInterface.close(); | ||
198 | + } | ||
199 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import android.content.res.AssetManager; | ||
19 | +import android.graphics.Bitmap; | ||
20 | +import android.graphics.RectF; | ||
21 | +import android.os.Trace; | ||
22 | +import java.io.BufferedReader; | ||
23 | +import java.io.FileInputStream; | ||
24 | +import java.io.IOException; | ||
25 | +import java.io.InputStream; | ||
26 | +import java.io.InputStreamReader; | ||
27 | +import java.util.ArrayList; | ||
28 | +import java.util.Comparator; | ||
29 | +import java.util.List; | ||
30 | +import java.util.PriorityQueue; | ||
31 | +import java.util.StringTokenizer; | ||
32 | +import org.tensorflow.Graph; | ||
33 | +import org.tensorflow.Operation; | ||
34 | +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; | ||
35 | +import org.tensorflow.demo.env.Logger; | ||
36 | + | ||
37 | +/** | ||
38 | + * A detector for general purpose object detection as described in Scalable Object Detection using | ||
39 | + * Deep Neural Networks (https://arxiv.org/abs/1312.2249). | ||
40 | + */ | ||
41 | +public class TensorFlowMultiBoxDetector implements Classifier { | ||
42 | + private static final Logger LOGGER = new Logger(); | ||
43 | + | ||
44 | + // Effectively no cap on the number of returned results. | ||
45 | + private static final int MAX_RESULTS = Integer.MAX_VALUE; | ||
46 | + | ||
47 | + // Config values. | ||
48 | + private String inputName; | ||
49 | + private int inputSize; | ||
50 | + private int imageMean; | ||
51 | + private float imageStd; | ||
52 | + | ||
53 | + // Pre-allocated buffers. | ||
54 | + private int[] intValues; | ||
55 | + private float[] floatValues; | ||
56 | + private float[] outputLocations; | ||
57 | + private float[] outputScores; | ||
58 | + private String[] outputNames; | ||
59 | + private int numLocations; | ||
60 | + | ||
61 | + private boolean logStats = false; | ||
62 | + | ||
63 | + private TensorFlowInferenceInterface inferenceInterface; | ||
64 | + | ||
65 | + private float[] boxPriors; | ||
66 | + | ||
67 | + /** | ||
68 | + * Initializes a native TensorFlow session for classifying images. | ||
69 | + * | ||
70 | + * @param assetManager The asset manager to be used to load assets. | ||
71 | + * @param modelFilename The filepath of the model GraphDef protocol buffer. | ||
72 | + * @param locationFilename The filepath of the box-prior file used to decode locations. | ||
73 | + * @param imageMean The assumed mean of the image values. | ||
74 | + * @param imageStd The assumed std of the image values. | ||
75 | + * @param inputName The label of the image input node; the input size is read from its shape. | ||
76 | + * @param outputLocationsName The label of the output node for box locations. | ||
77 | + * @param outputScoresName The label of the output node for box scores. | ||
78 | + */ | ||
79 | + public static Classifier create( | ||
80 | + final AssetManager assetManager, | ||
81 | + final String modelFilename, | ||
82 | + final String locationFilename, | ||
83 | + final int imageMean, | ||
84 | + final float imageStd, | ||
85 | + final String inputName, | ||
86 | + final String outputLocationsName, | ||
87 | + final String outputScoresName) { | ||
88 | + final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector(); | ||
89 | + | ||
90 | + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); | ||
91 | + | ||
92 | + final Graph g = d.inferenceInterface.graph(); | ||
93 | + | ||
94 | + d.inputName = inputName; | ||
95 | + // The inputName node has a shape of [N, H, W, C], where | ||
96 | + // N is the batch size | ||
97 | + // H = W are the height and width | ||
98 | + // C is the number of channels (3 for our purposes - RGB) | ||
99 | + final Operation inputOp = g.operation(inputName); | ||
100 | + if (inputOp == null) { | ||
101 | + throw new RuntimeException("Failed to find input Node '" + inputName + "'"); | ||
102 | + } | ||
103 | + d.inputSize = (int) inputOp.output(0).shape().size(1); | ||
104 | + d.imageMean = imageMean; | ||
105 | + d.imageStd = imageStd; | ||
106 | + // The outputScoresName node has a shape of [N, NumLocations], where N | ||
107 | + // is the batch size. | ||
108 | + final Operation outputOp = g.operation(outputScoresName); | ||
109 | + if (outputOp == null) { | ||
110 | + throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'"); | ||
111 | + } | ||
112 | + d.numLocations = (int) outputOp.output(0).shape().size(1); | ||
113 | + | ||
114 | + d.boxPriors = new float[d.numLocations * 8]; | ||
115 | + | ||
116 | + try { | ||
117 | + d.loadCoderOptions(assetManager, locationFilename, d.boxPriors); | ||
118 | + } catch (final IOException e) { | ||
119 | + throw new RuntimeException("Error initializing box priors from " + locationFilename, e); | ||
120 | + } | ||
121 | + | ||
122 | + // Pre-allocate buffers. | ||
123 | + d.outputNames = new String[] {outputLocationsName, outputScoresName}; | ||
124 | + d.intValues = new int[d.inputSize * d.inputSize]; | ||
125 | + d.floatValues = new float[d.inputSize * d.inputSize * 3]; | ||
126 | + d.outputScores = new float[d.numLocations]; | ||
127 | + d.outputLocations = new float[d.numLocations * 4]; | ||
128 | + | ||
129 | + return d; | ||
130 | + } | ||
131 | + | ||
132 | + private TensorFlowMultiBoxDetector() {} | ||
133 | + | ||
134 | + private void loadCoderOptions( | ||
135 | + final AssetManager assetManager, final String locationFilename, final float[] boxPriors) | ||
136 | + throws IOException { | ||
137 | + // Try to be intelligent about opening from assets or sdcard depending on prefix. | ||
138 | + final String assetPrefix = "file:///android_asset/"; | ||
139 | + InputStream is; | ||
140 | + if (locationFilename.startsWith(assetPrefix)) { | ||
141 | + is = assetManager.open(locationFilename.split(assetPrefix)[1]); | ||
142 | + } else { | ||
143 | + is = new FileInputStream(locationFilename); | ||
144 | + } | ||
145 | + | ||
146 | + // Read values. Number of values per line doesn't matter, as long as they are separated | ||
147 | + // by commas and/or whitespace, and there are exactly numLocations * 8 values total. | ||
148 | + // Values are in the order mean, std for each consecutive corner of each box, for a total of 8 | ||
149 | + // per location. | ||
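| + // An illustrative (made-up) line: "0.1, 0.05, 0.2, 0.05, ..." — successive (mean, std) pairs, | ||
| + // eight values per box location in total. | ||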
150 | + final BufferedReader reader = new BufferedReader(new InputStreamReader(is)); | ||
151 | + int priorIndex = 0; | ||
152 | + String line; | ||
153 | + while ((line = reader.readLine()) != null) { | ||
154 | + final StringTokenizer st = new StringTokenizer(line, ", "); | ||
155 | + while (st.hasMoreTokens()) { | ||
156 | + final String token = st.nextToken(); | ||
157 | + try { | ||
158 | + final float number = Float.parseFloat(token); | ||
159 | + boxPriors[priorIndex++] = number; | ||
160 | + } catch (final NumberFormatException e) { | ||
161 | + // Silently ignore. | ||
162 | + } | ||
163 | + } | ||
164 | + } | ||
| + reader.close(); | ||
165 | + if (priorIndex != boxPriors.length) { | ||
166 | + throw new RuntimeException( | ||
167 | + "BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length); | ||
168 | + } | ||
169 | + } | ||
170 | + | ||
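| + // Undo the prior normalization: each encoded coordinate is scaled by its std and | ||
| + // shifted by its mean, then clamped to [0, 1]. | ||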
171 | + private float[] decodeLocationsEncoding(final float[] locationEncoding) { | ||
172 | + final float[] locations = new float[locationEncoding.length]; | ||
173 | + boolean nonZero = false; | ||
174 | + for (int i = 0; i < numLocations; ++i) { | ||
175 | + for (int j = 0; j < 4; ++j) { | ||
176 | + final float currEncoding = locationEncoding[4 * i + j]; | ||
177 | + nonZero = nonZero || currEncoding != 0.0f; | ||
178 | + | ||
179 | + final float mean = boxPriors[i * 8 + j * 2]; | ||
180 | + final float stdDev = boxPriors[i * 8 + j * 2 + 1]; | ||
181 | + float currentLocation = currEncoding * stdDev + mean; | ||
182 | + currentLocation = Math.max(currentLocation, 0.0f); | ||
183 | + currentLocation = Math.min(currentLocation, 1.0f); | ||
184 | + locations[4 * i + j] = currentLocation; | ||
185 | + } | ||
186 | + } | ||
187 | + | ||
188 | + if (!nonZero) { | ||
189 | + LOGGER.w("No non-zero encodings; check log for inference errors."); | ||
190 | + } | ||
191 | + return locations; | ||
192 | + } | ||
193 | + | ||
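| + // The raw scores are logits; squash them to [0, 1] with the logistic sigmoid 1 / (1 + e^-x). | ||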
194 | + private float[] decodeScoresEncoding(final float[] scoresEncoding) { | ||
195 | + final float[] scores = new float[scoresEncoding.length]; | ||
196 | + for (int i = 0; i < scoresEncoding.length; ++i) { | ||
197 | + scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i]))); | ||
198 | + } | ||
199 | + return scores; | ||
200 | + } | ||
201 | + | ||
202 | + @Override | ||
203 | + public List<Recognition> recognizeImage(final Bitmap bitmap) { | ||
204 | + // Log this method so that it can be analyzed with systrace. | ||
205 | + Trace.beginSection("recognizeImage"); | ||
206 | + | ||
207 | + Trace.beginSection("preprocessBitmap"); | ||
208 | + // Preprocess the image data from 0-255 int to normalized float based | ||
209 | + // on the provided parameters. | ||
210 | + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); | ||
211 | + | ||
212 | + for (int i = 0; i < intValues.length; ++i) { | ||
213 | + floatValues[i * 3 + 0] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd; | ||
214 | + floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd; | ||
215 | + floatValues[i * 3 + 2] = ((intValues[i] & 0xFF) - imageMean) / imageStd; | ||
216 | + } | ||
217 | + Trace.endSection(); // preprocessBitmap | ||
218 | + | ||
219 | + // Copy the input data into TensorFlow. | ||
220 | + Trace.beginSection("feed"); | ||
221 | + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); | ||
222 | + Trace.endSection(); | ||
223 | + | ||
224 | + // Run the inference call. | ||
225 | + Trace.beginSection("run"); | ||
226 | + inferenceInterface.run(outputNames, logStats); | ||
227 | + Trace.endSection(); | ||
228 | + | ||
229 | + // Copy the output Tensor back into the output array. | ||
230 | + Trace.beginSection("fetch"); | ||
231 | + final float[] outputScoresEncoding = new float[numLocations]; | ||
232 | + final float[] outputLocationsEncoding = new float[numLocations * 4]; | ||
233 | + inferenceInterface.fetch(outputNames[0], outputLocationsEncoding); | ||
234 | + inferenceInterface.fetch(outputNames[1], outputScoresEncoding); | ||
235 | + Trace.endSection(); | ||
236 | + | ||
237 | + outputLocations = decodeLocationsEncoding(outputLocationsEncoding); | ||
238 | + outputScores = decodeScoresEncoding(outputScoresEncoding); | ||
239 | + | ||
240 | + // Find the best detections. | ||
241 | + final PriorityQueue<Recognition> pq = | ||
242 | + new PriorityQueue<Recognition>( | ||
243 | + 1, | ||
244 | + new Comparator<Recognition>() { | ||
245 | + @Override | ||
246 | + public int compare(final Recognition lhs, final Recognition rhs) { | ||
247 | + // Intentionally reversed to put high confidence at the head of the queue. | ||
248 | + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); | ||
249 | + } | ||
250 | + }); | ||
251 | + | ||
252 | + // Scale them back to the input size. | ||
253 | + for (int i = 0; i < outputScores.length; ++i) { | ||
254 | + final RectF detection = | ||
255 | + new RectF( | ||
256 | + outputLocations[4 * i] * inputSize, | ||
257 | + outputLocations[4 * i + 1] * inputSize, | ||
258 | + outputLocations[4 * i + 2] * inputSize, | ||
259 | + outputLocations[4 * i + 3] * inputSize); | ||
260 | + pq.add(new Recognition("" + i, null, outputScores[i], detection)); | ||
261 | + } | ||
262 | + | ||
263 | + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); | ||
264 | + for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) { | ||
265 | + recognitions.add(pq.poll()); | ||
266 | + } | ||
267 | + Trace.endSection(); // "recognizeImage" | ||
268 | + return recognitions; | ||
269 | + } | ||
270 | + | ||
271 | + @Override | ||
272 | + public void enableStatLogging(final boolean logStats) { | ||
273 | + this.logStats = logStats; | ||
274 | + } | ||
275 | + | ||
276 | + @Override | ||
277 | + public String getStatString() { | ||
278 | + return inferenceInterface.getStatString(); | ||
279 | + } | ||
280 | + | ||
281 | + @Override | ||
282 | + public void close() { | ||
283 | + inferenceInterface.close(); | ||
284 | + } | ||
285 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import android.content.res.AssetManager; | ||
19 | +import android.graphics.Bitmap; | ||
20 | +import android.graphics.RectF; | ||
21 | +import android.os.Trace; | ||
22 | +import java.io.BufferedReader; | ||
23 | +import java.io.IOException; | ||
24 | +import java.io.InputStream; | ||
25 | +import java.io.InputStreamReader; | ||
26 | +import java.util.ArrayList; | ||
27 | +import java.util.Comparator; | ||
28 | +import java.util.List; | ||
29 | +import java.util.PriorityQueue; | ||
30 | +import java.util.Vector; | ||
31 | +import org.tensorflow.Graph; | ||
32 | +import org.tensorflow.Operation; | ||
33 | +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; | ||
34 | +import org.tensorflow.demo.env.Logger; | ||
35 | + | ||
36 | +/** | ||
37 | + * Wrapper for frozen detection models trained using the Tensorflow Object Detection API: | ||
38 | + * github.com/tensorflow/models/tree/master/research/object_detection | ||
39 | + */ | ||
40 | +public class TensorFlowObjectDetectionAPIModel implements Classifier { | ||
41 | + private static final Logger LOGGER = new Logger(); | ||
42 | + | ||
43 | + // Only return this many results. | ||
44 | + private static final int MAX_RESULTS = 100; | ||
45 | + | ||
46 | + // Config values. | ||
47 | + private String inputName; | ||
48 | + private int inputSize; | ||
49 | + | ||
50 | + // Pre-allocated buffers. | ||
51 | + private Vector<String> labels = new Vector<String>(); | ||
52 | + private int[] intValues; | ||
53 | + private byte[] byteValues; | ||
54 | + private float[] outputLocations; | ||
55 | + private float[] outputScores; | ||
56 | + private float[] outputClasses; | ||
57 | + private float[] outputNumDetections; | ||
58 | + private String[] outputNames; | ||
59 | + | ||
60 | + private boolean logStats = false; | ||
61 | + | ||
62 | + private TensorFlowInferenceInterface inferenceInterface; | ||
63 | + | ||
64 | + /** | ||
65 | + * Initializes a native TensorFlow session for classifying images. | ||
66 | + * | ||
67 | + * @param assetManager The asset manager to be used to load assets. | ||
68 | + * @param modelFilename The filepath of the model GraphDef protocol buffer. | ||
69 | + * @param labelFilename The filepath of label file for classes. | ||
70 | + * @param labelFilename The filepath of label file for classes. | ||
| + * @param inputSize The input size. A square image of inputSize x inputSize is assumed. | ||
71 | + public static Classifier create( | ||
72 | + final AssetManager assetManager, | ||
73 | + final String modelFilename, | ||
74 | + final String labelFilename, | ||
75 | + final int inputSize) throws IOException { | ||
76 | + final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel(); | ||
77 | + | ||
78 | + final String actualFilename = labelFilename.split("file:///android_asset/")[1]; | ||
79 | + final InputStream labelsInput = assetManager.open(actualFilename); | ||
80 | + final BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput)); | ||
83 | + String line; | ||
84 | + while ((line = br.readLine()) != null) { | ||
85 | + LOGGER.w(line); | ||
86 | + d.labels.add(line); | ||
87 | + } | ||
88 | + br.close(); | ||
89 | + | ||
91 | + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); | ||
92 | + | ||
93 | + final Graph g = d.inferenceInterface.graph(); | ||
94 | + | ||
95 | + d.inputName = "image_tensor"; | ||
96 | + // The inputName node has a shape of [N, H, W, C], where | ||
97 | + // N is the batch size | ||
98 | + // H = W are the height and width | ||
99 | + // C is the number of channels (3 for our purposes - RGB) | ||
100 | + final Operation inputOp = g.operation(d.inputName); | ||
101 | + if (inputOp == null) { | ||
102 | + throw new RuntimeException("Failed to find input Node '" + d.inputName + "'"); | ||
103 | + } | ||
104 | + d.inputSize = inputSize; | ||
105 | + // The outputScoresName node has a shape of [N, NumLocations], where N | ||
106 | + // is the batch size. | ||
107 | + final Operation outputOp1 = g.operation("detection_scores"); | ||
108 | + if (outputOp1 == null) { | ||
109 | + throw new RuntimeException("Failed to find output Node 'detection_scores'"); | ||
110 | + } | ||
111 | + final Operation outputOp2 = g.operation("detection_boxes"); | ||
112 | + if (outputOp2 == null) { | ||
113 | + throw new RuntimeException("Failed to find output Node 'detection_boxes'"); | ||
114 | + } | ||
115 | + final Operation outputOp3 = g.operation("detection_classes"); | ||
116 | + if (outputOp3 == null) { | ||
117 | + throw new RuntimeException("Failed to find output Node 'detection_classes'"); | ||
118 | + } | ||
119 | + | ||
120 | + // Pre-allocate buffers. | ||
121 | + d.outputNames = new String[] {"detection_boxes", "detection_scores", | ||
122 | + "detection_classes", "num_detections"}; | ||
123 | + d.intValues = new int[d.inputSize * d.inputSize]; | ||
124 | + d.byteValues = new byte[d.inputSize * d.inputSize * 3]; | ||
125 | + d.outputScores = new float[MAX_RESULTS]; | ||
126 | + d.outputLocations = new float[MAX_RESULTS * 4]; | ||
127 | + d.outputClasses = new float[MAX_RESULTS]; | ||
128 | + d.outputNumDetections = new float[1]; | ||
129 | + return d; | ||
130 | + } | ||
131 | + | ||
132 | + private TensorFlowObjectDetectionAPIModel() {} | ||
133 | + | ||
134 | + @Override | ||
135 | + public List<Recognition> recognizeImage(final Bitmap bitmap) { | ||
136 | + // Log this method so that it can be analyzed with systrace. | ||
137 | + Trace.beginSection("recognizeImage"); | ||
138 | + | ||
139 | + Trace.beginSection("preprocessBitmap"); | ||
140 | + // Preprocess the image data, extracting the R, G, and B bytes from each pixel | ||
141 | + // packed as 0xAARRGGBB. | ||
142 | + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); | ||
143 | + | ||
144 | + for (int i = 0; i < intValues.length; ++i) { | ||
145 | + byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF); | ||
146 | + byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF); | ||
147 | + byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF); | ||
148 | + } | ||
149 | + Trace.endSection(); // preprocessBitmap | ||
150 | + | ||
151 | + // Copy the input data into TensorFlow. | ||
152 | + Trace.beginSection("feed"); | ||
153 | + inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3); | ||
154 | + Trace.endSection(); | ||
155 | + | ||
156 | + // Run the inference call. | ||
157 | + Trace.beginSection("run"); | ||
158 | + inferenceInterface.run(outputNames, logStats); | ||
159 | + Trace.endSection(); | ||
160 | + | ||
161 | + // Copy the output Tensor back into the output array. | ||
162 | + Trace.beginSection("fetch"); | ||
163 | + outputLocations = new float[MAX_RESULTS * 4]; | ||
164 | + outputScores = new float[MAX_RESULTS]; | ||
165 | + outputClasses = new float[MAX_RESULTS]; | ||
166 | + outputNumDetections = new float[1]; | ||
167 | + inferenceInterface.fetch(outputNames[0], outputLocations); | ||
168 | + inferenceInterface.fetch(outputNames[1], outputScores); | ||
169 | + inferenceInterface.fetch(outputNames[2], outputClasses); | ||
170 | + inferenceInterface.fetch(outputNames[3], outputNumDetections); | ||
171 | + Trace.endSection(); | ||
172 | + | ||
173 | + // Find the best detections. | ||
174 | + final PriorityQueue<Recognition> pq = | ||
175 | + new PriorityQueue<Recognition>( | ||
176 | + 1, | ||
177 | + new Comparator<Recognition>() { | ||
178 | + @Override | ||
179 | + public int compare(final Recognition lhs, final Recognition rhs) { | ||
180 | + // Intentionally reversed to put high confidence at the head of the queue. | ||
181 | + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); | ||
182 | + } | ||
183 | + }); | ||
184 | + | ||
185 | + // Scale them back to the input size. | ||
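| + // detection_boxes come back as [ymin, xmin, ymax, xmax] in normalized coordinates, | ||
| + // hence the reordering into RectF(left, top, right, bottom) below. | ||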
186 | + for (int i = 0; i < outputScores.length; ++i) { | ||
187 | + final RectF detection = | ||
188 | + new RectF( | ||
189 | + outputLocations[4 * i + 1] * inputSize, | ||
190 | + outputLocations[4 * i] * inputSize, | ||
191 | + outputLocations[4 * i + 3] * inputSize, | ||
192 | + outputLocations[4 * i + 2] * inputSize); | ||
193 | + pq.add( | ||
194 | + new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection)); | ||
195 | + } | ||
196 | + | ||
197 | + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); | ||
198 | + for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) { | ||
199 | + recognitions.add(pq.poll()); | ||
200 | + } | ||
201 | + Trace.endSection(); // "recognizeImage" | ||
202 | + return recognitions; | ||
203 | + } | ||
204 | + | ||
205 | + @Override | ||
206 | + public void enableStatLogging(final boolean logStats) { | ||
207 | + this.logStats = logStats; | ||
208 | + } | ||
209 | + | ||
210 | + @Override | ||
211 | + public String getStatString() { | ||
212 | + return inferenceInterface.getStatString(); | ||
213 | + } | ||
214 | + | ||
215 | + @Override | ||
216 | + public void close() { | ||
217 | + inferenceInterface.close(); | ||
218 | + } | ||
219 | +} |
1 | +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo; | ||
17 | + | ||
18 | +import android.content.res.AssetManager; | ||
19 | +import android.graphics.Bitmap; | ||
20 | +import android.graphics.RectF; | ||
21 | +import android.os.Trace; | ||
22 | +import java.util.ArrayList; | ||
23 | +import java.util.Comparator; | ||
24 | +import java.util.List; | ||
25 | +import java.util.PriorityQueue; | ||
26 | +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; | ||
27 | +import org.tensorflow.demo.env.Logger; | ||
28 | +import org.tensorflow.demo.env.SplitTimer; | ||
29 | + | ||
30 | +/** An object detector that uses TF and a YOLO model to detect objects. */ | ||
31 | +public class TensorFlowYoloDetector implements Classifier { | ||
32 | + private static final Logger LOGGER = new Logger(); | ||
33 | + | ||
34 | + // Only return this many results; low-confidence detections are filtered inline below. | ||
35 | + private static final int MAX_RESULTS = 5; | ||
36 | + | ||
37 | + private static final int NUM_CLASSES = 1; | ||
38 | + | ||
39 | + private static final int NUM_BOXES_PER_BLOCK = 5; | ||
40 | + | ||
41 | + // TODO(andrewharp): allow loading anchors and classes | ||
42 | + // from files. | ||
43 | + private static final double[] ANCHORS = { | ||
44 | + 1.08, 1.19, | ||
45 | + 3.42, 4.41, | ||
46 | + 6.63, 11.38, | ||
47 | + 9.42, 5.11, | ||
48 | + 16.62, 10.52 | ||
49 | + }; | ||
50 | + | ||
51 | + private static final String[] LABELS = { | ||
52 | + "dog" | ||
53 | + }; | ||
54 | + | ||
55 | + // Config values. | ||
56 | + private String inputName; | ||
57 | + private int inputSize; | ||
58 | + | ||
59 | + // Pre-allocated buffers. | ||
60 | + private int[] intValues; | ||
61 | + private float[] floatValues; | ||
62 | + private String[] outputNames; | ||
63 | + | ||
64 | + private int blockSize; | ||
65 | + | ||
66 | + private boolean logStats = false; | ||
67 | + | ||
68 | + private TensorFlowInferenceInterface inferenceInterface; | ||
69 | + | ||
70 | + /** Initializes a native TensorFlow session for classifying images. */ | ||
71 | + public static Classifier create( | ||
72 | + final AssetManager assetManager, | ||
73 | + final String modelFilename, | ||
74 | + final int inputSize, | ||
75 | + final String inputName, | ||
76 | + final String outputName, | ||
77 | + final int blockSize) { | ||
78 | + TensorFlowYoloDetector d = new TensorFlowYoloDetector(); | ||
79 | + d.inputName = inputName; | ||
80 | + d.inputSize = inputSize; | ||
81 | + | ||
82 | + // Pre-allocate buffers. | ||
83 | + d.outputNames = outputName.split(","); | ||
84 | + d.intValues = new int[inputSize * inputSize]; | ||
85 | + d.floatValues = new float[inputSize * inputSize * 3]; | ||
86 | + d.blockSize = blockSize; | ||
87 | + | ||
88 | + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); | ||
89 | + | ||
90 | + return d; | ||
91 | + } | ||
92 | + | ||
93 | + private TensorFlowYoloDetector() {} | ||
94 | + | ||
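| + // expit(x) is the logistic sigmoid 1 / (1 + e^-x), the inverse of the logit function. | ||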
95 | + private float expit(final float x) { | ||
96 | + return (float) (1. / (1. + Math.exp(-x))); | ||
97 | + } | ||
98 | + | ||
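| + // Softmax over the raw class scores; subtracting the max first avoids overflow in exp(). | ||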
99 | + private void softmax(final float[] vals) { | ||
100 | + float max = Float.NEGATIVE_INFINITY; | ||
101 | + for (final float val : vals) { | ||
102 | + max = Math.max(max, val); | ||
103 | + } | ||
104 | + float sum = 0.0f; | ||
105 | + for (int i = 0; i < vals.length; ++i) { | ||
106 | + vals[i] = (float) Math.exp(vals[i] - max); | ||
107 | + sum += vals[i]; | ||
108 | + } | ||
109 | + for (int i = 0; i < vals.length; ++i) { | ||
110 | + vals[i] = vals[i] / sum; | ||
111 | + } | ||
112 | + } | ||
113 | + | ||
114 | + @Override | ||
115 | + public List<Recognition> recognizeImage(final Bitmap bitmap) { | ||
116 | + final SplitTimer timer = new SplitTimer("recognizeImage"); | ||
117 | + | ||
118 | + // Log this method so that it can be analyzed with systrace. | ||
119 | + Trace.beginSection("recognizeImage"); | ||
120 | + | ||
121 | + Trace.beginSection("preprocessBitmap"); | ||
122 | + // Preprocess the image data from 0-255 int to normalized float based | ||
123 | + // on the provided parameters. | ||
124 | + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); | ||
125 | + | ||
126 | + for (int i = 0; i < intValues.length; ++i) { | ||
127 | + floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f; | ||
128 | + floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f; | ||
129 | + floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f; | ||
130 | + } | ||
131 | + Trace.endSection(); // preprocessBitmap | ||
132 | + | ||
133 | + // Copy the input data into TensorFlow. | ||
134 | + Trace.beginSection("feed"); | ||
135 | + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); | ||
136 | + Trace.endSection(); | ||
137 | + | ||
138 | + timer.endSplit("ready for inference"); | ||
139 | + | ||
140 | + // Run the inference call. | ||
141 | + Trace.beginSection("run"); | ||
142 | + inferenceInterface.run(outputNames, logStats); | ||
143 | + Trace.endSection(); | ||
144 | + | ||
145 | + timer.endSplit("ran inference"); | ||
146 | + | ||
147 | + // Copy the output Tensor back into the output array. | ||
148 | + Trace.beginSection("fetch"); | ||
149 | + final int gridWidth = bitmap.getWidth() / blockSize; | ||
150 | + final int gridHeight = bitmap.getHeight() / blockSize; | ||
151 | + final float[] output = | ||
152 | + new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK]; | ||
153 | + inferenceInterface.fetch(outputNames[0], output); | ||
154 | + Trace.endSection(); | ||
155 | + | ||
156 | + // Find the best detections. | ||
157 | + final PriorityQueue<Recognition> pq = | ||
158 | + new PriorityQueue<Recognition>( | ||
159 | + 1, | ||
160 | + new Comparator<Recognition>() { | ||
161 | + @Override | ||
162 | + public int compare(final Recognition lhs, final Recognition rhs) { | ||
163 | + // Intentionally reversed to put high confidence at the head of the queue. | ||
164 | + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); | ||
165 | + } | ||
166 | + }); | ||
167 | + | ||
168 | + for (int y = 0; y < gridHeight; ++y) { | ||
169 | + for (int x = 0; x < gridWidth; ++x) { | ||
170 | + for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) { | ||
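171 | +          // The output tensor is laid out as [gridHeight][gridWidth][boxes][5 + classes], | ||
172 | +          // flattened; the 5 leading values per box are x, y, w, h, and objectness. | ||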
171 | + final int offset = | ||
172 | + (gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y | ||
173 | + + (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x | ||
174 | + + (NUM_CLASSES + 5) * b; | ||
175 | + | ||
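176 | +          // Standard YOLOv2 decoding: the box center is a sigmoid offset within its grid | ||
177 | +          // cell, and width/height are exponentials scaled by the anchor box priors. | ||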
176 | + final float xPos = (x + expit(output[offset + 0])) * blockSize; | ||
177 | + final float yPos = (y + expit(output[offset + 1])) * blockSize; | ||
178 | + | ||
179 | + final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * blockSize; | ||
180 | + final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * blockSize; | ||
181 | + | ||
182 | + final RectF rect = | ||
183 | + new RectF( | ||
184 | + Math.max(0, xPos - w / 2), | ||
185 | + Math.max(0, yPos - h / 2), | ||
186 | + Math.min(bitmap.getWidth() - 1, xPos + w / 2), | ||
187 | + Math.min(bitmap.getHeight() - 1, yPos + h / 2)); | ||
188 | + final float confidence = expit(output[offset + 4]); | ||
189 | + | ||
190 | + int detectedClass = -1; | ||
191 | + float maxClass = 0; | ||
192 | + | ||
193 | + final float[] classes = new float[NUM_CLASSES]; | ||
194 | + for (int c = 0; c < NUM_CLASSES; ++c) { | ||
195 | + classes[c] = output[offset + 5 + c]; | ||
196 | + } | ||
197 | + softmax(classes); | ||
198 | + | ||
199 | + for (int c = 0; c < NUM_CLASSES; ++c) { | ||
200 | + if (classes[c] > maxClass) { | ||
201 | + detectedClass = c; | ||
202 | + maxClass = classes[c]; | ||
203 | + } | ||
204 | + } | ||
205 | + | ||
206 | + final float confidenceInClass = maxClass * confidence; | ||
207 | + if (confidenceInClass > 0.01) { | ||
208 | + LOGGER.i( | ||
209 | + "%s (%d) %f %s", LABELS[detectedClass], detectedClass, confidenceInClass, rect); | ||
210 | + pq.add(new Recognition("" + offset, LABELS[detectedClass], confidenceInClass, rect)); | ||
211 | + } | ||
212 | + } | ||
213 | + } | ||
214 | + } | ||
215 | + timer.endSplit("decoded results"); | ||
216 | + | ||
217 | +    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); | ||
218 | +    // pq.size() shrinks as elements are polled, so the bound must be computed up front; | ||
219 | +    // recomputing it inside the loop condition would return only about half the results. | ||
220 | +    final int numDetections = Math.min(pq.size(), MAX_RESULTS); | ||
221 | +    for (int i = 0; i < numDetections; ++i) { | ||
222 | +      recognitions.add(pq.poll()); | ||
223 | +    } | ||
221 | + Trace.endSection(); // "recognizeImage" | ||
222 | + | ||
223 | + timer.endSplit("processed results"); | ||
224 | + | ||
225 | + return recognitions; | ||
226 | + } | ||
227 | + | ||
228 | + @Override | ||
229 | + public void enableStatLogging(final boolean logStats) { | ||
230 | + this.logStats = logStats; | ||
231 | + } | ||
232 | + | ||
233 | + @Override | ||
234 | + public String getStatString() { | ||
235 | + return inferenceInterface.getStatString(); | ||
236 | + } | ||
237 | + | ||
238 | + @Override | ||
239 | + public void close() { | ||
240 | + inferenceInterface.close(); | ||
241 | + } | ||
242 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo.env; | ||
17 | + | ||
18 | +import android.graphics.Canvas; | ||
19 | +import android.graphics.Color; | ||
20 | +import android.graphics.Paint; | ||
21 | +import android.graphics.Paint.Align; | ||
22 | +import android.graphics.Paint.Style; | ||
23 | +import android.graphics.Rect; | ||
24 | +import android.graphics.Typeface; | ||
25 | +import java.util.Vector; | ||
26 | + | ||
27 | +/** | ||
28 | + * A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas. | ||
29 | + */ | ||
30 | +public class BorderedText { | ||
31 | + private final Paint interiorPaint; | ||
32 | + private final Paint exteriorPaint; | ||
33 | + | ||
34 | + private final float textSize; | ||
35 | + | ||
36 | + /** | ||
37 | + * Creates a left-aligned bordered text object with a white interior, and a black exterior with | ||
38 | + * the specified text size. | ||
39 | + * | ||
40 | + * @param textSize text size in pixels | ||
41 | + */ | ||
42 | + public BorderedText(final float textSize) { | ||
43 | + this(Color.WHITE, Color.BLACK, textSize); | ||
44 | + } | ||
45 | + | ||
46 | + /** | ||
47 | + * Create a bordered text object with the specified interior and exterior colors, text size and | ||
48 | + * alignment. | ||
49 | + * | ||
50 | + * @param interiorColor the interior text color | ||
51 | + * @param exteriorColor the exterior text color | ||
52 | + * @param textSize text size in pixels | ||
53 | + */ | ||
54 | + public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) { | ||
55 | + interiorPaint = new Paint(); | ||
56 | + interiorPaint.setTextSize(textSize); | ||
57 | + interiorPaint.setColor(interiorColor); | ||
58 | + interiorPaint.setStyle(Style.FILL); | ||
59 | + interiorPaint.setAntiAlias(false); | ||
60 | + interiorPaint.setAlpha(255); | ||
61 | + | ||
62 | + exteriorPaint = new Paint(); | ||
63 | + exteriorPaint.setTextSize(textSize); | ||
64 | + exteriorPaint.setColor(exteriorColor); | ||
65 | + exteriorPaint.setStyle(Style.FILL_AND_STROKE); | ||
66 | + exteriorPaint.setStrokeWidth(textSize / 8); | ||
67 | + exteriorPaint.setAntiAlias(false); | ||
68 | + exteriorPaint.setAlpha(255); | ||
69 | + | ||
70 | + this.textSize = textSize; | ||
71 | + } | ||
72 | + | ||
73 | + public void setTypeface(Typeface typeface) { | ||
74 | + interiorPaint.setTypeface(typeface); | ||
75 | + exteriorPaint.setTypeface(typeface); | ||
76 | + } | ||
77 | + | ||
78 | + public void drawText(final Canvas canvas, final float posX, final float posY, final String text) { | ||
79 | + canvas.drawText(text, posX, posY, exteriorPaint); | ||
80 | + canvas.drawText(text, posX, posY, interiorPaint); | ||
81 | + } | ||
82 | + | ||
83 | + public void drawLines(Canvas canvas, final float posX, final float posY, Vector<String> lines) { | ||
84 | + int lineNum = 0; | ||
85 | + for (final String line : lines) { | ||
86 | + drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line); | ||
87 | + ++lineNum; | ||
88 | + } | ||
89 | + } | ||
90 | + | ||
91 | + public void setInteriorColor(final int color) { | ||
92 | + interiorPaint.setColor(color); | ||
93 | + } | ||
94 | + | ||
95 | + public void setExteriorColor(final int color) { | ||
96 | + exteriorPaint.setColor(color); | ||
97 | + } | ||
98 | + | ||
99 | + public float getTextSize() { | ||
100 | + return textSize; | ||
101 | + } | ||
102 | + | ||
103 | + public void setAlpha(final int alpha) { | ||
104 | + interiorPaint.setAlpha(alpha); | ||
105 | + exteriorPaint.setAlpha(alpha); | ||
106 | + } | ||
107 | + | ||
108 | + public void getTextBounds( | ||
109 | + final String line, final int index, final int count, final Rect lineBounds) { | ||
110 | + interiorPaint.getTextBounds(line, index, count, lineBounds); | ||
111 | + } | ||
112 | + | ||
113 | + public void setTextAlign(final Align align) { | ||
114 | + interiorPaint.setTextAlign(align); | ||
115 | + exteriorPaint.setTextAlign(align); | ||
116 | + } | ||
117 | +} |
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo.env; | ||
17 | + | ||
18 | +import android.graphics.Bitmap; | ||
19 | +import android.graphics.Matrix; | ||
20 | +import android.os.Environment; | ||
21 | +import java.io.File; | ||
22 | +import java.io.FileOutputStream; | ||
23 | + | ||
24 | +/** | ||
25 | + * Utility class for manipulating images. | ||
26 | + **/ | ||
27 | +public class ImageUtils { | ||
28 | + @SuppressWarnings("unused") | ||
29 | + private static final Logger LOGGER = new Logger(); | ||
30 | + | ||
31 | + static { | ||
32 | + try { | ||
33 | + System.loadLibrary("tensorflow_demo"); | ||
34 | + } catch (UnsatisfiedLinkError e) { | ||
35 | + LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable."); | ||
36 | + } | ||
37 | + } | ||
38 | + | ||
39 | + /** | ||
40 | + * Utility method to compute the allocated size in bytes of a YUV420SP image | ||
41 | + * of the given dimensions. | ||
42 | + */ | ||
43 | + public static int getYUVByteSize(final int width, final int height) { | ||
44 | + // The luminance plane requires 1 byte per pixel. | ||
45 | + final int ySize = width * height; | ||
46 | + | ||
47 | + // The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up. | ||
48 | + // Each 2x2 block takes 2 bytes to encode, one each for U and V. | ||
49 | + final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2; | ||
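50 | +    // For example, a 640x480 frame needs 307200 + 153600 = 460800 bytes. | ||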
50 | + | ||
51 | + return ySize + uvSize; | ||
52 | + } | ||
53 | + | ||
54 | + /** | ||
55 | + * Saves a Bitmap object to disk for analysis. | ||
56 | + * | ||
57 | + * @param bitmap The bitmap to save. | ||
58 | + */ | ||
59 | + public static void saveBitmap(final Bitmap bitmap) { | ||
60 | + saveBitmap(bitmap, "preview.png"); | ||
61 | + } | ||
62 | + | ||
63 | + /** | ||
64 | + * Saves a Bitmap object to disk for analysis. | ||
65 | + * | ||
66 | + * @param bitmap The bitmap to save. | ||
67 | + * @param filename The location to save the bitmap to. | ||
68 | + */ | ||
69 | + public static void saveBitmap(final Bitmap bitmap, final String filename) { | ||
70 | + final String root = | ||
71 | + Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow"; | ||
72 | + LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root); | ||
73 | + final File myDir = new File(root); | ||
74 | + | ||
75 | + if (!myDir.mkdirs()) { | ||
76 | + LOGGER.i("Make dir failed"); | ||
77 | + } | ||
78 | + | ||
79 | + final String fname = filename; | ||
80 | + final File file = new File(myDir, fname); | ||
81 | + if (file.exists()) { | ||
82 | + file.delete(); | ||
83 | + } | ||
84 | + try { | ||
85 | + final FileOutputStream out = new FileOutputStream(file); | ||
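86 | +      // PNG is lossless, so the quality hint below is ignored by the encoder. | ||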
86 | + bitmap.compress(Bitmap.CompressFormat.PNG, 99, out); | ||
87 | + out.flush(); | ||
88 | + out.close(); | ||
89 | + } catch (final Exception e) { | ||
90 | + LOGGER.e(e, "Exception!"); | ||
91 | + } | ||
92 | + } | ||
93 | + | ||
94 | + // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges | ||
95 | + // are normalized to eight bits. | ||
96 | + static final int kMaxChannelValue = 262143; | ||
97 | + | ||
98 | + // Always prefer the native implementation if available. | ||
99 | + private static boolean useNativeConversion = true; | ||
100 | + | ||
101 | + public static void convertYUV420SPToARGB8888( | ||
102 | + byte[] input, | ||
103 | + int width, | ||
104 | + int height, | ||
105 | + int[] output) { | ||
106 | + if (useNativeConversion) { | ||
107 | + try { | ||
108 | + ImageUtils.convertYUV420SPToARGB8888(input, output, width, height, false); | ||
109 | + return; | ||
110 | + } catch (UnsatisfiedLinkError e) { | ||
111 | + LOGGER.w( | ||
112 | + "Native YUV420SP -> RGB implementation not found, falling back to Java implementation"); | ||
113 | + useNativeConversion = false; | ||
114 | + } | ||
115 | + } | ||
116 | + | ||
117 | +    // Java fallback for YUV420SP (NV21) to ARGB8888 conversion. NV21 interleaves | ||
118 | +    // the chroma plane as VUVU..., which is why V is read before U below. | ||
118 | + final int frameSize = width * height; | ||
119 | + for (int j = 0, yp = 0; j < height; j++) { | ||
120 | + int uvp = frameSize + (j >> 1) * width; | ||
121 | + int u = 0; | ||
122 | + int v = 0; | ||
123 | + | ||
124 | + for (int i = 0; i < width; i++, yp++) { | ||
125 | + int y = 0xff & input[yp]; | ||
126 | + if ((i & 1) == 0) { | ||
127 | + v = 0xff & input[uvp++]; | ||
128 | + u = 0xff & input[uvp++]; | ||
129 | + } | ||
130 | + | ||
131 | + output[yp] = YUV2RGB(y, u, v); | ||
132 | + } | ||
133 | + } | ||
134 | + } | ||
135 | + | ||
136 | + private static int YUV2RGB(int y, int u, int v) { | ||
137 | + // Adjust and check YUV values | ||
138 | + y = (y - 16) < 0 ? 0 : (y - 16); | ||
139 | + u -= 128; | ||
140 | + v -= 128; | ||
141 | + | ||
142 | + // This is the floating point equivalent. We do the conversion in integer | ||
143 | + // because some Android devices do not have floating point in hardware. | ||
144 | +    // nR = (int)(1.164 * nY + 1.596 * nV); | ||
145 | +    // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); | ||
146 | +    // nB = (int)(1.164 * nY + 2.018 * nU); | ||
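147 | +    // The integer coefficients are the float constants scaled by 1024 (e.g. | ||
148 | +    // 1.164 * 1024 ~= 1192), so intermediate channel values carry 10 extra | ||
149 | +    // fractional bits: 18 bits total, hence kMaxChannelValue = 2^18 - 1. | ||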
147 | + int y1192 = 1192 * y; | ||
148 | + int r = (y1192 + 1634 * v); | ||
149 | + int g = (y1192 - 833 * v - 400 * u); | ||
150 | + int b = (y1192 + 2066 * u); | ||
151 | + | ||
152 | +    // Clip RGB values to the range [0, kMaxChannelValue]. | ||
153 | + r = r > kMaxChannelValue ? kMaxChannelValue : (r < 0 ? 0 : r); | ||
154 | + g = g > kMaxChannelValue ? kMaxChannelValue : (g < 0 ? 0 : g); | ||
155 | + b = b > kMaxChannelValue ? kMaxChannelValue : (b < 0 ? 0 : b); | ||
156 | + | ||
157 | + return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff); | ||
158 | + } | ||
159 | + | ||
160 | + | ||
161 | + public static void convertYUV420ToARGB8888( | ||
162 | + byte[] yData, | ||
163 | + byte[] uData, | ||
164 | + byte[] vData, | ||
165 | + int width, | ||
166 | + int height, | ||
167 | + int yRowStride, | ||
168 | + int uvRowStride, | ||
169 | + int uvPixelStride, | ||
170 | + int[] out) { | ||
171 | + if (useNativeConversion) { | ||
172 | + try { | ||
173 | + convertYUV420ToARGB8888( | ||
174 | + yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false); | ||
175 | + return; | ||
176 | + } catch (UnsatisfiedLinkError e) { | ||
177 | + LOGGER.w( | ||
178 | + "Native YUV420 -> RGB implementation not found, falling back to Java implementation"); | ||
179 | + useNativeConversion = false; | ||
180 | + } | ||
181 | + } | ||
182 | + | ||
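183 | +    // U and V are subsampled 2x in each dimension; uvPixelStride abstracts over | ||
184 | +    // planar (stride 1) and semi-planar/interleaved (stride 2) chroma layouts. | ||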
183 | + int yp = 0; | ||
184 | + for (int j = 0; j < height; j++) { | ||
185 | + int pY = yRowStride * j; | ||
186 | + int pUV = uvRowStride * (j >> 1); | ||
187 | + | ||
188 | + for (int i = 0; i < width; i++) { | ||
189 | +        final int uvOffset = pUV + (i >> 1) * uvPixelStride; | ||
190 | + | ||
191 | +        out[yp++] = YUV2RGB( | ||
192 | +            0xff & yData[pY + i], | ||
193 | +            0xff & uData[uvOffset], | ||
194 | +            0xff & vData[uvOffset]); | ||
195 | + } | ||
196 | + } | ||
197 | + } | ||
198 | + | ||
199 | + | ||
200 | + /** | ||
201 | + * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The | ||
202 | + * input and output must already be allocated and non-null. For efficiency, no error checking is | ||
203 | + * performed. | ||
204 | + * | ||
205 | + * @param input The array of YUV 4:2:0 input data. | ||
206 | + * @param output A pre-allocated array for the ARGB 8:8:8:8 output data. | ||
207 | + * @param width The width of the input image. | ||
208 | + * @param height The height of the input image. | ||
209 | + * @param halfSize If true, downsample to 50% in each dimension, otherwise not. | ||
210 | + */ | ||
211 | + private static native void convertYUV420SPToARGB8888( | ||
212 | + byte[] input, int[] output, int width, int height, boolean halfSize); | ||
213 | + | ||
214 | + /** | ||
215 | +   * Converts YUV420 data with separate Y, U, and V planes to ARGB 8888 data using | ||
216 | +   * the supplied width and height. The input and output must already be allocated | ||
217 | +   * and non-null. For efficiency, no error checking is performed. | ||
218 | +   * | ||
219 | +   * @param y The Y (luminance) plane. | ||
220 | +   * @param u The U (chrominance) plane. | ||
221 | +   * @param v The V (chrominance) plane. | ||
222 | +   * @param output A pre-allocated array for the ARGB 8:8:8:8 output data. | ||
223 | +   * @param width The width of the input image. | ||
224 | +   * @param height The height of the input image. | ||
225 | +   * @param yRowStride The row stride of the Y plane, in bytes. | ||
226 | +   * @param uvRowStride The row stride of the U and V planes, in bytes. | ||
227 | +   * @param uvPixelStride The byte distance between adjacent U (or V) samples in a row. | ||
228 | +   * @param halfSize If true, downsample to 50% in each dimension, otherwise not. | ||
227 | + */ | ||
228 | + private static native void convertYUV420ToARGB8888( | ||
229 | + byte[] y, | ||
230 | + byte[] u, | ||
231 | + byte[] v, | ||
232 | + int[] output, | ||
233 | + int width, | ||
234 | + int height, | ||
235 | + int yRowStride, | ||
236 | + int uvRowStride, | ||
237 | + int uvPixelStride, | ||
238 | + boolean halfSize); | ||
239 | + | ||
240 | + /** | ||
241 | + * Converts YUV420 semi-planar data to RGB 565 data using the supplied width | ||
242 | + * and height. The input and output must already be allocated and non-null. | ||
243 | + * For efficiency, no error checking is performed. | ||
244 | + * | ||
245 | + * @param input The array of YUV 4:2:0 input data. | ||
246 | + * @param output A pre-allocated array for the RGB 5:6:5 output data. | ||
247 | + * @param width The width of the input image. | ||
248 | + * @param height The height of the input image. | ||
249 | + */ | ||
250 | + private static native void convertYUV420SPToRGB565( | ||
251 | + byte[] input, byte[] output, int width, int height); | ||
252 | + | ||
253 | + /** | ||
254 | + * Converts 32-bit ARGB8888 image data to YUV420SP data. This is useful, for | ||
255 | + * instance, in creating data to feed the classes that rely on raw camera | ||
256 | + * preview frames. | ||
257 | + * | ||
258 | + * @param input An array of input pixels in ARGB8888 format. | ||
259 | + * @param output A pre-allocated array for the YUV420SP output data. | ||
260 | + * @param width The width of the input image. | ||
261 | + * @param height The height of the input image. | ||
262 | + */ | ||
263 | + private static native void convertARGB8888ToYUV420SP( | ||
264 | + int[] input, byte[] output, int width, int height); | ||
265 | + | ||
266 | + /** | ||
267 | + * Converts 16-bit RGB565 image data to YUV420SP data. This is useful, for | ||
268 | + * instance, in creating data to feed the classes that rely on raw camera | ||
269 | + * preview frames. | ||
270 | + * | ||
271 | + * @param input An array of input pixels in RGB565 format. | ||
272 | + * @param output A pre-allocated array for the YUV420SP output data. | ||
273 | + * @param width The width of the input image. | ||
274 | + * @param height The height of the input image. | ||
275 | + */ | ||
276 | + private static native void convertRGB565ToYUV420SP( | ||
277 | + byte[] input, byte[] output, int width, int height); | ||
278 | + | ||
279 | + /** | ||
280 | + * Returns a transformation matrix from one reference frame into another. | ||
281 | + * Handles cropping (if maintaining aspect ratio is desired) and rotation. | ||
282 | + * | ||
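283 | +   * <p>For example (illustrative values): getTransformationMatrix(640, 480, 299, 299, 90, true) | ||
284 | +   * maps a landscape camera frame into a 299x299 model input, rotating by 90 degrees and | ||
285 | +   * scaling uniformly so the destination is fully covered (excess image is cropped). | ||
286 | +   * | ||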
283 | + * @param srcWidth Width of source frame. | ||
284 | + * @param srcHeight Height of source frame. | ||
285 | + * @param dstWidth Width of destination frame. | ||
286 | + * @param dstHeight Height of destination frame. | ||
287 | + * @param applyRotation Amount of rotation to apply from one frame to another. | ||
288 | + * Must be a multiple of 90. | ||
289 | + * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant, | ||
290 | + * cropping the image if necessary. | ||
291 | + * @return The transformation fulfilling the desired requirements. | ||
292 | + */ | ||
293 | + public static Matrix getTransformationMatrix( | ||
294 | + final int srcWidth, | ||
295 | + final int srcHeight, | ||
296 | + final int dstWidth, | ||
297 | + final int dstHeight, | ||
298 | + final int applyRotation, | ||
299 | + final boolean maintainAspectRatio) { | ||
300 | + final Matrix matrix = new Matrix(); | ||
301 | + | ||
302 | + if (applyRotation != 0) { | ||
303 | + if (applyRotation % 90 != 0) { | ||
304 | +        // The literal % must be escaped as %% so String.format does not choke on it. | ||
305 | +        LOGGER.w("Rotation of %d %% 90 != 0", applyRotation); | ||
305 | + } | ||
306 | + | ||
307 | + // Translate so center of image is at origin. | ||
308 | + matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f); | ||
309 | + | ||
310 | + // Rotate around origin. | ||
311 | + matrix.postRotate(applyRotation); | ||
312 | + } | ||
313 | + | ||
314 | + // Account for the already applied rotation, if any, and then determine how | ||
315 | + // much scaling is needed for each axis. | ||
316 | + final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0; | ||
317 | + | ||
318 | + final int inWidth = transpose ? srcHeight : srcWidth; | ||
319 | + final int inHeight = transpose ? srcWidth : srcHeight; | ||
320 | + | ||
321 | + // Apply scaling if necessary. | ||
322 | + if (inWidth != dstWidth || inHeight != dstHeight) { | ||
323 | + final float scaleFactorX = dstWidth / (float) inWidth; | ||
324 | + final float scaleFactorY = dstHeight / (float) inHeight; | ||
325 | + | ||
326 | + if (maintainAspectRatio) { | ||
327 | + // Scale by minimum factor so that dst is filled completely while | ||
328 | + // maintaining the aspect ratio. Some image may fall off the edge. | ||
329 | + final float scaleFactor = Math.max(scaleFactorX, scaleFactorY); | ||
330 | + matrix.postScale(scaleFactor, scaleFactor); | ||
331 | + } else { | ||
332 | + // Scale exactly to fill dst from src. | ||
333 | + matrix.postScale(scaleFactorX, scaleFactorY); | ||
334 | + } | ||
335 | + } | ||
336 | + | ||
337 | + if (applyRotation != 0) { | ||
338 | + // Translate back from origin centered reference to destination frame. | ||
339 | + matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f); | ||
340 | + } | ||
341 | + | ||
342 | + return matrix; | ||
343 | + } | ||
344 | +} |
1 | +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo.env; | ||
17 | + | ||
18 | +import android.util.Log; | ||
19 | + | ||
20 | +import java.util.HashSet; | ||
21 | +import java.util.Set; | ||
22 | + | ||
23 | +/** | ||
24 | + * Wrapper for the platform log function, allows convenient message prefixing and log disabling. | ||
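25 | + * | ||
26 | + * <p>For example, {@code new Logger("Demo").i("Frame %dx%d", 640, 480)} logs | ||
27 | + * "Demo: Frame 640x480" under the default "tensorflow" tag. | ||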
25 | + */ | ||
26 | +public final class Logger { | ||
27 | + private static final String DEFAULT_TAG = "tensorflow"; | ||
28 | + private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG; | ||
29 | + | ||
30 | + // Classes to be ignored when examining the stack trace | ||
31 | + private static final Set<String> IGNORED_CLASS_NAMES; | ||
32 | + | ||
33 | + static { | ||
34 | + IGNORED_CLASS_NAMES = new HashSet<String>(3); | ||
35 | + IGNORED_CLASS_NAMES.add("dalvik.system.VMStack"); | ||
36 | + IGNORED_CLASS_NAMES.add("java.lang.Thread"); | ||
37 | + IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName()); | ||
38 | + } | ||
39 | + | ||
40 | + private final String tag; | ||
41 | + private final String messagePrefix; | ||
42 | + private int minLogLevel = DEFAULT_MIN_LOG_LEVEL; | ||
43 | + | ||
44 | + /** | ||
45 | + * Creates a Logger using the class name as the message prefix. | ||
46 | + * | ||
47 | + * @param clazz the simple name of this class is used as the message prefix. | ||
48 | + */ | ||
49 | + public Logger(final Class<?> clazz) { | ||
50 | + this(clazz.getSimpleName()); | ||
51 | + } | ||
52 | + | ||
53 | + /** | ||
54 | + * Creates a Logger using the specified message prefix. | ||
55 | + * | ||
56 | + * @param messagePrefix is prepended to the text of every message. | ||
57 | + */ | ||
58 | + public Logger(final String messagePrefix) { | ||
59 | + this(DEFAULT_TAG, messagePrefix); | ||
60 | + } | ||
61 | + | ||
62 | + /** | ||
63 | + * Creates a Logger with a custom tag and a custom message prefix. If the message prefix | ||
64 | + * is set to <pre>null</pre>, the caller's class name is used as the prefix. | ||
65 | + * | ||
66 | + * @param tag identifies the source of a log message. | ||
67 | + * @param messagePrefix prepended to every message if non-null. If null, the name of the caller is | ||
68 | + * being used | ||
69 | + */ | ||
70 | + public Logger(final String tag, final String messagePrefix) { | ||
71 | + this.tag = tag; | ||
72 | + final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix; | ||
73 | + this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix; | ||
74 | + } | ||
75 | + | ||
76 | + /** | ||
77 | + * Creates a Logger using the caller's class name as the message prefix. | ||
78 | + */ | ||
79 | + public Logger() { | ||
80 | + this(DEFAULT_TAG, null); | ||
81 | + } | ||
82 | + | ||
83 | + /** | ||
84 | + * Creates a Logger using the caller's class name as the message prefix. | ||
85 | + */ | ||
86 | + public Logger(final int minLogLevel) { | ||
87 | + this(DEFAULT_TAG, null); | ||
88 | + this.minLogLevel = minLogLevel; | ||
89 | + } | ||
90 | + | ||
91 | + public void setMinLogLevel(final int minLogLevel) { | ||
92 | + this.minLogLevel = minLogLevel; | ||
93 | + } | ||
94 | + | ||
95 | + public boolean isLoggable(final int logLevel) { | ||
96 | + return logLevel >= minLogLevel || Log.isLoggable(tag, logLevel); | ||
97 | + } | ||
98 | + | ||
99 | + /** | ||
100 | + * Return caller's simple name. | ||
101 | + * | ||
102 | + * Android getStackTrace() returns an array that looks like this: | ||
103 | + * stackTrace[0]: dalvik.system.VMStack | ||
104 | + * stackTrace[1]: java.lang.Thread | ||
105 | + * stackTrace[2]: com.google.android.apps.unveil.env.UnveilLogger | ||
106 | + * stackTrace[3]: com.google.android.apps.unveil.BaseApplication | ||
107 | + * | ||
108 | + * This function returns the simple version of the first non-filtered name. | ||
109 | + * | ||
110 | + * @return caller's simple name | ||
111 | + */ | ||
112 | + private static String getCallerSimpleName() { | ||
113 | + // Get the current callstack so we can pull the class of the caller off of it. | ||
114 | + final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); | ||
115 | + | ||
116 | + for (final StackTraceElement elem : stackTrace) { | ||
117 | + final String className = elem.getClassName(); | ||
118 | + if (!IGNORED_CLASS_NAMES.contains(className)) { | ||
119 | + // We're only interested in the simple name of the class, not the complete package. | ||
120 | + final String[] classParts = className.split("\\."); | ||
121 | + return classParts[classParts.length - 1]; | ||
122 | + } | ||
123 | + } | ||
124 | + | ||
125 | + return Logger.class.getSimpleName(); | ||
126 | + } | ||
127 | + | ||
128 | + private String toMessage(final String format, final Object... args) { | ||
129 | + return messagePrefix + (args.length > 0 ? String.format(format, args) : format); | ||
130 | + } | ||
131 | + | ||
132 | + public void v(final String format, final Object... args) { | ||
133 | + if (isLoggable(Log.VERBOSE)) { | ||
134 | + Log.v(tag, toMessage(format, args)); | ||
135 | + } | ||
136 | + } | ||
137 | + | ||
138 | + public void v(final Throwable t, final String format, final Object... args) { | ||
139 | + if (isLoggable(Log.VERBOSE)) { | ||
140 | + Log.v(tag, toMessage(format, args), t); | ||
141 | + } | ||
142 | + } | ||
143 | + | ||
144 | + public void d(final String format, final Object... args) { | ||
145 | + if (isLoggable(Log.DEBUG)) { | ||
146 | + Log.d(tag, toMessage(format, args)); | ||
147 | + } | ||
148 | + } | ||
149 | + | ||
150 | + public void d(final Throwable t, final String format, final Object... args) { | ||
151 | + if (isLoggable(Log.DEBUG)) { | ||
152 | + Log.d(tag, toMessage(format, args), t); | ||
153 | + } | ||
154 | + } | ||
155 | + | ||
156 | + public void i(final String format, final Object... args) { | ||
157 | + if (isLoggable(Log.INFO)) { | ||
158 | + Log.i(tag, toMessage(format, args)); | ||
159 | + } | ||
160 | + } | ||
161 | + | ||
162 | + public void i(final Throwable t, final String format, final Object... args) { | ||
163 | + if (isLoggable(Log.INFO)) { | ||
164 | + Log.i(tag, toMessage(format, args), t); | ||
165 | + } | ||
166 | + } | ||
167 | + | ||
168 | + public void w(final String format, final Object... args) { | ||
169 | + if (isLoggable(Log.WARN)) { | ||
170 | + Log.w(tag, toMessage(format, args)); | ||
171 | + } | ||
172 | + } | ||
173 | + | ||
174 | + public void w(final Throwable t, final String format, final Object... args) { | ||
175 | + if (isLoggable(Log.WARN)) { | ||
176 | + Log.w(tag, toMessage(format, args), t); | ||
177 | + } | ||
178 | + } | ||
179 | + | ||
180 | + public void e(final String format, final Object... args) { | ||
181 | + if (isLoggable(Log.ERROR)) { | ||
182 | + Log.e(tag, toMessage(format, args)); | ||
183 | + } | ||
184 | + } | ||
185 | + | ||
186 | + public void e(final Throwable t, final String format, final Object... args) { | ||
187 | + if (isLoggable(Log.ERROR)) { | ||
188 | + Log.e(tag, toMessage(format, args), t); | ||
189 | + } | ||
190 | + } | ||
191 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo.env; | ||
17 | + | ||
18 | +import android.graphics.Bitmap; | ||
19 | +import android.text.TextUtils; | ||
20 | +import java.io.Serializable; | ||
21 | +import java.util.ArrayList; | ||
22 | +import java.util.List; | ||
23 | + | ||
24 | +/** | ||
25 | + * Size class independent of a Camera object. | ||
26 | + */ | ||
27 | +public class Size implements Comparable<Size>, Serializable { | ||
28 | + | ||
29 | +  // Version 1.4 shipped with this serialVersionUID, so it must be kept in order to | ||
30 | +  // preserve pending serialized queries across upgrades. | ||
31 | + public static final long serialVersionUID = 7689808733290872361L; | ||
32 | + | ||
33 | + public final int width; | ||
34 | + public final int height; | ||
35 | + | ||
36 | + public Size(final int width, final int height) { | ||
37 | + this.width = width; | ||
38 | + this.height = height; | ||
39 | + } | ||
40 | + | ||
41 | + public Size(final Bitmap bmp) { | ||
42 | + this.width = bmp.getWidth(); | ||
43 | + this.height = bmp.getHeight(); | ||
44 | + } | ||
45 | + | ||
46 | + /** | ||
47 | + * Rotate a size by the given number of degrees. | ||
48 | + * @param size Size to rotate. | ||
49 | + * @param rotation Degrees {0, 90, 180, 270} to rotate the size. | ||
50 | + * @return Rotated size. | ||
51 | + */ | ||
52 | + public static Size getRotatedSize(final Size size, final int rotation) { | ||
53 | + if (rotation % 180 != 0) { | ||
54 | + // The phone is portrait, therefore the camera is sideways and frame should be rotated. | ||
55 | + return new Size(size.height, size.width); | ||
56 | + } | ||
57 | + return size; | ||
58 | + } | ||
59 | + | ||
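60 | +  /** | ||
61 | +   * Parses a size from a "<width>x<height>" string, e.g. "640x480"; returns null | ||
62 | +   * for null, empty, or malformed input. | ||
63 | +   */ | ||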
60 | + public static Size parseFromString(String sizeString) { | ||
61 | + if (TextUtils.isEmpty(sizeString)) { | ||
62 | + return null; | ||
63 | + } | ||
64 | + | ||
65 | + sizeString = sizeString.trim(); | ||
66 | + | ||
67 | + // The expected format is "<width>x<height>". | ||
68 | + final String[] components = sizeString.split("x"); | ||
69 | + if (components.length == 2) { | ||
70 | + try { | ||
71 | + final int width = Integer.parseInt(components[0]); | ||
72 | + final int height = Integer.parseInt(components[1]); | ||
73 | + return new Size(width, height); | ||
74 | + } catch (final NumberFormatException e) { | ||
75 | + return null; | ||
76 | + } | ||
77 | + } else { | ||
78 | + return null; | ||
79 | + } | ||
80 | + } | ||
81 | + | ||
82 | + public static List<Size> sizeStringToList(final String sizes) { | ||
83 | + final List<Size> sizeList = new ArrayList<Size>(); | ||
84 | + if (sizes != null) { | ||
85 | + final String[] pairs = sizes.split(","); | ||
86 | + for (final String pair : pairs) { | ||
87 | + final Size size = Size.parseFromString(pair); | ||
88 | + if (size != null) { | ||
89 | + sizeList.add(size); | ||
90 | + } | ||
91 | + } | ||
92 | + } | ||
93 | + return sizeList; | ||
94 | + } | ||
95 | + | ||
 96 | +  public static String sizeListToString(final List<Size> sizes) { | ||
 97 | +    // Build with a StringBuilder to avoid repeated string concatenation in the loop. | ||
 98 | +    final StringBuilder sizesString = new StringBuilder(); | ||
 99 | +    if (sizes != null && sizes.size() > 0) { | ||
100 | +      sizesString.append(sizes.get(0).toString()); | ||
101 | +      for (int i = 1; i < sizes.size(); i++) { | ||
102 | +        sizesString.append(",").append(sizes.get(i).toString()); | ||
103 | +      } | ||
104 | +    } | ||
105 | +    return sizesString.toString(); | ||
106 | +  } | ||
106 | + | ||
107 | + public final float aspectRatio() { | ||
108 | + return (float) width / (float) height; | ||
109 | + } | ||
110 | + | ||
111 | + @Override | ||
112 | + public int compareTo(final Size other) { | ||
113 | + return width * height - other.width * other.height; | ||
114 | + } | ||
115 | + | ||
116 | + @Override | ||
117 | + public boolean equals(final Object other) { | ||
118 | + if (other == null) { | ||
119 | + return false; | ||
120 | + } | ||
121 | + | ||
122 | + if (!(other instanceof Size)) { | ||
123 | + return false; | ||
124 | + } | ||
125 | + | ||
126 | + final Size otherSize = (Size) other; | ||
127 | + return (width == otherSize.width && height == otherSize.height); | ||
128 | + } | ||
129 | + | ||
130 | + @Override | ||
131 | + public int hashCode() { | ||
132 | + return width * 32713 + height; | ||
133 | + } | ||
134 | + | ||
135 | + @Override | ||
136 | + public String toString() { | ||
137 | + return dimensionsAsString(width, height); | ||
138 | + } | ||
139 | + | ||
140 | + public static final String dimensionsAsString(final int width, final int height) { | ||
141 | + return width + "x" + height; | ||
142 | + } | ||
143 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo.env; | ||
17 | + | ||
18 | +import android.os.SystemClock; | ||
19 | + | ||
20 | +/** | ||
21 | + * A simple utility timer for measuring CPU time and wall-clock splits. | ||
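22 | + * | ||
23 | + * <p>Typical use: create a SplitTimer at the start of an operation and call | ||
24 | + * endSplit("stage name") after each stage; each call logs the CPU and wall time | ||
25 | + * elapsed since the previous split and then starts a new one. | ||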
22 | + */ | ||
23 | +public class SplitTimer { | ||
24 | + private final Logger logger; | ||
25 | + | ||
26 | + private long lastWallTime; | ||
27 | + private long lastCpuTime; | ||
28 | + | ||
29 | + public SplitTimer(final String name) { | ||
30 | + logger = new Logger(name); | ||
31 | + newSplit(); | ||
32 | + } | ||
33 | + | ||
34 | + public void newSplit() { | ||
35 | + lastWallTime = SystemClock.uptimeMillis(); | ||
36 | + lastCpuTime = SystemClock.currentThreadTimeMillis(); | ||
37 | + } | ||
38 | + | ||
39 | + public void endSplit(final String splitName) { | ||
40 | + final long currWallTime = SystemClock.uptimeMillis(); | ||
41 | + final long currCpuTime = SystemClock.currentThreadTimeMillis(); | ||
42 | + | ||
43 | + logger.i( | ||
44 | + "%s: cpu=%dms wall=%dms", | ||
45 | + splitName, currCpuTime - lastCpuTime, currWallTime - lastWallTime); | ||
46 | + | ||
47 | + lastWallTime = currWallTime; | ||
48 | + lastCpuTime = currCpuTime; | ||
49 | + } | ||
50 | +} |
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo.tracking; | ||
17 | + | ||
18 | +import android.content.Context; | ||
19 | +import android.graphics.Canvas; | ||
20 | +import android.graphics.Color; | ||
21 | +import android.graphics.Matrix; | ||
22 | +import android.graphics.Paint; | ||
23 | +import android.graphics.Paint.Cap; | ||
24 | +import android.graphics.Paint.Join; | ||
25 | +import android.graphics.Paint.Style; | ||
26 | +import android.graphics.RectF; | ||
27 | +import android.text.TextUtils; | ||
28 | +import android.util.Pair; | ||
29 | +import android.util.TypedValue; | ||
30 | +import android.widget.Toast; | ||
31 | +import java.util.LinkedList; | ||
32 | +import java.util.List; | ||
33 | +import java.util.Queue; | ||
34 | +import org.tensorflow.demo.Classifier.Recognition; | ||
35 | +import org.tensorflow.demo.env.BorderedText; | ||
36 | +import org.tensorflow.demo.env.ImageUtils; | ||
37 | +import org.tensorflow.demo.env.Logger; | ||
38 | + | ||
39 | +/** | ||
40 | + * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing | ||
41 | + * objects to new detections. | ||
42 | + */ | ||
43 | +public class MultiBoxTracker { | ||
44 | + private final Logger logger = new Logger(); | ||
45 | + | ||
46 | + private static final float TEXT_SIZE_DIP = 18; | ||
47 | + | ||
48 | +  // Maximum intersection-over-union allowed between a new detection and an existing | ||
49 | +  // tracked box; above this threshold the lower-scored box (new or old) is removed. | ||
50 | +  private static final float MAX_OVERLAP = 0.2f; | ||
51 | + | ||
52 | + private static final float MIN_SIZE = 16.0f; | ||
53 | + | ||
54 | + // Allow replacement of the tracked box with new results if | ||
55 | + // correlation has dropped below this level. | ||
56 | + private static final float MARGINAL_CORRELATION = 0.75f; | ||
57 | + | ||
58 | + // Consider object to be lost if correlation falls below this threshold. | ||
59 | + private static final float MIN_CORRELATION = 0.3f; | ||
60 | + | ||
61 | + private static final int[] COLORS = { | ||
62 | + Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE, | ||
63 | + Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"), | ||
64 | + Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"), | ||
65 | + Color.parseColor("#AA33AA"), Color.parseColor("#0D0068") | ||
66 | + }; | ||
67 | + | ||
68 | + private final Queue<Integer> availableColors = new LinkedList<Integer>(); | ||
69 | + | ||
70 | + public ObjectTracker objectTracker; | ||
71 | + | ||
72 | + final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>(); | ||
73 | + | ||
74 | + private static class TrackedRecognition { | ||
75 | + ObjectTracker.TrackedObject trackedObject; | ||
76 | + RectF location; | ||
77 | + float detectionConfidence; | ||
78 | + int color; | ||
79 | + String title; | ||
80 | + } | ||
81 | + | ||
82 | + private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>(); | ||
83 | + | ||
84 | + private final Paint boxPaint = new Paint(); | ||
85 | + | ||
86 | + private final float textSizePx; | ||
87 | + private final BorderedText borderedText; | ||
88 | + | ||
89 | + private Matrix frameToCanvasMatrix; | ||
90 | + | ||
91 | + private int frameWidth; | ||
92 | + private int frameHeight; | ||
93 | + | ||
94 | + private int sensorOrientation; | ||
95 | + private Context context; | ||
96 | + | ||
97 | + public MultiBoxTracker(final Context context) { | ||
98 | + this.context = context; | ||
99 | + for (final int color : COLORS) { | ||
100 | + availableColors.add(color); | ||
101 | + } | ||
102 | + | ||
103 | + boxPaint.setColor(Color.RED); | ||
104 | + boxPaint.setStyle(Style.STROKE); | ||
105 | + boxPaint.setStrokeWidth(12.0f); | ||
106 | + boxPaint.setStrokeCap(Cap.ROUND); | ||
107 | + boxPaint.setStrokeJoin(Join.ROUND); | ||
108 | + boxPaint.setStrokeMiter(100); | ||
109 | + | ||
110 | + textSizePx = | ||
111 | + TypedValue.applyDimension( | ||
112 | + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics()); | ||
113 | + borderedText = new BorderedText(textSizePx); | ||
114 | + } | ||
115 | + | ||
116 | + private Matrix getFrameToCanvasMatrix() { | ||
117 | + return frameToCanvasMatrix; | ||
118 | + } | ||
119 | + | ||
120 | + public synchronized void drawDebug(final Canvas canvas) { | ||
121 | + final Paint textPaint = new Paint(); | ||
122 | + textPaint.setColor(Color.WHITE); | ||
123 | + textPaint.setTextSize(60.0f); | ||
124 | + | ||
125 | + final Paint boxPaint = new Paint(); | ||
126 | + boxPaint.setColor(Color.RED); | ||
127 | + boxPaint.setAlpha(200); | ||
128 | + boxPaint.setStyle(Style.STROKE); | ||
129 | + | ||
130 | + for (final Pair<Float, RectF> detection : screenRects) { | ||
131 | + final RectF rect = detection.second; | ||
132 | + canvas.drawRect(rect, boxPaint); | ||
133 | + canvas.drawText("" + detection.first, rect.left, rect.top, textPaint); | ||
134 | + borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first); | ||
135 | + } | ||
136 | + | ||
137 | + if (objectTracker == null) { | ||
138 | + return; | ||
139 | + } | ||
140 | + | ||
141 | + // Draw correlations. | ||
142 | + for (final TrackedRecognition recognition : trackedObjects) { | ||
143 | + final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; | ||
144 | + | ||
145 | + final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame(); | ||
146 | + | ||
147 | + if (getFrameToCanvasMatrix().mapRect(trackedPos)) { | ||
148 | + final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation()); | ||
149 | + borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString); | ||
150 | + } | ||
151 | + } | ||
152 | + | ||
153 | + final Matrix matrix = getFrameToCanvasMatrix(); | ||
154 | + objectTracker.drawDebug(canvas, matrix); | ||
155 | + } | ||
156 | + | ||
157 | + public synchronized void trackResults( | ||
158 | + final List<Recognition> results, final byte[] frame, final long timestamp) { | ||
159 | +    logger.i("Processing %d results from timestamp %d", results.size(), timestamp); | ||
160 | + processResults(timestamp, results, frame); | ||
161 | + } | ||
162 | + | ||
163 | + public synchronized void draw(final Canvas canvas) { | ||
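164 | +    // Scale so that the (possibly rotated) frame fits entirely within the canvas; a | ||
165 | +    // 90- or 270-degree sensor orientation swaps the frame's width and height. | ||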
164 | + final boolean rotated = sensorOrientation % 180 == 90; | ||
165 | + final float multiplier = | ||
166 | + Math.min(canvas.getHeight() / (float) (rotated ? frameWidth : frameHeight), | ||
167 | + canvas.getWidth() / (float) (rotated ? frameHeight : frameWidth)); | ||
168 | + frameToCanvasMatrix = | ||
169 | + ImageUtils.getTransformationMatrix( | ||
170 | + frameWidth, | ||
171 | + frameHeight, | ||
172 | + (int) (multiplier * (rotated ? frameHeight : frameWidth)), | ||
173 | + (int) (multiplier * (rotated ? frameWidth : frameHeight)), | ||
174 | + sensorOrientation, | ||
175 | + false); | ||
176 | + for (final TrackedRecognition recognition : trackedObjects) { | ||
177 | + final RectF trackedPos = | ||
178 | + (objectTracker != null) | ||
179 | + ? recognition.trackedObject.getTrackedPositionInPreviewFrame() | ||
180 | + : new RectF(recognition.location); | ||
181 | + | ||
182 | + getFrameToCanvasMatrix().mapRect(trackedPos); | ||
183 | + boxPaint.setColor(recognition.color); | ||
184 | + | ||
185 | + final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f; | ||
186 | + canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint); | ||
187 | + | ||
188 | + final String labelString = | ||
189 | + !TextUtils.isEmpty(recognition.title) | ||
190 | + ? String.format("%s %.2f", recognition.title, recognition.detectionConfidence) | ||
191 | + : String.format("%.2f", recognition.detectionConfidence); | ||
192 | + borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString); | ||
193 | + } | ||
194 | + } | ||
195 | + | ||
196 | + private boolean initialized = false; | ||
197 | + | ||
198 | + public synchronized void onFrame( | ||
199 | + final int w, | ||
200 | + final int h, | ||
201 | + final int rowStride, | ||
202 | + final int sensorOrientation, | ||
203 | + final byte[] frame, | ||
204 | + final long timestamp) { | ||
205 | + if (objectTracker == null && !initialized) { | ||
206 | + ObjectTracker.clearInstance(); | ||
207 | + | ||
208 | + logger.i("Initializing ObjectTracker: %dx%d", w, h); | ||
209 | + objectTracker = ObjectTracker.getInstance(w, h, rowStride, true); | ||
210 | + frameWidth = w; | ||
211 | + frameHeight = h; | ||
212 | + this.sensorOrientation = sensorOrientation; | ||
213 | + initialized = true; | ||
214 | + | ||
215 | + if (objectTracker == null) { | ||
216 | + String message = | ||
217 | + "Object tracking support not found. " | ||
218 | + + "See tensorflow/examples/android/README.md for details."; | ||
219 | + Toast.makeText(context, message, Toast.LENGTH_LONG).show(); | ||
220 | + logger.e(message); | ||
221 | + } | ||
222 | + } | ||
223 | + | ||
224 | + if (objectTracker == null) { | ||
225 | + return; | ||
226 | + } | ||
227 | + | ||
228 | + objectTracker.nextFrame(frame, null, timestamp, null, true); | ||
229 | + | ||
230 | + // Clean up any objects not worth tracking any more. | ||
231 | + final LinkedList<TrackedRecognition> copyList = | ||
232 | + new LinkedList<TrackedRecognition>(trackedObjects); | ||
233 | + for (final TrackedRecognition recognition : copyList) { | ||
234 | + final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; | ||
235 | + final float correlation = trackedObject.getCurrentCorrelation(); | ||
236 | + if (correlation < MIN_CORRELATION) { | ||
237 | + logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation); | ||
238 | + trackedObject.stopTracking(); | ||
239 | + trackedObjects.remove(recognition); | ||
240 | + | ||
241 | + availableColors.add(recognition.color); | ||
242 | + } | ||
243 | + } | ||
244 | + } | ||
245 | + | ||
246 | + private void processResults( | ||
247 | + final long timestamp, final List<Recognition> results, final byte[] originalFrame) { | ||
248 | + final List<Pair<Float, Recognition>> rectsToTrack = new LinkedList<Pair<Float, Recognition>>(); | ||
249 | + | ||
250 | + screenRects.clear(); | ||
251 | + final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix()); | ||
252 | + | ||
253 | + for (final Recognition result : results) { | ||
254 | + if (result.getLocation() == null) { | ||
255 | + continue; | ||
256 | + } | ||
257 | + final RectF detectionFrameRect = new RectF(result.getLocation()); | ||
258 | + | ||
259 | + final RectF detectionScreenRect = new RectF(); | ||
260 | + rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect); | ||
261 | + | ||
262 | + logger.v( | ||
263 | + "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect); | ||
264 | + | ||
265 | + screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect)); | ||
266 | + | ||
267 | + if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) { | ||
268 | + logger.w("Degenerate rectangle! " + detectionFrameRect); | ||
269 | + continue; | ||
270 | + } | ||
271 | + | ||
272 | + rectsToTrack.add(new Pair<Float, Recognition>(result.getConfidence(), result)); | ||
273 | + } | ||
274 | + | ||
275 | + if (rectsToTrack.isEmpty()) { | ||
276 | + logger.v("Nothing to track, aborting."); | ||
277 | + return; | ||
278 | + } | ||
279 | + | ||
280 | + if (objectTracker == null) { | ||
281 | + trackedObjects.clear(); | ||
282 | + for (final Pair<Float, Recognition> potential : rectsToTrack) { | ||
283 | + final TrackedRecognition trackedRecognition = new TrackedRecognition(); | ||
284 | + trackedRecognition.detectionConfidence = potential.first; | ||
285 | + trackedRecognition.location = new RectF(potential.second.getLocation()); | ||
286 | + trackedRecognition.trackedObject = null; | ||
287 | + trackedRecognition.title = potential.second.getTitle(); | ||
288 | + trackedRecognition.color = COLORS[trackedObjects.size()]; | ||
289 | + trackedObjects.add(trackedRecognition); | ||
290 | + | ||
291 | + if (trackedObjects.size() >= COLORS.length) { | ||
292 | + break; | ||
293 | + } | ||
294 | + } | ||
295 | + return; | ||
296 | + } | ||
297 | + | ||
298 | + logger.i("%d rects to track", rectsToTrack.size()); | ||
299 | + for (final Pair<Float, Recognition> potential : rectsToTrack) { | ||
300 | + handleDetection(originalFrame, timestamp, potential); | ||
301 | + } | ||
302 | + } | ||
303 | + | ||
304 | + private void handleDetection( | ||
305 | + final byte[] frameCopy, final long timestamp, final Pair<Float, Recognition> potential) { | ||
306 | + final ObjectTracker.TrackedObject potentialObject = | ||
307 | + objectTracker.trackObject(potential.second.getLocation(), timestamp, frameCopy); | ||
308 | + | ||
309 | + final float potentialCorrelation = potentialObject.getCurrentCorrelation(); | ||
310 | + logger.v( | ||
311 | + "Tracked object went from %s to %s with correlation %.2f", | ||
312 | + potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation); | ||
313 | + | ||
314 | + if (potentialCorrelation < MARGINAL_CORRELATION) { | ||
315 | + logger.v("Correlation too low to begin tracking %s.", potentialObject); | ||
316 | + potentialObject.stopTracking(); | ||
317 | + return; | ||
318 | + } | ||
319 | + | ||
320 | + final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>(); | ||
321 | + | ||
322 | + float maxIntersect = 0.0f; | ||
323 | + | ||
324 | + // This is the current tracked object whose color we will take. If left null we'll take the | ||
325 | + // first one from the color queue. | ||
326 | + TrackedRecognition recogToReplace = null; | ||
327 | + | ||
328 | + // Look for intersections that will be overridden by this object or an intersection that would | ||
329 | + // prevent this one from being placed. | ||
330 | + for (final TrackedRecognition trackedRecognition : trackedObjects) { | ||
331 | + final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame(); | ||
332 | + final RectF b = potentialObject.getTrackedPositionInPreviewFrame(); | ||
333 | + final RectF intersection = new RectF(); | ||
334 | + final boolean intersects = intersection.setIntersect(a, b); | ||
335 | + | ||
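336 | +      // Intersection-over-union (Jaccard index) of the two tracked boxes. | ||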
336 | + final float intersectArea = intersection.width() * intersection.height(); | ||
337 | + final float totalArea = a.width() * a.height() + b.width() * b.height() - intersectArea; | ||
338 | + final float intersectOverUnion = intersectArea / totalArea; | ||
339 | + | ||
340 | + // If there is an intersection with this currently tracked box above the maximum overlap | ||
341 | + // percentage allowed, either the new recognition needs to be dismissed or the old | ||
342 | + // recognition needs to be removed and possibly replaced with the new one. | ||
343 | + if (intersects && intersectOverUnion > MAX_OVERLAP) { | ||
344 | + if (potential.first < trackedRecognition.detectionConfidence | ||
345 | + && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) { | ||
346 | + // If track for the existing object is still going strong and the detection score was | ||
347 | + // good, reject this new object. | ||
348 | + potentialObject.stopTracking(); | ||
349 | + return; | ||
350 | + } else { | ||
351 | + removeList.add(trackedRecognition); | ||
352 | + | ||
353 | + // Let the previously tracked object with max intersection amount donate its color to | ||
354 | + // the new object. | ||
355 | + if (intersectOverUnion > maxIntersect) { | ||
356 | + maxIntersect = intersectOverUnion; | ||
357 | + recogToReplace = trackedRecognition; | ||
358 | + } | ||
359 | + } | ||
360 | + } | ||
361 | + } | ||
362 | + | ||
363 | + // If we're already tracking the max object and no intersections were found to bump off, | ||
364 | + // pick the worst current tracked object to remove, if it's also worse than this candidate | ||
365 | + // object. | ||
366 | + if (availableColors.isEmpty() && removeList.isEmpty()) { | ||
367 | + for (final TrackedRecognition candidate : trackedObjects) { | ||
368 | + if (candidate.detectionConfidence < potential.first) { | ||
369 | + if (recogToReplace == null | ||
370 | + || candidate.detectionConfidence < recogToReplace.detectionConfidence) { | ||
371 | + // Save it so that we use this color for the new object. | ||
372 | + recogToReplace = candidate; | ||
373 | + } | ||
374 | + } | ||
375 | + } | ||
376 | + if (recogToReplace != null) { | ||
377 | + logger.v("Found non-intersecting object to remove."); | ||
378 | + removeList.add(recogToReplace); | ||
379 | + } else { | ||
380 | + logger.v("No non-intersecting object found to remove"); | ||
381 | + } | ||
382 | + } | ||
383 | + | ||
384 | + // Remove everything that got intersected. | ||
385 | + for (final TrackedRecognition trackedRecognition : removeList) { | ||
386 | + logger.v( | ||
387 | + "Removing tracked object %s with detection confidence %.2f, correlation %.2f", | ||
388 | + trackedRecognition.trackedObject, | ||
389 | + trackedRecognition.detectionConfidence, | ||
390 | + trackedRecognition.trackedObject.getCurrentCorrelation()); | ||
391 | + trackedRecognition.trackedObject.stopTracking(); | ||
392 | + trackedObjects.remove(trackedRecognition); | ||
393 | + if (trackedRecognition != recogToReplace) { | ||
394 | + availableColors.add(trackedRecognition.color); | ||
395 | + } | ||
396 | + } | ||
397 | + | ||
398 | + if (recogToReplace == null && availableColors.isEmpty()) { | ||
399 | + logger.e("No room to track this object, aborting."); | ||
400 | + potentialObject.stopTracking(); | ||
401 | + return; | ||
402 | + } | ||
403 | + | ||
404 | + // Finally safe to say we can track this object. | ||
405 | + logger.v( | ||
406 | + "Tracking object %s (%s) with detection confidence %.2f at position %s", | ||
407 | + potentialObject, | ||
408 | + potential.second.getTitle(), | ||
409 | + potential.first, | ||
410 | + potential.second.getLocation()); | ||
411 | + final TrackedRecognition trackedRecognition = new TrackedRecognition(); | ||
412 | + trackedRecognition.detectionConfidence = potential.first; | ||
413 | + trackedRecognition.trackedObject = potentialObject; | ||
414 | + trackedRecognition.title = potential.second.getTitle(); | ||
415 | + | ||
416 | + // Use the color from a replaced object before taking one from the color queue. | ||
417 | + trackedRecognition.color = | ||
418 | + recogToReplace != null ? recogToReplace.color : availableColors.poll(); | ||
419 | + trackedObjects.add(trackedRecognition); | ||
420 | + } | ||
421 | +} |
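The dismissal/replacement logic above keys off intersection-over-union (IoU): the intersection area divided by the union of the two box areas, which is 1.0 for identical boxes and 0.0 for disjoint ones. The following is a minimal, self-contained sketch of the same test using android.graphics.RectF, as the tracker does; the MAX_OVERLAP threshold here is an assumed placeholder for illustration, not necessarily the constant defined in the tracker itself.

import android.graphics.RectF;

final class IouSketch {
  // Assumed threshold for illustration; the real MAX_OVERLAP lives in the tracker class.
  static final float MAX_OVERLAP = 0.2f;

  // Returns intersection-over-union of two boxes, or 0 when they do not intersect.
  static float iou(final RectF a, final RectF b) {
    final RectF intersection = new RectF();
    if (!intersection.setIntersect(a, b)) {
      return 0.0f;
    }
    final float intersectArea = intersection.width() * intersection.height();
    final float unionArea = a.width() * a.height() + b.width() * b.height() - intersectArea;
    return intersectArea / unionArea;
  }

  // Mirrors the overlap check above: true means the new box must displace an old one or be dropped.
  static boolean overlapsTooMuch(final RectF a, final RectF b) {
    return iou(a, b) > MAX_OVERLAP;
  }
}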
1 | +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
2 | + | ||
3 | +Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +you may not use this file except in compliance with the License. | ||
5 | +You may obtain a copy of the License at | ||
6 | + | ||
7 | + http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | + | ||
9 | +Unless required by applicable law or agreed to in writing, software | ||
10 | +distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +See the License for the specific language governing permissions and | ||
13 | +limitations under the License. | ||
14 | +==============================================================================*/ | ||
15 | + | ||
16 | +package org.tensorflow.demo.tracking; | ||
17 | + | ||
18 | +import android.graphics.Canvas; | ||
19 | +import android.graphics.Color; | ||
20 | +import android.graphics.Matrix; | ||
21 | +import android.graphics.Paint; | ||
22 | +import android.graphics.PointF; | ||
23 | +import android.graphics.RectF; | ||
24 | +import android.graphics.Typeface; | ||
25 | +import java.util.ArrayList; | ||
26 | +import java.util.HashMap; | ||
27 | +import java.util.LinkedList; | ||
28 | +import java.util.List; | ||
29 | +import java.util.Map; | ||
30 | +import java.util.Vector; | ||
31 | +import javax.microedition.khronos.opengles.GL10; | ||
32 | +import org.tensorflow.demo.env.Logger; | ||
33 | +import org.tensorflow.demo.env.Size; | ||
34 | + | ||
35 | +/** | ||
36 | + * True object detector/tracker class that tracks objects across consecutive preview frames. | ||
37 | + * It provides a simplified Java interface to the analogous native object defined by | ||
38 | + * jni/client_vision/tracking/object_tracker.*. | ||
39 | + * | ||
40 | + * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must | ||
41 | + * be allocated by ObjectTracker.getInstance(). In addition, release() should be called | ||
42 | + * as soon as the ObjectTracker is no longer needed, and before a new one is created. | ||
43 | + * | ||
44 | + * nextFrame() should be called as new frames become available, preferably as often as possible. | ||
45 | + * | ||
46 | + * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects | ||
47 | + * are associated with the ObjectTracker that created them, and are only valid while that | ||
48 | + * ObjectTracker still exists. (A usage sketch follows at the end of this file.) | ||
49 | + */ | ||
50 | +public class ObjectTracker { | ||
51 | + private static final Logger LOGGER = new Logger(); | ||
52 | + | ||
53 | + private static boolean libraryFound = false; | ||
54 | + | ||
55 | + static { | ||
56 | + try { | ||
57 | + System.loadLibrary("tensorflow_demo"); | ||
58 | + libraryFound = true; | ||
59 | + } catch (UnsatisfiedLinkError e) { | ||
60 | + LOGGER.e("libtensorflow_demo.so not found, tracking unavailable"); | ||
61 | + } | ||
62 | + } | ||
63 | + | ||
64 | + private static final boolean DRAW_TEXT = false; | ||
65 | + | ||
66 | + /** | ||
67 | + * How many history points to keep track of and draw in the red history line. | ||
68 | + */ | ||
69 | + private static final int MAX_DEBUG_HISTORY_SIZE = 30; | ||
70 | + | ||
71 | + /** | ||
72 | + * How many frames of optical flow deltas to record. | ||
73 | + * TODO(andrewharp): Push this down to the native level so it can be polled | ||
74 | + * efficiently into an array for upload, instead of keeping a duplicate | ||
75 | + * copy in Java. | ||
76 | + */ | ||
77 | + private static final int MAX_FRAME_HISTORY_SIZE = 200; | ||
78 | + | ||
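+  // Frames and rectangles are divided by this factor before crossing into native code, and
+  // native results are scaled back up (see downscaleRect() and upscaleRect() below).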
79 | + private static final int DOWNSAMPLE_FACTOR = 2; | ||
80 | + | ||
81 | + private final byte[] downsampledFrame; | ||
82 | + | ||
83 | + protected static ObjectTracker instance; | ||
84 | + | ||
85 | + private final Map<String, TrackedObject> trackedObjects; | ||
86 | + | ||
87 | + private long lastTimestamp; | ||
88 | + | ||
89 | + private FrameChange lastKeypoints; | ||
90 | + | ||
91 | + private final Vector<PointF> debugHistory; | ||
92 | + | ||
93 | + private final LinkedList<TimestampedDeltas> timestampedDeltas; | ||
94 | + | ||
95 | + protected final int frameWidth; | ||
96 | + protected final int frameHeight; | ||
97 | + private final int rowStride; | ||
98 | + protected final boolean alwaysTrack; | ||
99 | + | ||
100 | + private static class TimestampedDeltas { | ||
101 | + final long timestamp; | ||
102 | + final byte[] deltas; | ||
103 | + | ||
104 | + public TimestampedDeltas(final long timestamp, final byte[] deltas) { | ||
105 | + this.timestamp = timestamp; | ||
106 | + this.deltas = deltas; | ||
107 | + } | ||
108 | + } | ||
109 | + | ||
110 | + /** | ||
111 | + * A simple class that records keypoint information, including its | ||
112 | + * location, score, and type. This is used when calculating | ||
113 | + * FrameChange. | ||
114 | + */ | ||
115 | + public static class Keypoint { | ||
116 | + public final float x; | ||
117 | + public final float y; | ||
118 | + public final float score; | ||
119 | + public final int type; | ||
120 | + | ||
121 | + public Keypoint(final float x, final float y) { | ||
122 | + this.x = x; | ||
123 | + this.y = y; | ||
124 | + this.score = 0; | ||
125 | + this.type = -1; | ||
126 | + } | ||
127 | + | ||
128 | + public Keypoint(final float x, final float y, final float score, final int type) { | ||
129 | + this.x = x; | ||
130 | + this.y = y; | ||
131 | + this.score = score; | ||
132 | + this.type = type; | ||
133 | + } | ||
134 | + | ||
135 | + Keypoint delta(final Keypoint other) { | ||
136 | + return new Keypoint(this.x - other.x, this.y - other.y); | ||
137 | + } | ||
138 | + } | ||
139 | + | ||
140 | + /** | ||
141 | + * A simple class that pairs two Keypoints and lazily computes their delta. | ||
142 | + * It is used when calculating the frame translation delta | ||
143 | + * for optical flow. | ||
144 | + */ | ||
145 | + public static class PointChange { | ||
146 | + public final Keypoint keypointA; | ||
147 | + public final Keypoint keypointB; | ||
148 | + Keypoint pointDelta; | ||
149 | + private final boolean wasFound; | ||
150 | + | ||
151 | + public PointChange(final float x1, final float y1, | ||
152 | + final float x2, final float y2, | ||
153 | + final float score, final int type, | ||
154 | + final boolean wasFound) { | ||
155 | + this.wasFound = wasFound; | ||
156 | + | ||
157 | + keypointA = new Keypoint(x1, y1, score, type); | ||
158 | + keypointB = new Keypoint(x2, y2); | ||
159 | + } | ||
160 | + | ||
161 | + public Keypoint getDelta() { | ||
162 | + if (pointDelta == null) { | ||
163 | + pointDelta = keypointB.delta(keypointA); | ||
164 | + } | ||
165 | + return pointDelta; | ||
166 | + } | ||
167 | + } | ||
168 | + | ||
169 | + /** A class that records a timestamped frame translation delta for optical flow. */ | ||
170 | + public static class FrameChange { | ||
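+    // Keypoints arrive from getKeypointsNative() packed as KEYPOINT_STEP floats apiece:
+    // [x1, y1, wasFound (> 0 when found), x2, y2, score, type], decoded in the constructor below.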
171 | + public static final int KEYPOINT_STEP = 7; | ||
172 | + | ||
173 | + public final Vector<PointChange> pointDeltas; | ||
174 | + | ||
175 | + private final float minScore; | ||
176 | + private final float maxScore; | ||
177 | + | ||
178 | + public FrameChange(final float[] framePoints) { | ||
179 | + float minScore = 100.0f; | ||
180 | + float maxScore = -100.0f; | ||
181 | + | ||
182 | + pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP); | ||
183 | + | ||
184 | + for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) { | ||
185 | + final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR; | ||
186 | + final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR; | ||
187 | + | ||
188 | + final boolean wasFound = framePoints[i + 2] > 0.0f; | ||
189 | + | ||
190 | + final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR; | ||
191 | + final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR; | ||
192 | + final float score = framePoints[i + 5]; | ||
193 | + final int type = (int) framePoints[i + 6]; | ||
194 | + | ||
195 | + minScore = Math.min(minScore, score); | ||
196 | + maxScore = Math.max(maxScore, score); | ||
197 | + | ||
198 | + pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound)); | ||
199 | + } | ||
200 | + | ||
201 | + this.minScore = minScore; | ||
202 | + this.maxScore = maxScore; | ||
203 | + } | ||
204 | + } | ||
205 | + | ||
206 | + public static synchronized ObjectTracker getInstance( | ||
207 | + final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { | ||
208 | + if (!libraryFound) { | ||
209 | + LOGGER.e( | ||
210 | + "Native object tracking support not found. " | ||
211 | + + "See tensorflow/examples/android/README.md for details."); | ||
212 | + return null; | ||
213 | + } | ||
214 | + | ||
215 | + if (instance == null) { | ||
216 | + instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack); | ||
217 | + instance.init(); | ||
218 | + } else { | ||
219 | + throw new RuntimeException( | ||
220 | + "Tried to create a new ObjectTracker before releasing the old one!"); | ||
221 | + } | ||
222 | + return instance; | ||
223 | + } | ||
224 | + | ||
225 | + public static synchronized void clearInstance() { | ||
226 | + if (instance != null) { | ||
227 | + instance.release(); | ||
228 | + } | ||
229 | + } | ||
230 | + | ||
231 | + protected ObjectTracker( | ||
232 | + final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { | ||
233 | + this.frameWidth = frameWidth; | ||
234 | + this.frameHeight = frameHeight; | ||
235 | + this.rowStride = rowStride; | ||
236 | + this.alwaysTrack = alwaysTrack; | ||
237 | + this.timestampedDeltas = new LinkedList<TimestampedDeltas>(); | ||
238 | + | ||
239 | + trackedObjects = new HashMap<String, TrackedObject>(); | ||
240 | + | ||
241 | + debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE); | ||
242 | + | ||
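+    // Ceiling division: (n + DOWNSAMPLE_FACTOR - 1) / DOWNSAMPLE_FACTOR rounds up, so frames
+    // with odd dimensions still fit in the downsampled buffer.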
243 | + downsampledFrame = | ||
244 | + new byte | ||
245 | + [(frameWidth + DOWNSAMPLE_FACTOR - 1) | ||
246 | + / DOWNSAMPLE_FACTOR | ||
247 | + * (frameHeight + DOWNSAMPLE_FACTOR - 1) | ||
248 | + / DOWNSAMPLE_FACTOR]; | ||
249 | + } | ||
250 | + | ||
251 | + protected void init() { | ||
252 | + // The native tracker never sees the full frame, so pre-scale dimensions | ||
253 | + // by the downsample factor. | ||
254 | + initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack); | ||
255 | + } | ||
256 | + | ||
257 | + private final float[] matrixValues = new float[9]; | ||
258 | + | ||
259 | + private long downsampledTimestamp; | ||
260 | + | ||
261 | + @SuppressWarnings("unused") | ||
262 | + public synchronized void drawOverlay(final GL10 gl, | ||
263 | + final Size cameraViewSize, final Matrix matrix) { | ||
264 | + final Matrix tempMatrix = new Matrix(matrix); | ||
265 | + tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR); | ||
266 | + tempMatrix.getValues(matrixValues); | ||
267 | + drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues); | ||
268 | + } | ||
269 | + | ||
270 | + public synchronized void nextFrame( | ||
271 | + final byte[] frameData, final byte[] uvData, | ||
272 | + final long timestamp, final float[] transformationMatrix, | ||
273 | + final boolean updateDebugInfo) { | ||
274 | + if (downsampledTimestamp != timestamp) { | ||
275 | + ObjectTracker.downsampleImageNative( | ||
276 | + frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame); | ||
277 | + downsampledTimestamp = timestamp; | ||
278 | + } | ||
279 | + | ||
280 | + // Run Lucas-Kanade using the full-frame initializer. | ||
281 | + nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix); | ||
282 | + | ||
283 | + timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR))); | ||
284 | + while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) { | ||
285 | + timestampedDeltas.removeFirst(); | ||
286 | + } | ||
287 | + | ||
288 | + for (final TrackedObject trackedObject : trackedObjects.values()) { | ||
289 | + trackedObject.updateTrackedPosition(); | ||
290 | + } | ||
291 | + | ||
292 | + if (updateDebugInfo) { | ||
293 | + updateDebugHistory(); | ||
294 | + } | ||
295 | + | ||
296 | + lastTimestamp = timestamp; | ||
297 | + } | ||
298 | + | ||
299 | + public synchronized void release() { | ||
300 | + releaseMemoryNative(); | ||
301 | + synchronized (ObjectTracker.class) { | ||
302 | + instance = null; | ||
303 | + } | ||
304 | + } | ||
305 | + | ||
306 | + private void drawHistoryDebug(final Canvas canvas) { | ||
307 | + drawHistoryPoint( | ||
308 | + canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2); | ||
309 | + } | ||
310 | + | ||
311 | + private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) { | ||
312 | + final Paint p = new Paint(); | ||
313 | + p.setAntiAlias(false); | ||
314 | + p.setTypeface(Typeface.SERIF); | ||
315 | + | ||
316 | + p.setColor(Color.RED); | ||
317 | + p.setStrokeWidth(2.0f); | ||
318 | + | ||
319 | + // Draw the center circle. | ||
320 | + p.setColor(Color.GREEN); | ||
321 | + canvas.drawCircle(startX, startY, 3.0f, p); | ||
322 | + | ||
323 | + p.setColor(Color.RED); | ||
324 | + | ||
325 | + // Iterate through in backwards order. | ||
326 | + synchronized (debugHistory) { | ||
327 | + final int numPoints = debugHistory.size(); | ||
328 | + float lastX = startX; | ||
329 | + float lastY = startY; | ||
330 | + for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) { | ||
331 | + final PointF delta = debugHistory.get(numPoints - keypointNum - 1); | ||
332 | + final float newX = lastX + delta.x; | ||
333 | + final float newY = lastY + delta.y; | ||
334 | + canvas.drawLine(lastX, lastY, newX, newY, p); | ||
335 | + lastX = newX; | ||
336 | + lastY = newY; | ||
337 | + } | ||
338 | + } | ||
339 | + } | ||
340 | + | ||
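+  // Maps a float in [0, 1] to a color channel value in [0, 255], clamping out-of-range input.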
341 | + private static int floatToChar(final float value) { | ||
342 | + return Math.max(0, Math.min((int) (value * 255.999f), 255)); | ||
343 | + } | ||
344 | + | ||
345 | + private void drawKeypointsDebug(final Canvas canvas) { | ||
346 | + final Paint p = new Paint(); | ||
347 | + if (lastKeypoints == null) { | ||
348 | + return; | ||
349 | + } | ||
350 | + final int keypointSize = 3; | ||
351 | + | ||
352 | + final float minScore = lastKeypoints.minScore; | ||
353 | + final float maxScore = lastKeypoints.maxScore; | ||
354 | + | ||
355 | + for (final PointChange keypoint : lastKeypoints.pointDeltas) { | ||
356 | + if (keypoint.wasFound) { | ||
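+        // Normalize the score into [0, 1] and map it to a red/blue gradient: high-scoring
+        // keypoints draw red, low-scoring ones blue.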
357 | + final int r = | ||
358 | + floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore)); | ||
359 | + final int b = | ||
360 | + floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore)); | ||
361 | + | ||
362 | + final int color = 0xFF000000 | (r << 16) | b; | ||
363 | + p.setColor(color); | ||
364 | + | ||
365 | + final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y, | ||
366 | + keypoint.keypointB.x, keypoint.keypointB.y}; | ||
367 | + canvas.drawRect(screenPoints[2] - keypointSize, | ||
368 | + screenPoints[3] - keypointSize, | ||
369 | + screenPoints[2] + keypointSize, | ||
370 | + screenPoints[3] + keypointSize, p); | ||
371 | + p.setColor(Color.CYAN); | ||
372 | + canvas.drawLine(screenPoints[2], screenPoints[3], | ||
373 | + screenPoints[0], screenPoints[1], p); | ||
374 | + | ||
375 | + if (DRAW_TEXT) { | ||
376 | + p.setColor(Color.WHITE); | ||
377 | + canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score, | ||
378 | + keypoint.keypointA.x, keypoint.keypointA.y, p); | ||
379 | + } | ||
380 | + } else { | ||
381 | + p.setColor(Color.YELLOW); | ||
382 | + final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y}; | ||
383 | + canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p); | ||
384 | + } | ||
385 | + } | ||
386 | + } | ||
387 | + | ||
388 | + private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX, | ||
389 | + final float positionY, final float radius) { | ||
390 | + final RectF currPosition = getCurrentPosition(timestamp, | ||
391 | + new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius)); | ||
392 | + return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY); | ||
393 | + } | ||
394 | + | ||
395 | + private synchronized RectF getCurrentPosition(final long timestamp, final RectF | ||
396 | + oldPosition) { | ||
397 | + final RectF downscaledFrameRect = downscaleRect(oldPosition); | ||
398 | + | ||
399 | + final float[] delta = new float[4]; | ||
400 | + getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top, | ||
401 | + downscaledFrameRect.right, downscaledFrameRect.bottom, delta); | ||
402 | + | ||
403 | + final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]); | ||
404 | + | ||
405 | + return upscaleRect(newPosition); | ||
406 | + } | ||
407 | + | ||
408 | + private void updateDebugHistory() { | ||
409 | + lastKeypoints = new FrameChange(getKeypointsNative(false)); | ||
410 | + | ||
411 | + if (lastTimestamp == 0) { | ||
412 | + return; | ||
413 | + } | ||
414 | + | ||
415 | + final PointF delta = | ||
416 | + getAccumulatedDelta( | ||
417 | + lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100); | ||
418 | + | ||
419 | + synchronized (debugHistory) { | ||
420 | + debugHistory.add(delta); | ||
421 | + | ||
422 | + while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) { | ||
423 | + debugHistory.remove(0); | ||
424 | + } | ||
425 | + } | ||
426 | + } | ||
427 | + | ||
428 | + public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) { | ||
429 | + canvas.save(); | ||
430 | + canvas.setMatrix(frameToCanvas); | ||
431 | + | ||
432 | + drawHistoryDebug(canvas); | ||
433 | + drawKeypointsDebug(canvas); | ||
434 | + | ||
435 | + canvas.restore(); | ||
436 | + } | ||
437 | + | ||
438 | + public Vector<String> getDebugText() { | ||
439 | + final Vector<String> lines = new Vector<String>(); | ||
440 | + | ||
441 | + if (lastKeypoints != null) { | ||
442 | + lines.add("Num keypoints " + lastKeypoints.pointDeltas.size()); | ||
443 | + lines.add("Min score: " + lastKeypoints.minScore); | ||
444 | + lines.add("Max score: " + lastKeypoints.maxScore); | ||
445 | + } | ||
446 | + | ||
447 | + return lines; | ||
448 | + } | ||
449 | + | ||
450 | + public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) { | ||
451 | + final List<byte[]> frameDeltas = new ArrayList<byte[]>(); | ||
452 | + while (timestampedDeltas.size() > 0) { | ||
453 | + final TimestampedDeltas currentDeltas = timestampedDeltas.peek(); | ||
454 | + if (currentDeltas.timestamp <= endFrameTime) { | ||
455 | + frameDeltas.add(currentDeltas.deltas); | ||
456 | + timestampedDeltas.removeFirst(); | ||
457 | + } else { | ||
458 | + break; | ||
459 | + } | ||
460 | + } | ||
461 | + | ||
462 | + return frameDeltas; | ||
463 | + } | ||
464 | + | ||
465 | + private RectF downscaleRect(final RectF fullFrameRect) { | ||
466 | + return new RectF( | ||
467 | + fullFrameRect.left / DOWNSAMPLE_FACTOR, | ||
468 | + fullFrameRect.top / DOWNSAMPLE_FACTOR, | ||
469 | + fullFrameRect.right / DOWNSAMPLE_FACTOR, | ||
470 | + fullFrameRect.bottom / DOWNSAMPLE_FACTOR); | ||
471 | + } | ||
472 | + | ||
473 | + private RectF upscaleRect(final RectF downsampledFrameRect) { | ||
474 | + return new RectF( | ||
475 | + downsampledFrameRect.left * DOWNSAMPLE_FACTOR, | ||
476 | + downsampledFrameRect.top * DOWNSAMPLE_FACTOR, | ||
477 | + downsampledFrameRect.right * DOWNSAMPLE_FACTOR, | ||
478 | + downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR); | ||
479 | + } | ||
480 | + | ||
481 | + /** | ||
482 | + * A TrackedObject represents a native TrackedObject, and provides access to the | ||
483 | + * relevant native tracking information available after every frame update. Instances | ||
484 | + * may be safely passed around and accessed externally, but become invalid after | ||
485 | + * stopTracking() is called or the ObjectTracker that created them is deactivated. | ||
486 | + * | ||
487 | + * @author andrewharp@google.com (Andrew Harp) | ||
488 | + */ | ||
489 | + public class TrackedObject { | ||
490 | + private final String id; | ||
491 | + | ||
492 | + private long lastExternalPositionTime; | ||
493 | + | ||
494 | + private RectF lastTrackedPosition; | ||
495 | + private boolean visibleInLastFrame; | ||
496 | + | ||
497 | + private boolean isDead; | ||
498 | + | ||
499 | + TrackedObject(final RectF position, final long timestamp, final byte[] data) { | ||
500 | + isDead = false; | ||
501 | + | ||
502 | + id = Integer.toString(this.hashCode()); | ||
503 | + | ||
504 | + lastExternalPositionTime = timestamp; | ||
505 | + | ||
506 | + synchronized (ObjectTracker.this) { | ||
507 | + registerInitialAppearance(position, data); | ||
508 | + setPreviousPosition(position, timestamp); | ||
509 | + trackedObjects.put(id, this); | ||
510 | + } | ||
511 | + } | ||
512 | + | ||
513 | + public void stopTracking() { | ||
514 | + checkValidObject(); | ||
515 | + | ||
516 | + synchronized (ObjectTracker.this) { | ||
517 | + isDead = true; | ||
518 | + forgetNative(id); | ||
519 | + trackedObjects.remove(id); | ||
520 | + } | ||
521 | + } | ||
522 | + | ||
523 | + public float getCurrentCorrelation() { | ||
524 | + checkValidObject(); | ||
525 | + return ObjectTracker.this.getCurrentCorrelation(id); | ||
526 | + } | ||
527 | + | ||
528 | + void registerInitialAppearance(final RectF position, final byte[] data) { | ||
529 | + final RectF externalPosition = downscaleRect(position); | ||
530 | + registerNewObjectWithAppearanceNative(id, | ||
531 | + externalPosition.left, externalPosition.top, | ||
532 | + externalPosition.right, externalPosition.bottom, | ||
533 | + data); | ||
534 | + } | ||
535 | + | ||
536 | + synchronized void setPreviousPosition(final RectF position, final long timestamp) { | ||
537 | + checkValidObject(); | ||
538 | + synchronized (ObjectTracker.this) { | ||
539 | + if (lastExternalPositionTime > timestamp) { | ||
540 | + LOGGER.w("Tried to use older position time!"); | ||
541 | + return; | ||
542 | + } | ||
543 | + final RectF externalPosition = downscaleRect(position); | ||
544 | + lastExternalPositionTime = timestamp; | ||
545 | + | ||
546 | + setPreviousPositionNative(id, | ||
547 | + externalPosition.left, externalPosition.top, | ||
548 | + externalPosition.right, externalPosition.bottom, | ||
549 | + lastExternalPositionTime); | ||
550 | + | ||
551 | + updateTrackedPosition(); | ||
552 | + } | ||
553 | + } | ||
554 | + | ||
555 | + void setCurrentPosition(final RectF position) { | ||
556 | + checkValidObject(); | ||
557 | + final RectF downsampledPosition = downscaleRect(position); | ||
558 | + synchronized (ObjectTracker.this) { | ||
559 | + setCurrentPositionNative(id, | ||
560 | + downsampledPosition.left, downsampledPosition.top, | ||
561 | + downsampledPosition.right, downsampledPosition.bottom); | ||
562 | + } | ||
563 | + } | ||
564 | + | ||
565 | + private synchronized void updateTrackedPosition() { | ||
566 | + checkValidObject(); | ||
567 | + | ||
568 | + final float[] delta = new float[4]; | ||
569 | + getTrackedPositionNative(id, delta); | ||
570 | + lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]); | ||
571 | + | ||
572 | + visibleInLastFrame = isObjectVisible(id); | ||
573 | + } | ||
574 | + | ||
575 | + public synchronized RectF getTrackedPositionInPreviewFrame() { | ||
576 | + checkValidObject(); | ||
577 | + | ||
578 | + if (lastTrackedPosition == null) { | ||
579 | + return null; | ||
580 | + } | ||
581 | + return upscaleRect(lastTrackedPosition); | ||
582 | + } | ||
583 | + | ||
584 | + synchronized long getLastExternalPositionTime() { | ||
585 | + return lastExternalPositionTime; | ||
586 | + } | ||
587 | + | ||
588 | + public synchronized boolean visibleInLastPreviewFrame() { | ||
589 | + return visibleInLastFrame; | ||
590 | + } | ||
591 | + | ||
592 | + private void checkValidObject() { | ||
593 | + if (isDead) { | ||
594 | + throw new RuntimeException("TrackedObject already removed from tracking!"); | ||
595 | + } else if (ObjectTracker.this != instance) { | ||
596 | + throw new RuntimeException("TrackedObject created with another ObjectTracker!"); | ||
597 | + } | ||
598 | + } | ||
599 | + } | ||
600 | + | ||
601 | + public synchronized TrackedObject trackObject( | ||
602 | + final RectF position, final long timestamp, final byte[] frameData) { | ||
603 | + if (downsampledTimestamp != timestamp) { | ||
604 | + ObjectTracker.downsampleImageNative( | ||
605 | + frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame); | ||
606 | + downsampledTimestamp = timestamp; | ||
607 | + } | ||
608 | + return new TrackedObject(position, timestamp, downsampledFrame); | ||
609 | + } | ||
610 | + | ||
611 | + public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) { | ||
612 | + return new TrackedObject(position, lastTimestamp, frameData); | ||
613 | + } | ||
614 | + | ||
615 | + /** ********************* NATIVE CODE ************************************ */ | ||
616 | + | ||
617 | + /** This will contain an opaque pointer to the native ObjectTracker */ | ||
618 | + private long nativeObjectTracker; | ||
619 | + | ||
620 | + private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack); | ||
621 | + | ||
622 | + protected native void registerNewObjectWithAppearanceNative( | ||
623 | + String objectId, float x1, float y1, float x2, float y2, byte[] data); | ||
624 | + | ||
625 | + protected native void setPreviousPositionNative( | ||
626 | + String objectId, float x1, float y1, float x2, float y2, long timestamp); | ||
627 | + | ||
628 | + protected native void setCurrentPositionNative( | ||
629 | + String objectId, float x1, float y1, float x2, float y2); | ||
630 | + | ||
631 | + protected native void forgetNative(String key); | ||
632 | + | ||
633 | + protected native String getModelIdNative(String key); | ||
634 | + | ||
635 | + protected native boolean haveObject(String key); | ||
636 | + protected native boolean isObjectVisible(String key); | ||
637 | + protected native float getCurrentCorrelation(String key); | ||
638 | + | ||
639 | + protected native float getMatchScore(String key); | ||
640 | + | ||
641 | + protected native void getTrackedPositionNative(String key, float[] points); | ||
642 | + | ||
643 | + protected native void nextFrameNative( | ||
644 | + byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix); | ||
645 | + | ||
646 | + protected native void releaseMemoryNative(); | ||
647 | + | ||
648 | + protected native void getCurrentPositionNative(long timestamp, | ||
649 | + final float positionX1, final float positionY1, | ||
650 | + final float positionX2, final float positionY2, | ||
651 | + final float[] delta); | ||
652 | + | ||
653 | + protected native byte[] getKeypointsPacked(float scaleFactor); | ||
654 | + | ||
655 | + protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints); | ||
656 | + | ||
657 | + protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas); | ||
658 | + | ||
659 | + protected static native void downsampleImageNative( | ||
660 | + int width, int height, int rowStride, byte[] input, int factor, byte[] output); | ||
661 | +} |
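The lifecycle rules in the class comment (singleton via getInstance(), nextFrame() per preview frame, trackObject() for new detections, release via clearInstance()) can be hard to follow in the abstract. Below is a hedged sketch of the intended call order using only the public API from the file above; the frame geometry, timestamps, and detection box are placeholder values, not values the demo itself uses.

import android.graphics.RectF;
import org.tensorflow.demo.tracking.ObjectTracker;

final class TrackerLifecycleSketch {
  // Placeholder preview geometry; real values come from the camera configuration.
  private static final int WIDTH = 640;
  private static final int HEIGHT = 480;
  private static final int ROW_STRIDE = 640;

  void run(final byte[] luminance, final byte[] uvData, final float[] frameAlignMatrix) {
    // getInstance() returns null when libtensorflow_demo.so could not be loaded.
    final ObjectTracker tracker =
        ObjectTracker.getInstance(WIDTH, HEIGHT, ROW_STRIDE, /* alwaysTrack= */ false);
    if (tracker == null) {
      return;  // Tracking unavailable on this device.
    }

    // Feed frames as they arrive; timestamps must increase monotonically.
    long timestamp = 1;
    tracker.nextFrame(luminance, uvData, timestamp, frameAlignMatrix, false);

    // Register a detection (placeholder box, in frame coordinates).
    final ObjectTracker.TrackedObject tracked =
        tracker.trackObject(new RectF(100, 100, 200, 200), timestamp, luminance);

    // After the next frame, the tracked position reflects the estimated motion.
    tracker.nextFrame(luminance, uvData, ++timestamp, frameAlignMatrix, false);
    final RectF position = tracked.getTrackedPositionInPreviewFrame();

    // Tear down in order: invalidate the TrackedObject, then release the singleton
    // before any new ObjectTracker is created.
    tracked.stopTracking();
    ObjectTracker.clearInstance();
  }
}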
android/freeze_graph.py
0 → 100644
1 | +# Copyright 2015 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +# ============================================================================== | ||
15 | +"""Converts checkpoint variables into Const ops in a standalone GraphDef file. | ||
16 | +This script is designed to take a GraphDef proto, a SaverDef proto, and a set of | ||
17 | +variable values stored in a checkpoint file, and output a GraphDef with all of | ||
18 | +the variable ops converted into const ops containing the values of the | ||
19 | +variables. | ||
20 | +It's useful to do this when we need to load a single file in C++, especially in | ||
21 | +environments like mobile or embedded where we may not have access to the | ||
22 | +RestoreTensor ops and file loading calls that they rely on. | ||
23 | +An example of command-line usage is: | ||
24 | +bazel build tensorflow/python/tools:freeze_graph && \ | ||
25 | +bazel-bin/tensorflow/python/tools/freeze_graph \ | ||
26 | +--input_graph=some_graph_def.pb \ | ||
27 | +--input_checkpoint=model.ckpt-8361242 \ | ||
28 | +--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax | ||
29 | +You can also look at freeze_graph_test.py for an example of how to use it. | ||
30 | +""" | ||
31 | +from __future__ import absolute_import | ||
32 | +from __future__ import division | ||
33 | +from __future__ import print_function | ||
34 | + | ||
35 | +import tensorflow as tf | ||
36 | + | ||
37 | +from google.protobuf import text_format | ||
38 | +from tensorflow.python.framework import graph_util | ||
39 | + | ||
40 | + | ||
41 | +FLAGS = tf.app.flags.FLAGS | ||
42 | + | ||
43 | +tf.app.flags.DEFINE_string("input_graph", "", | ||
44 | + """TensorFlow 'GraphDef' file to load.""") | ||
45 | +tf.app.flags.DEFINE_string("input_saver", "", | ||
46 | + """TensorFlow saver file to load.""") | ||
47 | +tf.app.flags.DEFINE_string("input_checkpoint", "", | ||
48 | + """TensorFlow variables file to load.""") | ||
49 | +tf.app.flags.DEFINE_string("output_graph", "", | ||
50 | + """Output 'GraphDef' file name.""") | ||
51 | +tf.app.flags.DEFINE_boolean("input_binary", False, | ||
52 | + """Whether the input files are in binary format.""") | ||
53 | +tf.app.flags.DEFINE_string("output_node_names", "", | ||
54 | + """The name of the output nodes, comma separated.""") | ||
55 | +tf.app.flags.DEFINE_string("restore_op_name", "save/restore_all", | ||
56 | + """The name of the master restore operator.""") | ||
57 | +tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0", | ||
58 | + """The name of the tensor holding the save path.""") | ||
59 | +tf.app.flags.DEFINE_boolean("clear_devices", True, | ||
60 | + """Whether to remove device specifications.""") | ||
61 | +tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of " | ||
62 | + "initializer nodes to run before freezing.") | ||
63 | + | ||
64 | + | ||
65 | +def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, | ||
66 | + output_node_names, restore_op_name, filename_tensor_name, | ||
67 | + output_graph, clear_devices, initializer_nodes): | ||
68 | + """Converts all variables in a graph and checkpoint into constants.""" | ||
69 | + | ||
70 | + if not tf.gfile.Exists(input_graph): | ||
71 | + print("Input graph file '" + input_graph + "' does not exist!") | ||
72 | + return -1 | ||
73 | + | ||
74 | + if input_saver and not tf.gfile.Exists(input_saver): | ||
75 | + print("Input saver file '" + input_saver + "' does not exist!") | ||
76 | + return -1 | ||
77 | + | ||
78 | + if not tf.gfile.Glob(input_checkpoint): | ||
79 | + print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") | ||
80 | + return -1 | ||
81 | + | ||
82 | + if not output_node_names: | ||
83 | + print("You need to supply the name of a node to --output_node_names.") | ||
84 | + return -1 | ||
85 | + | ||
86 | + input_graph_def = tf.GraphDef() | ||
87 | + mode = "rb" if input_binary else "r" | ||
88 | + with tf.gfile.FastGFile(input_graph, mode) as f: | ||
89 | + if input_binary: | ||
90 | + input_graph_def.ParseFromString(f.read()) | ||
91 | + else: | ||
92 | + text_format.Merge(f.read(), input_graph_def) | ||
93 | + # Remove all the explicit device specifications for this node. This helps to | ||
94 | + # make the graph more portable. | ||
95 | + if clear_devices: | ||
96 | + for node in input_graph_def.node: | ||
97 | + node.device = "" | ||
98 | + _ = tf.import_graph_def(input_graph_def, name="") | ||
99 | + | ||
100 | + with tf.Session() as sess: | ||
101 | + if input_saver: | ||
102 | + with tf.gfile.FastGFile(input_saver, mode) as f: | ||
103 | + saver_def = tf.train.SaverDef() | ||
104 | + if input_binary: | ||
105 | + saver_def.ParseFromString(f.read()) | ||
106 | + else: | ||
107 | + text_format.Merge(f.read(), saver_def) | ||
108 | + saver = tf.train.Saver(saver_def=saver_def) | ||
109 | + saver.restore(sess, input_checkpoint) | ||
110 | + else: | ||
111 | + sess.run([restore_op_name], {filename_tensor_name: input_checkpoint}) | ||
112 | + if initializer_nodes: | ||
113 | + sess.run(initializer_nodes) | ||
114 | + output_graph_def = graph_util.convert_variables_to_constants( | ||
115 | + sess, input_graph_def, output_node_names.split(",")) | ||
116 | + | ||
117 | + with tf.gfile.GFile(output_graph, "wb") as f: | ||
118 | + f.write(output_graph_def.SerializeToString()) | ||
119 | + print("%d ops in the final graph." % len(output_graph_def.node)) | ||
120 | + | ||
121 | + | ||
122 | +def main(unused_args): | ||
123 | + freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, | ||
124 | + FLAGS.input_checkpoint, FLAGS.output_node_names, | ||
125 | + FLAGS.restore_op_name, FLAGS.filename_tensor_name, | ||
126 | + FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes) | ||
127 | + | ||
128 | +if __name__ == "__main__": | ||
129 | + tf.app.run() | ||
... | \ No newline at end of file