장수창

added freeze_graph

Showing 115 changed files with 16755 additions and 0 deletions
1 +# This file is based on https://github.com/github/gitignore/blob/master/Android.gitignore
2 +*.iml
3 +.idea/compiler.xml
4 +.idea/copyright
5 +.idea/dictionaries
6 +.idea/gradle.xml
7 +.idea/libraries
8 +.idea/inspectionProfiles
9 +.idea/misc.xml
10 +.idea/modules.xml
11 +.idea/runConfigurations.xml
12 +.idea/tasks.xml
13 +.idea/workspace.xml
14 +.gradle
15 +local.properties
16 +.DS_Store
17 +build/
18 +gradleBuild/
19 +*.apk
20 +*.ap_
21 +*.dex
22 +*.class
23 +bin/
24 +gen/
25 +out/
26 +*.log
27 +.navigation/
28 +/captures
29 +.externalNativeBuild
1 +<component name="ProjectCodeStyleConfiguration">
2 + <code_scheme name="Project" version="173">
3 + <codeStyleSettings language="XML">
4 + <indentOptions>
5 + <option name="CONTINUATION_INDENT_SIZE" value="4" />
6 + </indentOptions>
7 + <arrangement>
8 + <rules>
9 + <section>
10 + <rule>
11 + <match>
12 + <AND>
13 + <NAME>xmlns:android</NAME>
14 + <XML_ATTRIBUTE />
15 + <XML_NAMESPACE>^$</XML_NAMESPACE>
16 + </AND>
17 + </match>
18 + </rule>
19 + </section>
20 + <section>
21 + <rule>
22 + <match>
23 + <AND>
24 + <NAME>xmlns:.*</NAME>
25 + <XML_ATTRIBUTE />
26 + <XML_NAMESPACE>^$</XML_NAMESPACE>
27 + </AND>
28 + </match>
29 + <order>BY_NAME</order>
30 + </rule>
31 + </section>
32 + <section>
33 + <rule>
34 + <match>
35 + <AND>
36 + <NAME>.*:id</NAME>
37 + <XML_ATTRIBUTE />
38 + <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
39 + </AND>
40 + </match>
41 + </rule>
42 + </section>
43 + <section>
44 + <rule>
45 + <match>
46 + <AND>
47 + <NAME>.*:name</NAME>
48 + <XML_ATTRIBUTE />
49 + <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
50 + </AND>
51 + </match>
52 + </rule>
53 + </section>
54 + <section>
55 + <rule>
56 + <match>
57 + <AND>
58 + <NAME>name</NAME>
59 + <XML_ATTRIBUTE />
60 + <XML_NAMESPACE>^$</XML_NAMESPACE>
61 + </AND>
62 + </match>
63 + </rule>
64 + </section>
65 + <section>
66 + <rule>
67 + <match>
68 + <AND>
69 + <NAME>style</NAME>
70 + <XML_ATTRIBUTE />
71 + <XML_NAMESPACE>^$</XML_NAMESPACE>
72 + </AND>
73 + </match>
74 + </rule>
75 + </section>
76 + <section>
77 + <rule>
78 + <match>
79 + <AND>
80 + <NAME>.*</NAME>
81 + <XML_ATTRIBUTE />
82 + <XML_NAMESPACE>^$</XML_NAMESPACE>
83 + </AND>
84 + </match>
85 + <order>BY_NAME</order>
86 + </rule>
87 + </section>
88 + <section>
89 + <rule>
90 + <match>
91 + <AND>
92 + <NAME>.*</NAME>
93 + <XML_ATTRIBUTE />
94 + <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
95 + </AND>
96 + </match>
97 + <order>ANDROID_ATTRIBUTE_ORDER</order>
98 + </rule>
99 + </section>
100 + <section>
101 + <rule>
102 + <match>
103 + <AND>
104 + <NAME>.*</NAME>
105 + <XML_ATTRIBUTE />
106 + <XML_NAMESPACE>.*</XML_NAMESPACE>
107 + </AND>
108 + </match>
109 + <order>BY_NAME</order>
110 + </rule>
111 + </section>
112 + </rules>
113 + </arrangement>
114 + </codeStyleSettings>
115 + </code_scheme>
116 +</component>
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<project version="4">
3 + <component name="VcsDirectoryMappings">
4 + <mapping directory="$PROJECT_DIR$/../../.." vcs="Git" />
5 + </component>
6 +</project>
\ No newline at end of file
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<!--
3 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4 +
5 + Licensed under the Apache License, Version 2.0 (the "License");
6 + you may not use this file except in compliance with the License.
7 + You may obtain a copy of the License at
8 +
9 + http://www.apache.org/licenses/LICENSE-2.0
10 +
11 + Unless required by applicable law or agreed to in writing, software
12 + distributed under the License is distributed on an "AS IS" BASIS,
13 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 + See the License for the specific language governing permissions and
15 + limitations under the License.
16 +-->
17 +
18 +<manifest xmlns:android="http://schemas.android.com/apk/res/android"
19 + package="org.tensorflow.demo">
20 +
21 + <uses-permission android:name="android.permission.CAMERA" />
22 + <uses-feature android:name="android.hardware.camera" />
23 + <uses-feature android:name="android.hardware.camera.autofocus" />
24 + <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
25 + <uses-permission android:name="android.permission.RECORD_AUDIO" />
26 +
27 + <application android:allowBackup="true"
28 + android:debuggable="true"
29 + android:label="@string/app_name"
30 + android:icon="@drawable/ic_launcher"
31 + android:theme="@style/MaterialTheme">
32 +
33 +<!-- <activity android:name="org.tensorflow.demo.ClassifierActivity"-->
34 +<!-- android:screenOrientation="portrait"-->
35 +<!-- android:label="@string/activity_name_classification">-->
36 +<!-- <intent-filter>-->
37 +<!-- <action android:name="android.intent.action.MAIN" />-->
38 +<!-- <category android:name="android.intent.category.LAUNCHER" />-->
39 +<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
40 +<!-- </intent-filter>-->
41 +<!-- </activity>-->
42 +
43 + <activity android:name="org.tensorflow.demo.DetectorActivity"
44 + android:screenOrientation="portrait"
45 + android:label="@string/activity_name_detection">
46 + <intent-filter>
47 + <action android:name="android.intent.action.MAIN" />
48 + <category android:name="android.intent.category.LAUNCHER" />
49 + <category android:name="android.intent.category.LEANBACK_LAUNCHER" />
50 + </intent-filter>
51 + </activity>
52 +
53 +<!-- <activity android:name="org.tensorflow.demo.StylizeActivity"-->
54 +<!-- android:screenOrientation="portrait"-->
55 +<!-- android:label="@string/activity_name_stylize">-->
56 +<!-- <intent-filter>-->
57 +<!-- <action android:name="android.intent.action.MAIN" />-->
58 +<!-- <category android:name="android.intent.category.LAUNCHER" />-->
59 +<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
60 +<!-- </intent-filter>-->
61 +<!-- </activity>-->
62 +
63 +<!-- <activity android:name="org.tensorflow.demo.SpeechActivity"-->
64 +<!-- android:screenOrientation="portrait"-->
65 +<!-- android:label="@string/activity_name_speech">-->
66 +<!-- <intent-filter>-->
67 +<!-- <action android:name="android.intent.action.MAIN" />-->
68 +<!-- <category android:name="android.intent.category.LAUNCHER" />-->
69 +<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
70 +<!-- </intent-filter>-->
71 +<!-- </activity>-->
72 + </application>
73 +
74 +</manifest>
1 +# Description:
2 +# TensorFlow camera demo app for Android.
3 +
4 +load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
5 +load(
6 + "//tensorflow:tensorflow.bzl",
7 + "tf_copts",
8 +)
9 +
10 +package(
11 + default_visibility = ["//visibility:public"],
12 + licenses = ["notice"], # Apache 2.0
13 +)
14 +
15 +exports_files(["LICENSE"])
16 +
17 +LINKER_SCRIPT = "jni/version_script.lds"
18 +
19 +# libtensorflow_demo.so contains the native code for image colorspace conversion
20 +# and object tracking used by the demo. It does not require TF as a dependency
21 +# to build if STANDALONE_DEMO_LIB is defined.
22 +# TF support for the demo is provided separately by libtensorflow_inference.so.
23 +cc_binary(
24 + name = "libtensorflow_demo.so",
25 + srcs = glob([
26 + "jni/**/*.cc",
27 + "jni/**/*.h",
28 + ]),
29 + copts = tf_copts(),
30 + defines = ["STANDALONE_DEMO_LIB"],
31 + linkopts = [
32 + "-landroid",
33 + "-ldl",
34 + "-ljnigraphics",
35 + "-llog",
36 + "-lm",
37 + "-z defs",
38 + "-s",
39 + "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
40 + ],
41 + linkshared = 1,
42 + linkstatic = 1,
43 + tags = [
44 + "manual",
45 + "notap",
46 + ],
47 + deps = [
48 + LINKER_SCRIPT,
49 + ],
50 +)
51 +
52 +cc_library(
53 + name = "tensorflow_native_libs",
54 + srcs = [
55 + ":libtensorflow_demo.so",
56 + "//tensorflow/tools/android/inference_interface:libtensorflow_inference.so",
57 + ],
58 + tags = [
59 + "manual",
60 + "notap",
61 + ],
62 +)
63 +
64 +android_binary(
65 + name = "tensorflow_demo",
66 + srcs = glob([
67 + "src/**/*.java",
68 + ]),
69 + # Package assets from assets dir as well as all model targets. Remove undesired models
70 + # (and corresponding Activities in source) to reduce APK size.
71 + assets = [
72 + "//tensorflow/examples/android/assets:asset_files",
73 + ":external_assets",
74 + ],
75 + assets_dir = "",
76 + custom_package = "org.tensorflow.demo",
77 + manifest = "AndroidManifest.xml",
78 + resource_files = glob(["res/**"]),
79 + tags = [
80 + "manual",
81 + "notap",
82 + ],
83 + deps = [
84 + ":tensorflow_native_libs",
85 + "//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java",
86 + ],
87 +)
88 +
89 +# LINT.IfChange
90 +filegroup(
91 + name = "external_assets",
92 + srcs = [
93 + "@inception_v1//:model_files",
94 + "@mobile_ssd//:model_files",
95 + "@speech_commands//:model_files",
96 + "@stylize//:model_files",
97 + ],
98 +)
99 +# LINT.ThenChange(//tensorflow/examples/android/download-models.gradle)
100 +
101 +filegroup(
102 + name = "java_files",
103 + srcs = glob(["src/**/*.java"]),
104 +)
105 +
106 +filegroup(
107 + name = "jni_files",
108 + srcs = glob([
109 + "jni/**/*.cc",
110 + "jni/**/*.h",
111 + ]),
112 +)
113 +
114 +filegroup(
115 + name = "resource_files",
116 + srcs = glob(["res/**"]),
117 +)
118 +
119 +exports_files([
120 + "AndroidManifest.xml",
121 +])
1 +# TensorFlow Android Camera Demo
2 +
3 +This folder contains an example application utilizing TensorFlow for Android
4 +devices.
5 +
6 +## Description
7 +
8 +The demos in this folder are designed to give straightforward samples of using
9 +TensorFlow in mobile applications.
10 +
11 +Inference is done using the [TensorFlow Android Inference
12 +Interface](../../tools/android/inference_interface), which may be built
13 +separately if you want a standalone library to drop into your existing
14 +application. Object tracking and efficient YUV -> RGB conversion are handled by
15 +`libtensorflow_demo.so`.
16 +
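Below is a minimal, hypothetical sketch of driving the Inference Interface from Java; the model path and the input/output node names are placeholders rather than values taken from this commit, and the demo's own activities wrap this pattern with camera handling and pre/post-processing.

```java
import android.content.res.AssetManager;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class InferenceSketch {
  // Placeholder asset path and graph node names -- substitute your own model.
  private static final String MODEL_FILE = "file:///android_asset/frozen_graph.pb";
  private static final String INPUT_NODE = "input";
  private static final String OUTPUT_NODE = "output";

  public static float[] runOnce(AssetManager assets, float[] inputValues, int numOutputs) {
    // Load the frozen GraphDef packaged in the APK's assets.
    TensorFlowInferenceInterface inference =
        new TensorFlowInferenceInterface(assets, MODEL_FILE);

    // Feed one batch of input values, run the graph, and fetch the result.
    inference.feed(INPUT_NODE, inputValues, 1, inputValues.length);
    inference.run(new String[] {OUTPUT_NODE});

    float[] outputs = new float[numOutputs];
    inference.fetch(OUTPUT_NODE, outputs);
    return outputs;
  }
}
```
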
17 +A device running Android 5.0 (API 21) or higher is required to run the demo due
18 +to the use of the camera2 API, although the native libraries themselves can run
19 +on API >= 14 devices.
20 +
21 +## Current samples:
22 +
23 +1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java):
24 + Uses the [Google Inception](https://arxiv.org/abs/1409.4842)
25 + model to classify camera frames in real-time, displaying the top results
26 + in an overlay on the camera image.
27 +2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
28 + Demonstrates an SSD-Mobilenet model trained using the
29 +    [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection/)
30 + introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to
31 + localize and track objects (from 80 categories) in the camera preview
32 + in real-time.
33 +3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java):
34 + Uses a model based on [A Learned Representation For Artistic
35 + Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
36 + image to that of a number of different artists.
37 +4. [TF
38 + Speech](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java):
39 + Runs a simple speech recognition model built by the [audio training
40 + tutorial](https://www.tensorflow.org/versions/master/tutorials/audio_recognition). Listens
41 + for a small set of words, and highlights them in the UI when they are
42 + recognized.
43 +
44 +<img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%">
45 +
46 +## Prebuilt Components:
47 +
48 +The fastest path to trying the demo is to download the [prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
49 +
50 +Also available are precompiled native libraries, and a jcenter package that you
51 +may simply drop into your own applications. See
52 +[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
53 +for more details.
54 +
55 +## Running the Demo
56 +
57 +Once the app is installed, it can be started via the "TF Classify", "TF Detect",
58 +"TF Stylize", and "TF Speech" launcher entries, all of which use the orange
59 +TensorFlow logo as their icon.
60 +
61 +While running the activities, pressing the volume keys on your device will
62 +toggle debug visualizations on/off, rendering additional info to the screen that
63 +may be useful for development purposes.
64 +
65 +## Building in Android Studio using the TensorFlow AAR from JCenter
66 +
67 +The simplest way to compile the demo app yourself and try out changes to the
68 +project code is to use Android Studio. Simply set this `android` directory as
69 +the project root.
70 +
71 +Then edit the `build.gradle` file and change the value of `nativeBuildSystem` to
72 +`'none'` so that the project is built in the simplest way possible:
73 +
74 +```None
75 +def nativeBuildSystem = 'none'
76 +```
77 +
78 +While this project includes full build integration for TensorFlow, this setting
79 +disables it, and uses the TensorFlow Inference Interface package from JCenter.
80 +
81 +Note: Currently, in this build mode, YUV -> RGB is done using a less efficient
82 +Java implementation, and object tracking is not available in the "TF Detect"
83 +activity. Setting the build system to `'cmake'` currently only builds
84 +`libtensorflow_demo.so`, which provides fast YUV -> RGB conversion and object
85 +tracking, while still acquiring TensorFlow support via the downloaded AAR, so it
86 +may be a lightweight way to enable these features.
87 +
88 +For any project that does not include custom low level TensorFlow code, this is
89 +likely sufficient.
90 +
91 +For details on how to include this JCenter package in your own project see
92 +[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md).
93 +
94 +## Building the Demo with TensorFlow from Source
95 +
96 +Pick your preferred approach below. At the moment, we have full support for
97 +Bazel, and partial support for gradle, cmake, make, and Android Studio.
98 +
99 +As a first step for all build types, clone the TensorFlow repo with:
100 +
101 +```
102 +git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
103 +```
104 +
105 +Note that `--recurse-submodules` is necessary to prevent some issues with
106 +protobuf compilation.
107 +
108 +### Bazel
109 +
110 +NOTE: Bazel does not currently support building for Android on Windows. Full
111 +support for gradle/cmake builds is coming soon, but in the meantime we suggest
112 +that Windows users download the
113 +[prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk)
114 +instead.
115 +
116 +##### Install Bazel and Android Prerequisites
117 +
118 +Bazel is the primary build system for TensorFlow. To build with Bazel, it and
119 +the Android NDK and SDK must be installed on your system.
120 +
121 +1. Install the latest version of Bazel as per the instructions [on the Bazel
122 + website](https://bazel.build/versions/master/docs/install.html).
123 +2. The Android NDK is required to build the native (C/C++) TensorFlow code. The
124 + current recommended version is 14b, which may be found
125 + [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
126 +3. The Android SDK and build tools may be obtained
127 + [here](https://developer.android.com/tools/revisions/build-tools.html), or
128 + alternatively as part of [Android
129 + Studio](https://developer.android.com/studio/index.html). Build tools API >=
130 + 23 is required to build the TF Android demo (though it will run on API >= 21
131 + devices).
132 +
133 +##### Edit WORKSPACE
134 +
135 +NOTE: As long as you have the SDK and NDK installed, the `./configure` script
136 +will create these rules for you. Answer "Yes" when the script asks to
137 +automatically configure the `./WORKSPACE`.
138 +
139 +The Android entries in
140 +[`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L19-L36) must be uncommented
141 +with the paths filled in appropriately depending on where you installed the NDK
142 +and SDK. Otherwise an error such as: "The external label
143 +'//external:android/sdk' is not bound to anything" will be reported.
144 +
145 +Also edit the API levels for the SDK in WORKSPACE to the highest level you have
146 +installed in your SDK. This must be >= 23 (this is completely independent of the
147 +API level of the demo, which is defined in AndroidManifest.xml). The NDK API
148 +level may remain at 14.
149 +
150 +##### Install Model Files (optional)
151 +
152 +The TensorFlow `GraphDef`s that contain the model definitions and weights are
153 +not packaged in the repo because of their size. They are downloaded
154 +automatically and packaged with the APK by Bazel via a `new_http_archive` defined
155 +in `WORKSPACE` during the build process, and by Gradle via
156 +`download-models.gradle`.
157 +
158 +**Optional**: If you wish to place the models in your assets manually, remove
159 +all of the `model_files` entries from the `assets` list in `tensorflow_demo`
160 +found in the [`BUILD`](BUILD#L92) file. Then download and extract the archives
161 +yourself to the `assets` directory in the source tree:
162 +
163 +```bash
164 +BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models
165 +for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip
166 +do
167 + curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP}
168 + unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/
169 +done
170 +```
171 +
172 +This will extract the models and their associated metadata files to the local
173 +assets/ directory.
174 +
175 +If you are using Gradle, make sure to remove the download-models.gradle reference
176 +from build.gradle after you manually download the models; otherwise Gradle might
177 +download the models again and overwrite your copies.
178 +
179 +##### Build
180 +
181 +After editing your WORKSPACE file to update the SDK/NDK configuration, you may
182 +build the APK. Run this from your workspace root:
183 +
184 +```bash
185 +bazel build --cxxopt='--std=c++11' -c opt //tensorflow/examples/android:tensorflow_demo
186 +```
187 +
188 +##### Install
189 +
190 +Make sure that adb debugging is enabled on your Android 5.0 (API 21) or later
191 +device, then after building use the following command from your workspace root
192 +to install the APK:
193 +
194 +```bash
195 +adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
196 +```
197 +
198 +### Android Studio with Bazel
199 +
200 +Android Studio may be used to build the demo in conjunction with Bazel. First,
201 +make sure that you can build with Bazel following the above directions. Then,
202 +look at [build.gradle](build.gradle) and make sure that the path to Bazel
203 +matches that of your system.
204 +
205 +At this point you can add the tensorflow/examples/android directory as a new
206 +Android Studio project. Click through installing all the Gradle extensions it
207 +requests, and you should be able to have Android Studio build the demo like any
208 +other application (it will call out to Bazel to build the native code with the
209 +NDK).
210 +
211 +### CMake
212 +
213 +Full CMake support for the demo is coming soon, but for now it is possible to
214 +build the TensorFlow Android Inference library using
215 +[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake).
1 +package(
2 + default_visibility = ["//visibility:public"],
3 + licenses = ["notice"], # Apache 2.0
4 +)
5 +
6 +# It is necessary to use this filegroup rather than globbing the files in this
7 +# folder directly in the examples/android:tensorflow_demo target due to the fact
8 +# that assets_dir is necessarily set to "" there (to allow using other
9 +# arbitrary targets as assets).
10 +filegroup(
11 + name = "asset_files",
12 + srcs = glob(
13 + ["**/*"],
14 + exclude = ["BUILD"],
15 + ),
16 +)
This file is too large to display.
1 +// This file provides basic support for building the TensorFlow demo
2 +// in Android Studio with Gradle.
3 +//
4 +// Note that Bazel is still used by default to compile the native libs,
5 +// and should be installed at the location noted below. This build file
6 +// automates the process of calling out to it and copying the compiled
7 +// libraries back into the appropriate directory.
8 +//
9 +// Alternatively, experimental support for Makefile builds is provided by
10 +// setting nativeBuildSystem below to 'makefile'. This will allow building the demo
11 +// on Windows machines, but note that full equivalence with the Bazel
12 +// build is not yet guaranteed. See comments below for caveats and tips
13 +// for speeding up the build, such as enabling ccache.
14 +// NOTE: Running a make build will cause subsequent Bazel builds to *fail*
15 +// unless the contrib/makefile/downloads/ and gen/ dirs are deleted afterwards.
16 +
17 +// The cmake build only creates libtensorflow_demo.so. In this situation,
18 +// libtensorflow_inference.so will be acquired via the tensorflow.aar dependency.
19 +
20 +// It is necessary to customize Gradle's build directory, as otherwise
21 +// it will conflict with the BUILD file used by Bazel on case-insensitive OSs.
22 +project.buildDir = 'gradleBuild'
23 +getProject().setBuildDir('gradleBuild')
24 +
25 +buildscript {
26 + repositories {
27 + jcenter()
28 + google()
29 + }
30 +
31 + dependencies {
32 + classpath 'com.android.tools.build:gradle:3.3.1'
33 + classpath 'org.apache.httpcomponents:httpclient:4.5.4'
34 + }
35 +}
36 +
37 +allprojects {
38 + repositories {
39 + jcenter()
40 + google()
41 + }
42 +}
43 +
44 +// set to 'bazel', 'cmake', 'makefile', 'none'
45 +def nativeBuildSystem = 'none'
46 +
47 +// Controls output directory in APK and CPU type for Bazel builds.
48 +// NOTE: Does not affect the Makefile build target API (yet), which currently
49 +// assumes armeabi-v7a. If building with make, changing this will require
50 +// editing the Makefile as well.
51 +// The CMake build has only been tested with armeabi-v7a; others may not work.
52 +def cpuType = 'armeabi-v7a'
53 +
54 +// Output directory in the local directory for packaging into the APK.
55 +def nativeOutDir = 'libs/' + cpuType
56 +
57 +// Default to building with Bazel and override with make if requested.
58 +def nativeBuildRule = 'buildNativeBazel'
59 +def demoLibPath = '../../../bazel-bin/tensorflow/examples/android/libtensorflow_demo.so'
60 +def inferenceLibPath = '../../../bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so'
61 +
62 +// Override for Makefile builds.
63 +if (nativeBuildSystem == 'makefile') {
64 + nativeBuildRule = 'buildNativeMake'
65 + demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_demo.so'
66 + inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_inference.so'
67 +}
68 +
69 +// If building with Bazel, this is the location of the bazel binary.
70 +// NOTE: Bazel does not yet support building for Android on Windows,
71 +// so in this case the Makefile build must be used as described above.
72 +def bazelLocation = '/usr/local/bin/bazel'
73 +
74 +// import DownloadModels task
75 +project.ext.ASSET_DIR = projectDir.toString() + '/assets'
76 +project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
77 +
78 +// Download default models; if you wish to use your own models then
79 +// place them in the "assets" directory and comment out this line.
80 +apply from: "download-models.gradle"
81 +
82 +apply plugin: 'com.android.application'
83 +
84 +android {
85 + compileSdkVersion 23
86 +
87 + if (nativeBuildSystem == 'cmake') {
88 + defaultConfig {
89 + applicationId = 'org.tensorflow.demo'
90 + minSdkVersion 21
91 + targetSdkVersion 23
92 + ndk {
93 + abiFilters "${cpuType}"
94 + }
95 + externalNativeBuild {
96 + cmake {
97 + arguments '-DANDROID_STL=c++_static'
98 + }
99 + }
100 + }
101 + externalNativeBuild {
102 + cmake {
103 + path './jni/CMakeLists.txt'
104 + }
105 + }
106 + }
107 +
108 + lintOptions {
109 + abortOnError false
110 + }
111 +
112 + sourceSets {
113 + main {
114 + if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
115 + // TensorFlow Java API sources.
116 + java {
117 + srcDir '../../java/src/main/java'
118 + exclude '**/examples/**'
119 + }
120 +
121 + // Android TensorFlow wrappers, etc.
122 + java {
123 + srcDir '../../tools/android/inference_interface/java'
124 + }
125 + }
126 + // Android demo app sources.
127 + java {
128 + srcDir 'src'
129 + }
130 +
131 + manifest.srcFile 'AndroidManifest.xml'
132 + resources.srcDirs = ['src']
133 + aidl.srcDirs = ['src']
134 + renderscript.srcDirs = ['src']
135 + res.srcDirs = ['res']
136 + assets.srcDirs = [project.ext.ASSET_DIR]
137 + jniLibs.srcDirs = ['libs']
138 + }
139 +
140 + debug.setRoot('build-types/debug')
141 + release.setRoot('build-types/release')
142 + }
143 + defaultConfig {
144 + targetSdkVersion 23
145 + minSdkVersion 21
146 + }
147 +}
148 +
149 +task buildNativeBazel(type: Exec) {
150 + workingDir '../../..'
151 + commandLine bazelLocation, 'build', '-c', 'opt', \
152 + 'tensorflow/examples/android:tensorflow_native_libs', \
153 + '--crosstool_top=//external:android/crosstool', \
154 + '--cpu=' + cpuType, \
155 + '--host_crosstool_top=@bazel_tools//tools/cpp:toolchain'
156 +}
157 +
158 +task buildNativeMake(type: Exec) {
159 + environment "NDK_ROOT", android.ndkDirectory
160 + // Tip: install ccache and uncomment the following to speed up
161 + // builds significantly.
162 + // environment "CC_PREFIX", 'ccache'
163 + workingDir '../../..'
164 + commandLine 'tensorflow/contrib/makefile/build_all_android.sh', \
165 + '-s', \
166 + 'tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in', \
167 + '-t', \
168 + 'libtensorflow_inference.so libtensorflow_demo.so all' \
169 + , '-a', cpuType \
170 + //, '-T' // Uncomment to skip protobuf and speed up subsequent builds.
171 +}
172 +
173 +
174 +task copyNativeLibs(type: Copy) {
175 + from demoLibPath
176 + from inferenceLibPath
177 + into nativeOutDir
178 + duplicatesStrategy = 'include'
179 + dependsOn nativeBuildRule
180 + fileMode 0644
181 +}
182 +
183 +tasks.whenTaskAdded { task ->
184 + if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
185 + if (task.name == 'assembleDebug') {
186 + task.dependsOn 'copyNativeLibs'
187 + }
188 + if (task.name == 'assembleRelease') {
189 + task.dependsOn 'copyNativeLibs'
190 + }
191 + }
192 +}
193 +
194 +dependencies {
195 + if (nativeBuildSystem == 'cmake' || nativeBuildSystem == 'none') {
196 + implementation 'org.tensorflow:tensorflow-android:+'
197 + }
198 +}
1 +/*
2 + * download-models.gradle
3 + * Downloads model files from ${MODEL_URL} into application's asset folder
4 + * Input:
5 + * project.ext.TMP_DIR: absolute path to hold downloaded zip files
6 + * project.ext.ASSET_DIR: absolute path to save unzipped model files
7 + * Output:
8 + *    the model files listed below will be downloaded into the folder given by ext.ASSET_DIR
9 + */
10 +// hard coded model files
11 +// LINT.IfChange
12 +def models = ['inception_v1.zip',
13 + 'object_detection/ssd_mobilenet_v1_android_export.zip',
14 + 'stylize_v1.zip',
15 + 'speech_commands_conv_actions.zip']
16 +// LINT.ThenChange(//tensorflow/examples/android/BUILD)
17 +
18 +// Root URL for model archives
19 +def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models'
20 +
21 +buildscript {
22 + repositories {
23 + jcenter()
24 + }
25 + dependencies {
26 + classpath 'de.undercouch:gradle-download-task:3.2.0'
27 + }
28 +}
29 +
30 +import de.undercouch.gradle.tasks.download.Download
31 +task downloadFile(type: Download){
32 + for (f in models) {
33 + src "${MODEL_URL}/" + f
34 + }
35 + dest new File(project.ext.TMP_DIR)
36 + overwrite true
37 +}
38 +
39 +task extractModels(type: Copy) {
40 + for (f in models) {
41 + def localFile = f.split("/")[-1]
42 + from zipTree(project.ext.TMP_DIR + '/' + localFile)
43 + }
44 +
45 + into file(project.ext.ASSET_DIR)
46 + fileMode 0644
47 + exclude '**/LICENSE'
48 +
49 + def needDownload = false
50 + for (f in models) {
51 + def localFile = f.split("/")[-1]
52 + if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) {
53 + needDownload = true
54 + }
55 + }
56 +
57 + if (needDownload) {
58 + dependsOn downloadFile
59 + }
60 +}
61 +
62 +tasks.whenTaskAdded { task ->
63 + if (task.name == 'assembleDebug') {
64 + task.dependsOn 'extractModels'
65 + }
66 + if (task.name == 'assembleRelease') {
67 + task.dependsOn 'extractModels'
68 + }
69 +}
70 +
1 +#Sat Nov 18 15:06:47 CET 2017
2 +distributionBase=GRADLE_USER_HOME
3 +distributionPath=wrapper/dists
4 +zipStoreBase=GRADLE_USER_HOME
5 +zipStorePath=wrapper/dists
6 +distributionUrl=https\://services.gradle.org/distributions/gradle-4.1-all.zip
1 +#!/usr/bin/env bash
2 +
3 +##############################################################################
4 +##
5 +## Gradle start up script for UN*X
6 +##
7 +##############################################################################
8 +
9 +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
10 +DEFAULT_JVM_OPTS=""
11 +
12 +APP_NAME="Gradle"
13 +APP_BASE_NAME=`basename "$0"`
14 +
15 +# Use the maximum available, or set MAX_FD != -1 to use that value.
16 +MAX_FD="maximum"
17 +
18 +warn ( ) {
19 + echo "$*"
20 +}
21 +
22 +die ( ) {
23 + echo
24 + echo "$*"
25 + echo
26 + exit 1
27 +}
28 +
29 +# OS specific support (must be 'true' or 'false').
30 +cygwin=false
31 +msys=false
32 +darwin=false
33 +case "`uname`" in
34 + CYGWIN* )
35 + cygwin=true
36 + ;;
37 + Darwin* )
38 + darwin=true
39 + ;;
40 + MINGW* )
41 + msys=true
42 + ;;
43 +esac
44 +
45 +# Attempt to set APP_HOME
46 +# Resolve links: $0 may be a link
47 +PRG="$0"
48 +# Need this for relative symlinks.
49 +while [ -h "$PRG" ] ; do
50 + ls=`ls -ld "$PRG"`
51 + link=`expr "$ls" : '.*-> \(.*\)$'`
52 + if expr "$link" : '/.*' > /dev/null; then
53 + PRG="$link"
54 + else
55 + PRG=`dirname "$PRG"`"/$link"
56 + fi
57 +done
58 +SAVED="`pwd`"
59 +cd "`dirname \"$PRG\"`/" >/dev/null
60 +APP_HOME="`pwd -P`"
61 +cd "$SAVED" >/dev/null
62 +
63 +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
64 +
65 +# Determine the Java command to use to start the JVM.
66 +if [ -n "$JAVA_HOME" ] ; then
67 + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
68 + # IBM's JDK on AIX uses strange locations for the executables
69 + JAVACMD="$JAVA_HOME/jre/sh/java"
70 + else
71 + JAVACMD="$JAVA_HOME/bin/java"
72 + fi
73 + if [ ! -x "$JAVACMD" ] ; then
74 + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
75 +
76 +Please set the JAVA_HOME variable in your environment to match the
77 +location of your Java installation."
78 + fi
79 +else
80 + JAVACMD="java"
81 + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
82 +
83 +Please set the JAVA_HOME variable in your environment to match the
84 +location of your Java installation."
85 +fi
86 +
87 +# Increase the maximum file descriptors if we can.
88 +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
89 + MAX_FD_LIMIT=`ulimit -H -n`
90 + if [ $? -eq 0 ] ; then
91 + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
92 + MAX_FD="$MAX_FD_LIMIT"
93 + fi
94 + ulimit -n $MAX_FD
95 + if [ $? -ne 0 ] ; then
96 + warn "Could not set maximum file descriptor limit: $MAX_FD"
97 + fi
98 + else
99 + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
100 + fi
101 +fi
102 +
103 +# For Darwin, add options to specify how the application appears in the dock
104 +if $darwin; then
105 + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
106 +fi
107 +
108 +# For Cygwin, switch paths to Windows format before running java
109 +if $cygwin ; then
110 + APP_HOME=`cygpath --path --mixed "$APP_HOME"`
111 + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
112 + JAVACMD=`cygpath --unix "$JAVACMD"`
113 +
114 + # We build the pattern for arguments to be converted via cygpath
115 + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
116 + SEP=""
117 + for dir in $ROOTDIRSRAW ; do
118 + ROOTDIRS="$ROOTDIRS$SEP$dir"
119 + SEP="|"
120 + done
121 + OURCYGPATTERN="(^($ROOTDIRS))"
122 + # Add a user-defined pattern to the cygpath arguments
123 + if [ "$GRADLE_CYGPATTERN" != "" ] ; then
124 + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
125 + fi
126 + # Now convert the arguments - kludge to limit ourselves to /bin/sh
127 + i=0
128 + for arg in "$@" ; do
129 + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
130 + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
131 +
132 + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
133 + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
134 + else
135 + eval `echo args$i`="\"$arg\""
136 + fi
137 + i=$((i+1))
138 + done
139 + case $i in
140 + (0) set -- ;;
141 + (1) set -- "$args0" ;;
142 + (2) set -- "$args0" "$args1" ;;
143 + (3) set -- "$args0" "$args1" "$args2" ;;
144 + (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
145 + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
146 + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
147 + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
148 + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
149 + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
150 + esac
151 +fi
152 +
153 +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
154 +function splitJvmOpts() {
155 + JVM_OPTS=("$@")
156 +}
157 +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
158 +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
159 +
160 +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
1 +@if "%DEBUG%" == "" @echo off
2 +@rem ##########################################################################
3 +@rem
4 +@rem Gradle startup script for Windows
5 +@rem
6 +@rem ##########################################################################
7 +
8 +@rem Set local scope for the variables with windows NT shell
9 +if "%OS%"=="Windows_NT" setlocal
10 +
11 +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
12 +set DEFAULT_JVM_OPTS=
13 +
14 +set DIRNAME=%~dp0
15 +if "%DIRNAME%" == "" set DIRNAME=.
16 +set APP_BASE_NAME=%~n0
17 +set APP_HOME=%DIRNAME%
18 +
19 +@rem Find java.exe
20 +if defined JAVA_HOME goto findJavaFromJavaHome
21 +
22 +set JAVA_EXE=java.exe
23 +%JAVA_EXE% -version >NUL 2>&1
24 +if "%ERRORLEVEL%" == "0" goto init
25 +
26 +echo.
27 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
28 +echo.
29 +echo Please set the JAVA_HOME variable in your environment to match the
30 +echo location of your Java installation.
31 +
32 +goto fail
33 +
34 +:findJavaFromJavaHome
35 +set JAVA_HOME=%JAVA_HOME:"=%
36 +set JAVA_EXE=%JAVA_HOME%/bin/java.exe
37 +
38 +if exist "%JAVA_EXE%" goto init
39 +
40 +echo.
41 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
42 +echo.
43 +echo Please set the JAVA_HOME variable in your environment to match the
44 +echo location of your Java installation.
45 +
46 +goto fail
47 +
48 +:init
49 +@rem Get command-line arguments, handling Windows variants
50 +
51 +if not "%OS%" == "Windows_NT" goto win9xME_args
52 +if "%@eval[2+2]" == "4" goto 4NT_args
53 +
54 +:win9xME_args
55 +@rem Slurp the command line arguments.
56 +set CMD_LINE_ARGS=
57 +set _SKIP=2
58 +
59 +:win9xME_args_slurp
60 +if "x%~1" == "x" goto execute
61 +
62 +set CMD_LINE_ARGS=%*
63 +goto execute
64 +
65 +:4NT_args
66 +@rem Get arguments from the 4NT Shell from JP Software
67 +set CMD_LINE_ARGS=%$
68 +
69 +:execute
70 +@rem Setup the command line
71 +
72 +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
73 +
74 +@rem Execute Gradle
75 +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
76 +
77 +:end
78 +@rem End local scope for the variables with windows NT shell
79 +if "%ERRORLEVEL%"=="0" goto mainEnd
80 +
81 +:fail
82 +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
83 +rem the _cmd.exe /c_ return code!
84 +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
85 +exit /b 1
86 +
87 +:mainEnd
88 +if "%OS%"=="Windows_NT" endlocal
89 +
90 +:omega
1 +#
2 +# Copyright (C) 2016 The Android Open Source Project
3 +#
4 +# Licensed under the Apache License, Version 2.0 (the "License");
5 +# you may not use this file except in compliance with the License.
6 +# You may obtain a copy of the License at
7 +#
8 +# http://www.apache.org/licenses/LICENSE-2.0
9 +#
10 +# Unless required by applicable law or agreed to in writing, software
11 +# distributed under the License is distributed on an "AS IS" BASIS,
12 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +# See the License for the specific language governing permissions and
14 +# limitations under the License.
15 +#
16 +
17 +project(TENSORFLOW_DEMO)
18 +cmake_minimum_required(VERSION 3.4.1)
19 +
20 +set(CMAKE_VERBOSE_MAKEFILE on)
21 +
22 +get_filename_component(TF_SRC_ROOT ${CMAKE_SOURCE_DIR}/../../../.. ABSOLUTE)
23 +get_filename_component(SAMPLE_SRC_DIR ${CMAKE_SOURCE_DIR}/.. ABSOLUTE)
24 +
25 +if (ANDROID_ABI MATCHES "^armeabi-v7a$")
26 + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
27 +elseif(ANDROID_ABI MATCHES "^arm64-v8a")
28 + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -ftree-vectorize")
29 +endif()
30 +
31 +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSTANDALONE_DEMO_LIB \
32 + -std=c++11 -fno-exceptions -fno-rtti -O2 -Wno-narrowing \
33 + -fPIE")
34 +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
35 + -Wl,--allow-multiple-definition \
36 + -Wl,--whole-archive -fPIE -v")
37 +
38 +file(GLOB_RECURSE tensorflow_demo_sources ${SAMPLE_SRC_DIR}/jni/*.*)
39 +add_library(tensorflow_demo SHARED
40 + ${tensorflow_demo_sources})
41 +target_include_directories(tensorflow_demo PRIVATE
42 + ${TF_SRC_ROOT}
43 + ${CMAKE_SOURCE_DIR})
44 +
45 +target_link_libraries(tensorflow_demo
46 + android
47 + log
48 + jnigraphics
49 + m
50 + atomic
51 + z)
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// This file binds the native image utility code to the Java class
17 +// which exposes them.
18 +
19 +#include <jni.h>
20 +#include <stdio.h>
21 +#include <stdlib.h>
22 +
23 +#include "tensorflow/examples/android/jni/rgb2yuv.h"
24 +#include "tensorflow/examples/android/jni/yuv2rgb.h"
25 +
26 +#define IMAGEUTILS_METHOD(METHOD_NAME) \
27 + Java_org_tensorflow_demo_env_ImageUtils_##METHOD_NAME // NOLINT
28 +
29 +#ifdef __cplusplus
30 +extern "C" {
31 +#endif
32 +
33 +JNIEXPORT void JNICALL
34 +IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
35 + JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
36 + jint width, jint height, jboolean halfSize);
37 +
38 +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
39 + JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
40 + jintArray output, jint width, jint height, jint y_row_stride,
41 + jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize);
42 +
43 +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
44 + JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
45 + jint height);
46 +
47 +JNIEXPORT void JNICALL
48 +IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
49 + JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
50 + jint width, jint height);
51 +
52 +JNIEXPORT void JNICALL
53 +IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
54 + JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
55 + jint width, jint height);
56 +
57 +#ifdef __cplusplus
58 +}
59 +#endif
60 +
61 +JNIEXPORT void JNICALL
62 +IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
63 + JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
64 + jint width, jint height, jboolean halfSize) {
65 + jboolean inputCopy = JNI_FALSE;
66 + jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
67 +
68 + jboolean outputCopy = JNI_FALSE;
69 + jint* const o = env->GetIntArrayElements(output, &outputCopy);
70 +
71 + if (halfSize) {
72 + ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(i),
73 + reinterpret_cast<uint32_t*>(o), width,
74 + height);
75 + } else {
76 + ConvertYUV420SPToARGB8888(reinterpret_cast<uint8_t*>(i),
77 + reinterpret_cast<uint8_t*>(i) + width * height,
78 + reinterpret_cast<uint32_t*>(o), width, height);
79 + }
80 +
81 + env->ReleaseByteArrayElements(input, i, JNI_ABORT);
82 + env->ReleaseIntArrayElements(output, o, 0);
83 +}
84 +
85 +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
86 + JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
87 + jintArray output, jint width, jint height, jint y_row_stride,
88 + jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize) {
89 + jboolean inputCopy = JNI_FALSE;
90 + jbyte* const y_buff = env->GetByteArrayElements(y, &inputCopy);
91 + jboolean outputCopy = JNI_FALSE;
92 + jint* const o = env->GetIntArrayElements(output, &outputCopy);
93 +
94 + if (halfSize) {
95 + ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(y_buff),
96 + reinterpret_cast<uint32_t*>(o), width,
97 + height);
98 + } else {
99 + jbyte* const u_buff = env->GetByteArrayElements(u, &inputCopy);
100 + jbyte* const v_buff = env->GetByteArrayElements(v, &inputCopy);
101 +
102 + ConvertYUV420ToARGB8888(
103 + reinterpret_cast<uint8_t*>(y_buff), reinterpret_cast<uint8_t*>(u_buff),
104 + reinterpret_cast<uint8_t*>(v_buff), reinterpret_cast<uint32_t*>(o),
105 + width, height, y_row_stride, uv_row_stride, uv_pixel_stride);
106 +
107 + env->ReleaseByteArrayElements(u, u_buff, JNI_ABORT);
108 + env->ReleaseByteArrayElements(v, v_buff, JNI_ABORT);
109 + }
110 +
111 + env->ReleaseByteArrayElements(y, y_buff, JNI_ABORT);
112 + env->ReleaseIntArrayElements(output, o, 0);
113 +}
114 +
115 +JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
116 + JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
117 + jint height) {
118 + jboolean inputCopy = JNI_FALSE;
119 + jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
120 +
121 + jboolean outputCopy = JNI_FALSE;
122 + jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
123 +
124 + ConvertYUV420SPToRGB565(reinterpret_cast<uint8_t*>(i),
125 + reinterpret_cast<uint16_t*>(o), width, height);
126 +
127 + env->ReleaseByteArrayElements(input, i, JNI_ABORT);
128 + env->ReleaseByteArrayElements(output, o, 0);
129 +}
130 +
131 +JNIEXPORT void JNICALL
132 +IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
133 + JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
134 + jint width, jint height) {
135 + jboolean inputCopy = JNI_FALSE;
136 + jint* const i = env->GetIntArrayElements(input, &inputCopy);
137 +
138 + jboolean outputCopy = JNI_FALSE;
139 + jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
140 +
141 + ConvertARGB8888ToYUV420SP(reinterpret_cast<uint32_t*>(i),
142 + reinterpret_cast<uint8_t*>(o), width, height);
143 +
144 + env->ReleaseIntArrayElements(input, i, JNI_ABORT);
145 + env->ReleaseByteArrayElements(output, o, 0);
146 +}
147 +
148 +JNIEXPORT void JNICALL
149 +IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
150 + JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
151 + jint width, jint height) {
152 + jboolean inputCopy = JNI_FALSE;
153 + jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
154 +
155 + jboolean outputCopy = JNI_FALSE;
156 + jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
157 +
158 + ConvertRGB565ToYUV420SP(reinterpret_cast<uint16_t*>(i),
159 + reinterpret_cast<uint8_t*>(o), width, height);
160 +
161 + env->ReleaseByteArrayElements(input, i, JNI_ABORT);
162 + env->ReleaseByteArrayElements(output, o, 0);
163 +}
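For context, each `IMAGEUTILS_METHOD` entry point above is the JNI implementation of a `static native` method on `org.tensorflow.demo.env.ImageUtils` (the class is fixed by the `Java_org_tensorflow_demo_env_ImageUtils_` prefix). Below is a sketch of the corresponding Java-side declarations, derived from the signatures in this file; parameter names are illustrative, and the repo's actual ImageUtils.java may contain additional Java helpers.

```java
package org.tensorflow.demo.env;

// Java counterparts of the JNI functions in imageutils_jni.cc; each byte[]/int[]
// corresponds to the jbyteArray/jintArray parameter of the matching native function.
public class ImageUtils {
  public static native void convertYUV420SPToARGB8888(
      byte[] input, int[] output, int width, int height, boolean halfSize);

  public static native void convertYUV420ToARGB8888(
      byte[] y, byte[] u, byte[] v, int[] output, int width, int height,
      int yRowStride, int uvRowStride, int uvPixelStride, boolean halfSize);

  public static native void convertYUV420SPToRGB565(
      byte[] input, byte[] output, int width, int height);

  public static native void convertARGB8888ToYUV420SP(
      int[] input, byte[] output, int width, int height);

  public static native void convertRGB565ToYUV420SP(
      byte[] input, byte[] output, int width, int height);
}
```
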
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
18 +
19 +#include <math.h>
20 +
21 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 +
23 +namespace tf_tracking {
24 +
25 +// Arbitrary keypoint type ids for labeling the origin of tracked keypoints.
26 +enum KeypointType {
27 + KEYPOINT_TYPE_DEFAULT = 0,
28 + KEYPOINT_TYPE_FAST = 1,
29 + KEYPOINT_TYPE_INTEREST = 2
30 +};
31 +
32 +// Struct that can be used to more richly store the results of a detection
33 +// than a single number, while still maintaining comparability.
34 +struct MatchScore {
35 + explicit MatchScore(double val) : value(val) {}
36 + MatchScore() { value = 0.0; }
37 +
38 + double value;
39 +
40 + MatchScore& operator+(const MatchScore& rhs) {
41 + value += rhs.value;
42 + return *this;
43 + }
44 +
45 + friend std::ostream& operator<<(std::ostream& stream,
46 + const MatchScore& detection) {
47 + stream << detection.value;
48 + return stream;
49 + }
50 +};
51 +inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) {
52 + return cC1.value < cC2.value;
53 +}
54 +inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) {
55 + return cC1.value > cC2.value;
56 +}
57 +inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) {
58 + return cC1.value >= cC2.value;
59 +}
60 +inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) {
61 + return cC1.value <= cC2.value;
62 +}
63 +
64 +// Fixed seed used for all random number generators.
65 +static const int kRandomNumberSeed = 11111;
66 +
67 +// TODO(andrewharp): Move as many of these settings as possible into a settings
68 +// object which can be passed in from Java at runtime.
69 +
70 +// Whether or not to use ESM instead of LK flow.
71 +static const bool kUseEsm = false;
72 +
73 +// This constant gets added to the diagonal of the Hessian
74 +// before solving for translation in 2dof ESM.
75 +// It ensures better behavior especially in the absence of
76 +// strong texture.
77 +static const int kEsmRegularizer = 20;
78 +
79 +// Do we want to brightness-normalize each keypoint patch when we compute
80 +// its flow using ESM?
81 +static const bool kDoBrightnessNormalize = true;
82 +
83 +// Whether or not to use fixed-point interpolated pixel lookups in optical flow.
84 +#define USE_FIXED_POINT_FLOW 1
85 +
86 +// Whether to normalize keypoint windows for intensity in LK optical flow.
87 +// This is a define for now because it helps keep the code streamlined.
88 +#define NORMALIZE 1
89 +
90 +// Number of keypoints to store per frame.
91 +static const int kMaxKeypoints = 76;
92 +
93 +// Keypoint detection.
94 +static const int kMaxTempKeypoints = 1024;
95 +
96 +// Number of floats each keypoint takes up when exporting to an array.
97 +static const int kKeypointStep = 7;
98 +
99 +// Number of frame deltas to keep around in the circular queue.
100 +static const int kNumFrames = 512;
101 +
102 +// Number of iterations to do tracking on each keypoint at each pyramid level.
103 +static const int kNumIterations = 3;
104 +
105 +// The number of bins (on a side) to divide each bin from the previous
106 +// cache level into. Higher numbers will decrease performance by increasing
107 +// cache misses, but mean that cache hits are more locally relevant.
108 +static const int kCacheBranchFactor = 2;
109 +
110 +// Number of levels to put in the cache.
111 +// Each level of the cache is a square grid of bins, length:
112 +// branch_factor^(level - 1) on each side.
113 +//
114 +// This may be greater than kNumPyramidLevels. Setting it to 0 means no
115 +// caching is enabled.
116 +static const int kNumCacheLevels = 3;
117 +
118 +// The level at which the cache pyramid gets cut off and replaced by a matrix
119 +// transform if such a matrix has been provided to the cache.
120 +static const int kCacheCutoff = 1;
121 +
122 +static const int kNumPyramidLevels = 4;
123 +
124 +// The minimum number of keypoints needed in an object's area.
125 +static const int kMaxKeypointsForObject = 16;
126 +
127 +// Minimum number of pyramid levels to use after getting cached value.
128 +// This allows fine-scale adjustment from the cached value, which is taken
129 +// from the center of the corresponding top cache level box.
130 +// Can be [0, kNumPyramidLevels).
131 +static const int kMinNumPyramidLevelsToUseForAdjustment = 1;
132 +
133 +// Window size to integrate over to find local image derivative.
134 +static const int kFlowIntegrationWindowSize = 3;
135 +
136 +// Total area of integration windows.
137 +static const int kFlowArraySize =
138 + (2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1);
139 +
140 +// Error that's considered good enough to early abort tracking.
141 +static const float kTrackingAbortThreshold = 0.03f;
142 +
143 +// Maximum number of deviations a keypoint-correspondence delta can be from the
144 +// weighted average before being thrown out for region-based queries.
145 +static const float kNumDeviations = 2.0f;
146 +
147 +// The length of the allowed delta between the forward and the backward
148 +// flow deltas in terms of the length of the forward flow vector.
149 +static const float kMaxForwardBackwardErrorAllowed = 0.5f;
150 +
151 +// Threshold for pixels to be considered different.
152 +static const int kFastDiffAmount = 10;
153 +
154 +// How far from edge of frame to stop looking for FAST keypoints.
155 +static const int kFastBorderBuffer = 10;
156 +
157 +// Determines if non-detected arbitrary keypoints should be added to regions.
158 +// This will help if no keypoints have been detected in the region yet.
159 +static const bool kAddArbitraryKeypoints = true;
160 +
161 +// How many arbitrary keypoints to add along each axis as candidates for each
162 +// region?
163 +static const int kNumToAddAsCandidates = 1;
164 +
165 +// In terms of region dimensions, how closely can we place keypoints
166 +// next to each other?
167 +static const float kClosestPercent = 0.6f;
168 +
169 +// How many FAST qualifying pixels must be connected to a pixel for it to be
170 +// considered a candidate keypoint for Harris filtering.
171 +static const int kMinNumConnectedForFastKeypoint = 8;
172 +
173 +// Size of the window to integrate over for Harris filtering.
174 +// Compare to kFlowIntegrationWindowSize.
175 +static const int kHarrisWindowSize = 2;
176 +
177 +
178 +// DETECTOR PARAMETERS
179 +
180 +// Before relocalizing, make sure the new proposed position is better than
181 +// the existing position by a small amount to prevent thrashing.
182 +static const MatchScore kMatchScoreBuffer(0.01f);
183 +
184 +// Minimum score a tracked object can have and still be considered a match.
185 +// TODO(andrewharp): Make this a per detector thing.
186 +static const MatchScore kMinimumMatchScore(0.5f);
187 +
188 +static const float kMinimumCorrelationForTracking = 0.4f;
189 +
190 +static const MatchScore kMatchScoreForImmediateTermination(0.0f);
191 +
192 +// Run the detector every N frames.
193 +static const int kDetectEveryNFrames = 4;
194 +
195 +// How many features does each feature_set contain?
196 +static const int kFeaturesPerFeatureSet = 10;
197 +
198 +// The number of FeatureSets managed by the object detector.
199 +// More FeatureSets can increase recall at the cost of performance.
200 +static const int kNumFeatureSets = 7;
201 +
202 +// How many FeatureSets must respond affirmatively for a candidate descriptor
203 +// and position to be given more thorough attention?
204 +static const int kNumFeatureSetsForCandidate = 2;
205 +
206 +// How large the thumbnails used for correlation validation are. Used for both
207 +// width and height.
208 +static const int kNormalizedThumbnailSize = 11;
209 +
210 +// The area of intersection divided by union for the bounding boxes that tells
211 +// if this tracking has slipped enough to invalidate all unlocked examples.
212 +static const float kPositionOverlapThreshold = 0.6f;
213 +
214 +// The number of detection failures allowed before an object goes invisible.
215 +// Tracking will still occur, so if it is actually still being tracked and
216 +// comes back into a detectable position, it's likely to be found.
217 +static const int kMaxNumDetectionFailures = 4;
218 +
219 +
220 +// Minimum square size to scan with sliding window.
221 +static const float kScanMinSquareSize = 16.0f;
222 +
223 +// Maximum square size to scan with sliding window.
224 +static const float kScanMaxSquareSize = 64.0f;
225 +
226 +// Scale difference for consecutive scans of the sliding window.
227 +static const float kScanScaleFactor = sqrtf(2.0f);
228 +
229 +// Step size for sliding window.
230 +static const int kScanStepSize = 10;
231 +
232 +
233 +// How tightly to pack the descriptor boxes for confirmed exemplars.
234 +static const float kLockedScaleFactor = 1 / sqrtf(2.0f);
235 +
236 +// How tightly to pack the descriptor boxes for unconfirmed exemplars.
237 +static const float kUnlockedScaleFactor = 1 / 2.0f;
238 +
239 +// How tightly the boxes to scan centered at the last known position will be
240 +// packed.
241 +static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f);
242 +
243 +// The bounds on how close a new object example must be to existing object
244 +// examples for detection to be valid.
245 +static const float kMinCorrelationForNewExample = 0.75f;
246 +static const float kMaxCorrelationForNewExample = 0.99f;
247 +
248 +
249 +// The number of safe tries an exemplar has after being created before
250 +// missed detections count against it.
251 +static const int kFreeTries = 5;
252 +
253 +// A false positive is worth this many missed detections.
254 +static const int kFalsePositivePenalty = 5;
255 +
256 +struct ObjectDetectorConfig {
257 + const Size image_size;
258 +
259 + explicit ObjectDetectorConfig(const Size& image_size)
260 + : image_size(image_size) {}
261 + virtual ~ObjectDetectorConfig() = default;
262 +};
263 +
264 +struct KeypointDetectorConfig {
265 + const Size image_size;
266 +
267 + bool detect_skin;
268 +
269 + explicit KeypointDetectorConfig(const Size& image_size)
270 + : image_size(image_size),
271 + detect_skin(false) {}
272 +};
273 +
274 +
275 +struct OpticalFlowConfig {
276 + const Size image_size;
277 +
278 + explicit OpticalFlowConfig(const Size& image_size)
279 + : image_size(image_size) {}
280 +};
281 +
282 +struct TrackerConfig {
283 + const Size image_size;
284 + KeypointDetectorConfig keypoint_detector_config;
285 + OpticalFlowConfig flow_config;
286 + bool always_track;
287 +
288 + float object_box_scale_factor_for_features;
289 +
290 + explicit TrackerConfig(const Size& image_size)
291 + : image_size(image_size),
292 + keypoint_detector_config(image_size),
293 + flow_config(image_size),
294 + always_track(false),
295 + object_box_scale_factor_for_features(1.0f) {}
296 +};
297 +
298 +} // namespace tf_tracking
299 +
300 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
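As a quick illustration of how the sliding-window constants above interact, the standalone sketch below (not part of the tracker; the constants are repeated locally so it compiles on its own) enumerates the scan square sizes: starting at kScanMinSquareSize and multiplying by kScanScaleFactor until kScanMaxSquareSize gives squares of roughly 16, 22.6, 32, 45.3, and 64 pixels per side.

#include <cmath>
#include <cstdio>

// Standalone sketch: enumerate the scan square sizes implied by the config.h
// constants above.
int main() {
  const float kScanMinSquareSize = 16.0f;
  const float kScanMaxSquareSize = 64.0f;
  const float kScanScaleFactor = std::sqrt(2.0f);
  for (float size = kScanMinSquareSize; size <= kScanMaxSquareSize + 0.5f;
       size *= kScanScaleFactor) {
    printf("scan square size: %.1f\n", size);  // 16.0 22.6 32.0 45.3 64.0
  }
  return 0;
}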
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
18 +
19 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
21 +
22 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
24 +
25 +namespace tf_tracking {
26 +
27 +// Class that helps OpticalFlow to speed up flow computation
28 +// by caching coarse-grained flow.
29 +class FlowCache {
30 + public:
31 + explicit FlowCache(const OpticalFlowConfig* const config)
32 + : config_(config),
33 + image_size_(config->image_size),
34 + optical_flow_(config),
35 + fullframe_matrix_(NULL) {
36 + for (int i = 0; i < kNumCacheLevels; ++i) {
37 + const int curr_dims = BlockDimForCacheLevel(i);
38 + has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
39 + displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
40 + }
41 + }
42 +
43 + ~FlowCache() {
44 + for (int i = 0; i < kNumCacheLevels; ++i) {
45 + SAFE_DELETE(has_cache_[i]);
46 + SAFE_DELETE(displacements_[i]);
47 + }
48 + delete[](fullframe_matrix_);
49 + fullframe_matrix_ = NULL;
50 + }
51 +
52 + void NextFrame(ImageData* const new_frame,
53 + const float* const align_matrix23) {
54 + ClearCache();
55 + SetFullframeAlignmentMatrix(align_matrix23);
56 + optical_flow_.NextFrame(new_frame);
57 + }
58 +
59 + void ClearCache() {
60 + for (int i = 0; i < kNumCacheLevels; ++i) {
61 + has_cache_[i]->Clear(false);
62 + }
63 + delete[](fullframe_matrix_);
64 + fullframe_matrix_ = NULL;
65 + }
66 +
67 + // Finds the flow at a point, using the cache for performance.
68 + bool FindFlowAtPoint(const float u_x, const float u_y,
69 + float* const flow_x, float* const flow_y) const {
70 + // Get the best guess from the cache.
71 + const Point2f guess_from_cache = LookupGuess(u_x, u_y);
72 +
73 + *flow_x = guess_from_cache.x;
74 + *flow_y = guess_from_cache.y;
75 +
76 + // Now refine the guess using the image pyramid.
77 + for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
78 + pyramid_level >= 0; --pyramid_level) {
79 + if (!optical_flow_.FindFlowAtPointSingleLevel(
80 + pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
81 + return false;
82 + }
83 + }
84 +
85 + return true;
86 + }
87 +
88 + // Determines the displacement of a point, and uses that to calculate a new
89 + // position.
90 + // Returns true iff the displacement determination worked and the new position
91 + // is in the image.
92 + bool FindNewPositionOfPoint(const float u_x, const float u_y,
93 + float* final_x, float* final_y) const {
94 + float flow_x;
95 + float flow_y;
96 + if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
97 + return false;
98 + }
99 +
100 + // Add in the displacement to get the final position.
101 + *final_x = u_x + flow_x;
102 + *final_y = u_y + flow_y;
103 +
104 + // Assign the best guess, if we're still in the image.
105 + if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
106 + InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
107 + return true;
108 + } else {
109 + return false;
110 + }
111 + }
112 +
113 +  // Comparison function for qsort; compares floats directly so sub-integer
114 +  // differences are not truncated to zero.
115 +  static int Compare(const void* a, const void* b) {
116 +    return (*reinterpret_cast<const float*>(a) <
117 +            *reinterpret_cast<const float*>(b)) ? -1 : 1;
118 +
119 + // Returns the median flow within the given bounding box as determined
120 + // by a grid_width x grid_height grid.
121 + Point2f GetMedianFlow(const BoundingBox& bounding_box,
122 + const bool filter_by_fb_error,
123 + const int grid_width,
124 + const int grid_height) const {
125 + const int kMaxPoints = 100;
126 + SCHECK(grid_width * grid_height <= kMaxPoints,
127 + "Too many points for Median flow!");
128 +
129 + const BoundingBox valid_box = bounding_box.Intersect(
130 + BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
131 +
132 + if (valid_box.GetArea() <= 0.0f) {
133 + return Point2f(0, 0);
134 + }
135 +
136 + float x_deltas[kMaxPoints];
137 + float y_deltas[kMaxPoints];
138 +
139 + int curr_offset = 0;
140 + for (int i = 0; i < grid_width; ++i) {
141 + for (int j = 0; j < grid_height; ++j) {
142 + const float x_in = valid_box.left_ +
143 + (valid_box.GetWidth() * i) / (grid_width - 1);
144 +
145 + const float y_in = valid_box.top_ +
146 + (valid_box.GetHeight() * j) / (grid_height - 1);
147 +
148 + float curr_flow_x;
149 + float curr_flow_y;
150 + const bool success = FindNewPositionOfPoint(x_in, y_in,
151 + &curr_flow_x, &curr_flow_y);
152 +
153 + if (success) {
154 + x_deltas[curr_offset] = curr_flow_x;
155 + y_deltas[curr_offset] = curr_flow_y;
156 + ++curr_offset;
157 + } else {
158 + LOGW("Tracking failure!");
159 + }
160 + }
161 + }
162 +
163 + if (curr_offset > 0) {
164 + qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
165 + qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
166 +
167 + return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
168 + }
169 +
170 + LOGW("No points were valid!");
171 + return Point2f(0, 0);
172 + }
173 +
174 + void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
175 + if (align_matrix23 != NULL) {
176 + if (fullframe_matrix_ == NULL) {
177 + fullframe_matrix_ = new float[6];
178 + }
179 +
180 + memcpy(fullframe_matrix_, align_matrix23,
181 + 6 * sizeof(fullframe_matrix_[0]));
182 + }
183 + }
184 +
185 + private:
186 + Point2f LookupGuessFromLevel(
187 + const int cache_level, const float x, const float y) const {
188 + // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
189 +
190 + // Cutoff at the target level and use the matrix transform instead.
191 + if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
192 + const float xnew = x * fullframe_matrix_[0] +
193 + y * fullframe_matrix_[1] +
194 + fullframe_matrix_[2];
195 + const float ynew = x * fullframe_matrix_[3] +
196 + y * fullframe_matrix_[4] +
197 + fullframe_matrix_[5];
198 +
199 + return Point2f(xnew - x, ynew - y);
200 + }
201 +
202 + const int level_dim = BlockDimForCacheLevel(cache_level);
203 + const int pixels_per_cache_block_x =
204 + (image_size_.width + level_dim - 1) / level_dim;
205 + const int pixels_per_cache_block_y =
206 + (image_size_.height + level_dim - 1) / level_dim;
207 + const int index_x = x / pixels_per_cache_block_x;
208 + const int index_y = y / pixels_per_cache_block_y;
209 +
210 + Point2f displacement;
211 + if (!(*has_cache_[cache_level])[index_y][index_x]) {
212 + (*has_cache_[cache_level])[index_y][index_x] = true;
213 +
214 + // Get the lower cache level's best guess, if it exists.
215 + displacement = cache_level >= kNumCacheLevels - 1 ?
216 + Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
217 + // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
218 + // best_guess.x, best_guess.y);
219 +
220 + // Find the center of the block.
221 + const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
222 + const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
223 + const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
224 +
225 + // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
226 + // "Querying %5.2f, %5.2f at pyramid level %d, ",
227 + // cache_level, index_x, index_y,
228 + // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
229 + // center_x, center_y, pyramid_level);
230 +
231 + // TODO(andrewharp): Turn on FB error filtering.
232 + const bool success = optical_flow_.FindFlowAtPointSingleLevel(
233 + pyramid_level, center_x, center_y, false,
234 + &displacement.x, &displacement.y);
235 +
236 + if (!success) {
237 + LOGV("Computation of cached value failed for level %d!", cache_level);
238 + }
239 +
240 + // Store the value for later use.
241 + (*displacements_[cache_level])[index_y][index_x] = displacement;
242 + } else {
243 + displacement = (*displacements_[cache_level])[index_y][index_x];
244 + }
245 +
246 + // LOGI("Returning %5.2f, %5.2f for level %d",
247 + // displacement.x, displacement.y, cache_level);
248 + return displacement;
249 + }
250 +
251 + Point2f LookupGuess(const float x, const float y) const {
252 + if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
253 + return Point2f(0, 0);
254 + }
255 +
256 + // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
257 + if (kNumCacheLevels > 0) {
258 + return LookupGuessFromLevel(0, x, y);
259 + } else {
260 + return Point2f(0, 0);
261 + }
262 + }
263 +
264 + // Returns the number of cache bins in each dimension for a given level
265 + // of the cache.
266 + int BlockDimForCacheLevel(const int cache_level) const {
267 + // The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
268 + // thus if there are 4 cache levels, requesting level 3 (0-based) should
269 + // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
270 + // and so on.
271 + int block_dim = kNumCacheLevels;
272 + for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
273 + --curr_level) {
274 + block_dim *= kCacheBranchFactor;
275 + }
276 + return block_dim;
277 + }
278 +
279 + // Returns the level of the image pyramid that a given cache level maps to.
280 + int PyramidLevelForCacheLevel(const int cache_level) const {
281 + // Higher cache and pyramid levels have smaller dimensions. The highest
282 + // cache level should refer to the highest image pyramid level. The
283 + // lower, finer image pyramid levels are uncached (assuming
284 + // kNumCacheLevels < kNumPyramidLevels).
285 + return cache_level + (kNumPyramidLevels - kNumCacheLevels);
286 + }
287 +
288 + const OpticalFlowConfig* const config_;
289 +
290 + const Size image_size_;
291 + OpticalFlow optical_flow_;
292 +
293 + float* fullframe_matrix_;
294 +
295 + // Whether this value is currently present in the cache.
296 + Image<bool>* has_cache_[kNumCacheLevels];
297 +
298 + // The cached displacement values.
299 + Image<Point2f>* displacements_[kNumCacheLevels];
300 +
301 + TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
302 +};
303 +
304 +} // namespace tf_tracking
305 +
306 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
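To make the cache-bin indexing in LookupGuessFromLevel above concrete, here is a standalone sketch of the same arithmetic; the frame size and bins-per-side are arbitrary example values, not constants from the tracker.

#include <cstdio>

int main() {
  // Hypothetical frame size and bins-per-side for one cache level.
  const int width = 320;
  const int height = 240;
  const int level_dim = 5;

  // Each bin covers ceil(width / level_dim) x ceil(height / level_dim) pixels,
  // and a query point divides down to its bin index.
  const int pixels_per_block_x = (width + level_dim - 1) / level_dim;   // 64
  const int pixels_per_block_y = (height + level_dim - 1) / level_dim;  // 48

  const float x = 200.0f;
  const float y = 100.0f;
  const int index_x = static_cast<int>(x / pixels_per_block_x);  // 3
  const int index_y = static_cast<int>(y / pixels_per_block_y);  // 2

  printf("point (%.0f, %.0f) -> cache bin [%d][%d]\n", x, y, index_y, index_x);
  return 0;
}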
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#include <float.h>
17 +
18 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
19 +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
20 +
21 +namespace tf_tracking {
22 +
23 +void FramePair::Init(const int64_t start_time, const int64_t end_time) {
24 + start_time_ = start_time;
25 + end_time_ = end_time;
26 + memset(optical_flow_found_keypoint_, false,
27 + sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
28 + number_of_keypoints_ = 0;
29 +}
30 +
31 +void FramePair::AdjustBox(const BoundingBox box,
32 + float* const translation_x,
33 + float* const translation_y,
34 + float* const scale_x,
35 + float* const scale_y) const {
36 + static float weights[kMaxKeypoints];
37 + static Point2f deltas[kMaxKeypoints];
38 + memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
39 +
40 + BoundingBox resized_box(box);
41 + resized_box.Scale(0.4f, 0.4f);
42 + FillWeights(resized_box, weights);
43 + FillTranslations(deltas);
44 +
45 + const Point2f translation = GetWeightedMedian(weights, deltas);
46 +
47 + *translation_x = translation.x;
48 + *translation_y = translation.y;
49 +
50 + const Point2f old_center = box.GetCenter();
51 + const int good_scale_points =
52 + FillScales(old_center, translation, weights, deltas);
53 +
54 + // Default scale factor is 1 for x and y.
55 + *scale_x = 1.0f;
56 + *scale_y = 1.0f;
57 +
58 + // The assumption is that all deltas that make it to this stage with a
59 + // corresponding optical_flow_found_keypoint_[i] == true are not in
60 + // themselves degenerate.
61 + //
62 +  // The degeneracy with scale arises when points are too close to the center
63 +  // of the object, which can make the scale ratio incalculable.
64 +  //
65 +  // The check for kMinNumInRange is not a degeneracy check, but merely an
66 +  // attempt to ensure some stability. The actual degeneracy check is the
67 +  // comparison to EPSILON in FillScales (which also returns the number of
68 +  // good points remaining).
69 + static const int kMinNumInRange = 5;
70 + if (good_scale_points >= kMinNumInRange) {
71 + const float scale_factor = GetWeightedMedianScale(weights, deltas);
72 +
73 + if (scale_factor > 0.0f) {
74 + *scale_x = scale_factor;
75 + *scale_y = scale_factor;
76 + }
77 + }
78 +}
79 +
80 +int FramePair::FillWeights(const BoundingBox& box,
81 + float* const weights) const {
82 + // Compute the max score.
83 + float max_score = -FLT_MAX;
84 + float min_score = FLT_MAX;
85 + for (int i = 0; i < kMaxKeypoints; ++i) {
86 + if (optical_flow_found_keypoint_[i]) {
87 + max_score = MAX(max_score, frame1_keypoints_[i].score_);
88 + min_score = MIN(min_score, frame1_keypoints_[i].score_);
89 + }
90 + }
91 +
92 + int num_in_range = 0;
93 + for (int i = 0; i < kMaxKeypoints; ++i) {
94 + if (!optical_flow_found_keypoint_[i]) {
95 + weights[i] = 0.0f;
96 + continue;
97 + }
98 +
99 + const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
100 + if (in_box) {
101 + ++num_in_range;
102 + }
103 +
104 +    // The weighting based on distance. Anything within the bounding box
105 + // has a weight of 1, and everything outside of that is within the range
106 + // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
107 + float distance_score = 1.0f;
108 + if (!in_box) {
109 + const Point2f initial = box.GetCenter();
110 + const float sq_x_dist =
111 + Square(initial.x - frame1_keypoints_[i].pos_.x);
112 + const float sq_y_dist =
113 + Square(initial.y - frame1_keypoints_[i].pos_.y);
114 + const float squared_half_width = Square(box.GetWidth() / 2.0f);
115 + const float squared_half_height = Square(box.GetHeight() / 2.0f);
116 +
117 + static const float kOutOfBoxMultiplier = 0.5f;
118 + distance_score = kOutOfBoxMultiplier *
119 + MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
120 + }
121 +
122 +    // The weighting based on relative score strength, in the range [kBaseScore, 1.0f].
123 + float intrinsic_score = 1.0f;
124 + if (max_score > min_score) {
125 + static const float kBaseScore = 0.5f;
126 + intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
127 + (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
128 + }
129 +
130 + // The final score will be in the range [0, 1].
131 + weights[i] = distance_score * intrinsic_score;
132 + }
133 +
134 + return num_in_range;
135 +}
136 +
137 +void FramePair::FillTranslations(Point2f* const translations) const {
138 + for (int i = 0; i < kMaxKeypoints; ++i) {
139 + if (!optical_flow_found_keypoint_[i]) {
140 + continue;
141 + }
142 + translations[i].x =
143 + frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
144 + translations[i].y =
145 + frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
146 + }
147 +}
148 +
149 +int FramePair::FillScales(const Point2f& old_center,
150 + const Point2f& translation,
151 + float* const weights,
152 + Point2f* const scales) const {
153 + int num_good = 0;
154 + for (int i = 0; i < kMaxKeypoints; ++i) {
155 + if (!optical_flow_found_keypoint_[i]) {
156 + continue;
157 + }
158 +
159 + const Keypoint keypoint1 = frame1_keypoints_[i];
160 + const Keypoint keypoint2 = frame2_keypoints_[i];
161 +
162 + const float dist1_x = keypoint1.pos_.x - old_center.x;
163 + const float dist1_y = keypoint1.pos_.y - old_center.y;
164 +
165 + const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
166 + const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
167 +
168 + // Make sure that the scale makes sense; points too close to the center
169 + // will result in either NaNs or infinite results for scale due to
170 + // limited tracking and floating point resolution.
171 +    // Also check that both points lie on the same side of the center in both
172 +    // x and y, as we can't make sense of data that has flipped.
173 + if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
174 + (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
175 + ((dist2_y > EPSILON && dist1_y > EPSILON) ||
176 + (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
177 + scales[i].x = dist2_x / dist1_x;
178 + scales[i].y = dist2_y / dist1_y;
179 + ++num_good;
180 + } else {
181 + weights[i] = 0.0f;
182 + scales[i].x = 1.0f;
183 + scales[i].y = 1.0f;
184 + }
185 + }
186 + return num_good;
187 +}
188 +
189 +struct WeightedDelta {
190 + float weight;
191 + float delta;
192 +};
193 +
194 +// Sort by delta, not by weight.
195 +inline int WeightedDeltaCompare(const void* const a, const void* const b) {
196 + return (reinterpret_cast<const WeightedDelta*>(a)->delta -
197 + reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1;
198 +}
199 +
200 +// Returns the median delta from a sorted set of weighted deltas.
201 +static float GetMedian(const int num_items,
202 + const WeightedDelta* const weighted_deltas,
203 + const float sum) {
204 + if (num_items == 0 || sum < EPSILON) {
205 + return 0.0f;
206 + }
207 +
208 + float current_weight = 0.0f;
209 + const float target_weight = sum / 2.0f;
210 + for (int i = 0; i < num_items; ++i) {
211 + if (weighted_deltas[i].weight > 0.0f) {
212 + current_weight += weighted_deltas[i].weight;
213 + if (current_weight >= target_weight) {
214 + return weighted_deltas[i].delta;
215 + }
216 + }
217 + }
218 + LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
219 + return 0.0f;
220 +}
221 +
222 +Point2f FramePair::GetWeightedMedian(
223 + const float* const weights, const Point2f* const deltas) const {
224 + Point2f median_delta;
225 +
226 + // TODO(andrewharp): only sort deltas that could possibly have an effect.
227 + static WeightedDelta weighted_deltas[kMaxKeypoints];
228 +
229 + // Compute median X value.
230 + {
231 + float total_weight = 0.0f;
232 +
233 + // Compute weighted mean and deltas.
234 + for (int i = 0; i < kMaxKeypoints; ++i) {
235 + weighted_deltas[i].delta = deltas[i].x;
236 + const float weight = weights[i];
237 + weighted_deltas[i].weight = weight;
238 + if (weight > 0.0f) {
239 + total_weight += weight;
240 + }
241 + }
242 + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
243 + WeightedDeltaCompare);
244 + median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
245 + }
246 +
247 + // Compute median Y value.
248 + {
249 + float total_weight = 0.0f;
250 +
251 + // Compute weighted mean and deltas.
252 + for (int i = 0; i < kMaxKeypoints; ++i) {
253 + const float weight = weights[i];
254 + weighted_deltas[i].weight = weight;
255 + weighted_deltas[i].delta = deltas[i].y;
256 + if (weight > 0.0f) {
257 + total_weight += weight;
258 + }
259 + }
260 + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
261 + WeightedDeltaCompare);
262 + median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
263 + }
264 +
265 + return median_delta;
266 +}
267 +
268 +float FramePair::GetWeightedMedianScale(
269 + const float* const weights, const Point2f* const deltas) const {
270 + float median_delta;
271 +
272 + // TODO(andrewharp): only sort deltas that could possibly have an effect.
273 + static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
274 +
275 + // Compute median scale value across x and y.
276 + {
277 + float total_weight = 0.0f;
278 +
279 + // Add X values.
280 + for (int i = 0; i < kMaxKeypoints; ++i) {
281 + weighted_deltas[i].delta = deltas[i].x;
282 + const float weight = weights[i];
283 + weighted_deltas[i].weight = weight;
284 + if (weight > 0.0f) {
285 + total_weight += weight;
286 + }
287 + }
288 +
289 + // Add Y values.
290 + for (int i = 0; i < kMaxKeypoints; ++i) {
291 + weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
292 + const float weight = weights[i];
293 + weighted_deltas[i + kMaxKeypoints].weight = weight;
294 + if (weight > 0.0f) {
295 + total_weight += weight;
296 + }
297 + }
298 +
299 + qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
300 + WeightedDeltaCompare);
301 +
302 + median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
303 + }
304 +
305 + return median_delta;
306 +}
307 +
308 +} // namespace tf_tracking
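Both the translation and the scale estimate above reduce to a weighted median over per-keypoint deltas. The standalone sketch below shows just that step, following the same scheme as GetMedian: sort by delta, then accumulate weights until half the total weight is covered. The sample weights and deltas are arbitrary.

#include <algorithm>
#include <cstdio>
#include <vector>

struct Sample {
  float weight;
  float delta;
};

int main() {
  // Arbitrary example data; note the outlier delta of 5.0 does not affect
  // the result.
  std::vector<Sample> samples = {
      {1.0f, 0.2f}, {0.5f, 5.0f}, {2.0f, 0.3f}, {1.5f, 0.25f}};

  std::sort(samples.begin(), samples.end(),
            [](const Sample& a, const Sample& b) { return a.delta < b.delta; });

  float total_weight = 0.0f;
  for (const Sample& s : samples) {
    total_weight += s.weight;
  }

  float accumulated = 0.0f;
  float median = 0.0f;
  for (const Sample& s : samples) {
    accumulated += s.weight;
    if (accumulated >= total_weight / 2.0f) {
      median = s.delta;
      break;
    }
  }

  printf("weighted median delta: %.2f\n", median);  // 0.25 for this data.
  return 0;
}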
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
18 +
19 +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
20 +
21 +namespace tf_tracking {
22 +
23 +// A class that records keypoint correspondences from pairs of
24 +// consecutive frames.
25 +class FramePair {
26 + public:
27 + FramePair()
28 + : start_time_(0),
29 + end_time_(0),
30 + number_of_keypoints_(0) {}
31 +
32 +  // Cleans up the FramePair so that it can be reused.
33 + void Init(const int64_t start_time, const int64_t end_time);
34 +
35 + void AdjustBox(const BoundingBox box,
36 + float* const translation_x,
37 + float* const translation_y,
38 + float* const scale_x,
39 + float* const scale_y) const;
40 +
41 + private:
42 + // Returns the weighted median of the given deltas, computed independently on
43 + // x and y. Returns 0,0 in case of failure. The assumption is that a
44 + // translation of 0.0 in the degenerate case is the best that can be done, and
45 + // should not be considered an error.
46 + //
47 +  // In the case of scale, a slight exception is made just to be safe: there
48 +  // is an explicit check for 0.0, though that case should never occur
49 +  // naturally because of the non-zero and same-sign checks in FillScales.
50 + Point2f GetWeightedMedian(const float* const weights,
51 + const Point2f* const deltas) const;
52 +
53 + float GetWeightedMedianScale(const float* const weights,
54 + const Point2f* const deltas) const;
55 +
56 +  // Weights points based on their distance from the box and their relative
57 +  // keypoint scores.
57 + int FillWeights(const BoundingBox& box,
58 + float* const weights) const;
59 +
60 + // Fills in the array of deltas with the translations of the points
61 + // between frames.
62 + void FillTranslations(Point2f* const translations) const;
63 +
64 + // Fills in the array of deltas with the relative scale factor of points
65 + // relative to a given center. Has the ability to override the weight to 0 if
66 + // a degenerate scale is detected.
67 + // Translation is the amount the center of the box has moved from one frame to
68 + // the next.
69 + int FillScales(const Point2f& old_center,
70 + const Point2f& translation,
71 + float* const weights,
72 + Point2f* const scales) const;
73 +
74 + // TODO(andrewharp): Make these private.
75 + public:
76 + // The time at frame1.
77 + int64_t start_time_;
78 +
79 + // The time at frame2.
80 + int64_t end_time_;
81 +
82 + // This array will contain the keypoints found in frame 1.
83 + Keypoint frame1_keypoints_[kMaxKeypoints];
84 +
85 +  // Contains the locations of the keypoints from frame 1 in frame 2.
86 + Keypoint frame2_keypoints_[kMaxKeypoints];
87 +
88 + // The number of keypoints in frame 1.
89 + int number_of_keypoints_;
90 +
91 + // Keeps track of which keypoint correspondences were actually found from one
92 + // frame to another.
93 + // The i-th element of this array will be non-zero if and only if the i-th
94 + // keypoint of frame 1 was found in frame 2.
95 + bool optical_flow_found_keypoint_[kMaxKeypoints];
96 +
97 + private:
98 + TF_DISALLOW_COPY_AND_ASSIGN(FramePair);
99 +};
100 +
101 +} // namespace tf_tracking
102 +
103 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
18 +
19 +#include "tensorflow/examples/android/jni/object_tracking/logging.h"
20 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
21 +
22 +namespace tf_tracking {
23 +
24 +struct Size {
25 + Size(const int width, const int height) : width(width), height(height) {}
26 +
27 + int width;
28 + int height;
29 +};
30 +
31 +
32 +class Point2f {
33 + public:
34 + Point2f() : x(0.0f), y(0.0f) {}
35 + Point2f(const float x, const float y) : x(x), y(y) {}
36 +
37 + inline Point2f operator- (const Point2f& that) const {
38 + return Point2f(this->x - that.x, this->y - that.y);
39 + }
40 +
41 + inline Point2f operator+ (const Point2f& that) const {
42 + return Point2f(this->x + that.x, this->y + that.y);
43 + }
44 +
45 + inline Point2f& operator+= (const Point2f& that) {
46 + this->x += that.x;
47 + this->y += that.y;
48 + return *this;
49 + }
50 +
51 + inline Point2f& operator-= (const Point2f& that) {
52 + this->x -= that.x;
53 + this->y -= that.y;
54 + return *this;
55 + }
56 +
57 + inline Point2f operator- (const Point2f& that) {
58 + return Point2f(this->x - that.x, this->y - that.y);
59 + }
60 +
61 + inline float LengthSquared() {
62 + return Square(this->x) + Square(this->y);
63 + }
64 +
65 + inline float Length() {
66 + return sqrtf(LengthSquared());
67 + }
68 +
69 + inline float DistanceSquared(const Point2f& that) {
70 + return Square(this->x - that.x) + Square(this->y - that.y);
71 + }
72 +
73 + inline float Distance(const Point2f& that) {
74 + return sqrtf(DistanceSquared(that));
75 + }
76 +
77 + float x;
78 + float y;
79 +};
80 +
81 +inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) {
82 + stream << point.x << "," << point.y;
83 + return stream;
84 +}
85 +
86 +class BoundingBox {
87 + public:
88 + BoundingBox()
89 + : left_(0),
90 + top_(0),
91 + right_(0),
92 + bottom_(0) {}
93 +
94 + BoundingBox(const BoundingBox& bounding_box)
95 + : left_(bounding_box.left_),
96 + top_(bounding_box.top_),
97 + right_(bounding_box.right_),
98 + bottom_(bounding_box.bottom_) {
99 + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
100 + SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
101 + }
102 +
103 + BoundingBox(const float left,
104 + const float top,
105 + const float right,
106 + const float bottom)
107 + : left_(left),
108 + top_(top),
109 + right_(right),
110 + bottom_(bottom) {
111 + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
112 + SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
113 + }
114 +
115 + BoundingBox(const Point2f& point1, const Point2f& point2)
116 + : left_(MIN(point1.x, point2.x)),
117 + top_(MIN(point1.y, point2.y)),
118 + right_(MAX(point1.x, point2.x)),
119 + bottom_(MAX(point1.y, point2.y)) {}
120 +
121 + inline void CopyToArray(float* const bounds_array) const {
122 + bounds_array[0] = left_;
123 + bounds_array[1] = top_;
124 + bounds_array[2] = right_;
125 + bounds_array[3] = bottom_;
126 + }
127 +
128 + inline float GetWidth() const {
129 + return right_ - left_;
130 + }
131 +
132 + inline float GetHeight() const {
133 + return bottom_ - top_;
134 + }
135 +
136 + inline float GetArea() const {
137 + const float width = GetWidth();
138 + const float height = GetHeight();
139 + if (width <= 0 || height <= 0) {
140 + return 0.0f;
141 + }
142 +
143 + return width * height;
144 + }
145 +
146 + inline Point2f GetCenter() const {
147 + return Point2f((left_ + right_) / 2.0f,
148 + (top_ + bottom_) / 2.0f);
149 + }
150 +
151 + inline bool ValidBox() const {
152 + return GetArea() > 0.0f;
153 + }
154 +
155 + // Returns a bounding box created from the overlapping area of these two.
156 + inline BoundingBox Intersect(const BoundingBox& that) const {
157 + const float new_left = MAX(this->left_, that.left_);
158 + const float new_right = MIN(this->right_, that.right_);
159 +
160 + if (new_left >= new_right) {
161 + return BoundingBox();
162 + }
163 +
164 + const float new_top = MAX(this->top_, that.top_);
165 + const float new_bottom = MIN(this->bottom_, that.bottom_);
166 +
167 + if (new_top >= new_bottom) {
168 + return BoundingBox();
169 + }
170 +
171 + return BoundingBox(new_left, new_top, new_right, new_bottom);
172 + }
173 +
174 + // Returns a bounding box that can contain both boxes.
175 + inline BoundingBox Union(const BoundingBox& that) const {
176 + return BoundingBox(MIN(this->left_, that.left_),
177 + MIN(this->top_, that.top_),
178 + MAX(this->right_, that.right_),
179 + MAX(this->bottom_, that.bottom_));
180 + }
181 +
182 + inline float PascalScore(const BoundingBox& that) const {
183 + SCHECK(GetArea() > 0.0f, "Empty bounding box!");
184 + SCHECK(that.GetArea() > 0.0f, "Empty bounding box!");
185 +
186 + const float intersect_area = this->Intersect(that).GetArea();
187 +
188 + if (intersect_area <= 0) {
189 + return 0;
190 + }
191 +
192 + const float score =
193 + intersect_area / (GetArea() + that.GetArea() - intersect_area);
194 + SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score);
195 + return score;
196 + }
197 +
198 + inline bool Intersects(const BoundingBox& that) const {
199 + return InRange(that.left_, left_, right_)
200 + || InRange(that.right_, left_, right_)
201 + || InRange(that.top_, top_, bottom_)
202 + || InRange(that.bottom_, top_, bottom_);
203 + }
204 +
205 + // Returns whether another bounding box is completely inside of this bounding
206 + // box. Sharing edges is ok.
207 + inline bool Contains(const BoundingBox& that) const {
208 + return that.left_ >= left_ &&
209 + that.right_ <= right_ &&
210 + that.top_ >= top_ &&
211 + that.bottom_ <= bottom_;
212 + }
213 +
214 + inline bool Contains(const Point2f& point) const {
215 + return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_);
216 + }
217 +
218 + inline void Shift(const Point2f shift_amount) {
219 + left_ += shift_amount.x;
220 + top_ += shift_amount.y;
221 + right_ += shift_amount.x;
222 + bottom_ += shift_amount.y;
223 + }
224 +
225 + inline void ScaleOrigin(const float scale_x, const float scale_y) {
226 + left_ *= scale_x;
227 + right_ *= scale_x;
228 + top_ *= scale_y;
229 + bottom_ *= scale_y;
230 + }
231 +
232 + inline void Scale(const float scale_x, const float scale_y) {
233 + const Point2f center = GetCenter();
234 + const float half_width = GetWidth() / 2.0f;
235 + const float half_height = GetHeight() / 2.0f;
236 +
237 + left_ = center.x - half_width * scale_x;
238 + right_ = center.x + half_width * scale_x;
239 +
240 + top_ = center.y - half_height * scale_y;
241 + bottom_ = center.y + half_height * scale_y;
242 + }
243 +
244 + float left_;
245 + float top_;
246 + float right_;
247 + float bottom_;
248 +};
249 +inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) {
250 + stream << "[" << box.left_ << " - " << box.right_
251 + << ", " << box.top_ << " - " << box.bottom_
252 + << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]";
253 + return stream;
254 +}
255 +
256 +
257 +class BoundingSquare {
258 + public:
259 + BoundingSquare(const float x, const float y, const float size)
260 + : x_(x), y_(y), size_(size) {}
261 +
262 + explicit BoundingSquare(const BoundingBox& box)
263 + : x_(box.left_), y_(box.top_), size_(box.GetWidth()) {
264 +#ifdef SANITY_CHECKS
265 + if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) {
266 + LOG(WARNING) << "This is not a square: " << box << std::endl;
267 + }
268 +#endif
269 + }
270 +
271 + inline BoundingBox ToBoundingBox() const {
272 + return BoundingBox(x_, y_, x_ + size_, y_ + size_);
273 + }
274 +
275 + inline bool ValidBox() {
276 + return size_ > 0.0f;
277 + }
278 +
279 + inline void Shift(const Point2f shift_amount) {
280 + x_ += shift_amount.x;
281 + y_ += shift_amount.y;
282 + }
283 +
284 + inline void Scale(const float scale) {
285 + const float new_size = size_ * scale;
286 + const float position_diff = (new_size - size_) / 2.0f;
287 + x_ -= position_diff;
288 + y_ -= position_diff;
289 + size_ = new_size;
290 + }
291 +
292 + float x_;
293 + float y_;
294 + float size_;
295 +};
296 +inline std::ostream& operator<<(std::ostream& stream,
297 + const BoundingSquare& square) {
298 + stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]";
299 + return stream;
300 +}
301 +
302 +
303 +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box,
304 + const float size) {
305 + const float width_diff = (original_box.GetWidth() - size) / 2.0f;
306 + const float height_diff = (original_box.GetHeight() - size) / 2.0f;
307 + return BoundingSquare(original_box.left_ + width_diff,
308 + original_box.top_ + height_diff,
309 + size);
310 +}
311 +
312 +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
313 + return GetCenteredSquare(
314 + original_box, MIN(original_box.GetWidth(), original_box.GetHeight()));
315 +}
316 +
317 +} // namespace tf_tracking
318 +
319 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
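BoundingBox::PascalScore above is the standard intersection-over-union overlap. A standalone numeric example with two arbitrary axis-aligned boxes:

#include <algorithm>
#include <cstdio>

int main() {
  // Box A: (0, 0) - (10, 10); box B: (5, 5) - (15, 15). Arbitrary values.
  const float a_left = 0.0f, a_top = 0.0f, a_right = 10.0f, a_bottom = 10.0f;
  const float b_left = 5.0f, b_top = 5.0f, b_right = 15.0f, b_bottom = 15.0f;

  const float inter_w =
      std::max(0.0f, std::min(a_right, b_right) - std::max(a_left, b_left));
  const float inter_h =
      std::max(0.0f, std::min(a_bottom, b_bottom) - std::max(a_top, b_top));
  const float inter_area = inter_w * inter_h;                     // 5 * 5 = 25
  const float a_area = (a_right - a_left) * (a_bottom - a_top);   // 100
  const float b_area = (b_right - b_left) * (b_bottom - b_top);   // 100
  const float iou = inter_area / (a_area + b_area - inter_area);  // 25 / 175

  printf("IoU = %.3f\n", iou);  // ~0.143
  return 0;
}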
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
18 +
19 +#include <GLES/gl.h>
20 +#include <GLES/glext.h>
21 +
22 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
23 +
24 +namespace tf_tracking {
25 +
26 +// Draws a box at the given position.
27 +inline static void DrawBox(const BoundingBox& bounding_box) {
28 + const GLfloat line[] = {
29 + bounding_box.left_, bounding_box.bottom_,
30 + bounding_box.left_, bounding_box.top_,
31 + bounding_box.left_, bounding_box.top_,
32 + bounding_box.right_, bounding_box.top_,
33 + bounding_box.right_, bounding_box.top_,
34 + bounding_box.right_, bounding_box.bottom_,
35 + bounding_box.right_, bounding_box.bottom_,
36 + bounding_box.left_, bounding_box.bottom_
37 + };
38 +
39 + glVertexPointer(2, GL_FLOAT, 0, line);
40 + glEnableClientState(GL_VERTEX_ARRAY);
41 +
42 + glDrawArrays(GL_LINES, 0, 8);
43 +}
44 +
45 +
46 +// Changes the coordinate system such that an arbitrary square in the world
47 +// can thereafter be drawn to using coordinates in the range 0 - 1.
48 +inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) {
49 + glScalef(square.size_, square.size_, 1.0f);
50 + glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f);
51 +}
52 +
53 +} // namespace tf_tracking
54 +
55 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
18 +
19 +#include <stdint.h>
20 +
21 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
24 +
25 +namespace tf_tracking {
26 +
27 +template <typename T>
28 +Image<T>::Image(const int width, const int height)
29 + : width_less_one_(width - 1),
30 + height_less_one_(height - 1),
31 + data_size_(width * height),
32 + own_data_(true),
33 + width_(width),
34 + height_(height),
35 + stride_(width) {
36 + Allocate();
37 +}
38 +
39 +template <typename T>
40 +Image<T>::Image(const Size& size)
41 + : width_less_one_(size.width - 1),
42 + height_less_one_(size.height - 1),
43 + data_size_(size.width * size.height),
44 + own_data_(true),
45 + width_(size.width),
46 + height_(size.height),
47 + stride_(size.width) {
48 + Allocate();
49 +}
50 +
51 +// Constructor that creates an image from preallocated data.
52 +// Note: The image takes ownership of the data lifecycle, unless own_data is
53 +// set to false.
54 +template <typename T>
55 +Image<T>::Image(const int width, const int height, T* const image_data,
56 + const bool own_data) :
57 + width_less_one_(width - 1),
58 + height_less_one_(height - 1),
59 + data_size_(width * height),
60 + own_data_(own_data),
61 + width_(width),
62 + height_(height),
63 + stride_(width) {
64 + image_data_ = image_data;
65 + SCHECK(image_data_ != NULL, "Can't create image with NULL data!");
66 +}
67 +
68 +template <typename T>
69 +Image<T>::~Image() {
70 + if (own_data_) {
71 + delete[] image_data_;
72 + }
73 + image_data_ = NULL;
74 +}
75 +
76 +template<typename T>
77 +template<class DstType>
78 +bool Image<T>::ExtractPatchAtSubpixelFixed1616(const int fp_x,
79 + const int fp_y,
80 + const int patchwidth,
81 + const int patchheight,
82 + DstType* to_data) const {
83 + // Calculate weights.
84 + const int trunc_x = fp_x >> 16;
85 + const int trunc_y = fp_y >> 16;
86 +
87 + if (trunc_x < 0 || trunc_y < 0 ||
88 + (trunc_x + patchwidth) >= width_less_one_ ||
89 + (trunc_y + patchheight) >= height_less_one_) {
90 + return false;
91 + }
92 +
93 + // Now walk over destination patch and fill from interpolated source image.
94 + for (int y = 0; y < patchheight; ++y, to_data += patchwidth) {
95 + for (int x = 0; x < patchwidth; ++x) {
96 + to_data[x] =
97 + static_cast<DstType>(GetPixelInterpFixed1616(fp_x + (x << 16),
98 + fp_y + (y << 16)));
99 + }
100 + }
101 +
102 + return true;
103 +}
104 +
105 +template <typename T>
106 +Image<T>* Image<T>::Crop(
107 + const int left, const int top, const int right, const int bottom) const {
108 + SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left);
109 + SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right);
110 + SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top);
111 + SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom);
112 +
113 + SCHECK(left <= right, "mismatch!");
114 + SCHECK(top <= bottom, "mismatch!");
115 +
116 + const int new_width = right - left + 1;
117 + const int new_height = bottom - top + 1;
118 +
119 + Image<T>* const cropped_image = new Image(new_width, new_height);
120 +
121 + for (int y = 0; y < new_height; ++y) {
122 + memcpy((*cropped_image)[y], ((*this)[y + top] + left),
123 + new_width * sizeof(T));
124 + }
125 +
126 + return cropped_image;
127 +}
128 +
129 +template <typename T>
130 +inline float Image<T>::GetPixelInterp(const float x, const float y) const {
131 + // Do int conversion one time.
132 + const int floored_x = static_cast<int>(x);
133 + const int floored_y = static_cast<int>(y);
134 +
135 + // Note: it might be the case that the *_[min|max] values are clipped, and
136 + // these (the a b c d vals) aren't (for speed purposes), but that doesn't
137 + // matter. We'll just be blending the pixel with itself in that case anyway.
138 + const float b = x - floored_x;
139 + const float a = 1.0f - b;
140 +
141 + const float d = y - floored_y;
142 + const float c = 1.0f - d;
143 +
144 + SCHECK(ValidInterpPixel(x, y),
145 + "x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)",
146 + x, width_less_one_, y, height_less_one_);
147 +
148 + const T* const pix_ptr = (*this)[floored_y] + floored_x;
149 +
150 + // Get the pixel values surrounding this point.
151 + const T& p1 = pix_ptr[0];
152 + const T& p2 = pix_ptr[1];
153 + const T& p3 = pix_ptr[width_];
154 + const T& p4 = pix_ptr[width_ + 1];
155 +
156 + // Simple bilinear interpolation between four reference pixels.
157 + // If x is the value requested:
158 + // a b
159 + // -------
160 + // c |p1 p2|
161 + // | x |
162 + // d |p3 p4|
163 + // -------
164 + return c * ((a * p1) + (b * p2)) +
165 + d * ((a * p3) + (b * p4));
166 +}
167 +
168 +
169 +template <typename T>
170 +inline T Image<T>::GetPixelInterpFixed1616(
171 + const int fp_x_whole, const int fp_y_whole) const {
172 + static const int kFixedPointOne = 0x00010000;
173 + static const int kFixedPointHalf = 0x00008000;
174 + static const int kFixedPointTruncateMask = 0xFFFF0000;
175 +
176 + int trunc_x = fp_x_whole & kFixedPointTruncateMask;
177 + int trunc_y = fp_y_whole & kFixedPointTruncateMask;
178 + const int fp_x = fp_x_whole - trunc_x;
179 + const int fp_y = fp_y_whole - trunc_y;
180 +
181 + // Scale the truncated values back to regular ints.
182 + trunc_x >>= 16;
183 + trunc_y >>= 16;
184 +
185 + const int one_minus_fp_x = kFixedPointOne - fp_x;
186 + const int one_minus_fp_y = kFixedPointOne - fp_y;
187 +
188 + const T* trunc_start = (*this)[trunc_y] + trunc_x;
189 +
190 + const T a = trunc_start[0];
191 + const T b = trunc_start[1];
192 + const T c = trunc_start[stride_];
193 + const T d = trunc_start[stride_ + 1];
194 +
195 + return (
196 + (one_minus_fp_y * static_cast<int64_t>(one_minus_fp_x * a + fp_x * b) +
197 + fp_y * static_cast<int64_t>(one_minus_fp_x * c + fp_x * d) +
198 + kFixedPointHalf) >>
199 + 32);
200 +}
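// Worked example of the 16.16 fixed-point convention used above (illustrative
// only): a coordinate of 2.5 is encoded as 2.5 * 65536 = 163840 = 0x00028000.
// Masking with kFixedPointTruncateMask gives 0x00020000 (trunc = 2 after the
// >> 16), and the fractional remainder is 0x00008000, i.e. kFixedPointHalf.
// The two 16-bit fractional weights multiplied together contribute 32
// fractional bits, which is why the final result is shifted right by 32.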
201 +
202 +template <typename T>
203 +inline bool Image<T>::ValidPixel(const int x, const int y) const {
204 + return InRange(x, ZERO, width_less_one_) &&
205 + InRange(y, ZERO, height_less_one_);
206 +}
207 +
208 +template <typename T>
209 +inline BoundingBox Image<T>::GetContainingBox() const {
210 + return BoundingBox(
211 + 0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON);
212 +}
213 +
214 +template <typename T>
215 +inline bool Image<T>::Contains(const BoundingBox& bounding_box) const {
216 + // TODO(andrewharp): Come up with a more elegant way of ensuring that bounds
217 + // are ok.
218 + return GetContainingBox().Contains(bounding_box);
219 +}
220 +
221 +template <typename T>
222 +inline bool Image<T>::ValidInterpPixel(const float x, const float y) const {
223 + // Exclusive of max because we can be more efficient if we don't handle
224 + // interpolating on or past the last pixel.
225 + return (x >= ZERO) && (x < width_less_one_) &&
226 + (y >= ZERO) && (y < height_less_one_);
227 +}
228 +
229 +template <typename T>
230 +void Image<T>::DownsampleAveraged(const T* const original, const int stride,
231 + const int factor) {
232 +#ifdef __ARM_NEON
233 + if (factor == 4 || factor == 2) {
234 + DownsampleAveragedNeon(original, stride, factor);
235 + return;
236 + }
237 +#endif
238 +
239 + // TODO(andrewharp): delete or enable this for non-uint8_t downsamples.
240 + const int pixels_per_block = factor * factor;
241 +
242 + // For every pixel in resulting image.
243 + for (int y = 0; y < height_; ++y) {
244 + const int orig_y = y * factor;
245 + const int y_bound = orig_y + factor;
246 +
247 + // Sum up the original pixels.
248 + for (int x = 0; x < width_; ++x) {
249 + const int orig_x = x * factor;
250 + const int x_bound = orig_x + factor;
251 +
252 +      // Accumulate in int32_t because type T might overflow.
253 + int32_t pixel_sum = 0;
254 +
255 + // Grab all the pixels that make up this pixel.
256 + for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) {
257 + const T* p = original + curr_y * stride + orig_x;
258 +
259 + for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) {
260 + pixel_sum += *p++;
261 + }
262 + }
263 +
264 + (*this)[y][x] = pixel_sum / pixels_per_block;
265 + }
266 + }
267 +}
268 +
269 +template <typename T>
270 +void Image<T>::DownsampleInterpolateNearest(const Image<T>& original) {
271 + // Calculating the scaling factors based on target image size.
272 + const float factor_x = static_cast<float>(original.GetWidth()) /
273 + static_cast<float>(width_);
274 + const float factor_y = static_cast<float>(original.GetHeight()) /
275 + static_cast<float>(height_);
276 +
277 + // Calculating initial offset in x-axis.
278 + const float offset_x = 0.5f * (original.GetWidth() - width_) / width_;
279 +
280 + // Calculating initial offset in y-axis.
281 + const float offset_y = 0.5f * (original.GetHeight() - height_) / height_;
282 +
283 + float orig_y = offset_y;
284 +
285 + // For every pixel in resulting image.
286 + for (int y = 0; y < height_; ++y) {
287 + float orig_x = offset_x;
288 +
289 + // Finding nearest pixel on y-axis.
290 + const int nearest_y = static_cast<int>(orig_y + 0.5f);
291 + const T* row_data = original[nearest_y];
292 +
293 + T* pixel_ptr = (*this)[y];
294 +
295 + for (int x = 0; x < width_; ++x) {
296 + // Finding nearest pixel on x-axis.
297 + const int nearest_x = static_cast<int>(orig_x + 0.5f);
298 +
299 + *pixel_ptr++ = row_data[nearest_x];
300 +
301 + orig_x += factor_x;
302 + }
303 +
304 + orig_y += factor_y;
305 + }
306 +}
307 +
308 +template <typename T>
309 +void Image<T>::DownsampleInterpolateLinear(const Image<T>& original) {
310 + // TODO(andrewharp): Turn this into a general compare sizes/bulk
311 + // copy method.
312 + if (original.GetWidth() == GetWidth() &&
313 + original.GetHeight() == GetHeight() &&
314 + original.stride() == stride()) {
315 + memcpy(image_data_, original.data(), data_size_ * sizeof(T));
316 + return;
317 + }
318 +
319 + // Calculating the scaling factors based on target image size.
320 + const float factor_x = static_cast<float>(original.GetWidth()) /
321 + static_cast<float>(width_);
322 + const float factor_y = static_cast<float>(original.GetHeight()) /
323 + static_cast<float>(height_);
324 +
325 + // Calculating initial offset in x-axis.
326 + const float offset_x = 0;
327 + const int offset_x_fp = RealToFixed1616(offset_x);
328 +
329 + // Calculating initial offset in y-axis.
330 + const float offset_y = 0;
331 + const int offset_y_fp = RealToFixed1616(offset_y);
332 +
333 + // Get the fixed point scaling factor value.
334 + // Shift by 8 so we can fit everything into a 4 byte int later for speed
335 + // reasons. This means the precision is limited to 1 / 256th of a pixel,
336 + // but this should be good enough.
337 + const int factor_x_fp = RealToFixed1616(factor_x) >> 8;
338 + const int factor_y_fp = RealToFixed1616(factor_y) >> 8;
339 +
340 + int src_y_fp = offset_y_fp >> 8;
341 +
342 + static const int kFixedPointOne8 = 0x00000100;
343 + static const int kFixedPointHalf8 = 0x00000080;
344 + static const int kFixedPointTruncateMask8 = 0xFFFFFF00;
345 +
346 + // For every pixel in resulting image.
347 + for (int y = 0; y < height_; ++y) {
348 + int src_x_fp = offset_x_fp >> 8;
349 +
350 + int trunc_y = src_y_fp & kFixedPointTruncateMask8;
351 + const int fp_y = src_y_fp - trunc_y;
352 +
353 + // Scale the truncated values back to regular ints.
354 + trunc_y >>= 8;
355 +
356 + const int one_minus_fp_y = kFixedPointOne8 - fp_y;
357 +
358 + T* pixel_ptr = (*this)[y];
359 +
360 + // Make sure not to read from an invalid row.
361 + const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1);
362 + const T* other_top_ptr = original[trunc_y];
363 + const T* other_bot_ptr = original[trunc_y_b];
364 +
365 + int last_trunc_x = -1;
366 + int trunc_x = -1;
367 +
368 + T a = 0;
369 + T b = 0;
370 + T c = 0;
371 + T d = 0;
372 +
373 + for (int x = 0; x < width_; ++x) {
374 + trunc_x = src_x_fp & kFixedPointTruncateMask8;
375 +
376 + const int fp_x = (src_x_fp - trunc_x) >> 8;
377 +
378 + // Scale the truncated values back to regular ints.
379 + trunc_x >>= 8;
380 +
381 +      // We may be sampling the same source pixels; only reload on a new column.
382 + if (trunc_x != last_trunc_x) {
383 + // Make sure not to read from an invalid column.
384 + const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1);
385 + a = other_top_ptr[trunc_x];
386 + b = other_top_ptr[trunc_x_b];
387 + c = other_bot_ptr[trunc_x];
388 + d = other_bot_ptr[trunc_x_b];
389 + last_trunc_x = trunc_x;
390 + }
391 +
392 + const int one_minus_fp_x = kFixedPointOne8 - fp_x;
393 +
394 + const int32_t value =
395 + ((one_minus_fp_y * one_minus_fp_x * a + fp_x * b) +
396 + (fp_y * one_minus_fp_x * c + fp_x * d) + kFixedPointHalf8) >>
397 + 16;
398 +
399 + *pixel_ptr++ = value;
400 +
401 + src_x_fp += factor_x_fp;
402 + }
403 + src_y_fp += factor_y_fp;
404 + }
405 +}
406 +
407 +template <typename T>
408 +void Image<T>::DownsampleSmoothed3x3(const Image<T>& original) {
409 + for (int y = 0; y < height_; ++y) {
410 + const int orig_y = Clip(2 * y, ZERO, original.height_less_one_);
411 + const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_);
412 + const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_);
413 +
414 + for (int x = 0; x < width_; ++x) {
415 + const int orig_x = Clip(2 * x, ZERO, original.width_less_one_);
416 + const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_);
417 + const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_);
418 +
419 + // Center.
420 + int32_t pixel_sum = original[orig_y][orig_x] * 4;
421 +
422 + // Sides.
423 + pixel_sum += (original[orig_y][max_x] +
424 + original[orig_y][min_x] +
425 + original[max_y][orig_x] +
426 + original[min_y][orig_x]) * 2;
427 +
428 + // Diagonals.
429 + pixel_sum += (original[min_y][max_x] +
430 + original[min_y][min_x] +
431 + original[max_y][max_x] +
432 + original[max_y][min_x]);
433 +
434 + (*this)[y][x] = pixel_sum >> 4; // 16
435 + }
436 + }
437 +}
438 +
439 +template <typename T>
440 +void Image<T>::DownsampleSmoothed5x5(const Image<T>& original) {
441 + const int max_x = original.width_less_one_;
442 + const int max_y = original.height_less_one_;
443 +
444 +  // The JY Bouguet paper on Lucas-Kanade recommends a
445 + // [1/16 1/4 3/8 1/4 1/16]^2 filter.
446 + // This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below.
447 + static const int window_radius = 2;
448 + static const int window_size = window_radius*2 + 1;
449 + static const int window_weights[] = {1, 4, 6, 4, 1, // 16 +
450 + 4, 16, 24, 16, 4, // 64 +
451 + 6, 24, 36, 24, 6, // 96 +
452 + 4, 16, 24, 16, 4, // 64 +
453 + 1, 4, 6, 4, 1}; // 16 = 256
454 +
455 + // We'll multiply and sum with the whole numbers first, then divide by
456 + // the total weight to normalize at the last moment.
457 + for (int y = 0; y < height_; ++y) {
458 + for (int x = 0; x < width_; ++x) {
459 + int32_t pixel_sum = 0;
460 +
461 + const int* w = window_weights;
462 + const int start_x = Clip((x << 1) - window_radius, ZERO, max_x);
463 +
464 + // Clip the boundaries to the size of the image.
465 + for (int window_y = 0; window_y < window_size; ++window_y) {
466 + const int start_y =
467 + Clip((y << 1) - window_radius + window_y, ZERO, max_y);
468 +
469 + const T* p = original[start_y] + start_x;
470 +
471 + for (int window_x = 0; window_x < window_size; ++window_x) {
472 + pixel_sum += *p++ * *w++;
473 + }
474 + }
475 +
476 + // Conversion to type T will happen here after shifting right 8 bits to
477 + // divide by 256.
478 + (*this)[y][x] = pixel_sum >> 8;
479 + }
480 + }
481 +}
482 +
483 +template <typename T>
484 +template <typename U>
485 +inline T Image<T>::ScharrPixelX(const Image<U>& original,
486 + const int center_x, const int center_y) const {
487 + const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_);
488 + const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_);
489 + const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_);
490 + const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_);
491 +
492 + // Convolution loop unrolled for performance...
493 + return (3 * (original[min_y][max_x]
494 + + original[max_y][max_x]
495 + - original[min_y][min_x]
496 + - original[max_y][min_x])
497 + + 10 * (original[center_y][max_x]
498 + - original[center_y][min_x])) / 32;
499 +}
500 +
501 +template <typename T>
502 +template <typename U>
503 +inline T Image<T>::ScharrPixelY(const Image<U>& original,
504 + const int center_x, const int center_y) const {
505 + const int min_x = Clip(center_x - 1, 0, original.width_less_one_);
506 + const int max_x = Clip(center_x + 1, 0, original.width_less_one_);
507 + const int min_y = Clip(center_y - 1, 0, original.height_less_one_);
508 + const int max_y = Clip(center_y + 1, 0, original.height_less_one_);
509 +
510 + // Convolution loop unrolled for performance...
511 + return (3 * (original[max_y][min_x]
512 + + original[max_y][max_x]
513 + - original[min_y][min_x]
514 + - original[min_y][max_x])
515 + + 10 * (original[max_y][center_x]
516 + - original[min_y][center_x])) / 32;
517 +}
518 +
519 +template <typename T>
520 +template <typename U>
521 +inline void Image<T>::ScharrX(const Image<U>& original) {
522 + for (int y = 0; y < height_; ++y) {
523 + for (int x = 0; x < width_; ++x) {
524 + SetPixel(x, y, ScharrPixelX(original, x, y));
525 + }
526 + }
527 +}
528 +
529 +template <typename T>
530 +template <typename U>
531 +inline void Image<T>::ScharrY(const Image<U>& original) {
532 + for (int y = 0; y < height_; ++y) {
533 + for (int x = 0; x < width_; ++x) {
534 + SetPixel(x, y, ScharrPixelY(original, x, y));
535 + }
536 + }
537 +}
538 +
539 +template <typename T>
540 +template <typename U>
541 +void Image<T>::DerivativeX(const Image<U>& original) {
542 + for (int y = 0; y < height_; ++y) {
543 + const U* const source_row = original[y];
544 + T* const dest_row = (*this)[y];
545 +
546 + // Compute first pixel. Approximated with forward difference.
547 + dest_row[0] = source_row[1] - source_row[0];
548 +
549 + // All the pixels in between. Central difference method.
550 + const U* source_prev_pixel = source_row;
551 + T* dest_pixel = dest_row + 1;
552 + const U* source_next_pixel = source_row + 2;
553 + for (int x = 1; x < width_less_one_; ++x) {
554 + *dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
555 + }
556 +
557 + // Last pixel. Approximated with backward difference.
558 + dest_row[width_less_one_] =
559 + source_row[width_less_one_] - source_row[width_less_one_ - 1];
560 + }
561 +}
562 +
563 +template <typename T>
564 +template <typename U>
565 +void Image<T>::DerivativeY(const Image<U>& original) {
566 + const int src_stride = original.stride();
567 +
568 + // Compute 1st row. Approximated with forward difference.
569 + {
570 + const U* const src_row = original[0];
571 + T* dest_row = (*this)[0];
572 + for (int x = 0; x < width_; ++x) {
573 + dest_row[x] = src_row[x + src_stride] - src_row[x];
574 + }
575 + }
576 +
577 + // Compute all rows in between using central difference.
578 + for (int y = 1; y < height_less_one_; ++y) {
579 + T* dest_row = (*this)[y];
580 +
581 + const U* source_prev_pixel = original[y - 1];
582 + const U* source_next_pixel = original[y + 1];
583 + for (int x = 0; x < width_; ++x) {
584 + *dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
585 + }
586 + }
587 +
588 + // Compute last row. Approximated with backward difference.
589 + {
590 + const U* const src_row = original[height_less_one_];
591 + T* dest_row = (*this)[height_less_one_];
592 + for (int x = 0; x < width_; ++x) {
593 + dest_row[x] = src_row[x] - src_row[x - src_stride];
594 + }
595 + }
596 +}
597 +
598 +template <typename T>
599 +template <typename U>
600 +inline T Image<T>::ConvolvePixel3x3(const Image<U>& original,
601 + const int* const filter,
602 + const int center_x, const int center_y,
603 + const int total) const {
604 + int32_t sum = 0;
605 + for (int filter_y = 0; filter_y < 3; ++filter_y) {
606 +    const int y = Clip(center_y - 1 + filter_y, 0, original.GetHeight() - 1);
607 + for (int filter_x = 0; filter_x < 3; ++filter_x) {
608 +      const int x = Clip(center_x - 1 + filter_x, 0, original.GetWidth() - 1);
609 + sum += original[y][x] * filter[filter_y * 3 + filter_x];
610 + }
611 + }
612 + return sum / total;
613 +}
614 +
615 +template <typename T>
616 +template <typename U>
617 +inline void Image<T>::Convolve3x3(const Image<U>& original,
618 + const int32_t* const filter) {
619 + int32_t sum = 0;
620 + for (int i = 0; i < 9; ++i) {
621 + sum += abs(filter[i]);
622 + }
623 + for (int y = 0; y < height_; ++y) {
624 + for (int x = 0; x < width_; ++x) {
625 + SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum));
626 + }
627 + }
628 +}
629 +
630 +template <typename T>
631 +inline void Image<T>::FromArray(const T* const pixels, const int stride,
632 + const int factor) {
633 + if (factor == 1 && stride == width_) {
634 +    // If not subsampling and rows are contiguous, a single memcpy is fastest.
635 + memcpy(this->image_data_, pixels, data_size_ * sizeof(T));
636 + return;
637 + }
638 +
639 + DownsampleAveraged(pixels, stride, factor);
640 +}
641 +
642 +} // namespace tf_tracking
643 +
644 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
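For reference, the Scharr weights used by ScharrPixelX/ScharrPixelY above are 3 on the diagonal neighbors and 10 on the direct neighbors, normalized by 32. Below is a minimal standalone sketch of the horizontal case on a plain row-major buffer; the Clamp helper and buffer layout are assumptions made for the example, not code from this change:

```cpp
// scharr_sketch.cc -- standalone illustration of the ScharrPixelX math above.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

namespace {
inline int Clamp(int v, int lo, int hi) { return std::min(std::max(v, lo), hi); }
}  // namespace

// Horizontal Scharr response at (cx, cy) for a row-major grayscale buffer,
// using the same 3/10/3 weights divided by 32.
int ScharrX(const std::vector<uint8_t>& img, int width, int height,
            int cx, int cy) {
  const int min_x = Clamp(cx - 1, 0, width - 1);
  const int max_x = Clamp(cx + 1, 0, width - 1);
  const int min_y = Clamp(cy - 1, 0, height - 1);
  const int max_y = Clamp(cy + 1, 0, height - 1);
  auto at = [&](int x, int y) { return static_cast<int>(img[y * width + x]); };
  return (3 * (at(max_x, min_y) + at(max_x, max_y) -
               at(min_x, min_y) - at(min_x, max_y)) +
          10 * (at(max_x, cy) - at(min_x, cy))) / 32;
}

int main() {
  // A 4x3 test image with a hard vertical edge between columns 1 and 2.
  const std::vector<uint8_t> img = {0, 0, 255, 255,
                                    0, 0, 255, 255,
                                    0, 0, 255, 255};
  std::printf("ScharrX at (1,1) = %d\n", ScharrX(img, 4, 3, 1, 1));  // 127
  std::printf("ScharrX at (3,1) = %d\n", ScharrX(img, 4, 3, 3, 1));  // 0
  return 0;
}
```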
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
18 +
19 +#include <stdint.h>
20 +
21 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 +
24 +// TODO(andrewharp): Make this a cast to uint32_t if/when we go unsigned for
25 +// operations.
26 +#define ZERO 0
27 +
28 +#ifdef SANITY_CHECKS
29 + #define CHECK_PIXEL(IMAGE, X, Y) {\
30 + SCHECK((IMAGE)->ValidPixel((X), (Y)), \
31 + "CHECK_PIXEL(%d,%d) in %dx%d image.", \
32 + static_cast<int>(X), static_cast<int>(Y), \
33 + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
34 + }
35 +
36 + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\
37 + SCHECK((IMAGE)->validInterpPixel((X), (Y)), \
38 + "CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \
39 + static_cast<float>(X), static_cast<float>(Y), \
40 + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
41 + }
42 +#else
43 + #define CHECK_PIXEL(image, x, y) {}
44 + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {}
45 +#endif
46 +
47 +namespace tf_tracking {
48 +
49 +#ifdef SANITY_CHECKS
50 +// Class which exists solely to provide bounds checking for array-style image
51 +// data access.
52 +template <typename T>
53 +class RowData {
54 + public:
55 + RowData(T* const row_data, const int max_col)
56 + : row_data_(row_data), max_col_(max_col) {}
57 +
58 + inline T& operator[](const int col) const {
59 + SCHECK(InRange(col, 0, max_col_),
60 + "Column out of range: %d (%d max)", col, max_col_);
61 + return row_data_[col];
62 + }
63 +
64 + inline operator T*() const {
65 + return row_data_;
66 + }
67 +
68 + private:
69 + T* const row_data_;
70 + const int max_col_;
71 +};
72 +#endif
73 +
74 +// Naive templated comparison function, suitable for use with qsort.
75 +template <typename T>
76 +int Comp(const void* a, const void* b) {
77 + const T val1 = *reinterpret_cast<const T*>(a);
78 + const T val2 = *reinterpret_cast<const T*>(b);
79 +
80 + if (val1 == val2) {
81 + return 0;
82 + } else if (val1 < val2) {
83 + return -1;
84 + } else {
85 + return 1;
86 + }
87 +}
88 +
89 +// TODO(andrewharp): Make explicit which operations support negative numbers or
90 +// struct/class types in image data (possibly create fast multi-dim array class
91 +// for data where pixel arithmetic does not make sense).
92 +
93 +// Image class optimized for working on numeric arrays as grayscale image data.
94 +// Supports other data types as a 2D array class, so long as no pixel math
95 +// operations are called (convolution, downsampling, etc).
96 +template <typename T>
97 +class Image {
98 + public:
99 + Image(const int width, const int height);
100 + explicit Image(const Size& size);
101 +
102 + // Constructor that creates an image from preallocated data.
103 + // Note: The image takes ownership of the data lifecycle, unless own_data is
104 + // set to false.
105 + Image(const int width, const int height, T* const image_data,
106 + const bool own_data = true);
107 +
108 + ~Image();
109 +
110 + // Extract a pixel patch from this image, starting at a subpixel location.
111 + // Uses 16:16 fixed point format for representing real values and doing the
112 + // bilinear interpolation.
113 + //
114 + // Arguments fp_x and fp_y tell the subpixel position in fixed point format,
115 + // patchwidth/patchheight give the size of the patch in pixels and
116 + // to_data must be a valid pointer to a *contiguous* destination data array.
117 + template<class DstType>
118 + bool ExtractPatchAtSubpixelFixed1616(const int fp_x,
119 + const int fp_y,
120 + const int patchwidth,
121 + const int patchheight,
122 + DstType* to_data) const;
123 +
124 + Image<T>* Crop(
125 + const int left, const int top, const int right, const int bottom) const;
126 +
127 + inline int GetWidth() const { return width_; }
128 + inline int GetHeight() const { return height_; }
129 +
130 + // Bilinearly sample a value between pixels. Values must be within the image.
131 + inline float GetPixelInterp(const float x, const float y) const;
132 +
133 +  // Bilinearly sample a pixel at a subpixel position using fixed-point
134 + // arithmetic.
135 + // Avoids float<->int conversions.
136 + // Values must be within the image.
137 + // Arguments fp_x and fp_y tell the subpixel position in
138 + // 16:16 fixed point format.
139 + //
140 + // Important: This function only makes sense for integer-valued images, such
141 + // as Image<uint8_t> or Image<int> etc.
142 + inline T GetPixelInterpFixed1616(const int fp_x_whole,
143 + const int fp_y_whole) const;
144 +
145 + // Returns true iff the pixel is in the image's boundaries.
146 + inline bool ValidPixel(const int x, const int y) const;
147 +
148 + inline BoundingBox GetContainingBox() const;
149 +
150 + inline bool Contains(const BoundingBox& bounding_box) const;
151 +
152 + inline T GetMedianValue() {
153 + qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp<T>);
154 + return image_data_[data_size_ >> 1];
155 + }
156 +
157 + // Returns true iff the pixel is in the image's boundaries for interpolation
158 + // purposes.
159 + // TODO(andrewharp): check in interpolation follow-up change.
160 + inline bool ValidInterpPixel(const float x, const float y) const;
161 +
162 + // Safe lookup with boundary enforcement.
163 + inline T GetPixelClipped(const int x, const int y) const {
164 + return (*this)[Clip(y, ZERO, height_less_one_)]
165 + [Clip(x, ZERO, width_less_one_)];
166 + }
167 +
168 +#ifdef SANITY_CHECKS
169 + inline RowData<T> operator[](const int row) {
170 + SCHECK(InRange(row, 0, height_less_one_),
171 + "Row out of range: %d (%d max)", row, height_less_one_);
172 + return RowData<T>(image_data_ + row * stride_, width_less_one_);
173 + }
174 +
175 + inline const RowData<T> operator[](const int row) const {
176 + SCHECK(InRange(row, 0, height_less_one_),
177 + "Row out of range: %d (%d max)", row, height_less_one_);
178 + return RowData<T>(image_data_ + row * stride_, width_less_one_);
179 + }
180 +#else
181 + inline T* operator[](const int row) {
182 + return image_data_ + row * stride_;
183 + }
184 +
185 + inline const T* operator[](const int row) const {
186 + return image_data_ + row * stride_;
187 + }
188 +#endif
189 +
190 + const T* data() const { return image_data_; }
191 +
192 + inline int stride() const { return stride_; }
193 +
194 + // Clears image to a single value.
195 + inline void Clear(const T& val) {
196 + memset(image_data_, val, sizeof(*image_data_) * data_size_);
197 + }
198 +
199 +#ifdef __ARM_NEON
200 + void Downsample2x32ColumnsNeon(const uint8_t* const original,
201 + const int stride, const int orig_x);
202 +
203 + void Downsample4x32ColumnsNeon(const uint8_t* const original,
204 + const int stride, const int orig_x);
205 +
206 + void DownsampleAveragedNeon(const uint8_t* const original, const int stride,
207 + const int factor);
208 +#endif
209 +
210 +  // Naive downsampler that shrinks the image by `factor`, averaging pixels
211 +  // in blocks of size factor x factor.
212 + void DownsampleAveraged(const T* const original, const int stride,
213 + const int factor);
214 +
215 +  // Naive downsampler that shrinks the image by `factor`, averaging pixels
216 +  // in blocks of size factor x factor.
217 + inline void DownsampleAveraged(const Image<T>& original, const int factor) {
218 + DownsampleAveraged(original.data(), original.GetWidth(), factor);
219 + }
220 +
221 +  // Naive downsampler that reduces image size using nearest-neighbor interpolation
222 + void DownsampleInterpolateNearest(const Image<T>& original);
223 +
224 +  // Naive downsampler that reduces image size using fixed-point bilinear
225 +  // interpolation.
226 + void DownsampleInterpolateLinear(const Image<T>& original);
227 +
228 + // Relatively efficient downsampling of an image by a factor of two with a
229 + // low-pass 3x3 smoothing operation thrown in.
230 + void DownsampleSmoothed3x3(const Image<T>& original);
231 +
232 + // Relatively efficient downsampling of an image by a factor of two with a
233 + // low-pass 5x5 smoothing operation thrown in.
234 + void DownsampleSmoothed5x5(const Image<T>& original);
235 +
236 + // Optimized Scharr filter on a single pixel in the X direction.
237 + // Scharr filters are like central-difference operators, but have more
238 + // rotational symmetry in their response because they also consider the
239 + // diagonal neighbors.
240 + template <typename U>
241 + inline T ScharrPixelX(const Image<U>& original,
242 + const int center_x, const int center_y) const;
243 +
244 +  // Optimized Scharr filter on a single pixel in the Y direction.
245 + // Scharr filters are like central-difference operators, but have more
246 + // rotational symmetry in their response because they also consider the
247 + // diagonal neighbors.
248 + template <typename U>
249 + inline T ScharrPixelY(const Image<U>& original,
250 + const int center_x, const int center_y) const;
251 +
252 + // Convolve the image with a Scharr filter in the X direction.
253 + // Much faster than an equivalent generic convolution.
254 + template <typename U>
255 + inline void ScharrX(const Image<U>& original);
256 +
257 + // Convolve the image with a Scharr filter in the Y direction.
258 + // Much faster than an equivalent generic convolution.
259 + template <typename U>
260 + inline void ScharrY(const Image<U>& original);
261 +
262 + static inline T HalfDiff(int32_t first, int32_t second) {
263 + return (second - first) / 2;
264 + }
265 +
266 + template <typename U>
267 + void DerivativeX(const Image<U>& original);
268 +
269 + template <typename U>
270 + void DerivativeY(const Image<U>& original);
271 +
272 + // Generic function for convolving pixel with 3x3 filter.
273 + // Filter pixels should be in row major order.
274 + template <typename U>
275 + inline T ConvolvePixel3x3(const Image<U>& original,
276 + const int* const filter,
277 + const int center_x, const int center_y,
278 + const int total) const;
279 +
280 + // Generic function for convolving an image with a 3x3 filter.
281 + // TODO(andrewharp): Generalize this for any size filter.
282 + template <typename U>
283 + inline void Convolve3x3(const Image<U>& original,
284 + const int32_t* const filter);
285 +
286 + // Load this image's data from a data array. The data at pixels is assumed to
287 + // have dimensions equivalent to this image's dimensions * factor.
288 + inline void FromArray(const T* const pixels, const int stride,
289 + const int factor = 1);
290 +
291 + // Copy the image back out to an appropriately sized data array.
292 + inline void ToArray(T* const pixels) const {
293 + // If not subsampling, memcpy should be faster.
294 + memcpy(pixels, this->image_data_, data_size_ * sizeof(T));
295 + }
296 +
297 + // Precompute these for efficiency's sake as they're used by a lot of
298 + // clipping code and loop code.
299 + // TODO(andrewharp): make these only accessible by other Images.
300 + const int width_less_one_;
301 + const int height_less_one_;
302 +
303 + // The raw size of the allocated data.
304 + const int data_size_;
305 +
306 + private:
307 + inline void Allocate() {
308 + image_data_ = new T[data_size_];
309 + if (image_data_ == NULL) {
310 + LOGE("Couldn't allocate image data!");
311 + }
312 + }
313 +
314 + T* image_data_;
315 +
316 + bool own_data_;
317 +
318 + const int width_;
319 + const int height_;
320 +
321 + // The image stride (offset to next row).
322 + // TODO(andrewharp): Make sure that stride is honored in all code.
323 + const int stride_;
324 +
325 + TF_DISALLOW_COPY_AND_ASSIGN(Image);
326 +};
327 +
328 +template <typename T>
329 +inline std::ostream& operator<<(std::ostream& stream, const Image<T>& image) {
330 + for (int y = 0; y < image.GetHeight(); ++y) {
331 + for (int x = 0; x < image.GetWidth(); ++x) {
332 + stream << image[y][x] << " ";
333 + }
334 + stream << std::endl;
335 + }
336 + return stream;
337 +}
338 +
339 +} // namespace tf_tracking
340 +
341 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
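Several of the methods declared above (GetPixelInterpFixed1616, ExtractPatchAtSubpixelFixed1616) rely on a 16:16 fixed-point convention: the upper 16 bits hold the integer pixel coordinate and the lower 16 bits the fraction. The following is a generic bilinear-lookup sketch under that convention; it is an illustration with an assumed row-major layout, not the class's actual implementation:

```cpp
// fixed1616_sketch.cc -- illustrates the 16:16 fixed-point convention.
#include <cstdint>
#include <cstdio>

// Convert a real coordinate to 16:16 fixed point.
inline int RealToFixed1616(float x) { return static_cast<int>(x * 65536.0f); }

// Bilinearly sample a row-major 8-bit buffer at a 16:16 fixed-point position.
// The sample point must lie strictly inside the image (as the class above
// also requires), since the four neighbors (x, y)..(x+1, y+1) are read.
uint8_t SampleFixed1616(const uint8_t* data, int stride, int fp_x, int fp_y) {
  const int x = fp_x >> 16;      // integer part
  const int y = fp_y >> 16;
  const int fx = fp_x & 0xFFFF;  // fractional part, 0..65535
  const int fy = fp_y & 0xFFFF;
  const uint8_t* row0 = data + y * stride;
  const uint8_t* row1 = row0 + stride;
  // Widen to 64 bits: the final weighted sum is shifted right by 32.
  const int64_t top = row0[x] * int64_t{65536 - fx} + row0[x + 1] * int64_t{fx};
  const int64_t bot = row1[x] * int64_t{65536 - fx} + row1[x + 1] * int64_t{fx};
  return static_cast<uint8_t>((top * (65536 - fy) + bot * fy) >> 32);
}

int main() {
  const uint8_t img[2 * 2] = {0, 100,
                              100, 200};
  const int fp_x = RealToFixed1616(0.5f);
  const int fp_y = RealToFixed1616(0.5f);
  std::printf("value at (0.5, 0.5) = %d\n",
              SampleFixed1616(img, 2, fp_x, fp_y));  // prints 100
  return 0;
}
```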
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
18 +
19 +#include <stdint.h>
20 +#include <memory>
21 +
22 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
27 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
28 +
29 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
30 +
31 +namespace tf_tracking {
32 +
33 +// Class that encapsulates all bulky processed data for a frame.
34 +class ImageData {
35 + public:
36 + explicit ImageData(const int width, const int height)
37 + : uv_frame_width_(width << 1),
38 + uv_frame_height_(height << 1),
39 + timestamp_(0),
40 + image_(width, height) {
41 + InitPyramid(width, height);
42 + ResetComputationCache();
43 + }
44 +
45 + private:
46 + void ResetComputationCache() {
47 + uv_data_computed_ = false;
48 + integral_image_computed_ = false;
49 + for (int i = 0; i < kNumPyramidLevels; ++i) {
50 + spatial_x_computed_[i] = false;
51 + spatial_y_computed_[i] = false;
52 + pyramid_sqrt2_computed_[i * 2] = false;
53 + pyramid_sqrt2_computed_[i * 2 + 1] = false;
54 + }
55 + }
56 +
57 + void InitPyramid(const int width, const int height) {
58 + int level_width = width;
59 + int level_height = height;
60 +
61 + for (int i = 0; i < kNumPyramidLevels; ++i) {
62 + pyramid_sqrt2_[i * 2] = NULL;
63 + pyramid_sqrt2_[i * 2 + 1] = NULL;
64 + spatial_x_[i] = NULL;
65 + spatial_y_[i] = NULL;
66 +
67 + level_width /= 2;
68 + level_height /= 2;
69 + }
70 +
71 + // Alias the first pyramid level to image_.
72 + pyramid_sqrt2_[0] = &image_;
73 + }
74 +
75 + public:
76 + ~ImageData() {
77 + // The first pyramid level is actually an alias to image_,
78 + // so make sure it doesn't get deleted here.
79 + pyramid_sqrt2_[0] = NULL;
80 +
81 + for (int i = 0; i < kNumPyramidLevels; ++i) {
82 + SAFE_DELETE(pyramid_sqrt2_[i * 2]);
83 + SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
84 + SAFE_DELETE(spatial_x_[i]);
85 + SAFE_DELETE(spatial_y_[i]);
86 + }
87 + }
88 +
89 + void SetData(const uint8_t* const new_frame, const int stride,
90 + const int64_t timestamp, const int downsample_factor) {
91 + SetData(new_frame, NULL, stride, timestamp, downsample_factor);
92 + }
93 +
94 + void SetData(const uint8_t* const new_frame, const uint8_t* const uv_frame,
95 + const int stride, const int64_t timestamp,
96 + const int downsample_factor) {
97 + ResetComputationCache();
98 +
99 + timestamp_ = timestamp;
100 +
101 + TimeLog("SetData!");
102 +
103 + pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
104 + pyramid_sqrt2_computed_[0] = true;
105 + TimeLog("Downsampled image");
106 +
107 + if (uv_frame != NULL) {
108 + if (u_data_.get() == NULL) {
109 + u_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
110 + v_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
111 + }
112 +
113 + GetUV(uv_frame, u_data_.get(), v_data_.get());
114 + uv_data_computed_ = true;
115 + TimeLog("Copied UV data");
116 + } else {
117 + LOGV("No uv data!");
118 + }
119 +
120 +#ifdef LOG_TIME
121 + // If profiling is enabled, precompute here to make it easier to distinguish
122 + // total costs.
123 + Precompute();
124 +#endif
125 + }
126 +
127 +  inline int64_t GetTimestamp() const { return timestamp_; }
128 +
129 + inline const Image<uint8_t>* GetImage() const {
130 + SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
131 + return pyramid_sqrt2_[0];
132 + }
133 +
134 + const Image<uint8_t>* GetPyramidSqrt2Level(const int level) const {
135 + if (!pyramid_sqrt2_computed_[level]) {
136 + SCHECK(level != 0, "Level equals 0!");
137 + if (level == 1) {
138 + const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(0);
139 + if (pyramid_sqrt2_[level] == NULL) {
140 + const int new_width =
141 + (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
142 + const int new_height =
143 + (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
144 + 2;
145 +
146 + pyramid_sqrt2_[level] = new Image<uint8_t>(new_width, new_height);
147 + }
148 + pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
149 + } else {
150 + const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(level - 2);
151 + if (pyramid_sqrt2_[level] == NULL) {
152 + pyramid_sqrt2_[level] = new Image<uint8_t>(
153 + upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
154 + }
155 + pyramid_sqrt2_[level]->DownsampleAveraged(
156 + upper_level.data(), upper_level.stride(), 2);
157 + }
158 + pyramid_sqrt2_computed_[level] = true;
159 + }
160 + return pyramid_sqrt2_[level];
161 + }
162 +
163 + inline const Image<int32_t>* GetSpatialX(const int level) const {
164 + if (!spatial_x_computed_[level]) {
165 + const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
166 + if (spatial_x_[level] == NULL) {
167 + spatial_x_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
168 + }
169 + spatial_x_[level]->DerivativeX(src);
170 + spatial_x_computed_[level] = true;
171 + }
172 + return spatial_x_[level];
173 + }
174 +
175 + inline const Image<int32_t>* GetSpatialY(const int level) const {
176 + if (!spatial_y_computed_[level]) {
177 + const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
178 + if (spatial_y_[level] == NULL) {
179 + spatial_y_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
180 + }
181 + spatial_y_[level]->DerivativeY(src);
182 + spatial_y_computed_[level] = true;
183 + }
184 + return spatial_y_[level];
185 + }
186 +
187 + // The integral image is currently only used for object detection, so lazily
188 + // initialize it on request.
189 + inline const IntegralImage* GetIntegralImage() const {
190 + if (integral_image_.get() == NULL) {
191 + integral_image_.reset(new IntegralImage(image_));
192 + } else if (!integral_image_computed_) {
193 + integral_image_->Recompute(image_);
194 + }
195 + integral_image_computed_ = true;
196 + return integral_image_.get();
197 + }
198 +
199 + inline const Image<uint8_t>* GetU() const {
200 + SCHECK(uv_data_computed_, "UV data not provided!");
201 + return u_data_.get();
202 + }
203 +
204 + inline const Image<uint8_t>* GetV() const {
205 + SCHECK(uv_data_computed_, "UV data not provided!");
206 + return v_data_.get();
207 + }
208 +
209 + private:
210 + void Precompute() {
211 + // Create the smoothed pyramids.
212 + for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
213 + (void) GetPyramidSqrt2Level(i);
214 + }
215 + TimeLog("Created smoothed pyramids");
216 +
217 +    // Create the smoothed sqrt(2) pyramids.
218 + for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
219 + (void) GetPyramidSqrt2Level(i);
220 + }
221 + TimeLog("Created smoothed sqrt pyramids");
222 +
223 + // Create the spatial derivatives for frame 1.
224 + for (int i = 0; i < kNumPyramidLevels; ++i) {
225 + (void) GetSpatialX(i);
226 + (void) GetSpatialY(i);
227 + }
228 + TimeLog("Created spatial derivatives");
229 +
230 + (void) GetIntegralImage();
231 + TimeLog("Got integral image!");
232 + }
233 +
234 + const int uv_frame_width_;
235 + const int uv_frame_height_;
236 +
237 + int64_t timestamp_;
238 +
239 + Image<uint8_t> image_;
240 +
241 + bool uv_data_computed_;
242 + std::unique_ptr<Image<uint8_t> > u_data_;
243 + std::unique_ptr<Image<uint8_t> > v_data_;
244 +
245 + mutable bool spatial_x_computed_[kNumPyramidLevels];
246 + mutable Image<int32_t>* spatial_x_[kNumPyramidLevels];
247 +
248 + mutable bool spatial_y_computed_[kNumPyramidLevels];
249 + mutable Image<int32_t>* spatial_y_[kNumPyramidLevels];
250 +
251 + // Mutable so the lazy initialization can work when this class is const.
252 + // Whether or not the integral image has been computed for the current image.
253 + mutable bool integral_image_computed_;
254 + mutable std::unique_ptr<IntegralImage> integral_image_;
255 +
256 + mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
257 + mutable Image<uint8_t>* pyramid_sqrt2_[kNumPyramidLevels * 2];
258 +
259 + TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
260 +};
261 +
262 +} // namespace tf_tracking
263 +
264 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
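GetPyramidSqrt2Level() above interleaves half-resolution levels (even indices, computed by 2x averaging) with levels shrunk by sqrt(2) (odd indices, computed by bilinear downsampling). A small sketch that just prints the level dimensions this scheme produces; the 640x480 input and the level count of 4 are assumptions for illustration, not values from this change:

```cpp
// pyramid_sizes_sketch.cc -- prints the frame sizes the sqrt(2) pyramid above
// would allocate per level, using the same size formulas.
#include <cmath>
#include <cstdio>

int main() {
  const int kLevels = 4;            // assumption; real value is kNumPyramidLevels
  int widths[kLevels * 2] = {640};  // assumption: a 640x480 input frame
  int heights[kLevels * 2] = {480};

  // Level 1 shrinks level 0 by sqrt(2), rounded up to an even size.
  widths[1] = (static_cast<int>(widths[0] / std::sqrt(2.0f)) + 1) / 2 * 2;
  heights[1] = (static_cast<int>(heights[0] / std::sqrt(2.0f)) + 1) / 2 * 2;

  // Every other level halves the level two steps above it.
  for (int level = 2; level < kLevels * 2; ++level) {
    widths[level] = widths[level - 2] / 2;
    heights[level] = heights[level - 2] / 2;
  }

  for (int level = 0; level < kLevels * 2; ++level) {
    std::printf("level %d: %dx%d\n", level, widths[level], heights[level]);
  }
  return 0;
}
```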
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// NEON implementations of Image methods for compatible devices. Control
17 +// should never enter this compilation unit on incompatible devices.
18 +
19 +#ifdef __ARM_NEON
20 +
21 +#include <arm_neon.h>
22 +
23 +#include <stdint.h>
24 +
25 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
27 +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
28 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
29 +
30 +namespace tf_tracking {
31 +
32 +// This function does the bulk of the work.
33 +template <>
34 +void Image<uint8_t>::Downsample2x32ColumnsNeon(const uint8_t* const original,
35 + const int stride,
36 + const int orig_x) {
37 + // Divide input x offset by 2 to find output offset.
38 + const int new_x = orig_x >> 1;
39 +
40 + // Initial offset into top row.
41 + const uint8_t* offset = original + orig_x;
42 +
43 +  // This points to the leftmost pixel of our 16 horizontally arranged
44 +  // pixels in the destination data.
45 + uint8_t* ptr_dst = (*this)[0] + new_x;
46 +
47 + // Sum along vertical columns.
48 + // Process 32x2 input pixels and 16x1 output pixels per iteration.
49 + for (int new_y = 0; new_y < height_; ++new_y) {
50 + uint16x8_t accum1 = vdupq_n_u16(0);
51 + uint16x8_t accum2 = vdupq_n_u16(0);
52 +
53 +    // Go top to bottom across the two rows of input pixels that make up
54 +    // this output row.
55 + for (int row_num = 0; row_num < 2; ++row_num) {
56 + // First 16 bytes.
57 + {
58 + // Load 16 bytes of data from current offset.
59 + const uint8x16_t curr_data1 = vld1q_u8(offset);
60 +
61 + // Pairwise add and accumulate into accum vectors (16 bit to account
62 + // for values above 255).
63 + accum1 = vpadalq_u8(accum1, curr_data1);
64 + }
65 +
66 + // Second 16 bytes.
67 + {
68 + // Load 16 bytes of data from current offset.
69 + const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
70 +
71 + // Pairwise add and accumulate into accum vectors (16 bit to account
72 + // for values above 255).
73 + accum2 = vpadalq_u8(accum2, curr_data2);
74 + }
75 +
76 + // Move offset down one row.
77 + offset += stride;
78 + }
79 +
80 + // Divide by 4 (number of input pixels per output
81 + // pixel) and narrow data from 16 bits per pixel to 8 bpp.
82 + const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2);
83 + const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2);
84 +
85 + // Concatenate 8x1 pixel strips into 16x1 pixel strip.
86 + const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2);
87 +
88 + // Copy all pixels from composite 16x1 vector into output strip.
89 + vst1q_u8(ptr_dst, allpixels);
90 +
91 + ptr_dst += stride_;
92 + }
93 +}
94 +
95 +// This function does the bulk of the work.
96 +template <>
97 +void Image<uint8_t>::Downsample4x32ColumnsNeon(const uint8_t* const original,
98 + const int stride,
99 + const int orig_x) {
100 + // Divide input x offset by 4 to find output offset.
101 + const int new_x = orig_x >> 2;
102 +
103 + // Initial offset into top row.
104 + const uint8_t* offset = original + orig_x;
105 +
106 + // This points to the leftmost pixel of our 8 horizontally arranged
107 + // pixels in the destination data.
108 + uint8_t* ptr_dst = (*this)[0] + new_x;
109 +
110 + // Sum along vertical columns.
111 + // Process 32x4 input pixels and 8x1 output pixels per iteration.
112 + for (int new_y = 0; new_y < height_; ++new_y) {
113 + uint16x8_t accum1 = vdupq_n_u16(0);
114 + uint16x8_t accum2 = vdupq_n_u16(0);
115 +
116 + // Go top to bottom across the four rows of input pixels that make up
117 + // this output row.
118 + for (int row_num = 0; row_num < 4; ++row_num) {
119 + // First 16 bytes.
120 + {
121 + // Load 16 bytes of data from current offset.
122 + const uint8x16_t curr_data1 = vld1q_u8(offset);
123 +
124 + // Pairwise add and accumulate into accum vectors (16 bit to account
125 + // for values above 255).
126 + accum1 = vpadalq_u8(accum1, curr_data1);
127 + }
128 +
129 + // Second 16 bytes.
130 + {
131 + // Load 16 bytes of data from current offset.
132 + const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
133 +
134 + // Pairwise add and accumulate into accum vectors (16 bit to account
135 + // for values above 255).
136 + accum2 = vpadalq_u8(accum2, curr_data2);
137 + }
138 +
139 + // Move offset down one row.
140 + offset += stride;
141 + }
142 +
143 + // Add and widen, then divide by 16 (number of input pixels per output
144 + // pixel) and narrow data from 32 bits per pixel to 16 bpp.
145 + const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4);
146 + const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4);
147 +
148 + // Combine 4x1 pixel strips into 8x1 pixel strip and narrow from
149 + // 16 bits to 8 bits per pixel.
150 + const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2));
151 +
152 + // Copy all pixels from composite 8x1 vector into output strip.
153 + vst1_u8(ptr_dst, allpixels);
154 +
155 + ptr_dst += stride_;
156 + }
157 +}
158 +
159 +
160 +// Hardware accelerated downsampling method for supported devices.
161 +// Requires that image size be a multiple of 16 pixels in each dimension,
162 +// and that downsampling be by a factor of 2 or 4.
163 +template <>
164 +void Image<uint8_t>::DownsampleAveragedNeon(const uint8_t* const original,
165 + const int stride,
166 + const int factor) {
167 + // TODO(andrewharp): stride is a bad approximation for the src image's width.
168 + // Better to pass that in directly.
169 +  SCHECK(width_ * factor <= stride, "Width %d * factor %d exceeds stride %d!", width_, factor, stride);
170 + const int last_starting_index = width_ * factor - 32;
171 +
172 + // We process 32 input pixels lengthwise at a time.
173 + // The output per pass of this loop is an 8 wide by downsampled height tall
174 + // pixel strip.
175 + int orig_x = 0;
176 + for (; orig_x <= last_starting_index; orig_x += 32) {
177 + if (factor == 2) {
178 + Downsample2x32ColumnsNeon(original, stride, orig_x);
179 + } else {
180 + Downsample4x32ColumnsNeon(original, stride, orig_x);
181 + }
182 + }
183 +
184 + // If a last pass is required, push it to the left enough so that it never
185 + // goes out of bounds. This will result in some extra computation on devices
186 + // whose frame widths are multiples of 16 and not 32.
187 + if (orig_x < last_starting_index + 32) {
188 + if (factor == 2) {
189 + Downsample2x32ColumnsNeon(original, stride, last_starting_index);
190 + } else {
191 + Downsample4x32ColumnsNeon(original, stride, last_starting_index);
192 + }
193 + }
194 +}
195 +
196 +
197 +// Puts the image gradient matrix about a pixel into the 2x2 float array G.
198 +// vals_x should be an array of the window x gradient values, whose indices
199 +// can be in any order but are parallel to the vals_y entries.
200 +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
201 +void CalculateGNeon(const float* const vals_x, const float* const vals_y,
202 + const int num_vals, float* const G) {
203 + const float32_t* const arm_vals_x = (const float32_t*) vals_x;
204 + const float32_t* const arm_vals_y = (const float32_t*) vals_y;
205 +
206 + // Running sums.
207 + float32x4_t xx = vdupq_n_f32(0.0f);
208 + float32x4_t xy = vdupq_n_f32(0.0f);
209 + float32x4_t yy = vdupq_n_f32(0.0f);
210 +
211 + // Maximum index we can load 4 consecutive values from.
212 + // e.g. if there are 81 values, our last full pass can be from index 77:
213 + // 81-4=>77 (77, 78, 79, 80)
214 + const int max_i = num_vals - 4;
215 +
216 + // Defined here because we want to keep track of how many values were
217 + // processed by NEON, so that we can finish off the remainder the normal
218 + // way.
219 + int i = 0;
220 +
221 + // Process values 4 at a time, accumulating the sums of
222 + // the pixel-wise x*x, x*y, and y*y values.
223 + for (; i <= max_i; i += 4) {
224 + // Load xs
225 + float32x4_t x = vld1q_f32(arm_vals_x + i);
226 +
227 + // Multiply x*x and accumulate.
228 + xx = vmlaq_f32(xx, x, x);
229 +
230 + // Load ys
231 + float32x4_t y = vld1q_f32(arm_vals_y + i);
232 +
233 + // Multiply x*y and accumulate.
234 + xy = vmlaq_f32(xy, x, y);
235 +
236 + // Multiply y*y and accumulate.
237 + yy = vmlaq_f32(yy, y, y);
238 + }
239 +
240 + static float32_t xx_vals[4];
241 + static float32_t xy_vals[4];
242 + static float32_t yy_vals[4];
243 +
244 + vst1q_f32(xx_vals, xx);
245 + vst1q_f32(xy_vals, xy);
246 + vst1q_f32(yy_vals, yy);
247 +
248 +  // Accumulated values are stored in sets of 4; we have to manually add
249 +  // the last bits together.
250 + for (int j = 0; j < 4; ++j) {
251 + G[0] += xx_vals[j];
252 + G[1] += xy_vals[j];
253 + G[3] += yy_vals[j];
254 + }
255 +
256 + // Finishes off last few values (< 4) from above.
257 + for (; i < num_vals; ++i) {
258 + G[0] += Square(vals_x[i]);
259 + G[1] += vals_x[i] * vals_y[i];
260 + G[3] += Square(vals_y[i]);
261 + }
262 +
263 + // The matrix is symmetric, so this is a given.
264 + G[2] = G[1];
265 +}
266 +
267 +} // namespace tf_tracking
268 +
269 +#endif
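As a scalar reference for what Downsample2x32ColumnsNeon() computes, each output pixel is the average of a 2x2 block of input pixels; the NEON path accumulates the block with vpadalq_u8 and divides by 4 with vqshrn_n_u16(acc, 2). Here is a plain-C++ sketch under an assumed row-major layout with even dimensions:

```cpp
// downsample2x_reference.cc -- scalar reference for the 2x NEON downsample.
#include <cstdint>
#include <cstdio>
#include <vector>

void Downsample2x(const uint8_t* src, int src_width, int src_height,
                  uint8_t* dst) {
  const int dst_width = src_width / 2;
  const int dst_height = src_height / 2;
  for (int y = 0; y < dst_height; ++y) {
    for (int x = 0; x < dst_width; ++x) {
      const uint8_t* p = src + (2 * y) * src_width + 2 * x;
      // Sum the 2x2 block, then divide by 4.
      const int sum = p[0] + p[1] + p[src_width] + p[src_width + 1];
      dst[y * dst_width + x] = static_cast<uint8_t>(sum >> 2);
    }
  }
}

int main() {
  const std::vector<uint8_t> src = {10, 20, 30, 40,
                                    50, 60, 70, 80};
  std::vector<uint8_t> dst(2);
  Downsample2x(src.data(), 4, 2, dst.data());
  std::printf("%d %d\n", dst[0], dst[1]);  // prints 35 55
  return 0;
}
```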
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
18 +
19 +#include <stdint.h>
20 +
21 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
25 +
26 +
27 +namespace tf_tracking {
28 +
29 +inline void GetUV(const uint8_t* const input, Image<uint8_t>* const u,
30 + Image<uint8_t>* const v) {
31 + const uint8_t* pUV = input;
32 +
33 + for (int row = 0; row < u->GetHeight(); ++row) {
34 + uint8_t* u_curr = (*u)[row];
35 + uint8_t* v_curr = (*v)[row];
36 + for (int col = 0; col < u->GetWidth(); ++col) {
37 +#ifdef __APPLE__
38 + *u_curr++ = *pUV++;
39 + *v_curr++ = *pUV++;
40 +#else
41 + *v_curr++ = *pUV++;
42 + *u_curr++ = *pUV++;
43 +#endif
44 + }
45 + }
46 +}
47 +
48 +// Marks every point within a circle of a given radius on the given boolean
49 +// image true.
50 +template <typename U>
51 +inline static void MarkImage(const int x, const int y, const int radius,
52 + Image<U>* const img) {
53 + SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y);
54 +
55 + // Precomputed for efficiency.
56 + const int squared_radius = Square(radius);
57 +
58 + // Mark every row in the circle.
59 + for (int d_y = 0; d_y <= radius; ++d_y) {
60 + const int squared_y_dist = Square(d_y);
61 +
62 + const int min_y = MAX(y - d_y, 0);
63 + const int max_y = MIN(y + d_y, img->height_less_one_);
64 +
65 +    // The max d_x of the circle must be greater than or equal to
66 +    // radius - d_y for any positive d_y. Thus, starting from radius - d_y
67 +    // will reduce the number of iterations required as compared to starting
68 +    // from either 0 and counting up or radius and counting down.
69 + for (int d_x = radius - d_y; d_x <= radius; ++d_x) {
70 +      // The first time this criterion is met, we know the width of the
71 +      // circle at this row (without using sqrt).
72 + if (squared_y_dist + Square(d_x) >= squared_radius) {
73 + const int min_x = MAX(x - d_x, 0);
74 + const int max_x = MIN(x + d_x, img->width_less_one_);
75 +
76 + // Mark both above and below the center row.
77 + bool* const top_row_start = (*img)[min_y] + min_x;
78 + bool* const bottom_row_start = (*img)[max_y] + min_x;
79 +
80 + const int x_width = max_x - min_x + 1;
81 + memset(top_row_start, true, sizeof(*top_row_start) * x_width);
82 + memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width);
83 +
84 + // This row is marked, time to move on to the next row.
85 + break;
86 + }
87 + }
88 + }
89 +}
90 +
91 +#ifdef __ARM_NEON
92 +void CalculateGNeon(
93 + const float* const vals_x, const float* const vals_y,
94 + const int num_vals, float* const G);
95 +#endif
96 +
97 +// Puts the image gradient matrix about a pixel into the 2x2 float array G.
98 +// vals_x should be an array of the window x gradient values, whose indices
99 +// can be in any order but are parallel to the vals_y entries.
100 +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
101 +inline void CalculateG(const float* const vals_x, const float* const vals_y,
102 + const int num_vals, float* const G) {
103 +#ifdef __ARM_NEON
104 + CalculateGNeon(vals_x, vals_y, num_vals, G);
105 + return;
106 +#endif
107 +
108 + // Non-accelerated version.
109 + for (int i = 0; i < num_vals; ++i) {
110 + G[0] += Square(vals_x[i]);
111 + G[1] += vals_x[i] * vals_y[i];
112 + G[3] += Square(vals_y[i]);
113 + }
114 +
115 + // The matrix is symmetric, so this is a given.
116 + G[2] = G[1];
117 +}
118 +
119 +inline void CalculateGInt16(const int16_t* const vals_x,
120 + const int16_t* const vals_y, const int num_vals,
121 + int* const G) {
122 + // Non-accelerated version.
123 + for (int i = 0; i < num_vals; ++i) {
124 + G[0] += Square(vals_x[i]);
125 + G[1] += vals_x[i] * vals_y[i];
126 + G[3] += Square(vals_y[i]);
127 + }
128 +
129 + // The matrix is symmetric, so this is a given.
130 + G[2] = G[1];
131 +}
132 +
133 +
134 +// Puts the image gradient matrix about a pixel into the 2x2 float array G.
135 +// Looks up interpolated pixels, then calls above method for implementation.
136 +inline void CalculateG(const int window_radius, const float center_x,
137 + const float center_y, const Image<int32_t>& I_x,
138 + const Image<int32_t>& I_y, float* const G) {
139 +  SCHECK(I_x.ValidPixel(center_x, center_y), "Invalid center pixel (%.2f, %.2f) in CalculateG!", center_x, center_y);
140 +
141 +  // Hardcoded to allow for a max window radius of 5 (11 pixels x 11 pixels).
142 + static const int kMaxWindowRadius = 5;
143 + SCHECK(window_radius <= kMaxWindowRadius,
144 + "Window %d > %d!", window_radius, kMaxWindowRadius);
145 +
146 + // Diameter of window is 2 * radius + 1 for center pixel.
147 + static const int kWindowBufferSize =
148 + (kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1);
149 +
150 + // Preallocate buffers statically for efficiency.
151 + static int16_t vals_x[kWindowBufferSize];
152 + static int16_t vals_y[kWindowBufferSize];
153 +
154 + const int src_left_fixed = RealToFixed1616(center_x - window_radius);
155 + const int src_top_fixed = RealToFixed1616(center_y - window_radius);
156 +
157 + int16_t* vals_x_ptr = vals_x;
158 + int16_t* vals_y_ptr = vals_y;
159 +
160 + const int window_size = 2 * window_radius + 1;
161 + for (int y = 0; y < window_size; ++y) {
162 + const int fp_y = src_top_fixed + (y << 16);
163 +
164 + for (int x = 0; x < window_size; ++x) {
165 + const int fp_x = src_left_fixed + (x << 16);
166 +
167 + *vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
168 + *vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
169 + }
170 + }
171 +
172 + int32_t g_temp[] = {0, 0, 0, 0};
173 + CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp);
174 +
175 + for (int i = 0; i < 4; ++i) {
176 + G[i] = g_temp[i];
177 + }
178 +}
179 +
180 +inline float ImageCrossCorrelation(const Image<float>& image1,
181 + const Image<float>& image2,
182 + const int x_offset, const int y_offset) {
183 + SCHECK(image1.GetWidth() == image2.GetWidth() &&
184 + image1.GetHeight() == image2.GetHeight(),
185 + "Dimension mismatch! %dx%d vs %dx%d",
186 + image1.GetWidth(), image1.GetHeight(),
187 + image2.GetWidth(), image2.GetHeight());
188 +
189 + const int num_pixels = image1.GetWidth() * image1.GetHeight();
190 + const float* data1 = image1.data();
191 + const float* data2 = image2.data();
192 + return ComputeCrossCorrelation(data1, data2, num_pixels);
193 +}
194 +
195 +// Copies an arbitrary region of an image to another (floating point)
196 +// image, scaling as it goes using bilinear interpolation.
197 +inline void CopyArea(const Image<uint8_t>& image,
198 + const BoundingBox& area_to_copy,
199 + Image<float>* const patch_image) {
200 + VLOG(2) << "Copying from: " << area_to_copy << std::endl;
201 +
202 + const int patch_width = patch_image->GetWidth();
203 + const int patch_height = patch_image->GetHeight();
204 +
205 +  const float x_dist_between_samples = patch_width > 1 ?
206 +      area_to_copy.GetWidth() / (patch_width - 1) : 0;
207 +
208 +  const float y_dist_between_samples = patch_height > 1 ?
209 +      area_to_copy.GetHeight() / (patch_height - 1) : 0;
210 +
211 + for (int y_index = 0; y_index < patch_height; ++y_index) {
212 + const float sample_y =
213 + y_index * y_dist_between_samples + area_to_copy.top_;
214 +
215 + for (int x_index = 0; x_index < patch_width; ++x_index) {
216 + const float sample_x =
217 + x_index * x_dist_between_samples + area_to_copy.left_;
218 +
219 + if (image.ValidInterpPixel(sample_x, sample_y)) {
220 + // TODO(andrewharp): Do area averaging when downsampling.
221 + (*patch_image)[y_index][x_index] =
222 + image.GetPixelInterp(sample_x, sample_y);
223 + } else {
224 + (*patch_image)[y_index][x_index] = -1.0f;
225 + }
226 + }
227 + }
228 +}
229 +
230 +
231 +// Takes a floating point image and normalizes it in-place.
232 +//
233 +// First, negative values will be set to the mean of the non-negative pixels
234 +// in the image.
235 +//
236 +// Then, the resulting image will be normalized such that it has a mean value
237 +// of 0.0 and a standard deviation of 1.0.
238 +inline void NormalizeImage(Image<float>* const image) {
239 + const float* const data_ptr = image->data();
240 +
241 +  // Sum the non-negative values (negative pixels are flagged as -1.0f).
242 + float running_sum = 0.0f;
243 + int num_data_gte_zero = 0;
244 + {
245 + float* const curr_data = (*image)[0];
246 + for (int i = 0; i < image->data_size_; ++i) {
247 + if (curr_data[i] >= 0.0f) {
248 + running_sum += curr_data[i];
249 + ++num_data_gte_zero;
250 + } else {
251 + curr_data[i] = -1.0f;
252 + }
253 + }
254 + }
255 +
256 + // If none of the pixels are valid, just set the entire thing to 0.0f.
257 + if (num_data_gte_zero == 0) {
258 + image->Clear(0.0f);
259 + return;
260 + }
261 +
262 + const float corrected_mean = running_sum / num_data_gte_zero;
263 +
264 + float* curr_data = (*image)[0];
265 + for (int i = 0; i < image->data_size_; ++i) {
266 + const float curr_val = *curr_data;
267 + *curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean;
268 + }
269 +
270 + const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f);
271 +
272 + if (std_dev > 0.0f) {
273 + curr_data = (*image)[0];
274 + for (int i = 0; i < image->data_size_; ++i) {
275 + *curr_data++ /= std_dev;
276 + }
277 +
278 +#ifdef SANITY_CHECKS
279 + LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev);
280 + const float correlation =
281 + ComputeCrossCorrelation(image->data(),
282 + image->data(),
283 + image->data_size_);
284 +
285 + if (std::abs(correlation - 1.0f) > EPSILON) {
286 + LOG(ERROR) << "Bad image!" << std::endl;
287 + LOG(ERROR) << *image << std::endl;
288 + }
289 +
290 + SCHECK(std::abs(correlation - 1.0f) < EPSILON,
291 + "Correlation wasn't 1.0f: %.10f", correlation);
292 +#endif
293 + }
294 +}
295 +
296 +} // namespace tf_tracking
297 +
298 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
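The 2x2 matrix G assembled by CalculateG()/CalculateGInt16() above is the Lucas-Kanade structure tensor: G = [sum(Ix*Ix), sum(Ix*Iy); sum(Ix*Iy), sum(Iy*Iy)] over the window. Below is a standalone copy of the non-accelerated branch with a tiny worked example (the gradient values are made up):

```cpp
// calculate_g_sketch.cc -- standalone version of the plain-C++ branch of
// CalculateG() above, with a worked example.
#include <cstdio>

void CalculateG(const float* vals_x, const float* vals_y, int num_vals,
                float* G) {
  // G = [ sum(Ix*Ix)  sum(Ix*Iy) ]
  //     [ sum(Ix*Iy)  sum(Iy*Iy) ]
  for (int i = 0; i < num_vals; ++i) {
    G[0] += vals_x[i] * vals_x[i];
    G[1] += vals_x[i] * vals_y[i];
    G[3] += vals_y[i] * vals_y[i];
  }
  G[2] = G[1];  // symmetric
}

int main() {
  const float ix[] = {1.0f, 2.0f, -1.0f};
  const float iy[] = {0.5f, 1.0f, 2.0f};
  float G[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  CalculateG(ix, iy, 3, G);
  std::printf("G = [%.2f %.2f; %.2f %.2f]\n", G[0], G[1], G[2], G[3]);
  // Prints: G = [6.00 0.50; 0.50 5.25]
  return 0;
}
```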
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
18 +
19 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 +
24 +namespace tf_tracking {
25 +
26 +typedef uint8_t Code;
27 +
28 +class IntegralImage : public Image<uint32_t> {
29 + public:
30 + explicit IntegralImage(const Image<uint8_t>& image_base)
31 + : Image<uint32_t>(image_base.GetWidth(), image_base.GetHeight()) {
32 + Recompute(image_base);
33 + }
34 +
35 + IntegralImage(const int width, const int height)
36 + : Image<uint32_t>(width, height) {}
37 +
38 + void Recompute(const Image<uint8_t>& image_base) {
39 + SCHECK(image_base.GetWidth() == GetWidth() &&
40 + image_base.GetHeight() == GetHeight(), "Dimensions don't match!");
41 +
42 + // Sum along first row.
43 + {
44 + int x_sum = 0;
45 + for (int x = 0; x < image_base.GetWidth(); ++x) {
46 + x_sum += image_base[0][x];
47 + (*this)[0][x] = x_sum;
48 + }
49 + }
50 +
51 + // Sum everything else.
52 + for (int y = 1; y < image_base.GetHeight(); ++y) {
53 + uint32_t* curr_sum = (*this)[y];
54 +
55 + // Previously summed pointers.
56 + const uint32_t* up_one = (*this)[y - 1];
57 +
58 + // Current value pointer.
59 + const uint8_t* curr_delta = image_base[y];
60 +
61 + uint32_t row_till_now = 0;
62 +
63 + for (int x = 0; x < GetWidth(); ++x) {
64 + // Add the one above and the one to the left.
65 + row_till_now += *curr_delta;
66 + *curr_sum = *up_one + row_till_now;
67 +
68 + // Scoot everything along.
69 + ++curr_sum;
70 + ++up_one;
71 + ++curr_delta;
72 + }
73 + }
74 +
75 + SCHECK(VerifyData(image_base), "Images did not match!");
76 + }
77 +
78 + bool VerifyData(const Image<uint8_t>& image_base) {
79 + for (int y = 0; y < GetHeight(); ++y) {
80 + for (int x = 0; x < GetWidth(); ++x) {
81 + uint32_t curr_val = (*this)[y][x];
82 +
83 + if (x > 0) {
84 + curr_val -= (*this)[y][x - 1];
85 + }
86 +
87 + if (y > 0) {
88 + curr_val -= (*this)[y - 1][x];
89 + }
90 +
91 + if (x > 0 && y > 0) {
92 + curr_val += (*this)[y - 1][x - 1];
93 + }
94 +
95 + if (curr_val != image_base[y][x]) {
96 + LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]);
97 + return false;
98 + }
99 +
100 + if (GetRegionSum(x, y, x, y) != curr_val) {
101 + LOGE("Mismatch!");
102 + }
103 + }
104 + }
105 +
106 + return true;
107 + }
108 +
109 + // Returns the sum of all pixels in the specified region.
110 + inline uint32_t GetRegionSum(const int x1, const int y1, const int x2,
111 + const int y2) const {
112 + SCHECK(x1 >= 0 && y1 >= 0 &&
113 + x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(),
114 + "indices out of bounds! %d-%d / %d, %d-%d / %d, ",
115 + x1, x2, GetWidth(), y1, y2, GetHeight());
116 +
117 + const uint32_t everything = (*this)[y2][x2];
118 +
119 + uint32_t sum = everything;
120 + if (x1 > 0 && y1 > 0) {
121 + // Most common case.
122 + const uint32_t left = (*this)[y2][x1 - 1];
123 + const uint32_t top = (*this)[y1 - 1][x2];
124 + const uint32_t top_left = (*this)[y1 - 1][x1 - 1];
125 +
126 + sum = everything - left - top + top_left;
127 + SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d",
128 + everything, left, top, top_left, sum, x1, y1, x2, y2);
129 + } else if (x1 > 0) {
130 +      // Flush against top of image.
131 +      // Subtract out the region to the left only.
132 +      const uint32_t left = (*this)[y2][x1 - 1];
133 +      sum = everything - left;
134 +      SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum);
135 +    } else if (y1 > 0) {
136 +      // Flush against left side of image.
137 +      // Subtract out the region above only.
138 +      const uint32_t top = (*this)[y1 - 1][x2];
139 +      sum = everything - top;
140 +      SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum);
141 + }
142 +
143 + SCHECK(sum >= 0, "Negative sum!");
144 +
145 + return sum;
146 + }
147 +
148 +  // Returns the 2-bit code associated with this region, which represents
149 + // the overall gradient.
150 + inline Code GetCode(const BoundingBox& bounding_box) const {
151 + return GetCode(bounding_box.left_, bounding_box.top_,
152 + bounding_box.right_, bounding_box.bottom_);
153 + }
154 +
155 + inline Code GetCode(const int x1, const int y1,
156 + const int x2, const int y2) const {
157 + SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d",
158 + x1, y1, x2, y2);
159 +
160 + // Gradient computed vertically.
161 + const int box_height = (y2 - y1) / 2;
162 + const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height);
163 + const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2);
164 + const bool vertical_code = top_sum > bottom_sum;
165 +
166 + // Gradient computed horizontally.
167 + const int box_width = (x2 - x1) / 2;
168 + const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2);
169 + const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2);
170 + const bool horizontal_code = left_sum > right_sum;
171 +
172 + const Code final_code = (vertical_code << 1) | horizontal_code;
173 +
174 + SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)),
175 + "Invalid code! %d", final_code);
176 +
177 + // Returns a value 0-3.
178 + return final_code;
179 + }
180 +
181 + private:
182 + TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage);
183 +};
184 +
185 +} // namespace tf_tracking
186 +
187 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
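GetRegionSum() above recovers the sum over any rectangle with four table lookups via inclusion-exclusion: everything, minus the strip to the left, minus the strip above, plus the doubly subtracted top-left corner. A self-contained sketch on a made-up 3x3 image:

```cpp
// integral_image_sketch.cc -- shows the inclusion-exclusion lookup that
// GetRegionSum() performs, on a tiny hand-built integral image.
#include <cstdint>
#include <cstdio>

int main() {
  const int w = 3, h = 3;
  const uint8_t img[h][w] = {{1, 2, 3},
                             {4, 5, 6},
                             {7, 8, 9}};

  // Build the integral image: I[y][x] = sum of img over [0..y] x [0..x].
  uint32_t I[h][w];
  for (int y = 0; y < h; ++y) {
    uint32_t row_sum = 0;
    for (int x = 0; x < w; ++x) {
      row_sum += img[y][x];
      I[y][x] = row_sum + (y > 0 ? I[y - 1][x] : 0);
    }
  }

  // Sum of the bottom-right 2x2 region (x1=1, y1=1, x2=2, y2=2):
  // everything - left strip - top strip + double-subtracted corner.
  const uint32_t sum = I[2][2] - I[2][0] - I[0][2] + I[0][0];
  std::printf("region sum = %u (expected 5+6+8+9 = 28)\n",
              static_cast<unsigned int>(sum));
  return 0;
}
```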
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
18 +
19 +#include <jni.h>
20 +#include <stdint.h>
21 +
22 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 +
24 +// The JniLongField class is used to access Java fields from native code. This
25 +// technique of hiding pointers to native objects in opaque Java fields is how
26 +// the Android hardware libraries work. This reduces the number of static
27 +// native methods and makes it easier to manage the lifetime of native objects.
28 +class JniLongField {
29 + public:
30 + JniLongField(const char* field_name)
31 + : field_name_(field_name), field_ID_(0) {}
32 +
33 + int64_t get(JNIEnv* env, jobject thiz) {
34 + if (field_ID_ == 0) {
35 + jclass cls = env->GetObjectClass(thiz);
36 + CHECK_ALWAYS(cls != 0, "Unable to find class");
37 + field_ID_ = env->GetFieldID(cls, field_name_, "J");
38 + CHECK_ALWAYS(field_ID_ != 0,
39 + "Unable to find field %s. (Check proguard cfg)", field_name_);
40 + }
41 +
42 + return env->GetLongField(thiz, field_ID_);
43 + }
44 +
45 + void set(JNIEnv* env, jobject thiz, int64_t value) {
46 + if (field_ID_ == 0) {
47 + jclass cls = env->GetObjectClass(thiz);
48 + CHECK_ALWAYS(cls != 0, "Unable to find class");
49 + field_ID_ = env->GetFieldID(cls, field_name_, "J");
50 + CHECK_ALWAYS(field_ID_ != 0,
51 + "Unable to find field %s (Check proguard cfg)", field_name_);
52 + }
53 +
54 + env->SetLongField(thiz, field_ID_, value);
55 + }
56 +
57 + private:
58 + const char* const field_name_;
59 +
60 +  // Cached field ID; looked up lazily on first access.
61 + jfieldID field_ID_;
62 +};
63 +
64 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
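A sketch of how JniLongField is typically used: a native object is allocated, its address is stored in an opaque Java long field, and later recovered and freed. The Java class path, the "nativeHandle" field name, and the Tracker type below are assumptions for illustration only, not names from this change:

```cpp
// jni_long_field_usage.cc -- illustration only; assumes the jni_utils.h header
// above is on the include path.
#include <jni.h>

#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"

struct Tracker {  // stand-in for the real native object
  int frames_seen = 0;
};

static JniLongField tracker_field("nativeHandle");

extern "C" JNIEXPORT void JNICALL
Java_com_example_TrackerWrapper_nativeCreate(JNIEnv* env, jobject thiz) {
  // Allocate the native object and hide its address in the Java field.
  Tracker* tracker = new Tracker();
  tracker_field.set(env, thiz, reinterpret_cast<jlong>(tracker));
}

extern "C" JNIEXPORT void JNICALL
Java_com_example_TrackerWrapper_nativeDestroy(JNIEnv* env, jobject thiz) {
  // Recover the pointer, free it, and clear the field.
  Tracker* tracker =
      reinterpret_cast<Tracker*>(tracker_field.get(env, thiz));
  delete tracker;
  tracker_field.set(env, thiz, 0);
}
```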
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
18 +
19 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/logging.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
25 +
26 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
27 +
28 +namespace tf_tracking {
29 +
30 +// For keeping track of keypoints.
31 +struct Keypoint {
32 + Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {}
33 + Keypoint(const float x, const float y)
34 + : pos_(x, y), score_(0.0f), type_(0) {}
35 +
36 + Point2f pos_;
37 + float score_;
38 + uint8_t type_;
39 +};
40 +
41 +inline std::ostream& operator<<(std::ostream& stream, const Keypoint& keypoint) {
42 +  return stream << "[" << keypoint.pos_ << ", " << keypoint.score_ << ", "
43 +                << static_cast<int>(keypoint.type_) << "]";
44 +}
45 +
46 +} // namespace tf_tracking
47 +
48 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
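A minimal usage sketch for the Keypoint struct above: construct a few keypoints and order them by descending score, mirroring the ordering the detector's SortKeypoints() (later in this change) produces with qsort. The coordinates and scores are made up:

```cpp
// keypoint_usage_sketch.cc -- illustration only; assumes keypoint.h and its
// dependencies are on the include path.
#include <algorithm>
#include <cstdio>
#include <vector>

#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"

int main() {
  std::vector<tf_tracking::Keypoint> keypoints;
  keypoints.push_back(tf_tracking::Keypoint(10.0f, 12.0f));
  keypoints.push_back(tf_tracking::Keypoint(40.0f, 8.0f));
  keypoints.push_back(tf_tracking::Keypoint(25.0f, 30.0f));
  keypoints[0].score_ = 0.2f;
  keypoints[1].score_ = 0.9f;
  keypoints[2].score_ = 0.5f;

  // Sort by descending score, the same ordering SortKeypoints() produces.
  std::sort(keypoints.begin(), keypoints.end(),
            [](const tf_tracking::Keypoint& a, const tf_tracking::Keypoint& b) {
              return a.score_ > b.score_;
            });

  for (const tf_tracking::Keypoint& kp : keypoints) {
    std::printf("(%.1f, %.1f) score=%.2f\n", kp.pos_.x, kp.pos_.y, kp.score_);
  }
  return 0;
}
```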
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// Various keypoint detecting functions.
17 +
18 +#include <float.h>
19 +
20 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
24 +
25 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
27 +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
28 +
29 +namespace tf_tracking {
30 +
31 +static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) {
32 + return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]);
33 +}
34 +
35 +void KeypointDetector::ScoreKeypoints(const ImageData& image_data,
36 + const int num_candidates,
37 + Keypoint* const candidate_keypoints) {
38 + const Image<int>& I_x = *image_data.GetSpatialX(0);
39 + const Image<int>& I_y = *image_data.GetSpatialY(0);
40 +
41 + if (config_->detect_skin) {
42 + const Image<uint8_t>& u_data = *image_data.GetU();
43 + const Image<uint8_t>& v_data = *image_data.GetV();
44 +
45 +    static const int reference[] = {111, 155};  // Approximate skin tone in (U, V).
46 +
47 + // Score all the keypoints.
48 + for (int i = 0; i < num_candidates; ++i) {
49 + Keypoint* const keypoint = candidate_keypoints + i;
50 +
51 + const int x_pos = keypoint->pos_.x * 2;
52 + const int y_pos = keypoint->pos_.y * 2;
53 +
54 + const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]};
55 + keypoint->score_ =
56 + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) /
57 + GetDistSquaredBetween(reference, curr_color);
58 + }
59 + } else {
60 + // Score all the keypoints.
61 + for (int i = 0; i < num_candidates; ++i) {
62 + Keypoint* const keypoint = candidate_keypoints + i;
63 + keypoint->score_ =
64 + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y);
65 + }
66 + }
67 +}
68 +
69 +
70 +inline int KeypointCompare(const void* const a, const void* const b) {
71 + return (reinterpret_cast<const Keypoint*>(a)->score_ -
72 + reinterpret_cast<const Keypoint*>(b)->score_) <= 0 ? 1 : -1;
73 +}
74 +
75 +
76 +// Quicksorts detected keypoints by score.
77 +void KeypointDetector::SortKeypoints(const int num_candidates,
78 + Keypoint* const candidate_keypoints) const {
79 + qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare);
80 +
81 +#ifdef SANITY_CHECKS
82 + // Verify that the array got sorted.
83 + float last_score = FLT_MAX;
84 + for (int i = 0; i < num_candidates; ++i) {
85 + const float curr_score = candidate_keypoints[i].score_;
86 +
87 +      // Scores should be monotonically decreasing.
88 +      SCHECK(last_score >= curr_score,
89 +             "Quicksort failure! %d: %.5f < %d: %.5f (%d total)",
90 + i - 1, last_score, i, curr_score, num_candidates);
91 +
92 + last_score = curr_score;
93 + }
94 +#endif
95 +}
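The comparator above orders keypoints by descending score, so after SortKeypoints the strongest candidates sit at the front of the array. Purely as a sketch of what that ordering means (the file keeps the C-style qsort above), the equivalent std::sort call would be:

    #include <algorithm>

    // Illustrative equivalent of KeypointCompare + qsort: highest score first.
    std::sort(candidate_keypoints, candidate_keypoints + num_candidates,
              [](const Keypoint& a, const Keypoint& b) {
                return a.score_ > b.score_;
              });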
96 +
97 +
98 +int KeypointDetector::SelectKeypointsInBox(
99 + const BoundingBox& box,
100 + const Keypoint* const candidate_keypoints,
101 + const int num_candidates,
102 + const int max_keypoints,
103 + const int num_existing_keypoints,
104 + const Keypoint* const existing_keypoints,
105 + Keypoint* const final_keypoints) const {
106 + if (max_keypoints <= 0) {
107 + return 0;
108 + }
109 +
110 +  // The minimum distance that must separate any two keypoints selected within
111 +  // this box, roughly based on the box dimensions.
112 + const int distance =
113 + MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f);
114 +
115 + // First, mark keypoints that already happen to be inside this region. Ignore
116 + // keypoints that are outside it, however close they might be.
117 + interest_map_->Clear(false);
118 + for (int i = 0; i < num_existing_keypoints; ++i) {
119 + const Keypoint& candidate = existing_keypoints[i];
120 +
121 + const int x_pos = candidate.pos_.x;
122 + const int y_pos = candidate.pos_.y;
123 + if (box.Contains(candidate.pos_)) {
124 + MarkImage(x_pos, y_pos, distance, interest_map_.get());
125 + }
126 + }
127 +
128 + // Now, go through and check which keypoints will still fit in the box.
129 + int num_keypoints_selected = 0;
130 + for (int i = 0; i < num_candidates; ++i) {
131 + const Keypoint& candidate = candidate_keypoints[i];
132 +
133 + const int x_pos = candidate.pos_.x;
134 + const int y_pos = candidate.pos_.y;
135 +
136 + if (!box.Contains(candidate.pos_) ||
137 + !interest_map_->ValidPixel(x_pos, y_pos)) {
138 + continue;
139 + }
140 +
141 + if (!(*interest_map_)[y_pos][x_pos]) {
142 + final_keypoints[num_keypoints_selected++] = candidate;
143 + if (num_keypoints_selected >= max_keypoints) {
144 + break;
145 + }
146 + MarkImage(x_pos, y_pos, distance, interest_map_.get());
147 + }
148 + }
149 + return num_keypoints_selected;
150 +}
151 +
152 +
153 +void KeypointDetector::SelectKeypoints(
154 + const std::vector<BoundingBox>& boxes,
155 + const Keypoint* const candidate_keypoints,
156 + const int num_candidates,
157 + FramePair* const curr_change) const {
158 +  // Now select all the interesting keypoints that fall inside our boxes.
159 + curr_change->number_of_keypoints_ = 0;
160 + for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
161 + iter != boxes.end(); ++iter) {
162 + const BoundingBox bounding_box = *iter;
163 +
164 + // Count up keypoints that have already been selected, and fall within our
165 + // box.
166 + int num_keypoints_already_in_box = 0;
167 + for (int i = 0; i < curr_change->number_of_keypoints_; ++i) {
168 + if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) {
169 + ++num_keypoints_already_in_box;
170 + }
171 + }
172 +
173 + const int max_keypoints_to_find_in_box =
174 + MIN(kMaxKeypointsForObject - num_keypoints_already_in_box,
175 + kMaxKeypoints - curr_change->number_of_keypoints_);
176 +
177 + const int num_new_keypoints_in_box = SelectKeypointsInBox(
178 + bounding_box,
179 + candidate_keypoints,
180 + num_candidates,
181 + max_keypoints_to_find_in_box,
182 + curr_change->number_of_keypoints_,
183 + curr_change->frame1_keypoints_,
184 + curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_);
185 +
186 + curr_change->number_of_keypoints_ += num_new_keypoints_in_box;
187 +
188 + LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_);
189 + }
190 +}
191 +
192 +
193 +// Walks along the given circle checking for pixels above or below the center.
194 +// Returns a score, or 0 if the keypoint did not pass the criteria.
195 +//
196 +// Parameters:
197 +// circle_perimeter: the circumference in pixels of the circle.
198 +// threshold: the minimum number of contiguous pixels that must be above or
199 +// below the center value.
200 +// center_ptr: the location of the center pixel in memory
201 +// offsets: the relative offsets from the center pixel of the edge pixels.
202 +inline int TestCircle(const int circle_perimeter, const int threshold,
203 + const uint8_t* const center_ptr, const int* offsets) {
204 + // Get the actual value of the center pixel for easier reference later on.
205 + const int center_value = static_cast<int>(*center_ptr);
206 +
207 + // Number of total pixels to check. Have to wrap around some in case
208 + // the contiguous section is split by the array edges.
209 + const int num_total = circle_perimeter + threshold - 1;
210 +
211 + int num_above = 0;
212 + int above_diff = 0;
213 +
214 + int num_below = 0;
215 + int below_diff = 0;
216 +
217 + // Used to tell when this is definitely not going to meet the threshold so we
218 +  // can abort early.
219 + int minimum_by_now = threshold - num_total + 1;
220 +
221 + // Go through every pixel along the perimeter of the circle, and then around
222 + // again a little bit.
223 + for (int i = 0; i < num_total; ++i) {
224 + // This should be faster than mod.
225 + const int perim_index = i < circle_perimeter ? i : i - circle_perimeter;
226 +
227 + // This gets the value of the current pixel along the perimeter by using
228 + // a precomputed offset.
229 + const int curr_value =
230 + static_cast<int>(center_ptr[offsets[perim_index]]);
231 +
232 + const int difference = curr_value - center_value;
233 +
234 + if (difference > kFastDiffAmount) {
235 + above_diff += difference;
236 + ++num_above;
237 +
238 + num_below = 0;
239 + below_diff = 0;
240 +
241 + if (num_above >= threshold) {
242 + return above_diff;
243 + }
244 + } else if (difference < -kFastDiffAmount) {
245 + below_diff += difference;
246 + ++num_below;
247 +
248 + num_above = 0;
249 + above_diff = 0;
250 +
251 + if (num_below >= threshold) {
252 + return below_diff;
253 + }
254 + } else {
255 + num_above = 0;
256 + num_below = 0;
257 + above_diff = 0;
258 + below_diff = 0;
259 + }
260 +
261 + // See if there's any chance of making the threshold.
262 + if (MAX(num_above, num_below) < minimum_by_now) {
263 + // Didn't pass.
264 + return 0;
265 + }
266 + ++minimum_by_now;
267 + }
268 +
269 + // Didn't pass.
270 + return 0;
271 +}
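As a concrete reading of the loop above: with a perimeter of 16 and a threshold of 12 (the values used for the full circle in FindFastKeypoints below), a candidate passes only if at least 12 consecutive perimeter pixels are all brighter than the center by more than kFastDiffAmount, or all darker than it by more than kFastDiffAmount. The extra threshold - 1 iterations let such a run wrap around the end of the offsets table, and minimum_by_now aborts the scan as soon as the remaining iterations could no longer complete a qualifying run.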
272 +
273 +
274 +// Returns a score in the range [0.0, positive infinity) which represents the
275 +// relative likelihood of a point being a corner.
276 +float KeypointDetector::HarrisFilter(const Image<int32_t>& I_x,
277 + const Image<int32_t>& I_y, const float x,
278 + const float y) const {
279 + if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) &&
280 + I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) {
281 + // Image gradient matrix.
282 + float G[] = { 0, 0, 0, 0 };
283 + CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G);
284 +
285 + const float dx = G[0];
286 + const float dy = G[3];
287 + const float dxy = G[1];
288 +
289 +    // Harris-Noble corner score: det(G) / (trace(G) + epsilon).
290 + return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN);
291 + }
292 +
293 + return 0.0f;
294 +}
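Written out, with G = [Gxx Gxy; Gxy Gyy] the windowed gradient matrix filled in by CalculateG (so dx = Gxx, dy = Gyy, dxy = Gxy), the value returned above is det(G) / (trace(G) + FLT_MIN). This is Noble's variant of the Harris corner measure: it is large only when both eigenvalues of G are large, i.e. when the local gradients vary strongly in two independent directions, which is what distinguishes a corner from an edge or a flat region.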
295 +
296 +
297 +int KeypointDetector::AddExtraCandidatesForBoxes(
298 + const std::vector<BoundingBox>& boxes,
299 + const int max_num_keypoints,
300 + Keypoint* const keypoints) const {
301 + int num_keypoints_added = 0;
302 +
303 + for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
304 + iter != boxes.end(); ++iter) {
305 + const BoundingBox box = *iter;
306 +
307 + for (int i = 0; i < kNumToAddAsCandidates; ++i) {
308 + for (int j = 0; j < kNumToAddAsCandidates; ++j) {
309 + if (num_keypoints_added >= max_num_keypoints) {
310 + LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints);
311 + return num_keypoints_added;
312 + }
313 +
314 + Keypoint& curr_keypoint = keypoints[num_keypoints_added++];
315 + curr_keypoint.pos_ = Point2f(
316 + box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates,
317 + box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates);
318 + curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST;
319 + }
320 + }
321 + }
322 +
323 + return num_keypoints_added;
324 +}
325 +
326 +
327 +void KeypointDetector::FindKeypoints(const ImageData& image_data,
328 + const std::vector<BoundingBox>& rois,
329 + const FramePair& prev_change,
330 + FramePair* const curr_change) {
331 + // Copy keypoints from second frame of last pass to temp keypoints of this
332 + // pass.
333 + int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_);
334 +
335 + const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints;
336 + number_of_tmp_keypoints +=
337 + FindFastKeypoints(image_data, max_num_fast,
338 + tmp_keypoints_ + number_of_tmp_keypoints);
339 +
340 + TimeLog("Found FAST keypoints");
341 +
342 + if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
343 + LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints",
344 + kMaxTempKeypoints, number_of_tmp_keypoints);
345 + }
346 +
347 + if (kAddArbitraryKeypoints) {
348 + // Add some for each object prior to scoring.
349 + const int max_num_box_keypoints =
350 + kMaxTempKeypoints - number_of_tmp_keypoints;
351 + number_of_tmp_keypoints +=
352 + AddExtraCandidatesForBoxes(rois, max_num_box_keypoints,
353 + tmp_keypoints_ + number_of_tmp_keypoints);
354 + TimeLog("Added box keypoints");
355 +
356 + if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
357 + LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints",
358 + kMaxTempKeypoints, number_of_tmp_keypoints);
359 + }
360 + }
361 +
362 + // Score them...
363 + LOGV("Scoring %d keypoints!", number_of_tmp_keypoints);
364 + ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_);
365 + TimeLog("Scored keypoints");
366 +
367 + // Now pare it down a bit.
368 + SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_);
369 + TimeLog("Sorted keypoints");
370 +
371 + LOGV("%d keypoints to select from!", number_of_tmp_keypoints);
372 +
373 + SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change);
374 + TimeLog("Selected keypoints");
375 +
376 + LOGV("Picked %d (%d max) final keypoints out of %d potential.",
377 + curr_change->number_of_keypoints_,
378 + kMaxKeypoints, number_of_tmp_keypoints);
379 +}
380 +
381 +
382 +int KeypointDetector::CopyKeypoints(const FramePair& prev_change,
383 + Keypoint* const new_keypoints) {
384 + int number_of_keypoints = 0;
385 +
386 + // Caching values from last pass, just copy and compact.
387 + for (int i = 0; i < prev_change.number_of_keypoints_; ++i) {
388 + if (prev_change.optical_flow_found_keypoint_[i]) {
389 + new_keypoints[number_of_keypoints] =
390 + prev_change.frame2_keypoints_[i];
391 +
392 + new_keypoints[number_of_keypoints].score_ =
393 + prev_change.frame1_keypoints_[i].score_;
394 +
395 + ++number_of_keypoints;
396 + }
397 + }
398 +
399 + TimeLog("Copied keypoints");
400 + return number_of_keypoints;
401 +}
402 +
403 +
404 +// FAST keypoint detector.
405 +int KeypointDetector::FindFastKeypoints(const Image<uint8_t>& frame,
406 + const int quadrant,
407 + const int downsample_factor,
408 + const int max_num_keypoints,
409 + Keypoint* const keypoints) {
410 + /*
411 + // Reference for a circle of diameter 7.
412 + const int circle[] = {0, 0, 1, 1, 1, 0, 0,
413 + 0, 1, 0, 0, 0, 1, 0,
414 + 1, 0, 0, 0, 0, 0, 1,
415 + 1, 0, 0, 0, 0, 0, 1,
416 + 1, 0, 0, 0, 0, 0, 1,
417 + 0, 1, 0, 0, 0, 1, 0,
418 + 0, 0, 1, 1, 1, 0, 0};
419 + const int circle_offset[] =
420 + {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46};
421 + */
422 +
423 +  // Quick test of the four compass directions. Any contiguous run of at least
424 +  // 12 pixels on the 16-pixel circle must include at least 3 of these 4 points.
425 + static const int short_circle_perimeter = 4;
426 + static const int short_threshold = 3;
427 + static const int short_circle_x[] = { -3, 0, +3, 0 };
428 + static const int short_circle_y[] = { 0, -3, 0, +3 };
429 +
430 + // Precompute image offsets.
431 + int short_offsets[short_circle_perimeter];
432 + for (int i = 0; i < short_circle_perimeter; ++i) {
433 + short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth();
434 + }
435 +
436 + // Large circle values.
437 + static const int full_circle_perimeter = 16;
438 + static const int full_threshold = 12;
439 + static const int full_circle_x[] =
440 + { -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 };
441 + static const int full_circle_y[] =
442 + { -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 };
443 +
444 + // Precompute image offsets.
445 + int full_offsets[full_circle_perimeter];
446 + for (int i = 0; i < full_circle_perimeter; ++i) {
447 + full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth();
448 + }
449 +
450 + const int scratch_stride = frame.stride();
451 +
452 + keypoint_scratch_->Clear(0);
453 +
454 + // Set up the bounds on the region to test based on the passed-in quadrant.
455 + const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer;
456 + const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer;
457 + const int start_x =
458 + kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width);
459 + const int start_y =
460 + kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height);
461 + const int end_x = start_x + quadrant_width;
462 + const int end_y = start_y + quadrant_height;
463 +
464 + // Loop through once to find FAST keypoint clumps.
465 + for (int img_y = start_y; img_y < end_y; ++img_y) {
466 + const uint8_t* curr_pixel_ptr = frame[img_y] + start_x;
467 +
468 + for (int img_x = start_x; img_x < end_x; ++img_x) {
469 + // Only insert it if it meets the quick minimum requirements test.
470 + if (TestCircle(short_circle_perimeter, short_threshold,
471 + curr_pixel_ptr, short_offsets) != 0) {
472 +        // Longer test for the actual keypoint score.
473 + const int fast_score = TestCircle(full_circle_perimeter,
474 + full_threshold,
475 + curr_pixel_ptr,
476 + full_offsets);
477 +
478 + // Non-zero score means the keypoint was found.
479 + if (fast_score != 0) {
480 + uint8_t* const center_ptr = (*keypoint_scratch_)[img_y] + img_x;
481 +
482 + // Increase the keypoint count on this pixel and the pixels in all
483 + // 4 cardinal directions.
484 + *center_ptr += 5;
485 + *(center_ptr - 1) += 1;
486 + *(center_ptr + 1) += 1;
487 + *(center_ptr - scratch_stride) += 1;
488 + *(center_ptr + scratch_stride) += 1;
489 + }
490 + }
491 +
492 + ++curr_pixel_ptr;
493 + } // x
494 + } // y
495 +
496 + TimeLog("Found FAST keypoints.");
497 +
498 + int num_keypoints = 0;
499 + // Loop through again and Harris filter pixels in the center of clumps.
500 + // We can shrink the window by 1 pixel on every side.
501 + for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) {
502 + const uint8_t* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x;
503 +
504 + for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) {
505 + if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) {
506 + Keypoint* const keypoint = keypoints + num_keypoints;
507 + keypoint->pos_ = Point2f(
508 + img_x * downsample_factor, img_y * downsample_factor);
509 + keypoint->score_ = 0;
510 + keypoint->type_ = KEYPOINT_TYPE_FAST;
511 +
512 + ++num_keypoints;
513 + if (num_keypoints >= max_num_keypoints) {
514 + return num_keypoints;
515 + }
516 + }
517 +
518 + ++curr_pixel_ptr;
519 + } // x
520 + } // y
521 +
522 + TimeLog("Picked FAST keypoints.");
523 +
524 + return num_keypoints;
525 +}
526 +
527 +int KeypointDetector::FindFastKeypoints(const ImageData& image_data,
528 + const int max_num_keypoints,
529 + Keypoint* const keypoints) {
530 + int downsample_factor = 1;
531 + int num_found = 0;
532 +
533 + // TODO(andrewharp): Get this working for multiple image scales.
534 + for (int i = 0; i < 1; ++i) {
535 + const Image<uint8_t>& frame = *image_data.GetPyramidSqrt2Level(i);
536 + num_found += FindFastKeypoints(
537 + frame, fast_quadrant_,
538 + downsample_factor, max_num_keypoints, keypoints + num_found);
539 + downsample_factor *= 2;
540 + }
541 +
542 + // Increment the current quadrant.
543 + fast_quadrant_ = (fast_quadrant_ + 1) % 4;
544 +
545 + return num_found;
546 +}
547 +
548 +} // namespace tf_tracking
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
18 +
19 +#include <stdint.h>
20 +#include <vector>
21 +
22 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
26 +
27 +namespace tf_tracking {
28 +
29 +struct Keypoint;
30 +
31 +class KeypointDetector {
32 + public:
33 + explicit KeypointDetector(const KeypointDetectorConfig* const config)
34 + : config_(config),
35 + keypoint_scratch_(new Image<uint8_t>(config_->image_size)),
36 + interest_map_(new Image<bool>(config_->image_size)),
37 + fast_quadrant_(0) {
38 + interest_map_->Clear(false);
39 + }
40 +
41 + ~KeypointDetector() {}
42 +
43 + // Finds a new set of keypoints for the current frame, picked from the current
44 + // set of keypoints and also from a set discovered via a keypoint detector.
45 + // Special attention is applied to make sure that keypoints are distributed
46 + // within the supplied ROIs.
47 + void FindKeypoints(const ImageData& image_data,
48 + const std::vector<BoundingBox>& rois,
49 + const FramePair& prev_change,
50 + FramePair* const curr_change);
51 +
52 + private:
53 + // Compute the corneriness of a point in the image.
54 + float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y,
55 + const float x, const float y) const;
56 +
57 +  // Adds a kNumToAddAsCandidates x kNumToAddAsCandidates grid of candidate
58 +  // keypoints to each given box, stopping at max_num_keypoints total.
59 + int AddExtraCandidatesForBoxes(
60 + const std::vector<BoundingBox>& boxes,
61 + const int max_num_keypoints,
62 + Keypoint* const keypoints) const;
63 +
64 + // Scan the frame for potential keypoints using the FAST keypoint detector.
65 + // Quadrant is an argument 0-3 which refers to the quadrant of the image in
66 + // which to detect keypoints.
67 + int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant,
68 + const int downsample_factor,
69 + const int max_num_keypoints, Keypoint* const keypoints);
70 +
71 + int FindFastKeypoints(const ImageData& image_data,
72 + const int max_num_keypoints,
73 + Keypoint* const keypoints);
74 +
75 + // Score a bunch of candidate keypoints. Assigns the scores to the input
76 + // candidate_keypoints array entries.
77 + void ScoreKeypoints(const ImageData& image_data,
78 + const int num_candidates,
79 + Keypoint* const candidate_keypoints);
80 +
81 + void SortKeypoints(const int num_candidates,
82 + Keypoint* const candidate_keypoints) const;
83 +
84 + // Selects a set of keypoints falling within the supplied box such that the
85 + // most highly rated keypoints are picked first, and so that none of them are
86 + // too close together.
87 + int SelectKeypointsInBox(
88 + const BoundingBox& box,
89 + const Keypoint* const candidate_keypoints,
90 + const int num_candidates,
91 + const int max_keypoints,
92 + const int num_existing_keypoints,
93 + const Keypoint* const existing_keypoints,
94 + Keypoint* const final_keypoints) const;
95 +
96 + // Selects from the supplied sorted keypoint pool a set of keypoints that will
97 + // best cover the given set of boxes, such that each box is covered at a
98 + // resolution proportional to its size.
99 + void SelectKeypoints(
100 + const std::vector<BoundingBox>& boxes,
101 + const Keypoint* const candidate_keypoints,
102 + const int num_candidates,
103 + FramePair* const frame_change) const;
104 +
105 + // Copies and compacts the found keypoints in the second frame of prev_change
106 + // into the array at new_keypoints.
107 + static int CopyKeypoints(const FramePair& prev_change,
108 + Keypoint* const new_keypoints);
109 +
110 + const KeypointDetectorConfig* const config_;
111 +
112 + // Scratch memory for keypoint candidacy detection and non-max suppression.
113 + std::unique_ptr<Image<uint8_t> > keypoint_scratch_;
114 +
115 + // Regions of the image to pay special attention to.
116 + std::unique_ptr<Image<bool> > interest_map_;
117 +
118 + // The current quadrant of the image to detect FAST keypoints in.
119 + // Keypoint detection is staggered for performance reasons. Every four frames
120 + // a full scan of the frame will have been performed.
121 + int fast_quadrant_;
122 +
123 + Keypoint tmp_keypoints_[kMaxTempKeypoints];
124 +};
125 +
126 +} // namespace tf_tracking
127 +
128 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
1 +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#include "tensorflow/examples/android/jni/object_tracking/logging.h"
17 +
18 +#ifdef STANDALONE_DEMO_LIB
19 +
20 +#include <android/log.h>
21 +#include <stdlib.h>
22 +#include <time.h>
23 +#include <iostream>
24 +#include <sstream>
25 +
26 +LogMessage::LogMessage(const char* fname, int line, int severity)
27 + : fname_(fname), line_(line), severity_(severity) {}
28 +
29 +void LogMessage::GenerateLogMessage() {
30 + int android_log_level;
31 + switch (severity_) {
32 + case INFO:
33 + android_log_level = ANDROID_LOG_INFO;
34 + break;
35 + case WARNING:
36 + android_log_level = ANDROID_LOG_WARN;
37 + break;
38 + case ERROR:
39 + android_log_level = ANDROID_LOG_ERROR;
40 + break;
41 + case FATAL:
42 + android_log_level = ANDROID_LOG_FATAL;
43 + break;
44 + default:
45 + if (severity_ < INFO) {
46 + android_log_level = ANDROID_LOG_VERBOSE;
47 + } else {
48 + android_log_level = ANDROID_LOG_ERROR;
49 + }
50 + break;
51 + }
52 +
53 + std::stringstream ss;
54 + const char* const partial_name = strrchr(fname_, '/');
55 + ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_
56 + << " " << str();
57 + __android_log_write(android_log_level, "native", ss.str().c_str());
58 +
59 + // Also log to stderr (for standalone Android apps).
60 + std::cerr << "native : " << ss.str() << std::endl;
61 +
62 + // Android logging at level FATAL does not terminate execution, so abort()
63 + // is still required to stop the program.
64 + if (severity_ == FATAL) {
65 + abort();
66 + }
67 +}
68 +
69 +namespace {
70 +
71 +// Parse log level (int64) from environment variable (char*)
72 +int64_t LogLevelStrToInt(const char* tf_env_var_val) {
73 + if (tf_env_var_val == nullptr) {
74 + return 0;
75 + }
76 +
77 + // Ideally we would use env_var / safe_strto64, but it is
78 + // hard to use here without pulling in a lot of dependencies,
79 +  // so we use std::istringstream instead.
80 + std::string min_log_level(tf_env_var_val);
81 + std::istringstream ss(min_log_level);
82 + int64_t level;
83 + if (!(ss >> level)) {
84 + // Invalid vlog level setting, set level to default (0)
85 + level = 0;
86 + }
87 +
88 + return level;
89 +}
90 +
91 +int64_t MinLogLevelFromEnv() {
92 + const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL");
93 + return LogLevelStrToInt(tf_env_var_val);
94 +}
95 +
96 +int64_t MinVLogLevelFromEnv() {
97 + const char* tf_env_var_val = getenv("TF_CPP_MIN_VLOG_LEVEL");
98 + return LogLevelStrToInt(tf_env_var_val);
99 +}
100 +
101 +} // namespace
102 +
103 +LogMessage::~LogMessage() {
104 + // Read the min log level once during the first call to logging.
105 + static int64_t min_log_level = MinLogLevelFromEnv();
106 + if (TF_PREDICT_TRUE(severity_ >= min_log_level)) GenerateLogMessage();
107 +}
108 +
109 +int64_t LogMessage::MinVLogLevel() {
110 + static const int64_t min_vlog_level = MinVLogLevelFromEnv();
111 + return min_vlog_level;
112 +}
113 +
114 +LogMessageFatal::LogMessageFatal(const char* file, int line)
115 +    : LogMessage(file, line, FATAL) {}
116 +LogMessageFatal::~LogMessageFatal() {
117 + // abort() ensures we don't return (we promised we would not via
118 + // ATTRIBUTE_NORETURN).
119 + GenerateLogMessage();
120 + abort();
121 +}
122 +
123 +void LogString(const char* fname, int line, int severity,
124 + const std::string& message) {
125 + LogMessage(fname, line, severity) << message;
126 +}
127 +
128 +void LogPrintF(const int severity, const char* format, ...) {
129 + char message[1024];
130 + va_list argptr;
131 + va_start(argptr, format);
132 + vsnprintf(message, 1024, format, argptr);
133 + va_end(argptr);
134 + __android_log_write(severity, "native", message);
135 +
136 + // Also log to stderr (for standalone Android apps).
137 + std::cerr << "native : " << message << std::endl;
138 +}
139 +
140 +#endif
1 +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
18 +
19 +#include <android/log.h>
20 +#include <string.h>
21 +#include <ostream>
22 +#include <sstream>
23 +#include <string>
24 +
25 +// Allow this library to be built without depending on TensorFlow by
26 +// defining STANDALONE_DEMO_LIB. Otherwise TensorFlow headers will be
27 +// used.
28 +#ifdef STANDALONE_DEMO_LIB
29 +
30 +// A macro to disallow the copy constructor and operator= functions
31 +// This is usually placed in the private: declarations for a class.
32 +#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \
33 + TypeName(const TypeName&) = delete; \
34 + void operator=(const TypeName&) = delete
35 +
36 +#if defined(COMPILER_GCC3)
37 +#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0))
38 +#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
39 +#else
40 +#define TF_PREDICT_FALSE(x) (x)
41 +#define TF_PREDICT_TRUE(x) (x)
42 +#endif
43 +
44 +// Log levels equivalent to those defined by
45 +// third_party/tensorflow/core/platform/logging.h
46 +const int INFO = 0; // base_logging::INFO;
47 +const int WARNING = 1; // base_logging::WARNING;
48 +const int ERROR = 2; // base_logging::ERROR;
49 +const int FATAL = 3; // base_logging::FATAL;
50 +const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES;
51 +
52 +class LogMessage : public std::basic_ostringstream<char> {
53 + public:
54 + LogMessage(const char* fname, int line, int severity);
55 + ~LogMessage();
56 +
57 + // Returns the minimum log level for VLOG statements.
58 + // E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output,
59 + // but VLOG(3) will not. Defaults to 0.
60 + static int64_t MinVLogLevel();
61 +
62 + protected:
63 + void GenerateLogMessage();
64 +
65 + private:
66 + const char* fname_;
67 + int line_;
68 + int severity_;
69 +};
70 +
71 +// LogMessageFatal ensures the process will exit in failure after
72 +// logging this message.
73 +class LogMessageFatal : public LogMessage {
74 + public:
75 + LogMessageFatal(const char* file, int line);
76 + ~LogMessageFatal();
77 +};
78 +
79 +#define _TF_LOG_INFO \
80 +  LogMessage(__FILE__, __LINE__, INFO)
81 +#define _TF_LOG_WARNING \
82 +  LogMessage(__FILE__, __LINE__, WARNING)
83 +#define _TF_LOG_ERROR \
84 +  LogMessage(__FILE__, __LINE__, ERROR)
85 +#define _TF_LOG_FATAL \
86 +  LogMessageFatal(__FILE__, __LINE__)
87 +
88 +#define _TF_LOG_QFATAL _TF_LOG_FATAL
89 +
90 +#define LOG(severity) _TF_LOG_##severity
91 +
92 +#define VLOG_IS_ON(lvl) ((lvl) <= LogMessage::MinVLogLevel())
93 +
94 +#define VLOG(lvl) \
95 + if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \
96 +    LogMessage(__FILE__, __LINE__, INFO)
97 +
98 +void LogPrintF(const int severity, const char* format, ...);
99 +
100 +// Support for printf style logging.
101 +#define LOGV(...)
102 +#define LOGD(...)
103 +#define LOGI(...) LogPrintF(ANDROID_LOG_INFO, __VA_ARGS__);
104 +#define LOGW(...) LogPrintF(ANDROID_LOG_WARN, __VA_ARGS__);
105 +#define LOGE(...) LogPrintF(ANDROID_LOG_ERROR, __VA_ARGS__);
106 +
107 +#else
108 +
109 +#include "tensorflow/core/lib/strings/stringprintf.h"
110 +#include "tensorflow/core/platform/logging.h"
111 +
112 +// Support for printf style logging.
113 +#define LOGV(...)
114 +#define LOGD(...)
115 +#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__);
116 +#define LOGW(...) LOG(WARNING) << tensorflow::strings::Printf(__VA_ARGS__);
117 +#define LOGE(...) LOG(ERROR) << tensorflow::strings::Printf(__VA_ARGS__);
118 +
119 +#endif
120 +
121 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
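A minimal usage sketch for the printf-style macros above; the helper function is hypothetical and not part of this change. Note that LOGV and LOGD compile away entirely in both branches of the header.

    #include "tensorflow/examples/android/jni/object_tracking/logging.h"

    // Hypothetical helper demonstrating the printf-style logging macros.
    void ReportFrameStats(const int frame_index, const float fps) {
      LOGI("Processed frame %d at %.1f fps", frame_index, fps);
      if (fps < 5.0f) {
        LOGW("Frame rate is low on frame %d: %.1f fps", frame_index, fps);
      }
    }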
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// NOTE: no native object detectors are currently provided or used by the code
17 +// in this directory. This class remains mainly for historical reasons.
18 +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
19 +
20 +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
21 +
22 +namespace tf_tracking {
23 +
24 +// This is here so that the vtable gets created properly.
25 +ObjectDetectorBase::~ObjectDetectorBase() {}
26 +
27 +} // namespace tf_tracking
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// NOTE: no native object detectors are currently provided or used by the code
17 +// in this directory. This class remains mainly for historical reasons.
18 +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
19 +
20 +// Defines the ObjectDetector class that is the main interface for detecting
21 +// ObjectModelBases in frames.
22 +
23 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
24 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
25 +
26 +#include <float.h>
27 +#include <map>
28 +#include <memory>
29 +#include <sstream>
30 +#include <string>
31 +#include <vector>
32 +
33 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
34 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
35 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
36 +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
37 +#ifdef __RENDER_OPENGL__
38 +#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
39 +#endif
40 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
41 +
42 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
43 +#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
44 +#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
45 +
46 +namespace tf_tracking {
47 +
48 +// Adds BoundingSquares to a vector such that the first square added is
49 +// centered at the given position with size starting_square_size, and the
50 +// remaining squares are added concentrically, scaling down by scale_factor
51 +// until the minimum threshold size is passed.
52 +// Squares that do not fall completely within image_bounds will not be added.
53 +static inline void FillWithSquares(
54 + const BoundingBox& image_bounds,
55 + const BoundingBox& position,
56 + const float starting_square_size,
57 + const float smallest_square_size,
58 + const float scale_factor,
59 + std::vector<BoundingSquare>* const squares) {
60 + BoundingSquare descriptor_area =
61 + GetCenteredSquare(position, starting_square_size);
62 +
63 + SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor);
64 +
65 + // Use a do/while loop to ensure that at least one descriptor is created.
66 + do {
67 + if (image_bounds.Contains(descriptor_area.ToBoundingBox())) {
68 + squares->push_back(descriptor_area);
69 + }
70 + descriptor_area.Scale(scale_factor);
71 + } while (descriptor_area.size_ >= smallest_square_size - EPSILON);
72 + LOGV("Created %zu squares starting from size %.2f to min size %.2f "
73 + "using scale factor: %.2f",
74 + squares->size(), starting_square_size, smallest_square_size,
75 + scale_factor);
76 +}
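As a worked example, assuming every square stays inside image_bounds: a call with starting_square_size = 64, smallest_square_size = 16, and scale_factor = 0.5 pushes squares of size 64, 32, and 16 and then stops, since the next size (8) falls below the minimum. Because of the do/while structure, the starting size is always attempted even when it is already at or below smallest_square_size.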
77 +
78 +
79 +// Represents a potential detection of a specific ObjectExemplar and Descriptor
80 +// at a specific position in the image.
81 +class Detection {
82 + public:
83 + explicit Detection(const ObjectModelBase* const object_model,
84 + const MatchScore match_score,
85 + const BoundingBox& bounding_box)
86 + : object_model_(object_model),
87 + match_score_(match_score),
88 + bounding_box_(bounding_box) {}
89 +
90 + Detection(const Detection& other)
91 + : object_model_(other.object_model_),
92 + match_score_(other.match_score_),
93 + bounding_box_(other.bounding_box_) {}
94 +
95 + virtual ~Detection() {}
96 +
97 + inline BoundingBox GetObjectBoundingBox() const {
98 + return bounding_box_;
99 + }
100 +
101 + inline MatchScore GetMatchScore() const {
102 + return match_score_;
103 + }
104 +
105 + inline const ObjectModelBase* GetObjectModel() const {
106 + return object_model_;
107 + }
108 +
109 + inline bool Intersects(const Detection& other) {
110 + // Check if any of the four axes separates us, there must be at least one.
111 + return bounding_box_.Intersects(other.bounding_box_);
112 + }
113 +
114 + struct Comp {
115 + inline bool operator()(const Detection& a, const Detection& b) const {
116 + return a.match_score_ > b.match_score_;
117 + }
118 + };
119 +
120 + // TODO(andrewharp): add accessors to update these instead.
121 + const ObjectModelBase* object_model_;
122 + MatchScore match_score_;
123 + BoundingBox bounding_box_;
124 +};
125 +
126 +inline std::ostream& operator<<(std::ostream& stream,
127 + const Detection& detection) {
128 + const BoundingBox actual_area = detection.GetObjectBoundingBox();
129 + stream << actual_area;
130 + return stream;
131 +}
132 +
133 +class ObjectDetectorBase {
134 + public:
135 + explicit ObjectDetectorBase(const ObjectDetectorConfig* const config)
136 + : config_(config),
137 + image_data_(NULL) {}
138 +
139 + virtual ~ObjectDetectorBase();
140 +
141 + // Sets the current image data. All calls to ObjectDetector other than
142 + // FillDescriptors use the image data last set.
143 + inline void SetImageData(const ImageData* const image_data) {
144 + image_data_ = image_data;
145 + }
146 +
147 + // Main entry point into the detection algorithm.
148 + // Scans the frame for candidates, tweaks them, and fills in the
149 + // given std::vector of Detection objects with acceptable matches.
150 + virtual void Detect(const std::vector<BoundingSquare>& positions,
151 + std::vector<Detection>* const detections) const = 0;
152 +
153 + virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0;
154 +
155 + virtual void DeleteObjectModel(const std::string& name) = 0;
156 +
157 + virtual void GetObjectModels(
158 + std::vector<const ObjectModelBase*>* models) const = 0;
159 +
160 +  // Updates the given model from the given position in the context of
161 +  // the last frame passed to NextFrame.
162 +  // No update is made in the case that there's no room for a descriptor to be
163 + // created in the example area, or the example area is not completely
164 + // contained within the frame.
165 + virtual void UpdateModel(const Image<uint8_t>& base_image,
166 + const IntegralImage& integral_image,
167 + const BoundingBox& bounding_box, const bool locked,
168 + ObjectModelBase* model) const = 0;
169 +
170 + virtual void Draw() const = 0;
171 +
172 + virtual bool AllowSpontaneousDetections() = 0;
173 +
174 + protected:
175 + const std::unique_ptr<const ObjectDetectorConfig> config_;
176 +
177 + // The latest frame data, upon which all detections will be performed.
178 + // Not owned by this object, just provided for reference by ObjectTracker
179 + // via SetImageData().
180 + const ImageData* image_data_;
181 +
182 + private:
183 + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase);
184 +};
185 +
186 +template <typename ModelType>
187 +class ObjectDetector : public ObjectDetectorBase {
188 + public:
189 + explicit ObjectDetector(const ObjectDetectorConfig* const config)
190 + : ObjectDetectorBase(config) {}
191 +
192 + virtual ~ObjectDetector() {
193 + typename std::map<std::string, ModelType*>::const_iterator it =
194 + object_models_.begin();
195 + for (; it != object_models_.end(); ++it) {
196 + ModelType* model = it->second;
197 + delete model;
198 + }
199 + }
200 +
201 + virtual void DeleteObjectModel(const std::string& name) {
202 + ModelType* model = object_models_[name];
203 + CHECK_ALWAYS(model != NULL, "Model was null!");
204 + object_models_.erase(name);
205 + SAFE_DELETE(model);
206 + }
207 +
208 + virtual void GetObjectModels(
209 + std::vector<const ObjectModelBase*>* models) const {
210 + typename std::map<std::string, ModelType*>::const_iterator it =
211 + object_models_.begin();
212 + for (; it != object_models_.end(); ++it) {
213 + models->push_back(it->second);
214 + }
215 + }
216 +
217 + virtual bool AllowSpontaneousDetections() {
218 + return false;
219 + }
220 +
221 + protected:
222 + std::map<std::string, ModelType*> object_models_;
223 +
224 + private:
225 + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector);
226 +};
227 +
228 +} // namespace tf_tracking
229 +
230 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// NOTE: no native object detectors are currently provided or used by the code
17 +// in this directory. This class remains mainly for historical reasons.
18 +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
19 +
20 +// Contains ObjectModelBase declaration.
21 +
22 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
23 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
24 +
25 +#ifdef __RENDER_OPENGL__
26 +#include <GLES/gl.h>
27 +#include <GLES/glext.h>
28 +#endif
29 +
30 +#include <vector>
31 +
32 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
33 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
34 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
35 +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
36 +#ifdef __RENDER_OPENGL__
37 +#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
38 +#endif
39 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
40 +
41 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
42 +#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
43 +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
44 +
45 +namespace tf_tracking {
46 +
47 +// The ObjectModelBase class represents all the known appearance information for
48 +// an object. It is not a specific instance of the object in the world,
49 +// but just the general appearance information that enables detection. An
50 +// ObjectModelBase can be reused across multiple instances of TrackedObject.
51 +class ObjectModelBase {
52 + public:
53 + ObjectModelBase(const std::string& name) : name_(name) {}
54 +
55 + virtual ~ObjectModelBase() {}
56 +
57 + // Called when the next step in an ongoing track occurs.
58 + virtual void TrackStep(const BoundingBox& position,
59 + const Image<uint8_t>& image,
60 + const IntegralImage& integral_image,
61 + const bool authoritative) {}
62 +
63 + // Called when an object track is lost.
64 + virtual void TrackLost() {}
65 +
66 + // Called when an object track is confirmed as legitimate.
67 + virtual void TrackConfirmed() {}
68 +
69 + virtual float GetMaxCorrelation(const Image<float>& patch_image) const = 0;
70 +
71 + virtual MatchScore GetMatchScore(
72 + const BoundingBox& position, const ImageData& image_data) const = 0;
73 +
74 + virtual void Draw(float* const depth) const = 0;
75 +
76 + inline const std::string& GetName() const {
77 + return name_;
78 + }
79 +
80 + protected:
81 + const std::string name_;
82 +
83 + private:
84 + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase);
85 +};
86 +
87 +template <typename DetectorType>
88 +class ObjectModel : public ObjectModelBase {
89 + public:
90 + ObjectModel<DetectorType>(const DetectorType* const detector,
91 + const std::string& name)
92 + : ObjectModelBase(name), detector_(detector) {}
93 +
94 + protected:
95 + const DetectorType* const detector_;
96 +
97 + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel<DetectorType>);
98 +};
99 +
100 +} // namespace tf_tracking
101 +
102 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifdef __RENDER_OPENGL__
17 +#include <GLES/gl.h>
18 +#include <GLES/glext.h>
19 +#endif
20 +
21 +#include <cinttypes>
22 +#include <map>
23 +#include <string>
24 +
25 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
27 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
28 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
29 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
30 +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
31 +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
32 +#include "tensorflow/examples/android/jni/object_tracking/logging.h"
33 +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
34 +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
35 +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
36 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
37 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
38 +
39 +namespace tf_tracking {
40 +
41 +ObjectTracker::ObjectTracker(const TrackerConfig* const config,
42 + ObjectDetectorBase* const detector)
43 + : config_(config),
44 + frame_width_(config->image_size.width),
45 + frame_height_(config->image_size.height),
46 + curr_time_(0),
47 + num_frames_(0),
48 + flow_cache_(&config->flow_config),
49 + keypoint_detector_(&config->keypoint_detector_config),
50 + curr_num_frame_pairs_(0),
51 + first_frame_index_(0),
52 + frame1_(new ImageData(frame_width_, frame_height_)),
53 + frame2_(new ImageData(frame_width_, frame_height_)),
54 + detector_(detector),
55 + num_detected_(0) {
56 + for (int i = 0; i < kNumFrames; ++i) {
57 + frame_pairs_[i].Init(-1, -1);
58 + }
59 +}
60 +
61 +
62 +ObjectTracker::~ObjectTracker() {
63 + for (TrackedObjectMap::iterator iter = objects_.begin();
64 + iter != objects_.end(); iter++) {
65 + TrackedObject* object = iter->second;
66 + SAFE_DELETE(object);
67 + }
68 +}
69 +
70 +
71 +// Finds the correspondences for all the points in the current pair of frames.
72 +// Stores the results in the given FramePair.
73 +void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const {
74 +  // Mark every keypoint as not yet found; optical flow sets the flags below.
75 + memset(frame_pair->optical_flow_found_keypoint_, false,
76 + sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints);
77 + TimeLog("Cleared old found keypoints");
78 +
79 + int num_keypoints_found = 0;
80 +
81 + // For every keypoint...
82 + for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) {
83 + Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat;
84 + Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat;
85 +
86 + if (flow_cache_.FindNewPositionOfPoint(
87 + keypoint1->pos_.x, keypoint1->pos_.y,
88 + &keypoint2->pos_.x, &keypoint2->pos_.y)) {
89 + frame_pair->optical_flow_found_keypoint_[i_feat] = true;
90 + ++num_keypoints_found;
91 + }
92 + }
93 +
94 + TimeLog("Found correspondences");
95 +
96 + LOGV("Found %d of %d keypoint correspondences",
97 + num_keypoints_found, frame_pair->number_of_keypoints_);
98 +}
99 +
100 +void ObjectTracker::NextFrame(const uint8_t* const new_frame,
101 + const uint8_t* const uv_frame,
102 + const int64_t timestamp,
103 + const float* const alignment_matrix_2x3) {
104 + IncrementFrameIndex();
105 + LOGV("Received frame %d", num_frames_);
106 +
107 + FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0);
108 + curr_change->Init(curr_time_, timestamp);
109 +
110 + CHECK_ALWAYS(curr_time_ < timestamp,
111 + "Timestamp must monotonically increase! Went from %" PRId64
112 + " to %" PRId64 " on frame %d.",
113 + curr_time_, timestamp, num_frames_);
114 +
115 + curr_time_ = timestamp;
116 +
117 + // Swap the frames.
118 + frame1_.swap(frame2_);
119 +
120 + frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1);
121 +
122 + if (detector_.get() != NULL) {
123 + detector_->SetImageData(frame2_.get());
124 + }
125 +
126 + flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3);
127 +
128 + if (num_frames_ == 1) {
129 +    // This must be the first frame, so there is nothing to track yet.
130 + return;
131 + }
132 +
133 + if (config_->always_track || objects_.size() > 0) {
134 + LOGV("Tracking %zu targets", objects_.size());
135 + ComputeKeypoints(true);
136 + TimeLog("Keypoints computed!");
137 +
138 + FindCorrespondences(curr_change);
139 + TimeLog("Flow computed!");
140 +
141 + TrackObjects();
142 + }
143 + TimeLog("Targets tracked!");
144 +
145 + if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) {
146 + DetectTargets();
147 + }
148 + TimeLog("Detected objects.");
149 +}
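A sketch of the per-frame call pattern implied by NextFrame, assuming a fully constructed tracker and caller-owned image planes; the helper function itself is hypothetical, and passing a null alignment matrix is assumed to be acceptable when no camera-motion estimate is available (FlowCache::NextFrame lives in flow_cache.h, which is not shown here).

    // Illustrative driver: feed each camera frame to the tracker. Timestamps
    // must strictly increase between calls, as enforced by the CHECK_ALWAYS
    // in NextFrame above.
    void OnCameraFrame(tf_tracking::ObjectTracker* const tracker,
                       const uint8_t* const y_plane,
                       const uint8_t* const uv_plane,
                       const int64_t timestamp) {
      tracker->NextFrame(y_plane, uv_plane, timestamp, nullptr);
    }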
150 +
151 +TrackedObject* ObjectTracker::MaybeAddObject(
152 + const std::string& id, const Image<uint8_t>& source_image,
153 + const BoundingBox& bounding_box, const ObjectModelBase* object_model) {
154 + // Train the detector if this is a new object.
155 + if (objects_.find(id) != objects_.end()) {
156 + return objects_[id];
157 + }
158 +
159 + // Need to get a non-const version of the model, or create a new one if it
160 + // wasn't given.
161 + ObjectModelBase* model = NULL;
162 + if (detector_ != NULL) {
163 + // If a detector is registered, then this new object must have a model.
164 + CHECK_ALWAYS(object_model != NULL, "No model given!");
165 + model = detector_->CreateObjectModel(object_model->GetName());
166 + }
167 + TrackedObject* const object =
168 + new TrackedObject(id, source_image, bounding_box, model);
169 +
170 + objects_[id] = object;
171 + return object;
172 +}
173 +
174 +void ObjectTracker::RegisterNewObjectWithAppearance(
175 + const std::string& id, const uint8_t* const new_frame,
176 + const BoundingBox& bounding_box) {
177 + ObjectModelBase* object_model = NULL;
178 +
179 + Image<uint8_t> image(frame_width_, frame_height_);
180 + image.FromArray(new_frame, frame_width_, 1);
181 +
182 + if (detector_ != NULL) {
183 + object_model = detector_->CreateObjectModel(id);
184 + CHECK_ALWAYS(object_model != NULL, "Null object model!");
185 +
186 + const IntegralImage integral_image(image);
187 + object_model->TrackStep(bounding_box, image, integral_image, true);
188 + }
189 +
190 + // Create an object at this position.
191 + CHECK_ALWAYS(!HaveObject(id), "Already have this object!");
192 + if (objects_.find(id) == objects_.end()) {
193 + TrackedObject* const object =
194 + MaybeAddObject(id, image, bounding_box, object_model);
195 + CHECK_ALWAYS(object != NULL, "Object not created!");
196 + }
197 +}
198 +
199 +void ObjectTracker::SetPreviousPositionOfObject(const std::string& id,
200 + const BoundingBox& bounding_box,
201 + const int64_t timestamp) {
202 + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
203 + CHECK_ALWAYS(timestamp <= curr_time_,
204 + "Timestamp too great! %" PRId64 " vs %" PRId64, timestamp,
205 + curr_time_);
206 +
207 + TrackedObject* const object = GetObject(id);
208 +
209 + // Track this bounding box from the past to the current time.
210 + const BoundingBox current_position = TrackBox(bounding_box, timestamp);
211 +
212 + object->UpdatePosition(current_position, curr_time_, *frame2_, false);
213 +
214 + VLOG(2) << "Set tracked position for " << id << " to " << bounding_box
215 + << std::endl;
216 +}
217 +
218 +
219 +void ObjectTracker::SetCurrentPositionOfObject(
220 + const std::string& id, const BoundingBox& bounding_box) {
221 + SetPreviousPositionOfObject(id, bounding_box, curr_time_);
222 +}
223 +
224 +
225 +void ObjectTracker::ForgetTarget(const std::string& id) {
226 + LOGV("Forgetting object %s", id.c_str());
227 + TrackedObject* const object = GetObject(id);
228 + delete object;
229 + objects_.erase(id);
230 +
231 + if (detector_ != NULL) {
232 + detector_->DeleteObjectModel(id);
233 + }
234 +}
235 +
236 +int ObjectTracker::GetKeypointsPacked(uint16_t* const out_data,
237 + const float scale) const {
238 + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
239 + uint16_t* curr_data = out_data;
240 + int num_keypoints = 0;
241 +
242 + for (int i = 0; i < change.number_of_keypoints_; ++i) {
243 + if (change.optical_flow_found_keypoint_[i]) {
244 + ++num_keypoints;
245 + const Point2f& point1 = change.frame1_keypoints_[i].pos_;
246 + *curr_data++ = RealToFixed115(point1.x * scale);
247 + *curr_data++ = RealToFixed115(point1.y * scale);
248 +
249 + const Point2f& point2 = change.frame2_keypoints_[i].pos_;
250 + *curr_data++ = RealToFixed115(point2.x * scale);
251 + *curr_data++ = RealToFixed115(point2.y * scale);
252 + }
253 + }
254 +
255 + return num_keypoints;
256 +}
257 +
258 +
259 +int ObjectTracker::GetKeypoints(const bool only_found,
260 + float* const out_data) const {
261 + int curr_keypoint = 0;
262 + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
263 +
264 + for (int i = 0; i < change.number_of_keypoints_; ++i) {
265 + if (!only_found || change.optical_flow_found_keypoint_[i]) {
266 + const int base = curr_keypoint * kKeypointStep;
267 + out_data[base + 0] = change.frame1_keypoints_[i].pos_.x;
268 + out_data[base + 1] = change.frame1_keypoints_[i].pos_.y;
269 +
270 + out_data[base + 2] =
271 + change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f;
272 + out_data[base + 3] = change.frame2_keypoints_[i].pos_.x;
273 + out_data[base + 4] = change.frame2_keypoints_[i].pos_.y;
274 +
275 + out_data[base + 5] = change.frame1_keypoints_[i].score_;
276 + out_data[base + 6] = change.frame1_keypoints_[i].type_;
277 + ++curr_keypoint;
278 + }
279 + }
280 +
281 + LOGV("Got %d keypoints.", curr_keypoint);
282 +
283 + return curr_keypoint;
284 +}
285 +
286 +
287 +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
288 + const FramePair& frame_pair) const {
289 + float translation_x;
290 + float translation_y;
291 +
292 + float scale_x;
293 + float scale_y;
294 +
295 + BoundingBox tracked_box(region);
296 + frame_pair.AdjustBox(
297 + tracked_box, &translation_x, &translation_y, &scale_x, &scale_y);
298 +
299 + tracked_box.Shift(Point2f(translation_x, translation_y));
300 +
301 + if (scale_x > 0 && scale_y > 0) {
302 + tracked_box.Scale(scale_x, scale_y);
303 + }
304 + return tracked_box;
305 +}
306 +
307 +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
308 + const int64_t timestamp) const {
309 + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
310 + CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!");
311 +
312 + // Anything that ended before the requested timestamp is of no concern to us.
313 + bool found_it = false;
314 + int num_frames_back = -1;
315 + for (int i = 0; i < curr_num_frame_pairs_; ++i) {
316 + const FramePair& frame_pair =
317 + frame_pairs_[GetNthIndexFromEnd(i)];
318 +
319 + if (frame_pair.end_time_ <= timestamp) {
320 + num_frames_back = i - 1;
321 +
322 + if (num_frames_back > 0) {
323 + LOGV("Went %d out of %d frames before finding frame. (index: %d)",
324 + num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i));
325 + }
326 +
327 + found_it = true;
328 + break;
329 + }
330 + }
331 +
332 + if (!found_it) {
333 + LOGW("History did not go back far enough! %" PRId64 " vs %" PRId64,
334 + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ -
335 + frame_pairs_[GetNthIndexFromStart(0)].end_time_,
336 + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp);
337 + }
338 +
339 + // Loop over all the frames in the queue, tracking the accumulated delta
340 + // of the point from frame to frame. It's possible the point could
341 + // go out of frame, but keep tracking as best we can, using points near
342 + // the edge of the screen where it went out of bounds.
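  // Note: the loop below composes the per-pair motion from the oldest relevant
  // frame pair (index num_frames_back) forward to the most recent pair,
  // advancing the box step by step up to curr_time_.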
343 + BoundingBox tracked_box(region);
344 + for (int i = num_frames_back; i >= 0; --i) {
345 + const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)];
346 + SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!");
347 + tracked_box = TrackBox(tracked_box, frame_pair);
348 + }
349 + return tracked_box;
350 +}
351 +
352 +
353 +// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4
354 +// 3d transformation matrix.
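// For illustration (not part of the original code): with a row-major input
//   [a b tx]
//   [d e ty]
// the output, read as a column-major 4x4 array, is
//   { a, d, 0, 0,   b, e, 0, 0,   0, 0, 1, 0,   tx, ty, 0, 1 },
// i.e. an OpenGL-style matrix with the 2D transform in the upper-left block
// and the translation in the last column.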
355 +inline void Convert3x3To4x4(
356 + const float* const in_matrix, float* const out_matrix) {
357 + // X
358 + out_matrix[0] = in_matrix[0];
359 + out_matrix[1] = in_matrix[3];
360 + out_matrix[2] = 0.0f;
361 + out_matrix[3] = 0.0f;
362 +
363 + // Y
364 + out_matrix[4] = in_matrix[1];
365 + out_matrix[5] = in_matrix[4];
366 + out_matrix[6] = 0.0f;
367 + out_matrix[7] = 0.0f;
368 +
369 + // Z
370 + out_matrix[8] = 0.0f;
371 + out_matrix[9] = 0.0f;
372 + out_matrix[10] = 1.0f;
373 + out_matrix[11] = 0.0f;
374 +
375 + // Translation
376 + out_matrix[12] = in_matrix[2];
377 + out_matrix[13] = in_matrix[5];
378 + out_matrix[14] = 0.0f;
379 + out_matrix[15] = 1.0f;
380 +}
381 +
382 +
383 +void ObjectTracker::Draw(const int canvas_width, const int canvas_height,
384 + const float* const frame_to_canvas) const {
385 +#ifdef __RENDER_OPENGL__
386 + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
387 +
388 + glMatrixMode(GL_PROJECTION);
389 + glLoadIdentity();
390 +
391 + glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f);
392 +
393 + // To make Y go the right direction (0 at top of frame).
394 + glScalef(1.0f, -1.0f, 1.0f);
395 + glTranslatef(0.0f, -canvas_height, 0.0f);
396 +
397 + glMatrixMode(GL_MODELVIEW);
398 + glLoadIdentity();
399 +
400 + glPushMatrix();
401 +
402 + // Apply the frame to canvas transformation.
403 + static GLfloat transformation[16];
404 + Convert3x3To4x4(frame_to_canvas, transformation);
405 + glMultMatrixf(transformation);
406 +
407 + // Draw tracked object bounding boxes.
408 + for (TrackedObjectMap::const_iterator iter = objects_.begin();
409 + iter != objects_.end(); ++iter) {
410 + TrackedObject* tracked_object = iter->second;
411 + tracked_object->Draw();
412 + }
413 +
414 + static const bool kRenderDebugPyramid = false;
415 + if (kRenderDebugPyramid) {
416 + glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
417 + for (int i = 0; i < kNumPyramidLevels * 2; ++i) {
418 + Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw();
419 + }
420 + }
421 +
422 + static const bool kRenderDebugDerivative = false;
423 + if (kRenderDebugDerivative) {
424 + glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
425 + for (int i = 0; i < kNumPyramidLevels; ++i) {
426 + const Image<int32_t>& dx = *frame1_->GetSpatialX(i);
427 + Image<uint8_t> render_image(dx.GetWidth(), dx.GetHeight());
428 + for (int y = 0; y < dx.GetHeight(); ++y) {
429 + const int32_t* dx_ptr = dx[y];
430 + uint8_t* dst_ptr = render_image[y];
431 + for (int x = 0; x < dx.GetWidth(); ++x) {
432 + *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255);
433 + }
434 + }
435 +
436 + Sprite(render_image).Draw();
437 + }
438 + }
439 +
440 + if (detector_ != NULL) {
441 + glDisable(GL_CULL_FACE);
442 + detector_->Draw();
443 + }
444 + glPopMatrix();
445 +#endif
446 +}
447 +
448 +static void AddQuadrants(const BoundingBox& box,
449 + std::vector<BoundingBox>* boxes) {
450 + const Point2f center = box.GetCenter();
451 +
452 + float x1 = box.left_;
453 + float x2 = center.x;
454 + float x3 = box.right_;
455 +
456 + float y1 = box.top_;
457 + float y2 = center.y;
458 + float y3 = box.bottom_;
459 +
460 + // Upper left.
461 + boxes->push_back(BoundingBox(x1, y1, x2, y2));
462 +
463 + // Upper right.
464 + boxes->push_back(BoundingBox(x2, y1, x3, y2));
465 +
466 + // Bottom left.
467 + boxes->push_back(BoundingBox(x1, y2, x2, y3));
468 +
469 + // Bottom right.
470 + boxes->push_back(BoundingBox(x2, y2, x3, y3));
471 +
472 + // Whole thing.
473 + boxes->push_back(box);
474 +}
475 +
476 +void ObjectTracker::ComputeKeypoints(const bool cached_ok) {
477 + const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)];
478 + FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)];
479 +
480 + std::vector<BoundingBox> boxes;
481 +
482 + for (TrackedObjectMap::iterator object_iter = objects_.begin();
483 + object_iter != objects_.end(); ++object_iter) {
484 + BoundingBox box = object_iter->second->GetPosition();
485 + box.Scale(config_->object_box_scale_factor_for_features,
486 + config_->object_box_scale_factor_for_features);
487 + AddQuadrants(box, &boxes);
488 + }
489 +
490 + AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes);
491 +
492 + keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change);
493 +}
494 +
495 +
496 +// Given a detection, finds the existing tracked object (if any) that best
497 +// matches it by overlap and match score. Returns false on a collision.
498 +bool ObjectTracker::GetBestObjectForDetection(
499 + const Detection& detection, TrackedObject** match) const {
500 + TrackedObject* best_match = NULL;
501 + float best_overlap = -FLT_MAX;
502 +
503 + LOGV("Looking for matches in %zu objects!", objects_.size());
504 + for (TrackedObjectMap::const_iterator object_iter = objects_.begin();
505 + object_iter != objects_.end(); ++object_iter) {
506 + TrackedObject* const tracked_object = object_iter->second;
507 +
508 + const float overlap = tracked_object->GetPosition().PascalScore(
509 + detection.GetObjectBoundingBox());
510 +
511 + if (!detector_->AllowSpontaneousDetections() &&
512 + (detection.GetObjectModel() != tracked_object->GetModel())) {
513 + if (overlap > 0.0f) {
514 + return false;
515 + }
516 + continue;
517 + }
518 +
519 + const float jump_distance =
520 + (tracked_object->GetPosition().GetCenter() -
521 + detection.GetObjectBoundingBox().GetCenter()).LengthSquared();
522 +
523 + const float allowed_distance =
524 + tracked_object->GetAllowableDistanceSquared();
525 +
526 + LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f",
527 + jump_distance, allowed_distance, overlap);
528 +
529 + // TODO(andrewharp): No need to do this verification twice, eliminate
530 + // one of the score checks (the other being in OnDetection).
531 + if (jump_distance < allowed_distance &&
532 + overlap > best_overlap &&
533 + tracked_object->GetMatchScore() + kMatchScoreBuffer <
534 + detection.GetMatchScore()) {
535 + best_match = tracked_object;
536 + best_overlap = overlap;
537 + } else if (overlap > 0.0f) {
538 + return false;
539 + }
540 + }
541 +
542 + *match = best_match;
543 + return true;
544 +}
545 +
546 +
547 +void ObjectTracker::ProcessDetections(
548 + std::vector<Detection>* const detections) {
549 + LOGV("Initial detection done, iterating over %zu detections now.",
550 + detections->size());
551 +
552 + const bool spontaneous_detections_allowed =
553 + detector_->AllowSpontaneousDetections();
554 + for (std::vector<Detection>::const_iterator it = detections->begin();
555 + it != detections->end(); ++it) {
556 + const Detection& detection = *it;
557 + SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()),
558 + "Frame does not contain bounding box!");
559 +
560 + TrackedObject* best_match = NULL;
561 +
562 + const bool no_collisions =
563 + GetBestObjectForDetection(detection, &best_match);
564 +
565 +    // Get a non-const handle to the detection's model (it is stored const)
566 +    // so it can be passed to OnDetection below.
567 + ObjectModelBase* model =
568 + const_cast<ObjectModelBase*>(detection.GetObjectModel());
569 +
570 + if (best_match != NULL) {
571 + if (model != best_match->GetModel()) {
572 + CHECK_ALWAYS(detector_->AllowSpontaneousDetections(),
573 + "Model for object changed but spontaneous detections not allowed!");
574 + }
575 + best_match->OnDetection(model,
576 + detection.GetObjectBoundingBox(),
577 + detection.GetMatchScore(),
578 + curr_time_, *frame2_);
579 + } else if (no_collisions && spontaneous_detections_allowed) {
580 + if (detection.GetMatchScore() > kMinimumMatchScore) {
581 + LOGV("No match, adding it!");
582 + const ObjectModelBase* model = detection.GetObjectModel();
583 + std::ostringstream ss;
584 + // TODO(andrewharp): Generate this in a more general fashion.
585 + ss << "hand_" << num_detected_++;
586 + std::string object_name = ss.str();
587 + MaybeAddObject(object_name, *frame2_->GetImage(),
588 + detection.GetObjectBoundingBox(), model);
589 + }
590 + }
591 + }
592 +}
593 +
594 +
595 +void ObjectTracker::DetectTargets() {
596 + // Detect all object model types that we're currently tracking.
597 + std::vector<const ObjectModelBase*> object_models;
598 + detector_->GetObjectModels(&object_models);
599 + if (object_models.size() == 0) {
600 + LOGV("No objects to search for, aborting.");
601 + return;
602 + }
603 +
604 + LOGV("Trying to detect %zu models", object_models.size());
605 +
606 + LOGV("Creating test vector!");
607 + std::vector<BoundingSquare> positions;
608 +
609 + for (TrackedObjectMap::iterator object_iter = objects_.begin();
610 + object_iter != objects_.end(); ++object_iter) {
611 + TrackedObject* const tracked_object = object_iter->second;
612 +
613 +#if DEBUG_PREDATOR
614 + positions.push_back(GetCenteredSquare(
615 + frame2_->GetImage()->GetContainingBox(), 32.0f));
616 +#else
617 + const BoundingBox& position = tracked_object->GetPosition();
618 +
619 + const float square_size = MAX(
620 + kScanMinSquareSize / (kLastKnownPositionScaleFactor *
621 + kLastKnownPositionScaleFactor),
622 + MIN(position.GetWidth(),
623 + position.GetHeight())) / kLastKnownPositionScaleFactor;
624 +
625 + FillWithSquares(frame2_->GetImage()->GetContainingBox(),
626 + tracked_object->GetPosition(),
627 + square_size,
628 + kScanMinSquareSize,
629 + kLastKnownPositionScaleFactor,
630 + &positions);
631 +#endif
632 +  }
633 +
634 + LOGV("Created test vector!");
635 +
636 + std::vector<Detection> detections;
637 + LOGV("Detecting!");
638 + detector_->Detect(positions, &detections);
639 + LOGV("Found %zu detections", detections.size());
640 +
641 + TimeLog("Finished detection.");
642 +
643 + ProcessDetections(&detections);
644 +
645 + TimeLog("iterated over detections");
646 +
647 + LOGV("Done detecting!");
648 +}
649 +
650 +
651 +void ObjectTracker::TrackObjects() {
652 + // TODO(andrewharp): Correlation should be allowed to remove objects too.
653 + const bool automatic_removal_allowed = detector_.get() != NULL ?
654 + detector_->AllowSpontaneousDetections() : false;
655 +
656 + LOGV("Tracking %zu objects!", objects_.size());
657 + std::vector<std::string> dead_objects;
658 + for (TrackedObjectMap::iterator iter = objects_.begin();
659 + iter != objects_.end(); iter++) {
660 + TrackedObject* object = iter->second;
661 + const BoundingBox tracked_position = TrackBox(
662 + object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]);
663 + object->UpdatePosition(tracked_position, curr_time_, *frame2_, false);
664 +
665 + if (automatic_removal_allowed &&
666 + object->GetNumConsecutiveFramesBelowThreshold() >
667 + kMaxNumDetectionFailures * 5) {
668 + dead_objects.push_back(iter->first);
669 + }
670 + }
671 +
672 + if (detector_ != NULL && automatic_removal_allowed) {
673 + for (std::vector<std::string>::iterator iter = dead_objects.begin();
674 + iter != dead_objects.end(); iter++) {
675 + LOGE("Removing object! %s", iter->c_str());
676 + ForgetTarget(*iter);
677 + }
678 + }
679 + TimeLog("Tracked all objects.");
680 +
681 + LOGV("%zu objects tracked!", objects_.size());
682 +}
683 +
684 +} // namespace tf_tracking
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
18 +
19 +#include <map>
20 +#include <string>
21 +
22 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
23 +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/logging.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
27 +
28 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
29 +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
30 +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
31 +#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
32 +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
33 +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
34 +
35 +namespace tf_tracking {
36 +
37 +typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;
38 +
39 +inline std::ostream& operator<<(std::ostream& stream,
40 + const TrackedObjectMap& map) {
41 + for (TrackedObjectMap::const_iterator iter = map.begin();
42 + iter != map.end(); ++iter) {
43 + const TrackedObject& tracked_object = *iter->second;
44 + const std::string& key = iter->first;
45 + stream << key << ": " << tracked_object;
46 + }
47 + return stream;
48 +}
49 +
50 +
51 +// ObjectTracker is the highest-level class in the tracking/detection framework.
52 +// It handles basic image processing, keypoint detection, keypoint tracking,
53 +// object tracking, and object detection/relocalization.
54 +class ObjectTracker {
55 + public:
56 + ObjectTracker(const TrackerConfig* const config,
57 + ObjectDetectorBase* const detector);
58 + virtual ~ObjectTracker();
59 +
60 + virtual void NextFrame(const uint8_t* const new_frame,
61 + const int64_t timestamp,
62 + const float* const alignment_matrix_2x3) {
63 + NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
64 + }
65 +
66 + // Called upon the arrival of a new frame of raw data.
67 + // Does all image processing, keypoint detection, and object
68 + // tracking/detection for registered objects.
69 + // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
70 + // represents the main transformation that has happened between the last
71 + // and the current frame.
72 +  // (Note: this overload does not take an align_level argument; the 2x3
73 +  // matrix above is the only alignment input.)
74 + virtual void NextFrame(const uint8_t* const new_frame,
75 + const uint8_t* const uv_frame, const int64_t timestamp,
76 + const float* const alignment_matrix_2x3);
77 +
78 + virtual void RegisterNewObjectWithAppearance(const std::string& id,
79 + const uint8_t* const new_frame,
80 + const BoundingBox& bounding_box);
81 +
82 + // Updates the position of a tracked object, given that it was known to be at
83 + // a certain position at some point in the past.
84 + virtual void SetPreviousPositionOfObject(const std::string& id,
85 + const BoundingBox& bounding_box,
86 + const int64_t timestamp);
87 +
88 + // Sets the current position of the object in the most recent frame provided.
89 + virtual void SetCurrentPositionOfObject(const std::string& id,
90 + const BoundingBox& bounding_box);
91 +
92 + // Tells the ObjectTracker to stop tracking a target.
93 + void ForgetTarget(const std::string& id);
94 +
95 + // Fills the given out_data buffer with the latest detected keypoint
96 + // correspondences, first scaled by scale_factor (to adjust for downsampling
97 + // that may have occurred elsewhere), then packed in a fixed-point format.
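  // Judging from the implementation, each correspondence occupies four
  // uint16_t values: (x1, y1) from the previous frame followed by (x2, y2)
  // from the current frame, both converted with RealToFixed115 (presumably
  // an 11.5 fixed-point encoding).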
98 + int GetKeypointsPacked(uint16_t* const out_data,
99 + const float scale_factor) const;
100 +
101 + // Copy the keypoint arrays after computeFlow is called.
102 + // out_data should be at least kMaxKeypoints * kKeypointStep long.
103 +  // Currently, its format is [x1 y1 found x2 y2 score type] repeated N times,
104 + // where N is the number of keypoints tracked. N is returned as the result.
105 + int GetKeypoints(const bool only_found, float* const out_data) const;
106 +
107 + // Returns the current position of a box, given that it was at a certain
108 + // position at the given time.
109 + BoundingBox TrackBox(const BoundingBox& region,
110 + const int64_t timestamp) const;
111 +
112 + // Returns the number of frames that have been passed to NextFrame().
113 + inline int GetNumFrames() const {
114 + return num_frames_;
115 + }
116 +
117 + inline bool HaveObject(const std::string& id) const {
118 + return objects_.find(id) != objects_.end();
119 + }
120 +
121 + // Returns the TrackedObject associated with the given id.
122 + inline const TrackedObject* GetObject(const std::string& id) const {
123 + TrackedObjectMap::const_iterator iter = objects_.find(id);
124 + CHECK_ALWAYS(iter != objects_.end(),
125 + "Unknown object key! \"%s\"", id.c_str());
126 + TrackedObject* const object = iter->second;
127 + return object;
128 + }
129 +
130 + // Returns the TrackedObject associated with the given id.
131 + inline TrackedObject* GetObject(const std::string& id) {
132 + TrackedObjectMap::iterator iter = objects_.find(id);
133 + CHECK_ALWAYS(iter != objects_.end(),
134 + "Unknown object key! \"%s\"", id.c_str());
135 + TrackedObject* const object = iter->second;
136 + return object;
137 + }
138 +
139 + bool IsObjectVisible(const std::string& id) const {
140 + SCHECK(HaveObject(id), "Don't have this object.");
141 +
142 + const TrackedObject* object = GetObject(id);
143 + return object->IsVisible();
144 + }
145 +
146 + virtual void Draw(const int canvas_width, const int canvas_height,
147 + const float* const frame_to_canvas) const;
148 +
149 + protected:
150 + // Creates a new tracked object at the given position.
151 + // If an object model is provided, then that model will be associated with the
152 + // object. If not, a new model may be created from the appearance at the
153 + // initial position and registered with the object detector.
154 + virtual TrackedObject* MaybeAddObject(const std::string& id,
155 + const Image<uint8_t>& image,
156 + const BoundingBox& bounding_box,
157 + const ObjectModelBase* object_model);
158 +
159 + // Find the keypoints in the frame before the current frame.
160 + // If only one frame exists, keypoints will be found in that frame.
161 + void ComputeKeypoints(const bool cached_ok = false);
162 +
163 + // Finds the correspondences for all the points in the current pair of frames.
164 + // Stores the results in the given FramePair.
165 + void FindCorrespondences(FramePair* const curr_change) const;
166 +
167 + inline int GetNthIndexFromEnd(const int offset) const {
168 + return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
169 + }
170 +
171 + BoundingBox TrackBox(const BoundingBox& region,
172 + const FramePair& frame_pair) const;
173 +
174 + inline void IncrementFrameIndex() {
175 + // Move the current framechange index up.
176 + ++num_frames_;
177 + ++curr_num_frame_pairs_;
178 +
179 + // If we've got too many, push up the start of the queue.
180 + if (curr_num_frame_pairs_ > kNumFrames) {
181 + first_frame_index_ = GetNthIndexFromStart(1);
182 + --curr_num_frame_pairs_;
183 + }
184 + }
185 +
186 + inline int GetNthIndexFromStart(const int offset) const {
187 + SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
188 + "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_);
189 + return (first_frame_index_ + offset) % kNumFrames;
190 + }
191 +
192 + void TrackObjects();
193 +
194 + const std::unique_ptr<const TrackerConfig> config_;
195 +
196 + const int frame_width_;
197 + const int frame_height_;
198 +
199 + int64_t curr_time_;
200 +
201 + int num_frames_;
202 +
203 + TrackedObjectMap objects_;
204 +
205 + FlowCache flow_cache_;
206 +
207 + KeypointDetector keypoint_detector_;
208 +
209 + int curr_num_frame_pairs_;
210 + int first_frame_index_;
211 +
212 + std::unique_ptr<ImageData> frame1_;
213 + std::unique_ptr<ImageData> frame2_;
214 +
215 + FramePair frame_pairs_[kNumFrames];
216 +
217 + std::unique_ptr<ObjectDetectorBase> detector_;
218 +
219 + int num_detected_;
220 +
221 + private:
222 + void TrackTarget(TrackedObject* const object);
223 +
224 + bool GetBestObjectForDetection(
225 + const Detection& detection, TrackedObject** match) const;
226 +
227 + void ProcessDetections(std::vector<Detection>* const detections);
228 +
229 + void DetectTargets();
230 +
231 + // Temp object used in ObjectTracker::CreateNewExample.
232 + mutable std::vector<BoundingSquare> squares;
233 +
234 + friend std::ostream& operator<<(std::ostream& stream,
235 + const ObjectTracker& tracker);
236 +
237 + TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
238 +};
239 +
240 +inline std::ostream& operator<<(std::ostream& stream,
241 + const ObjectTracker& tracker) {
242 + stream << "Frame size: " << tracker.frame_width_ << "x"
243 + << tracker.frame_height_ << std::endl;
244 +
245 + stream << "Num frames: " << tracker.num_frames_ << std::endl;
246 +
247 + stream << "Curr time: " << tracker.curr_time_ << std::endl;
248 +
249 + const int first_frame_index = tracker.GetNthIndexFromStart(0);
250 + const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];
251 +
252 + const int last_frame_index = tracker.GetNthIndexFromEnd(0);
253 + const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];
254 +
255 + stream << "first frame: " << first_frame_index << ","
256 + << first_frame_pair.end_time_ << " "
257 + << "last frame: " << last_frame_index << ","
258 + << last_frame_pair.end_time_ << " diff: "
259 + << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
260 + << std::endl;
261 +
262 + stream << "Tracked targets:";
263 + stream << tracker.objects_;
264 +
265 + return stream;
266 +}
267 +
268 +} // namespace tf_tracking
269 +
270 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#include <android/log.h>
17 +#include <jni.h>
18 +#include <stdint.h>
19 +#include <stdlib.h>
20 +#include <string.h>
21 +#include <cstdint>
22 +
23 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
27 +
28 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
29 +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
30 +
31 +namespace tf_tracking {
32 +
33 +#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
34 + Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT
35 +
36 +JniLongField object_tracker_field("nativeObjectTracker");
37 +
38 +ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
39 + ObjectTracker* const object_tracker =
40 + reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
41 + CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
42 + return object_tracker;
43 +}
44 +
45 +void set_object_tracker(JNIEnv* env, jobject thiz,
46 + const ObjectTracker* object_tracker) {
47 + object_tracker_field.set(env, thiz,
48 + reinterpret_cast<intptr_t>(object_tracker));
49 +}
50 +
51 +#ifdef __cplusplus
52 +extern "C" {
53 +#endif
54 +JNIEXPORT
55 +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
56 + jint width, jint height,
57 + jboolean always_track);
58 +
59 +JNIEXPORT
60 +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
61 + jobject thiz);
62 +
63 +JNIEXPORT
64 +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
65 + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
66 + jfloat x2, jfloat y2, jbyteArray frame_data);
67 +
68 +JNIEXPORT
69 +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
70 + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
71 + jfloat x2, jfloat y2, jlong timestamp);
72 +
73 +JNIEXPORT
74 +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
75 + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
76 + jfloat x2, jfloat y2);
77 +
78 +JNIEXPORT
79 +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
80 + jstring object_id);
81 +
82 +JNIEXPORT
83 +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
84 + jobject thiz,
85 + jstring object_id);
86 +
87 +JNIEXPORT
88 +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
89 + jobject thiz,
90 + jstring object_id);
91 +
92 +JNIEXPORT
93 +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
94 + jobject thiz,
95 + jstring object_id);
96 +
97 +JNIEXPORT
98 +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
99 + jstring object_id);
100 +
101 +JNIEXPORT
102 +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
103 + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);
104 +
105 +JNIEXPORT
106 +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
107 + jbyteArray y_data,
108 + jbyteArray uv_data,
109 + jlong timestamp,
110 + jfloatArray vg_matrix_2x3);
111 +
112 +JNIEXPORT
113 +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
114 + jstring object_id);
115 +
116 +JNIEXPORT
117 +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
118 + JNIEnv* env, jobject thiz, jfloat scale_factor);
119 +
120 +JNIEXPORT
121 +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
122 + JNIEnv* env, jobject thiz, jboolean only_found_);
123 +
124 +JNIEXPORT
125 +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
126 + JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
127 + jfloat position_y1, jfloat position_x2, jfloat position_y2,
128 + jfloatArray delta);
129 +
130 +JNIEXPORT
131 +void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
132 + jint view_width,
133 + jint view_height,
134 + jfloatArray delta);
135 +
136 +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
137 + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
138 + jbyteArray input, jint factor, jbyteArray output);
139 +
140 +#ifdef __cplusplus
141 +}
142 +#endif
143 +
144 +JNIEXPORT
145 +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
146 + jint width, jint height,
147 + jboolean always_track) {
148 + LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
149 + const Size image_size(width, height);
150 + TrackerConfig* const tracker_config = new TrackerConfig(image_size);
151 + tracker_config->always_track = always_track;
152 +
153 +  // No detector is attached, so the tracker runs in tracking-only mode.
154 + ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
155 + set_object_tracker(env, thiz, tracker);
156 + LOGI("Initialized!");
157 +
158 + CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
159 +               "Failure to set object tracker!");
160 +}
161 +
162 +JNIEXPORT
163 +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
164 + jobject thiz) {
165 + delete get_object_tracker(env, thiz);
166 + set_object_tracker(env, thiz, NULL);
167 +}
168 +
169 +JNIEXPORT
170 +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
171 + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
172 + jfloat x2, jfloat y2, jbyteArray frame_data) {
173 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
174 +
175 + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
176 + x2, y2);
177 +
178 + jboolean iCopied = JNI_FALSE;
179 +
180 + // Copy image into currFrame.
181 + jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
182 +
183 + BoundingBox bounding_box(x1, y1, x2, y2);
184 + get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
185 + id_str, reinterpret_cast<const uint8_t*>(pixels), bounding_box);
186 +
187 + env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
188 +
189 + env->ReleaseStringUTFChars(object_id, id_str);
190 +}
191 +
192 +JNIEXPORT
193 +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
194 + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
195 + jfloat x2, jfloat y2, jlong timestamp) {
196 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
197 +
198 + LOGI(
199 + "Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
200 + " at time %lld",
201 + id_str, x1, y1, x2, y2, static_cast<long long>(timestamp));
202 +
203 + get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
204 + id_str, BoundingBox(x1, y1, x2, y2), timestamp);
205 +
206 + env->ReleaseStringUTFChars(object_id, id_str);
207 +}
208 +
209 +JNIEXPORT
210 +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
211 + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
212 + jfloat x2, jfloat y2) {
213 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
214 +
215 + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
216 + x2, y2);
217 +
218 + get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
219 + id_str, BoundingBox(x1, y1, x2, y2));
220 +
221 + env->ReleaseStringUTFChars(object_id, id_str);
222 +}
223 +
224 +JNIEXPORT
225 +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
226 + jstring object_id) {
227 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
228 +
229 + const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
230 + env->ReleaseStringUTFChars(object_id, id_str);
231 + return haveObject;
232 +}
233 +
234 +JNIEXPORT
235 +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
236 + jobject thiz,
237 + jstring object_id) {
238 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
239 +
240 + const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
241 + env->ReleaseStringUTFChars(object_id, id_str);
242 + return visible;
243 +}
244 +
245 +JNIEXPORT
246 +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
247 + jobject thiz,
248 + jstring object_id) {
249 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
250 + const TrackedObject* const object =
251 + get_object_tracker(env, thiz)->GetObject(id_str);
252 + env->ReleaseStringUTFChars(object_id, id_str);
253 + jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
254 + return model_name;
255 +}
256 +
257 +JNIEXPORT
258 +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
259 + jobject thiz,
260 + jstring object_id) {
261 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
262 +
263 + const float correlation =
264 + get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
265 + env->ReleaseStringUTFChars(object_id, id_str);
266 + return correlation;
267 +}
268 +
269 +JNIEXPORT
270 +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
271 + jstring object_id) {
272 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
273 +
274 + const float match_score =
275 + get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
276 + env->ReleaseStringUTFChars(object_id, id_str);
277 + return match_score;
278 +}
279 +
280 +JNIEXPORT
281 +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
282 + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
283 + jboolean iCopied = JNI_FALSE;
284 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
285 +
286 + const BoundingBox bounding_box =
287 + get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
288 + env->ReleaseStringUTFChars(object_id, id_str);
289 +
290 + jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
291 + bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
292 + env->ReleaseFloatArrayElements(rect_array, rect, 0);
293 +}
294 +
295 +JNIEXPORT
296 +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
297 + jbyteArray y_data,
298 + jbyteArray uv_data,
299 + jlong timestamp,
300 + jfloatArray vg_matrix_2x3) {
301 + TimeLog("Starting object tracker");
302 +
303 + jboolean iCopied = JNI_FALSE;
304 +
305 + float vision_gyro_matrix_array[6];
306 + jfloat* jmat = NULL;
307 +
308 + if (vg_matrix_2x3 != NULL) {
309 + // Copy the alignment matrix into a float array.
310 + jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
311 + for (int i = 0; i < 6; ++i) {
312 + vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
313 + }
314 + }
315 + // Copy image into currFrame.
316 + jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
317 + jbyte* uv_pixels =
318 + uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
319 +
320 + TimeLog("Got elements");
321 +
322 + // Add the frame to the object tracker object.
323 + get_object_tracker(env, thiz)->NextFrame(
324 + reinterpret_cast<uint8_t*>(pixels), reinterpret_cast<uint8_t*>(uv_pixels),
325 + timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
326 +
327 + env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
328 +
329 + if (uv_data != NULL) {
330 + env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
331 + }
332 +
333 + if (vg_matrix_2x3 != NULL) {
334 + env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
335 + }
336 +
337 + TimeLog("Released elements");
338 +
339 + PrintTimeLog();
340 + ResetTimeLog();
341 +}
342 +
343 +JNIEXPORT
344 +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
345 + jstring object_id) {
346 + const char* const id_str = env->GetStringUTFChars(object_id, 0);
347 +
348 + get_object_tracker(env, thiz)->ForgetTarget(id_str);
349 +
350 + env->ReleaseStringUTFChars(object_id, id_str);
351 +}
352 +
353 +JNIEXPORT
354 +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
355 + JNIEnv* env, jobject thiz, jboolean only_found) {
356 + jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
357 +
358 + const int number_of_keypoints =
359 + get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
360 +
361 + // Create and return the array that will be passed back to Java.
362 + jfloatArray keypoints =
363 + env->NewFloatArray(number_of_keypoints * kKeypointStep);
364 + if (keypoints == NULL) {
365 + LOGE("null array!");
366 + return NULL;
367 + }
368 + env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
369 + keypoint_arr);
370 +
371 + return keypoints;
372 +}
373 +
374 +JNIEXPORT
375 +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
376 + JNIEnv* env, jobject thiz, jfloat scale_factor) {
377 +  // Two (x, y) coordinate pairs per keypoint, each value a 2-byte uint16_t.
378 + const int bytes_per_keypoint = sizeof(uint16_t) * 2 * 2;
379 + jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
380 +
381 + const int number_of_keypoints =
382 + get_object_tracker(env, thiz)->GetKeypointsPacked(
383 + reinterpret_cast<uint16_t*>(keypoint_arr), scale_factor);
384 +
385 + // Create and return the array that will be passed back to Java.
386 + jbyteArray keypoints =
387 + env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
388 +
389 + if (keypoints == NULL) {
390 + LOGE("null array!");
391 + return NULL;
392 + }
393 +
394 + env->SetByteArrayRegion(
395 + keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
396 +
397 + return keypoints;
398 +}
399 +
400 +JNIEXPORT
401 +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
402 + JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
403 + jfloat position_y1, jfloat position_x2, jfloat position_y2,
404 + jfloatArray delta) {
405 + jfloat point_arr[4];
406 +
407 + const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
408 + BoundingBox(position_x1, position_y1, position_x2, position_y2),
409 + timestamp);
410 +
411 + new_position.CopyToArray(point_arr);
412 + env->SetFloatArrayRegion(delta, 0, 4, point_arr);
413 +}
414 +
415 +JNIEXPORT
416 +void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
417 + JNIEnv* env, jobject thiz, jint view_width, jint view_height,
418 + jfloatArray frame_to_canvas_arr) {
419 + ObjectTracker* object_tracker = get_object_tracker(env, thiz);
420 + if (object_tracker != NULL) {
421 + jfloat* frame_to_canvas =
422 + env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
423 +
424 + object_tracker->Draw(view_width, view_height, frame_to_canvas);
425 + env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
426 + JNI_ABORT);
427 + }
428 +}
429 +
430 +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
431 + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
432 + jbyteArray input, jint factor, jbyteArray output) {
433 + if (input == NULL || output == NULL) {
434 + LOGW("Received null arrays, hopefully this is a test!");
435 + return;
436 + }
437 +
438 + jbyte* const input_array = env->GetByteArrayElements(input, 0);
439 + jbyte* const output_array = env->GetByteArrayElements(output, 0);
440 +
441 + {
442 + tf_tracking::Image<uint8_t> full_image(
443 + width, height, reinterpret_cast<uint8_t*>(input_array), false);
444 +
445 + const int new_width = (width + factor - 1) / factor;
446 + const int new_height = (height + factor - 1) / factor;
447 +
448 + tf_tracking::Image<uint8_t> downsampled_image(
449 + new_width, new_height, reinterpret_cast<uint8_t*>(output_array), false);
450 +
451 + downsampled_image.DownsampleAveraged(
452 + reinterpret_cast<uint8_t*>(input_array), row_stride, factor);
453 + }
454 +
455 + env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
456 + env->ReleaseByteArrayElements(output, output_array, 0);
457 +}
458 +
459 +} // namespace tf_tracking
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#include <math.h>
17 +
18 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
19 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
20 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
21 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 +
24 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
27 +#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
28 +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
29 +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
30 +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
31 +
32 +namespace tf_tracking {
33 +
34 +OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config)
35 + : config_(config),
36 + frame1_(NULL),
37 + frame2_(NULL),
38 + working_size_(config->image_size) {}
39 +
40 +
41 +void OpticalFlow::NextFrame(const ImageData* const image_data) {
42 + // Special case for the first frame: make sure the image ends up in
43 + // frame1_ so that keypoint detection can be done on it if desired.
44 + frame1_ = (frame1_ == NULL) ? image_data : frame2_;
45 + frame2_ = image_data;
46 +}
47 +
48 +
49 +// Static core of the optical flow computation: tracks a single point with
50 +// the iterative Lucas-Kanade algorithm.
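// In outline: the 2x2 spatial gradient matrix G is built once from the patch
// around (p_x, p_y) (via CalculateG), and each iteration accumulates the
// mismatch vector b = sum of dI * [I_x, I_y] from the (optionally normalized)
// image difference dI, then refines the displacement by g += G^-1 * b until
// the update drops below kTrackingAbortThreshold or kNumIterations is reached.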
51 +bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8_t>& img_I,
52 + const Image<uint8_t>& img_J,
53 + const Image<int32_t>& I_x,
54 + const Image<int32_t>& I_y, const float p_x,
55 + const float p_y, float* out_g_x,
56 + float* out_g_y) {
57 + float g_x = *out_g_x;
58 + float g_y = *out_g_y;
59 + // Get values for frame 1. They remain constant through the inner
60 + // iteration loop.
61 + float vals_I[kFlowArraySize];
62 + float vals_I_x[kFlowArraySize];
63 + float vals_I_y[kFlowArraySize];
64 +
65 + const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1;
66 + const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize);
67 +
68 +#if USE_FIXED_POINT_FLOW
69 + const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1;
70 + const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1;
71 +#else
72 + const float real_x_max = I_x.width_less_one_ - EPSILON;
73 + const float real_y_max = I_x.height_less_one_ - EPSILON;
74 +#endif
75 +
76 + // Get the window around the original point.
77 + const float src_left_real = p_x - kWindowSizeFloat;
78 + const float src_top_real = p_y - kWindowSizeFloat;
79 + float* vals_I_ptr = vals_I;
80 + float* vals_I_x_ptr = vals_I_x;
81 + float* vals_I_y_ptr = vals_I_y;
82 +#if USE_FIXED_POINT_FLOW
83 + // Source integer coordinates.
84 + const int src_left_fixed = RealToFixed1616(src_left_real);
85 + const int src_top_fixed = RealToFixed1616(src_top_real);
86 +
87 + for (int y = 0; y < kPatchSize; ++y) {
88 + const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max);
89 +
90 + for (int x = 0; x < kPatchSize; ++x) {
91 + const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max);
92 +
93 + *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y);
94 + *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
95 + *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
96 + }
97 + }
98 +#else
99 + for (int y = 0; y < kPatchSize; ++y) {
100 + const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max);
101 +
102 + for (int x = 0; x < kPatchSize; ++x) {
103 + const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max);
104 +
105 + *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos);
106 + *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos);
107 + *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos);
108 + }
109 + }
110 +#endif
111 +
112 + // Compute the spatial gradient matrix about point p.
113 + float G[] = { 0, 0, 0, 0 };
114 + CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G);
115 +
116 + // Find the inverse of G.
117 + float G_inv[4];
118 + if (!Invert2x2(G, G_inv)) {
119 + return false;
120 + }
121 +
122 +#if NORMALIZE
123 + const float mean_I = ComputeMean(vals_I, kFlowArraySize);
124 + const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I);
125 +#endif
126 +
127 + // Iterate kNumIterations times or until we converge.
128 + for (int iteration = 0; iteration < kNumIterations; ++iteration) {
129 + // Get values for frame 2.
130 + float vals_J[kFlowArraySize];
131 +
132 + // Get the window around the destination point.
133 + const float left_real = p_x + g_x - kWindowSizeFloat;
134 + const float top_real = p_y + g_y - kWindowSizeFloat;
135 + float* vals_J_ptr = vals_J;
136 +#if USE_FIXED_POINT_FLOW
137 + // The top-left sub-pixel is set for the current iteration (in 16:16
138 + // fixed). This is constant over one iteration.
139 + const int left_fixed = RealToFixed1616(left_real);
140 + const int top_fixed = RealToFixed1616(top_real);
141 +
142 + for (int win_y = 0; win_y < kPatchSize; ++win_y) {
143 + const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max);
144 + for (int win_x = 0; win_x < kPatchSize; ++win_x) {
145 + const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max);
146 + *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y);
147 + }
148 + }
149 +#else
150 + for (int win_y = 0; win_y < kPatchSize; ++win_y) {
151 + const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max);
152 + for (int win_x = 0; win_x < kPatchSize; ++win_x) {
153 + const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max);
154 + *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos);
155 + }
156 + }
157 +#endif
158 +
159 +#if NORMALIZE
160 + const float mean_J = ComputeMean(vals_J, kFlowArraySize);
161 + const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J);
162 +
163 + // TODO(andrewharp): Probably better to completely detect and handle the
164 + // "corner case" where the patch is fully outside the image diagonally.
165 + const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f;
166 +#endif
167 +
168 + // Compute image mismatch vector.
169 + float b_x = 0.0f;
170 + float b_y = 0.0f;
171 +
172 + vals_I_ptr = vals_I;
173 + vals_J_ptr = vals_J;
174 + vals_I_x_ptr = vals_I_x;
175 + vals_I_y_ptr = vals_I_y;
176 +
177 + for (int win_y = 0; win_y < kPatchSize; ++win_y) {
178 + for (int win_x = 0; win_x < kPatchSize; ++win_x) {
179 +#if NORMALIZE
180 + // Normalized Image difference.
181 + const float dI =
182 + (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio;
183 +#else
184 + const float dI = *vals_I_ptr++ - *vals_J_ptr++;
185 +#endif
186 + b_x += dI * *vals_I_x_ptr++;
187 + b_y += dI * *vals_I_y_ptr++;
188 + }
189 + }
190 +
191 + // Optical flow... solve n = G^-1 * b
192 + const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y);
193 + const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y);
194 +
195 + // Update best guess with residual displacement from this level and
196 + // iteration.
197 + g_x += n_x;
198 + g_y += n_y;
199 +
200 + // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y);
201 +
202 + // Abort early if we're already below the threshold.
203 + if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) {
204 + break;
205 + }
206 + } // Iteration.
207 +
208 + // Copy value back into output.
209 + *out_g_x = g_x;
210 + *out_g_y = g_y;
211 + return true;
212 +}
213 +
214 +
215 +// Pointwise flow using translational 2dof ESM.
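// (ESM here presumably stands for Efficient Second-order Minimization: each
// iteration averages the source and target gradients to form the Jacobian,
// accumulates the 2x2 normal equations J^T*J and the residual J^T*r over the
// patch, regularizes the diagonal with kEsmRegularizer, and applies the
// resulting two-parameter update to (g_x, g_y), optionally maintaining a
// brightness offset between the patches as well.)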
216 +bool OpticalFlow::FindFlowAtPoint_ESM(
217 + const Image<uint8_t>& img_I, const Image<uint8_t>& img_J,
218 + const Image<int32_t>& I_x, const Image<int32_t>& I_y,
219 + const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x,
220 + const float p_y, float* out_g_x, float* out_g_y) {
221 + float g_x = *out_g_x;
222 + float g_y = *out_g_y;
223 + const float area_inv = 1.0f / static_cast<float>(kFlowArraySize);
224 +
225 + // Get values for frame 1. They remain constant through the inner
226 + // iteration loop.
227 + uint8_t vals_I[kFlowArraySize];
228 + uint8_t vals_J[kFlowArraySize];
229 + int16_t src_gradient_x[kFlowArraySize];
230 + int16_t src_gradient_y[kFlowArraySize];
231 +
232 + // TODO(rspring): try out the IntegerPatchAlign() method once
233 + // the code for that is in ../common.
234 + const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize);
235 + const int src_left_fixed = RealToFixed1616(p_x - wsize_float);
236 + const int src_top_fixed = RealToFixed1616(p_y - wsize_float);
237 + const int patch_size = 2 * kFlowIntegrationWindowSize + 1;
238 +
239 + // Create the keypoint template patch from a subpixel location.
240 + if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
241 + patch_size, patch_size, vals_I) ||
242 + !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
243 + patch_size, patch_size,
244 + src_gradient_x) ||
245 + !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
246 + patch_size, patch_size,
247 + src_gradient_y)) {
248 + return false;
249 + }
250 +
251 + int bright_offset = 0;
252 + int sum_diff = 0;
253 +
254 + // The top-left sub-pixel is set for the current iteration (in 16:16 fixed).
255 + // This is constant over one iteration.
256 + int left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
257 + int top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
258 +
259 +  // The truncated (integer) coordinates give the top-left-most pixel used.
260 + int left_trunc = left_fixed >> 16;
261 + int top_trunc = top_fixed >> 16;
262 +
263 + // Compute an initial brightness offset.
264 + if (kDoBrightnessNormalize &&
265 + left_trunc >= 0 && top_trunc >= 0 &&
266 + (left_trunc + patch_size) < img_J.width_less_one_ &&
267 + (top_trunc + patch_size) < img_J.height_less_one_) {
268 + int templ_index = 0;
269 + const uint8_t* j_row = img_J[top_trunc] + left_trunc;
270 +
271 + const int j_stride = img_J.stride();
272 +
273 + for (int y = 0; y < patch_size; ++y, j_row += j_stride) {
274 + for (int x = 0; x < patch_size; ++x) {
275 + sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++];
276 + }
277 + }
278 +
279 + bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv);
280 + }
281 +
282 + // Iterate kNumIterations times or until we go out of image.
283 + for (int iteration = 0; iteration < kNumIterations; ++iteration) {
284 + int jtj[3] = { 0, 0, 0 };
285 + int jtr[2] = { 0, 0 };
286 + sum_diff = 0;
287 +
288 + // Extract the target image values.
289 + // Extract the gradient from the target image patch and accumulate to
290 + // the gradient of the source image patch.
291 + if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed,
292 + patch_size, patch_size,
293 + vals_J)) {
294 + break;
295 + }
296 +
297 + const uint8_t* templ_row = vals_I;
298 + const uint8_t* extract_row = vals_J;
299 + const int16_t* src_dx_row = src_gradient_x;
300 + const int16_t* src_dy_row = src_gradient_y;
301 +
302 + for (int y = 0; y < patch_size; ++y, templ_row += patch_size,
303 + src_dx_row += patch_size, src_dy_row += patch_size,
304 + extract_row += patch_size) {
305 + const int fp_y = top_fixed + (y << 16);
306 + for (int x = 0; x < patch_size; ++x) {
307 + const int fp_x = left_fixed + (x << 16);
308 + int32_t target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y);
309 + int32_t target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y);
310 +
311 + // Combine the two Jacobians.
312 + // Right-shift by one to account for the fact that we add
313 + // two Jacobians.
314 + int32_t dx = (src_dx_row[x] + target_dx) >> 1;
315 + int32_t dy = (src_dy_row[x] + target_dy) >> 1;
316 +
317 + // The current residual b - h(q) == extracted - (template + offset)
318 + int32_t diff = static_cast<int32_t>(extract_row[x]) -
319 + static_cast<int32_t>(templ_row[x]) - bright_offset;
320 +
321 + jtj[0] += dx * dx;
322 + jtj[1] += dx * dy;
323 + jtj[2] += dy * dy;
324 +
325 + jtr[0] += dx * diff;
326 + jtr[1] += dy * diff;
327 +
328 + sum_diff += diff;
329 + }
330 + }
331 +
332 + const float jtr1_float = static_cast<float>(jtr[0]);
333 + const float jtr2_float = static_cast<float>(jtr[1]);
334 +
335 + // Add some baseline stability to the system.
336 + jtj[0] += kEsmRegularizer;
337 + jtj[2] += kEsmRegularizer;
338 +
339 + const int64_t prod1 = static_cast<int64_t>(jtj[0]) * jtj[2];
340 + const int64_t prod2 = static_cast<int64_t>(jtj[1]) * jtj[1];
341 +
342 + // One ESM step.
343 + const float jtj_1[4] = { static_cast<float>(jtj[2]),
344 + static_cast<float>(-jtj[1]),
345 + static_cast<float>(-jtj[1]),
346 + static_cast<float>(jtj[0]) };
347 + const double det_inv = 1.0 / static_cast<double>(prod1 - prod2);
348 +
349 + g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float);
350 + g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float);
351 +
352 + if (kDoBrightnessNormalize) {
353 + bright_offset +=
354 + static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f);
355 + }
356 +
357 + // Update top left position.
358 + left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
359 + top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
360 +
361 + left_trunc = left_fixed >> 16;
362 + top_trunc = top_fixed >> 16;
363 +
364 + // Abort iterations if we go out of borders.
365 + if (left_trunc < 0 || top_trunc < 0 ||
366 + (left_trunc + patch_size) >= J_x.width_less_one_ ||
367 + (top_trunc + patch_size) >= J_y.height_less_one_) {
368 + break;
369 + }
370 + } // Iteration.
371 +
372 + // Copy value back into output.
373 + *out_g_x = g_x;
374 + *out_g_y = g_y;
375 + return true;
376 +}
377 +
378 +
379 +bool OpticalFlow::FindFlowAtPointReversible(
380 + const int level, const float u_x, const float u_y,
381 + const bool reverse_flow,
382 + float* flow_x, float* flow_y) const {
383 + const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_;
384 + const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_;
385 +
386 + // Images I (prev) and J (next).
387 + const Image<uint8_t>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2);
388 + const Image<uint8_t>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2);
389 +
390 + // Computed gradients.
391 + const Image<int32_t>& I_x = *frame_a.GetSpatialX(level);
392 + const Image<int32_t>& I_y = *frame_a.GetSpatialY(level);
393 + const Image<int32_t>& J_x = *frame_b.GetSpatialX(level);
394 + const Image<int32_t>& J_y = *frame_b.GetSpatialY(level);
395 +
396 + // Shrink factor from original.
397 + const float shrink_factor = (1 << level);
398 +
399 + // Image position vector (p := u^l), scaled for this level.
400 + const float scaled_p_x = u_x / shrink_factor;
401 + const float scaled_p_y = u_y / shrink_factor;
402 +
403 + float scaled_flow_x = *flow_x / shrink_factor;
404 + float scaled_flow_y = *flow_y / shrink_factor;
405 +
406 + // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level,
407 + // scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y);
408 +
409 + const bool success = kUseEsm ?
410 + FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y,
411 + scaled_p_x, scaled_p_y,
412 + &scaled_flow_x, &scaled_flow_y) :
413 + FindFlowAtPoint_LK(img_I, img_J, I_x, I_y,
414 + scaled_p_x, scaled_p_y,
415 + &scaled_flow_x, &scaled_flow_y);
416 +
417 + *flow_x = scaled_flow_x * shrink_factor;
418 + *flow_y = scaled_flow_y * shrink_factor;
419 +
420 + return success;
421 +}
422 +
423 +
424 +bool OpticalFlow::FindFlowAtPointSingleLevel(
425 + const int level,
426 + const float u_x, const float u_y,
427 + const bool filter_by_fb_error,
428 + float* flow_x, float* flow_y) const {
429 + if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) {
430 + return false;
431 + }
432 +
433 + if (filter_by_fb_error) {
434 + const float new_position_x = u_x + *flow_x;
435 + const float new_position_y = u_y + *flow_y;
436 +
437 + float reverse_flow_x = 0.0f;
438 + float reverse_flow_y = 0.0f;
439 +
440 + // Now find the backwards flow and confirm it lines up with the original
441 + // starting point.
442 + if (!FindFlowAtPointReversible(level, new_position_x, new_position_y,
443 + true,
444 + &reverse_flow_x, &reverse_flow_y)) {
445 + LOGE("Backward error!");
446 + return false;
447 + }
448 +
449 + const float discrepancy_length =
450 + sqrtf(Square(*flow_x + reverse_flow_x) +
451 + Square(*flow_y + reverse_flow_y));
452 +
453 + const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y));
454 +
455 + return discrepancy_length <
456 + (kMaxForwardBackwardErrorAllowed * flow_length);
457 + }
458 +
459 + return true;
460 +}
461 +
462 +
463 +// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm.
464 +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details.
465 +bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y,
466 + const bool filter_by_fb_error,
467 + float* flow_x, float* flow_y) const {
468 + const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment,
469 + kNumPyramidLevels - kNumCacheLevels);
470 +
471 + // For every level in the pyramid, update the coordinates of the best match.
472 + for (int l = max_level - 1; l >= 0; --l) {
473 + if (!FindFlowAtPointSingleLevel(l, u_x, u_y,
474 + filter_by_fb_error, flow_x, flow_y)) {
475 + return false;
476 + }
477 + }
478 +
479 + return true;
480 +}
481 +
482 +} // namespace tf_tracking
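
For reference, the per-pixel loop above accumulates the entries of a 2x2 normal-equation system (jtj holds J^T J, jtr holds J^T r), and the update applies the closed-form inverse of that regularized 2x2 matrix. A minimal standalone sketch of the same solve step in plain floats follows; the function and variable names are illustrative only, not part of the tracker.

// Sketch: one regularized 2x2 Gauss-Newton/ESM solve step in plain floats.
// jtj = [a b; b c] (J^T J), (jtr_x, jtr_y) = J^T r; names are illustrative.
#include <cstdio>

static bool SolveStep2x2(float a, float b, float c,
                         float jtr_x, float jtr_y,
                         float regularizer,
                         float* dx, float* dy) {
  a += regularizer;
  c += regularizer;
  const double det = static_cast<double>(a) * c - static_cast<double>(b) * b;
  if (det == 0.0) {
    return false;
  }
  const double det_inv = 1.0 / det;
  // Inverse of [a b; b c] is (1/det) * [c -b; -b a].
  *dx = static_cast<float>(det_inv * (c * jtr_x - b * jtr_y));
  *dy = static_cast<float>(det_inv * (-b * jtr_x + a * jtr_y));
  return true;
}

int main() {
  float dx = 0.0f, dy = 0.0f;
  if (SolveStep2x2(400.0f, 20.0f, 380.0f, -120.0f, 60.0f, 5.0f, &dx, &dy)) {
    std::printf("step: (%f, %f)\n", dx, dy);  // Subtracted from the flow estimate.
  }
  return 0;
}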
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
18 +
19 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 +
24 +#include "tensorflow/examples/android/jni/object_tracking/config.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
27 +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
28 +
29 +namespace tf_tracking {
30 +
31 +class FlowCache;
32 +
33 +// Class encapsulating all the data and logic necessary for performing optical
34 +// flow.
35 +class OpticalFlow {
36 + public:
37 + explicit OpticalFlow(const OpticalFlowConfig* const config);
38 +
39 + // Add a new frame to the optical flow. Will update all the non-keypoint
40 + // related member variables.
41 + //
42 +  // image_data should wrap a buffer of grayscale values, one byte per
43 +  // pixel, at the original frame_width and frame_height used to initialize
44 +  // the OpticalFlow object. Downsampling will be handled internally.
45 +  //
46 +  // Its time stamp should be a time in milliseconds that later calls to this
47 +  // and other methods will be relative to.
48 + void NextFrame(const ImageData* const image_data);
49 +
50 + // An implementation of the Lucas-Kanade Optical Flow algorithm.
51 + static bool FindFlowAtPoint_LK(const Image<uint8_t>& img_I,
52 + const Image<uint8_t>& img_J,
53 + const Image<int32_t>& I_x,
54 + const Image<int32_t>& I_y, const float p_x,
55 + const float p_y, float* out_g_x,
56 + float* out_g_y);
57 +
58 + // Pointwise flow using translational 2dof ESM.
59 + static bool FindFlowAtPoint_ESM(
60 + const Image<uint8_t>& img_I, const Image<uint8_t>& img_J,
61 + const Image<int32_t>& I_x, const Image<int32_t>& I_y,
62 + const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x,
63 + const float p_y, float* out_g_x, float* out_g_y);
64 +
65 + // Finds the flow using a specific level, in either direction.
66 + // If reversed, the coordinates are in the context of the latest
67 + // frame, not the frame before it.
68 + // All coordinates used in parameters are global, not scaled.
69 + bool FindFlowAtPointReversible(
70 + const int level, const float u_x, const float u_y,
71 + const bool reverse_flow,
72 + float* final_x, float* final_y) const;
73 +
74 + // Finds the flow using a specific level, filterable by forward-backward
75 + // error. All coordinates used in parameters are global, not scaled.
76 + bool FindFlowAtPointSingleLevel(const int level,
77 + const float u_x, const float u_y,
78 + const bool filter_by_fb_error,
79 + float* flow_x, float* flow_y) const;
80 +
81 + // Pyramidal optical-flow using all levels.
82 + bool FindFlowAtPointPyramidal(const float u_x, const float u_y,
83 + const bool filter_by_fb_error,
84 + float* flow_x, float* flow_y) const;
85 +
86 + private:
87 + const OpticalFlowConfig* const config_;
88 +
89 + const ImageData* frame1_;
90 + const ImageData* frame2_;
91 +
92 + // Size of the internally allocated images (after original is downsampled).
93 + const Size working_size_;
94 +
95 + TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow);
96 +};
97 +
98 +} // namespace tf_tracking
99 +
100 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
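
The pyramidal method declared above refines the flow from the coarsest level down to level 0, dividing the global coordinates by 2^level before each per-level solve and scaling the result back up. A rough sketch of that coordinate handling, with a hypothetical FindAtLevel standing in for FindFlowAtPointSingleLevel:

// Sketch: coarse-to-fine coordinate scaling around a per-level solver.
#include <cstdio>

static bool FindAtLevel(float px, float py, float* fx, float* fy) {
  // Placeholder solver: pretend it nudges the flow by a fixed amount.
  (void)px; (void)py;
  *fx += 0.25f;
  *fy -= 0.10f;
  return true;
}

static bool FindFlowPyramidal(const int num_levels,
                              const float u_x, const float u_y,
                              float* flow_x, float* flow_y) {
  for (int level = num_levels - 1; level >= 0; --level) {
    const float shrink = static_cast<float>(1 << level);
    float fx = *flow_x / shrink;
    float fy = *flow_y / shrink;
    if (!FindAtLevel(u_x / shrink, u_y / shrink, &fx, &fy)) {
      return false;
    }
    *flow_x = fx * shrink;
    *flow_y = fy * shrink;
  }
  return true;
}

int main() {
  float fx = 0.0f, fy = 0.0f;
  FindFlowPyramidal(4, 320.0f, 240.0f, &fx, &fy);
  std::printf("flow: (%f, %f)\n", fx, fy);
  return 0;
}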
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
18 +
19 +#ifdef __RENDER_OPENGL__
20 +
21 +#include <GLES/gl.h>
22 +#include <GLES/glext.h>
23 +
24 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
26 +
27 +namespace tf_tracking {
28 +
29 +// This class encapsulates the logic necessary to load and render image data
30 +// at the same aspect ratio as the original source.
31 +class Sprite {
32 + public:
33 + // Only create Sprites when you have an OpenGl context.
34 + explicit Sprite(const Image<uint8_t>& image) { LoadTexture(image, NULL); }
35 +
36 + Sprite(const Image<uint8_t>& image, const BoundingBox* const area) {
37 + LoadTexture(image, area);
38 + }
39 +
40 + // Also, try to only delete a Sprite when holding an OpenGl context.
41 + ~Sprite() {
42 + glDeleteTextures(1, &texture_);
43 + }
44 +
45 + inline int GetWidth() const {
46 + return actual_width_;
47 + }
48 +
49 + inline int GetHeight() const {
50 + return actual_height_;
51 + }
52 +
53 + // Draw the sprite at 0,0 - original width/height in the current reference
54 + // frame. Any transformations desired must be applied before calling this
55 + // function.
56 + void Draw() const {
57 + const float float_width = static_cast<float>(actual_width_);
58 + const float float_height = static_cast<float>(actual_height_);
59 +
60 + // Where it gets rendered to.
61 + const float vertices[] = { 0.0f, 0.0f, 0.0f,
62 + 0.0f, float_height, 0.0f,
63 + float_width, 0.0f, 0.0f,
64 + float_width, float_height, 0.0f,
65 + };
66 +
67 + // The coordinates the texture gets drawn from.
68 + const float max_x = float_width / texture_width_;
69 + const float max_y = float_height / texture_height_;
70 + const float textureVertices[] = {
71 + 0, 0,
72 + 0, max_y,
73 + max_x, 0,
74 + max_x, max_y,
75 + };
76 +
77 + glEnable(GL_TEXTURE_2D);
78 + glBindTexture(GL_TEXTURE_2D, texture_);
79 +
80 + glEnableClientState(GL_VERTEX_ARRAY);
81 + glEnableClientState(GL_TEXTURE_COORD_ARRAY);
82 +
83 + glVertexPointer(3, GL_FLOAT, 0, vertices);
84 + glTexCoordPointer(2, GL_FLOAT, 0, textureVertices);
85 +
86 + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
87 +
88 + glDisableClientState(GL_VERTEX_ARRAY);
89 + glDisableClientState(GL_TEXTURE_COORD_ARRAY);
90 + }
91 +
92 + private:
93 + inline int GetNextPowerOfTwo(const int number) const {
94 + int power_of_two = 1;
95 + while (power_of_two < number) {
96 + power_of_two *= 2;
97 + }
98 + return power_of_two;
99 + }
100 +
101 + // TODO(andrewharp): Allow sprites to have their textures reloaded.
102 + void LoadTexture(const Image<uint8_t>& texture_source,
103 + const BoundingBox* const area) {
104 + glEnable(GL_TEXTURE_2D);
105 +
106 + glGenTextures(1, &texture_);
107 +
108 + glBindTexture(GL_TEXTURE_2D, texture_);
109 +
110 + int left = 0;
111 + int top = 0;
112 +
113 + if (area != NULL) {
114 + // If a sub-region was provided to pull the texture from, use that.
115 + left = area->left_;
116 + top = area->top_;
117 + actual_width_ = area->GetWidth();
118 + actual_height_ = area->GetHeight();
119 + } else {
120 + actual_width_ = texture_source.GetWidth();
121 + actual_height_ = texture_source.GetHeight();
122 + }
123 +
124 + // The textures must be a power of two, so find the sizes that are large
125 + // enough to contain the image data.
126 + texture_width_ = GetNextPowerOfTwo(actual_width_);
127 + texture_height_ = GetNextPowerOfTwo(actual_height_);
128 +
129 + bool allocated_data = false;
130 + uint8_t* texture_data;
131 +
132 + // Except in the lucky case where we're not using a sub-region of the
133 +    // original image AND the source data has dimensions that are powers of two,
134 + // care must be taken to copy data at the appropriate source and destination
135 + // strides so that the final block can be copied directly into texture
136 + // memory.
137 + // TODO(andrewharp): Figure out if data can be pulled directly from the
138 + // source image with some alignment modifications.
139 + if (left != 0 || top != 0 ||
140 + actual_width_ != texture_source.GetWidth() ||
141 + actual_height_ != texture_source.GetHeight()) {
142 + texture_data = new uint8_t[actual_width_ * actual_height_];
143 +
144 + for (int y = 0; y < actual_height_; ++y) {
145 + memcpy(texture_data + actual_width_ * y, texture_source[top + y] + left,
146 + actual_width_ * sizeof(uint8_t));
147 + }
148 + allocated_data = true;
149 + } else {
150 + // Cast away const-ness because for some reason glTexSubImage2D wants
151 + // a non-const data pointer.
152 + texture_data = const_cast<uint8_t*>(texture_source.data());
153 + }
154 +
155 + glTexImage2D(GL_TEXTURE_2D,
156 + 0,
157 + GL_LUMINANCE,
158 + texture_width_,
159 + texture_height_,
160 + 0,
161 + GL_LUMINANCE,
162 + GL_UNSIGNED_BYTE,
163 + NULL);
164 +
165 + glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
166 + glTexSubImage2D(GL_TEXTURE_2D,
167 + 0,
168 + 0,
169 + 0,
170 + actual_width_,
171 + actual_height_,
172 + GL_LUMINANCE,
173 + GL_UNSIGNED_BYTE,
174 + texture_data);
175 +
176 + if (allocated_data) {
177 +      delete[] texture_data;  // Matches the new[] allocation above.
178 + }
179 +
180 + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
181 + }
182 +
183 + // The id for the texture on the GPU.
184 + GLuint texture_;
185 +
186 + // The width and height to be used for display purposes, referring to the
187 + // dimensions of the original texture.
188 + int actual_width_;
189 + int actual_height_;
190 +
191 + // The allocated dimensions of the texture data, which must be powers of 2.
192 + int texture_width_;
193 + int texture_height_;
194 +
195 + TF_DISALLOW_COPY_AND_ASSIGN(Sprite);
196 +};
197 +
198 +} // namespace tf_tracking
199 +
200 +#endif // __RENDER_OPENGL__
201 +
202 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
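
Since the GLES 1.x textures used here must have power-of-two dimensions, the Sprite pads the image into a larger texture and crops the padding away with the texture coordinates max_x and max_y. A small sketch of that sizing arithmetic (the numbers are only an example):

// Sketch: padding a non-power-of-two image into a power-of-two texture and
// computing the texture coordinates that crop the padding back out.
#include <cstdio>

static int NextPowerOfTwo(int n) {
  int p = 1;
  while (p < n) p *= 2;
  return p;
}

int main() {
  const int actual_width = 300, actual_height = 180;
  const int tex_width = NextPowerOfTwo(actual_width);    // 512
  const int tex_height = NextPowerOfTwo(actual_height);  // 256
  // Only the top-left actual_width x actual_height block holds real pixels,
  // so the largest texture coordinates used when drawing are:
  const float max_x = static_cast<float>(actual_width) / tex_width;    // ~0.586
  const float max_y = static_cast<float>(actual_height) / tex_height;  // ~0.703
  std::printf("texture %dx%d, max coords (%.3f, %.3f)\n",
              tex_width, tex_height, max_x, max_y);
  return 0;
}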
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
17 +
18 +#ifdef LOG_TIME
19 +// Storage for logging functionality.
20 +int num_time_logs = 0;
21 +LogEntry time_logs[NUM_LOGS];
22 +
23 +int num_avg_entries = 0;
24 +AverageEntry avg_entries[NUM_LOGS];
25 +#endif
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// Utility functions for performance profiling.
17 +
18 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
19 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
20 +
21 +#include <stdint.h>
22 +
23 +#include "tensorflow/examples/android/jni/object_tracking/logging.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
25 +
26 +#ifdef LOG_TIME
27 +
28 +// Blend constant for running average.
29 +#define ALPHA 0.98f
30 +#define NUM_LOGS 100
31 +
32 +struct LogEntry {
33 + const char* id;
34 + int64_t time_stamp;
35 +};
36 +
37 +struct AverageEntry {
38 + const char* id;
39 + float average_duration;
40 +};
41 +
42 +// Storage for keeping track of this frame's values.
43 +extern int num_time_logs;
44 +extern LogEntry time_logs[NUM_LOGS];
45 +
46 +// Storage for keeping track of average values (each entry may not be printed
47 +// out each frame).
48 +extern AverageEntry avg_entries[NUM_LOGS];
49 +extern int num_avg_entries;
50 +
51 +// Call this at the start of a logging phase.
52 +inline static void ResetTimeLog() {
53 + num_time_logs = 0;
54 +}
55 +
56 +
57 +// Log a message to be printed out when PrintTimeLog is called, along with the
58 +// amount of time in ms that has passed since the last call to this function.
59 +inline static void TimeLog(const char* const str) {
60 + LOGV("%s", str);
61 + if (num_time_logs >= NUM_LOGS) {
62 + LOGE("Out of log entries!");
63 + return;
64 + }
65 +
66 + time_logs[num_time_logs].id = str;
67 + time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos();
68 + ++num_time_logs;
69 +}
70 +
71 +
72 +inline static float Blend(float old_val, float new_val) {
73 + return ALPHA * old_val + (1.0f - ALPHA) * new_val;
74 +}
75 +
76 +
77 +inline static float UpdateAverage(const char* str, const float new_val) {
78 + for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) {
79 + AverageEntry* const entry = avg_entries + entry_num;
80 + if (str == entry->id) {
81 + entry->average_duration = Blend(entry->average_duration, new_val);
82 + return entry->average_duration;
83 + }
84 + }
85 +
86 +  if (num_avg_entries >= NUM_LOGS) {
87 +    LOGE("Too many log entries!");
88 +    return new_val;  // Avoid writing past the end of avg_entries.
89 +  }
90 + // If it wasn't there already, add it.
91 + avg_entries[num_avg_entries].id = str;
92 + avg_entries[num_avg_entries].average_duration = new_val;
93 + ++num_avg_entries;
94 +
95 + return new_val;
96 +}
97 +
98 +
99 +// Prints out all the timeLog statements in chronological order with the
100 +// interval that passed between subsequent statements. The total time between
101 +// the first and last statements is printed last.
102 +inline static void PrintTimeLog() {
103 + LogEntry* last_time = time_logs;
104 +
105 + float average_running_total = 0.0f;
106 +
107 + for (int i = 0; i < num_time_logs; ++i) {
108 + LogEntry* const this_time = time_logs + i;
109 +
110 + const float curr_time =
111 + (this_time->time_stamp - last_time->time_stamp) / 1000000.0f;
112 +
113 + const float avg_time = UpdateAverage(this_time->id, curr_time);
114 + average_running_total += avg_time;
115 +
116 + LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time);
117 + last_time = this_time;
118 + }
119 +
120 + const float total_time =
121 + (last_time->time_stamp - time_logs->time_stamp) / 1000000.0f;
122 +
123 + LOGD("TOTAL TIME: %6.3fms %6.4fms\n",
124 + total_time, average_running_total);
125 + LOGD(" ");
126 +}
127 +#else
128 +inline static void ResetTimeLog() {}
129 +
130 +inline static void TimeLog(const char* const str) {
131 + LOGV("%s", str);
132 +}
133 +
134 +inline static void PrintTimeLog() {}
135 +#endif
136 +
137 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
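
A typical use of the functions above brackets one frame of work; something like the following, where the processing steps are placeholders:

// Sketch: instrument one frame of processing. Each TimeLog call records the
// elapsed time since the previous one; PrintTimeLog dumps the intervals, a
// running average per label, and the total frame time.
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"

void ProcessOneFrame() {
  ResetTimeLog();
  TimeLog("Frame start");

  // ... downsample, build pyramids, track keypoints ...
  TimeLog("Tracking done");

  // ... run detection, update object models, render overlays ...
  TimeLog("Frame complete!");

  PrintTimeLog();
}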
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
17 +
18 +namespace tf_tracking {
19 +
20 +static const float kInitialDistance = 20.0f;
21 +
22 +static void InitNormalized(const Image<uint8_t>& src_image,
23 + const BoundingBox& position,
24 + Image<float>* const dst_image) {
25 + BoundingBox scaled_box(position);
26 + CopyArea(src_image, scaled_box, dst_image);
27 + NormalizeImage(dst_image);
28 +}
29 +
30 +TrackedObject::TrackedObject(const std::string& id, const Image<uint8_t>& image,
31 + const BoundingBox& bounding_box,
32 + ObjectModelBase* const model)
33 + : id_(id),
34 + last_known_position_(bounding_box),
35 + last_detection_position_(bounding_box),
36 + position_last_computed_time_(-1),
37 + object_model_(model),
38 + last_detection_thumbnail_(kNormalizedThumbnailSize,
39 + kNormalizedThumbnailSize),
40 + last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize),
41 + tracked_correlation_(0.0f),
42 + tracked_match_score_(0.0),
43 + num_consecutive_frames_below_threshold_(0),
44 + allowable_detection_distance_(Square(kInitialDistance)) {
45 + InitNormalized(image, bounding_box, &last_detection_thumbnail_);
46 +}
47 +
48 +TrackedObject::~TrackedObject() {}
49 +
50 +void TrackedObject::UpdatePosition(const BoundingBox& new_position,
51 + const int64_t timestamp,
52 + const ImageData& image_data,
53 + const bool authoritative) {
54 + last_known_position_ = new_position;
55 + position_last_computed_time_ = timestamp;
56 +
57 + InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_);
58 +
59 + const float last_localization_correlation = ComputeCrossCorrelation(
60 + last_detection_thumbnail_.data(),
61 + last_frame_thumbnail_.data(),
62 + last_frame_thumbnail_.data_size_);
63 + LOGV("Tracked correlation to last localization: %.6f",
64 + last_localization_correlation);
65 +
66 + // Correlation to object model, if it exists.
67 + if (object_model_ != NULL) {
68 + tracked_correlation_ =
69 + object_model_->GetMaxCorrelation(last_frame_thumbnail_);
70 + LOGV("Tracked correlation to model: %.6f",
71 + tracked_correlation_);
72 +
73 + tracked_match_score_ =
74 + object_model_->GetMatchScore(new_position, image_data);
75 + LOGV("Tracked match score with model: %.6f",
76 + tracked_match_score_.value);
77 + } else {
78 + // If there's no model to check against, set the tracked correlation to
79 + // simply be the correlation to the last set position.
80 + tracked_correlation_ = last_localization_correlation;
81 + tracked_match_score_ = MatchScore(0.0f);
82 + }
83 +
84 + // Determine if it's still being tracked.
85 + if (tracked_correlation_ >= kMinimumCorrelationForTracking &&
86 + tracked_match_score_ >= kMinimumMatchScore) {
87 + num_consecutive_frames_below_threshold_ = 0;
88 +
89 + if (object_model_ != NULL) {
90 + object_model_->TrackStep(last_known_position_, *image_data.GetImage(),
91 + *image_data.GetIntegralImage(), authoritative);
92 + }
93 + } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) {
94 + if (num_consecutive_frames_below_threshold_ < 1000) {
95 + LOGD("Tracked match score is way too low (%.6f), aborting track.",
96 + tracked_match_score_.value);
97 + }
98 +
99 + // Add an absurd amount of missed frames so that all heuristics will
100 + // consider it a lost track.
101 + num_consecutive_frames_below_threshold_ += 1000;
102 +
103 + if (object_model_ != NULL) {
104 + object_model_->TrackLost();
105 + }
106 + } else {
107 + ++num_consecutive_frames_below_threshold_;
108 + allowable_detection_distance_ *= 1.1f;
109 + }
110 +}
111 +
112 +void TrackedObject::OnDetection(ObjectModelBase* const model,
113 + const BoundingBox& detection_position,
114 + const MatchScore match_score,
115 + const int64_t timestamp,
116 + const ImageData& image_data) {
117 + const float overlap = detection_position.PascalScore(last_known_position_);
118 + if (overlap > kPositionOverlapThreshold) {
119 + // If the position agreement with the current tracked position is good
120 + // enough, lock all the current unlocked examples.
121 + object_model_->TrackConfirmed();
122 + num_consecutive_frames_below_threshold_ = 0;
123 + }
124 +
125 + // Before relocalizing, make sure the new proposed position is better than
126 + // the existing position by a small amount to prevent thrashing.
127 + if (match_score <= tracked_match_score_ + kMatchScoreBuffer) {
128 + LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f",
129 + match_score.value, tracked_match_score_.value,
130 + kMatchScoreBuffer.value);
131 + return;
132 + }
133 +
134 + LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to "
135 + "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f",
136 + last_known_position_.left_, last_known_position_.top_,
137 + last_known_position_.GetWidth(), last_known_position_.GetHeight(),
138 + detection_position.left_, detection_position.top_,
139 + detection_position.GetWidth(), detection_position.GetHeight(),
140 + match_score.value, tracked_match_score_.value);
141 +
142 + if (overlap < kPositionOverlapThreshold) {
143 + // The path might be good, it might be bad, but it's no longer a path
144 + // since we're moving the box to a new position, so just nuke it from
145 + // orbit to be safe.
146 + object_model_->TrackLost();
147 + }
148 +
149 + object_model_ = model;
150 +
151 + // Reset the last detected appearance.
152 + InitNormalized(
153 + *image_data.GetImage(), detection_position, &last_detection_thumbnail_);
154 +
155 + num_consecutive_frames_below_threshold_ = 0;
156 + last_detection_position_ = detection_position;
157 +
158 + UpdatePosition(detection_position, timestamp, image_data, false);
159 + allowable_detection_distance_ = Square(kInitialDistance);
160 +}
161 +
162 +} // namespace tf_tracking
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
18 +
19 +#ifdef __RENDER_OPENGL__
20 +#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
21 +#endif
22 +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
23 +
24 +namespace tf_tracking {
25 +
26 +// A TrackedObject is a specific instance of an ObjectModel, with a known
27 +// position in the world.
28 +// It provides the last known position and number of recent detection failures,
29 +// in addition to the more general appearance data associated with the object
30 +// class (which is in ObjectModel).
31 +// TODO(andrewharp): Make getters/setters follow styleguide.
32 +class TrackedObject {
33 + public:
34 + TrackedObject(const std::string& id, const Image<uint8_t>& image,
35 + const BoundingBox& bounding_box, ObjectModelBase* const model);
36 +
37 + ~TrackedObject();
38 +
39 + void UpdatePosition(const BoundingBox& new_position, const int64_t timestamp,
40 + const ImageData& image_data, const bool authoritative);
41 +
42 + // This method is called when the tracked object is detected at a
43 + // given position, and allows the associated Model to grow and/or prune
44 + // itself based on where the detection occurred.
45 + void OnDetection(ObjectModelBase* const model,
46 + const BoundingBox& detection_position,
47 + const MatchScore match_score, const int64_t timestamp,
48 + const ImageData& image_data);
49 +
50 + // Called when there's no detection of the tracked object. This will cause
51 + // a tracking failure after enough consecutive failures if the area under
52 + // the current bounding box also doesn't meet a minimum correlation threshold
53 + // with the model.
54 + void OnDetectionFailure() {}
55 +
56 + inline bool IsVisible() const {
57 + return tracked_correlation_ >= kMinimumCorrelationForTracking ||
58 + num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures;
59 + }
60 +
61 + inline float GetCorrelation() {
62 + return tracked_correlation_;
63 + }
64 +
65 + inline MatchScore GetMatchScore() {
66 + return tracked_match_score_;
67 + }
68 +
69 + inline BoundingBox GetPosition() const {
70 + return last_known_position_;
71 + }
72 +
73 + inline BoundingBox GetLastDetectionPosition() const {
74 + return last_detection_position_;
75 + }
76 +
77 + inline const ObjectModelBase* GetModel() const {
78 + return object_model_;
79 + }
80 +
81 + inline const std::string& GetName() const {
82 + return id_;
83 + }
84 +
85 + inline void Draw() const {
86 +#ifdef __RENDER_OPENGL__
87 + if (tracked_correlation_ < kMinimumCorrelationForTracking) {
88 + glColor4f(MAX(0.0f, -tracked_correlation_),
89 + MAX(0.0f, tracked_correlation_),
90 + 0.0f,
91 + 1.0f);
92 + } else {
93 + glColor4f(MAX(0.0f, -tracked_correlation_),
94 + MAX(0.0f, tracked_correlation_),
95 + 1.0f,
96 + 1.0f);
97 + }
98 +
99 + // Render the box itself.
100 + BoundingBox temp_box(last_known_position_);
101 + DrawBox(temp_box);
102 +
103 +    // Render a box just outside this one (in case the actual box is hidden).
104 + const float kBufferSize = 1.0f;
105 + temp_box.left_ -= kBufferSize;
106 + temp_box.top_ -= kBufferSize;
107 + temp_box.right_ += kBufferSize;
108 + temp_box.bottom_ += kBufferSize;
109 + DrawBox(temp_box);
110 +
111 +    // Render one just inside as well.
112 + temp_box.left_ -= -2.0f * kBufferSize;
113 + temp_box.top_ -= -2.0f * kBufferSize;
114 + temp_box.right_ += -2.0f * kBufferSize;
115 + temp_box.bottom_ += -2.0f * kBufferSize;
116 + DrawBox(temp_box);
117 +#endif
118 + }
119 +
120 + // Get current object's num_consecutive_frames_below_threshold_.
121 + inline int64_t GetNumConsecutiveFramesBelowThreshold() {
122 + return num_consecutive_frames_below_threshold_;
123 + }
124 +
125 + // Reset num_consecutive_frames_below_threshold_ to 0.
126 + inline void resetNumConsecutiveFramesBelowThreshold() {
127 + num_consecutive_frames_below_threshold_ = 0;
128 + }
129 +
130 + inline float GetAllowableDistanceSquared() const {
131 + return allowable_detection_distance_;
132 + }
133 +
134 + private:
135 + // The unique id used throughout the system to identify this
136 + // tracked object.
137 + const std::string id_;
138 +
139 + // The last known position of the object.
140 + BoundingBox last_known_position_;
141 +
142 +  // The position at which the object was last detected.
143 + BoundingBox last_detection_position_;
144 +
145 + // When the position was last computed.
146 + int64_t position_last_computed_time_;
147 +
148 + // The object model this tracked object is representative of.
149 + ObjectModelBase* object_model_;
150 +
151 + Image<float> last_detection_thumbnail_;
152 +
153 + Image<float> last_frame_thumbnail_;
154 +
155 + // The correlation of the object model with the preview frame at its last
156 + // tracked position.
157 + float tracked_correlation_;
158 +
159 + MatchScore tracked_match_score_;
160 +
161 + // The number of consecutive frames that the tracked position for this object
162 + // has been under the correlation threshold.
163 + int num_consecutive_frames_below_threshold_;
164 +
165 + float allowable_detection_distance_;
166 +
167 + friend std::ostream& operator<<(std::ostream& stream,
168 + const TrackedObject& tracked_object);
169 +
170 + TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject);
171 +};
172 +
173 +inline std::ostream& operator<<(std::ostream& stream,
174 + const TrackedObject& tracked_object) {
175 + stream << tracked_object.id_
176 + << " " << tracked_object.last_known_position_
177 + << " " << tracked_object.position_last_computed_time_
178 + << " " << tracked_object.num_consecutive_frames_below_threshold_
179 + << " " << tracked_object.object_model_
180 + << " " << tracked_object.tracked_correlation_;
181 + return stream;
182 +}
183 +
184 +} // namespace tf_tracking
185 +
186 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
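
The visibility test above is a hysteresis: an object stays visible while either its correlation is still acceptable or it has not yet accumulated too many consecutive below-threshold frames. A tiny sketch with placeholder constants (the real values live in config.h):

// Sketch of the IsVisible logic; constants here are placeholders, not the
// real kMinimumCorrelationForTracking / kMaxNumDetectionFailures values.
#include <cstdio>

static bool IsVisibleSketch(float correlation, int frames_below_threshold) {
  const float kMinCorrelation = 0.3f;  // placeholder
  const int kMaxFailures = 4;          // placeholder
  return correlation >= kMinCorrelation ||
         frames_below_threshold < kMaxFailures;
}

int main() {
  // Still visible on good correlation, or on a short run of misses.
  std::printf("%d %d\n", IsVisibleSketch(0.5f, 10), IsVisibleSketch(0.1f, 2));
  return 0;
}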
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
18 +
19 +#include <math.h>
20 +#include <stdint.h>
21 +#include <stdlib.h>
22 +#include <time.h>
23 +
24 +#include <cmath> // for std::abs(float)
25 +
26 +#ifndef HAVE_CLOCK_GETTIME
27 +// Use gettimeofday() instead of clock_gettime().
28 +#include <sys/time.h>
29 +#endif  // ifndef HAVE_CLOCK_GETTIME
30 +
31 +#include "tensorflow/examples/android/jni/object_tracking/logging.h"
32 +
33 +// TODO(andrewharp): clean up these macros to use the codebase standard.
34 +
35 +// A very small number, generally used as the tolerance for accumulated
36 +// floating point errors in bounds-checks.
37 +#define EPSILON 0.00001f
38 +
39 +#define SAFE_DELETE(pointer) {\
40 + if ((pointer) != NULL) {\
41 + LOGV("Safe deleting pointer: %s", #pointer);\
42 + delete (pointer);\
43 + (pointer) = NULL;\
44 + } else {\
45 + LOGV("Pointer already null: %s", #pointer);\
46 + }\
47 +}
48 +
49 +
50 +#ifdef __GOOGLE__
51 +
52 +#define CHECK_ALWAYS(condition, format, ...) {\
53 + CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
54 +}
55 +
56 +#define SCHECK(condition, format, ...) {\
57 + DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
58 +}
59 +
60 +#else
61 +
62 +#define CHECK_ALWAYS(condition, format, ...) {\
63 + if (!(condition)) {\
64 + LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\
65 + abort();\
66 + }\
67 +}
68 +
69 +#ifdef SANITY_CHECKS
70 +#define SCHECK(condition, format, ...) {\
71 + CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\
72 +}
73 +#else
74 +#define SCHECK(condition, format, ...) {}
75 +#endif // SANITY_CHECKS
76 +
77 +#endif // __GOOGLE__
78 +
79 +
80 +#ifndef MAX
81 +#define MAX(a, b) (((a) > (b)) ? (a) : (b))
82 +#endif
83 +#ifndef MIN
84 +#define MIN(a, b) (((a) > (b)) ? (b) : (a))
85 +#endif
86 +
87 +inline static int64_t CurrentThreadTimeNanos() {
88 +#ifdef HAVE_CLOCK_GETTIME
89 + struct timespec tm;
90 + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm);
91 + return tm.tv_sec * 1000000000LL + tm.tv_nsec;
92 +#else
93 + struct timeval tv;
94 + gettimeofday(&tv, NULL);
95 +  return tv.tv_sec * 1000000000LL + tv.tv_usec * 1000LL;
96 +#endif
97 +}
98 +
99 +inline static int64_t CurrentRealTimeMillis() {
100 +#ifdef HAVE_CLOCK_GETTIME
101 + struct timespec tm;
102 + clock_gettime(CLOCK_MONOTONIC, &tm);
103 + return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL;
104 +#else
105 + struct timeval tv;
106 + gettimeofday(&tv, NULL);
107 +  return tv.tv_sec * 1000LL + tv.tv_usec / 1000;
108 +#endif
109 +}
110 +
111 +
112 +template<typename T>
113 +inline static T Square(const T a) {
114 + return a * a;
115 +}
116 +
117 +
118 +template<typename T>
119 +inline static T Clip(const T a, const T floor, const T ceil) {
120 + SCHECK(ceil >= floor, "Bounds mismatch!");
121 + return (a <= floor) ? floor : ((a >= ceil) ? ceil : a);
122 +}
123 +
124 +
125 +template<typename T>
126 +inline static int Floor(const T a) {
127 + return static_cast<int>(a);
128 +}
129 +
130 +
131 +template<typename T>
132 +inline static int Ceil(const T a) {
133 + return Floor(a) + 1;
134 +}
135 +
136 +
137 +template<typename T>
138 +inline static bool InRange(const T a, const T min, const T max) {
139 + return (a >= min) && (a <= max);
140 +}
141 +
142 +
143 +inline static bool ValidIndex(const int a, const int max) {
144 + return (a >= 0) && (a < max);
145 +}
146 +
147 +
148 +inline bool NearlyEqual(const float a, const float b, const float tolerance) {
149 + return std::abs(a - b) < tolerance;
150 +}
151 +
152 +
153 +inline bool NearlyEqual(const float a, const float b) {
154 + return NearlyEqual(a, b, EPSILON);
155 +}
156 +
157 +
158 +template<typename T>
159 +inline static int Round(const float a) {
160 +  return (a - static_cast<float>(floor(a)) > 0.5f) ? ceil(a) : floor(a);
161 +}
162 +
163 +
164 +template<typename T>
165 +inline static void Swap(T* const a, T* const b) {
166 + // Cache out the VALUE of what's at a.
167 + T tmp = *a;
168 + *a = *b;
169 +
170 + *b = tmp;
171 +}
172 +
173 +
174 +static inline float randf() {
175 + return rand() / static_cast<float>(RAND_MAX);
176 +}
177 +
178 +static inline float randf(const float min_value, const float max_value) {
179 + return randf() * (max_value - min_value) + min_value;
180 +}
181 +
182 +static inline uint16_t RealToFixed115(const float real_number) {
183 + SCHECK(InRange(real_number, 0.0f, 2048.0f),
184 + "Value out of range! %.2f", real_number);
185 +
186 + static const float kMult = 32.0f;
187 + const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
188 + return static_cast<uint16_t>(real_number * kMult + round_add);
189 +}
190 +
191 +static inline float FixedToFloat115(const uint16_t fp_number) {
192 + const float kDiv = 32.0f;
193 + return (static_cast<float>(fp_number) / kDiv);
194 +}
195 +
196 +static inline int RealToFixed1616(const float real_number) {
197 + static const float kMult = 65536.0f;
198 + SCHECK(InRange(real_number, -kMult, kMult),
199 + "Value out of range! %.2f", real_number);
200 +
201 + const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
202 + return static_cast<int>(real_number * kMult + round_add);
203 +}
204 +
205 +static inline float FixedToFloat1616(const int fp_number) {
206 + const float kDiv = 65536.0f;
207 + return (static_cast<float>(fp_number) / kDiv);
208 +}
209 +
210 +template<typename T>
211 +// produces numbers in range [0,2*M_PI] (rather than -PI,PI)
212 +inline T FastAtan2(const T y, const T x) {
213 + static const T coeff_1 = (T)(M_PI / 4.0);
214 + static const T coeff_2 = (T)(3.0 * coeff_1);
215 + const T abs_y = fabs(y);
216 + T angle;
217 + if (x >= 0) {
218 + T r = (x - abs_y) / (x + abs_y);
219 + angle = coeff_1 - coeff_1 * r;
220 + } else {
221 + T r = (x + abs_y) / (abs_y - x);
222 + angle = coeff_2 - coeff_1 * r;
223 + }
224 + static const T PI_2 = 2.0 * M_PI;
225 + return y < 0 ? PI_2 - angle : angle;
226 +}
227 +
228 +#define NELEMS(X) (sizeof(X) / sizeof(X[0]))
229 +
230 +namespace tf_tracking {
231 +
232 +#ifdef __ARM_NEON
233 +float ComputeMeanNeon(const float* const values, const int num_vals);
234 +
235 +float ComputeStdDevNeon(const float* const values, const int num_vals,
236 + const float mean);
237 +
238 +float ComputeWeightedMeanNeon(const float* const values,
239 + const float* const weights, const int num_vals);
240 +
241 +float ComputeCrossCorrelationNeon(const float* const values1,
242 + const float* const values2,
243 + const int num_vals);
244 +#endif
245 +
246 +inline float ComputeMeanCpu(const float* const values, const int num_vals) {
247 + // Get mean.
248 + float sum = values[0];
249 + for (int i = 1; i < num_vals; ++i) {
250 + sum += values[i];
251 + }
252 + return sum / static_cast<float>(num_vals);
253 +}
254 +
255 +
256 +inline float ComputeMean(const float* const values, const int num_vals) {
257 + return
258 +#ifdef __ARM_NEON
259 + (num_vals >= 8) ? ComputeMeanNeon(values, num_vals) :
260 +#endif
261 + ComputeMeanCpu(values, num_vals);
262 +}
263 +
264 +
265 +inline float ComputeStdDevCpu(const float* const values,
266 + const int num_vals,
267 + const float mean) {
268 + // Get Std dev.
269 + float squared_sum = 0.0f;
270 + for (int i = 0; i < num_vals; ++i) {
271 + squared_sum += Square(values[i] - mean);
272 + }
273 + return sqrt(squared_sum / static_cast<float>(num_vals));
274 +}
275 +
276 +
277 +inline float ComputeStdDev(const float* const values,
278 + const int num_vals,
279 + const float mean) {
280 + return
281 +#ifdef __ARM_NEON
282 + (num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) :
283 +#endif
284 + ComputeStdDevCpu(values, num_vals, mean);
285 +}
286 +
287 +
288 +// TODO(andrewharp): Accelerate with NEON.
289 +inline float ComputeWeightedMean(const float* const values,
290 + const float* const weights,
291 + const int num_vals) {
292 + float sum = 0.0f;
293 + float total_weight = 0.0f;
294 + for (int i = 0; i < num_vals; ++i) {
295 + sum += values[i] * weights[i];
296 + total_weight += weights[i];
297 + }
298 +  return sum / total_weight;
299 +}
300 +
301 +
302 +inline float ComputeCrossCorrelationCpu(const float* const values1,
303 + const float* const values2,
304 + const int num_vals) {
305 + float sxy = 0.0f;
306 + for (int offset = 0; offset < num_vals; ++offset) {
307 + sxy += values1[offset] * values2[offset];
308 + }
309 +
310 + const float cross_correlation = sxy / num_vals;
311 +
312 + return cross_correlation;
313 +}
314 +
315 +
316 +inline float ComputeCrossCorrelation(const float* const values1,
317 + const float* const values2,
318 + const int num_vals) {
319 + return
320 +#ifdef __ARM_NEON
321 + (num_vals >= 8) ? ComputeCrossCorrelationNeon(values1, values2, num_vals)
322 + :
323 +#endif
324 + ComputeCrossCorrelationCpu(values1, values2, num_vals);
325 +}
326 +
327 +
328 +inline void NormalizeNumbers(float* const values, const int num_vals) {
329 + // Find the mean and then subtract so that the new mean is 0.0.
330 + const float mean = ComputeMean(values, num_vals);
331 + VLOG(2) << "Mean is " << mean;
332 + float* curr_data = values;
333 + for (int i = 0; i < num_vals; ++i) {
334 + *curr_data -= mean;
335 + curr_data++;
336 + }
337 +
338 + // Now divide by the std deviation so the new standard deviation is 1.0.
339 + // The numbers might all be identical (and thus shifted to 0.0 now),
340 + // so only scale by the standard deviation if this is not the case.
341 + const float std_dev = ComputeStdDev(values, num_vals, 0.0f);
342 + if (std_dev > 0.0f) {
343 + VLOG(2) << "Std dev is " << std_dev;
344 + curr_data = values;
345 + for (int i = 0; i < num_vals; ++i) {
346 + *curr_data /= std_dev;
347 + curr_data++;
348 + }
349 + }
350 +}
351 +
352 +
353 +// Returns the determinant of a 2x2 matrix.
354 +template<class T>
355 +inline T FindDeterminant2x2(const T* const a) {
356 + // Determinant: (ad - bc)
357 + return a[0] * a[3] - a[1] * a[2];
358 +}
359 +
360 +
361 +// Finds the inverse of a 2x2 matrix.
362 +// Returns true upon success, false if the matrix is not invertible.
363 +template<class T>
364 +inline bool Invert2x2(const T* const a, float* const a_inv) {
365 + const float det = static_cast<float>(FindDeterminant2x2(a));
366 + if (fabs(det) < EPSILON) {
367 + return false;
368 + }
369 + const float inv_det = 1.0f / det;
370 +
371 + a_inv[0] = inv_det * static_cast<float>(a[3]); // d
372 + a_inv[1] = inv_det * static_cast<float>(-a[1]); // -b
373 + a_inv[2] = inv_det * static_cast<float>(-a[2]); // -c
374 + a_inv[3] = inv_det * static_cast<float>(a[0]); // a
375 +
376 + return true;
377 +}
378 +
379 +} // namespace tf_tracking
380 +
381 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
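
Two of the helpers above come up repeatedly in the tracker: the 16.16 fixed-point conversions and the 2x2 inverse. A short sketch of both, assuming utils.h is on the include path:

// Sketch: round-tripping a value through the 16.16 fixed-point helpers and
// inverting a 2x2 matrix with Invert2x2, both declared in utils.h above.
#include <cstdio>

#include "tensorflow/examples/android/jni/object_tracking/utils.h"

int main() {
  // 16.16 fixed point: multiply by 65536 and round.
  const float value = 3.25f;
  const int fixed = RealToFixed1616(value);    // 3.25 * 65536 = 212992
  const float back = FixedToFloat1616(fixed);  // 3.25
  std::printf("%.2f -> %d -> %.2f\n", value, fixed, back);

  // Invert [[2, 1], [1, 2]] (determinant 3).
  const float m[4] = {2.0f, 1.0f, 1.0f, 2.0f};
  float m_inv[4];
  if (tf_tracking::Invert2x2(m, m_inv)) {
    // Expected: [0.667 -0.333; -0.333 0.667]
    std::printf("inverse: [%.3f %.3f; %.3f %.3f]\n",
                m_inv[0], m_inv[1], m_inv[2], m_inv[3]);
  }
  return 0;
}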
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// NEON implementations of Image methods for compatible devices. Control
17 +// should never enter this compilation unit on incompatible devices.
18 +
19 +#ifdef __ARM_NEON
20 +
21 +#include <arm_neon.h>
22 +
23 +#include "tensorflow/examples/android/jni/object_tracking/geom.h"
24 +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
25 +#include "tensorflow/examples/android/jni/object_tracking/image.h"
26 +#include "tensorflow/examples/android/jni/object_tracking/utils.h"
27 +
28 +namespace tf_tracking {
29 +
30 +inline static float GetSum(const float32x4_t& values) {
31 + static float32_t summed_values[4];
32 + vst1q_f32(summed_values, values);
33 + return summed_values[0]
34 + + summed_values[1]
35 + + summed_values[2]
36 + + summed_values[3];
37 +}
38 +
39 +
40 +float ComputeMeanNeon(const float* const values, const int num_vals) {
41 + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
42 +
43 + const float32_t* const arm_vals = (const float32_t* const) values;
44 + float32x4_t accum = vdupq_n_f32(0.0f);
45 +
46 + int offset = 0;
47 + for (; offset <= num_vals - 4; offset += 4) {
48 + accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset]));
49 + }
50 +
51 + // Pull the accumulated values into a single variable.
52 + float sum = GetSum(accum);
53 +
54 + // Get the remaining 1 to 3 values.
55 + for (; offset < num_vals; ++offset) {
56 + sum += values[offset];
57 + }
58 +
59 + const float mean_neon = sum / static_cast<float>(num_vals);
60 +
61 +#ifdef SANITY_CHECKS
62 + const float mean_cpu = ComputeMeanCpu(values, num_vals);
63 + SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals),
64 + "Neon mismatch with CPU mean! %.10f vs %.10f",
65 + mean_neon, mean_cpu);
66 +#endif
67 +
68 + return mean_neon;
69 +}
70 +
71 +
72 +float ComputeStdDevNeon(const float* const values,
73 + const int num_vals, const float mean) {
74 + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
75 +
76 + const float32_t* const arm_vals = (const float32_t* const) values;
77 + const float32x4_t mean_vec = vdupq_n_f32(-mean);
78 +
79 + float32x4_t accum = vdupq_n_f32(0.0f);
80 +
81 + int offset = 0;
82 + for (; offset <= num_vals - 4; offset += 4) {
83 + const float32x4_t deltas =
84 + vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset]));
85 +
86 + accum = vmlaq_f32(accum, deltas, deltas);
87 + }
88 +
89 + // Pull the accumulated values into a single variable.
90 + float squared_sum = GetSum(accum);
91 +
92 + // Get the remaining 1 to 3 values.
93 + for (; offset < num_vals; ++offset) {
94 + squared_sum += Square(values[offset] - mean);
95 + }
96 +
97 + const float std_dev_neon = sqrt(squared_sum / static_cast<float>(num_vals));
98 +
99 +#ifdef SANITY_CHECKS
100 + const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean);
101 + SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals),
102 + "Neon mismatch with CPU std dev! %.10f vs %.10f",
103 + std_dev_neon, std_dev_cpu);
104 +#endif
105 +
106 + return std_dev_neon;
107 +}
108 +
109 +
110 +float ComputeCrossCorrelationNeon(const float* const values1,
111 + const float* const values2,
112 + const int num_vals) {
113 + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
114 +
115 + const float32_t* const arm_vals1 = (const float32_t* const) values1;
116 + const float32_t* const arm_vals2 = (const float32_t* const) values2;
117 +
118 + float32x4_t accum = vdupq_n_f32(0.0f);
119 +
120 + int offset = 0;
121 + for (; offset <= num_vals - 4; offset += 4) {
122 + accum = vmlaq_f32(accum,
123 + vld1q_f32(&arm_vals1[offset]),
124 + vld1q_f32(&arm_vals2[offset]));
125 + }
126 +
127 + // Pull the accumulated values into a single variable.
128 + float sxy = GetSum(accum);
129 +
130 + // Get the remaining 1 to 3 values.
131 + for (; offset < num_vals; ++offset) {
132 + sxy += values1[offset] * values2[offset];
133 + }
134 +
135 + const float cross_correlation_neon = sxy / num_vals;
136 +
137 +#ifdef SANITY_CHECKS
138 + const float cross_correlation_cpu =
139 + ComputeCrossCorrelationCpu(values1, values2, num_vals);
140 + SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu,
141 + EPSILON * num_vals),
142 + "Neon mismatch with CPU cross correlation! %.10f vs %.10f",
143 + cross_correlation_neon, cross_correlation_cpu);
144 +#endif
145 +
146 + return cross_correlation_neon;
147 +}
148 +
149 +} // namespace tf_tracking
150 +
151 +#endif // __ARM_NEON
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// These utility functions allow for the conversion of RGB data to YUV data.
17 +
18 +#include "tensorflow/examples/android/jni/rgb2yuv.h"
19 +
20 +static inline void WriteYUV(const int x, const int y, const int width,
21 + const int r8, const int g8, const int b8,
22 + uint8_t* const pY, uint8_t* const pUV) {
23 + // Using formulas from http://msdn.microsoft.com/en-us/library/ms893078
24 + *pY = ((66 * r8 + 129 * g8 + 25 * b8 + 128) >> 8) + 16;
25 +
26 + // Odd widths get rounded up so that UV blocks on the side don't get cut off.
27 + const int blocks_per_row = (width + 1) / 2;
28 +
29 + // 2 bytes per UV block
30 + const int offset = 2 * (((y / 2) * blocks_per_row + (x / 2)));
31 +
32 + // U and V are the average values of all 4 pixels in the block.
33 + if (!(x & 1) && !(y & 1)) {
34 + // Explicitly clear the block if this is the first pixel in it.
35 + pUV[offset] = 0;
36 + pUV[offset + 1] = 0;
37 + }
38 +
39 + // V (with divide by 4 factored in)
40 +#ifdef __APPLE__
41 + const int u_offset = 0;
42 + const int v_offset = 1;
43 +#else
44 + const int u_offset = 1;
45 + const int v_offset = 0;
46 +#endif
47 + pUV[offset + v_offset] += ((112 * r8 - 94 * g8 - 18 * b8 + 128) >> 10) + 32;
48 +
49 + // U (with divide by 4 factored in)
50 + pUV[offset + u_offset] += ((-38 * r8 - 74 * g8 + 112 * b8 + 128) >> 10) + 32;
51 +}
52 +
53 +void ConvertARGB8888ToYUV420SP(const uint32_t* const input,
54 + uint8_t* const output, int width, int height) {
55 + uint8_t* pY = output;
56 + uint8_t* pUV = output + (width * height);
57 + const uint32_t* in = input;
58 +
59 + for (int y = 0; y < height; y++) {
60 + for (int x = 0; x < width; x++) {
61 + const uint32_t rgb = *in++;
62 +#ifdef __APPLE__
63 + const int nB = (rgb >> 8) & 0xFF;
64 + const int nG = (rgb >> 16) & 0xFF;
65 + const int nR = (rgb >> 24) & 0xFF;
66 +#else
67 + const int nR = (rgb >> 16) & 0xFF;
68 + const int nG = (rgb >> 8) & 0xFF;
69 + const int nB = rgb & 0xFF;
70 +#endif
71 + WriteYUV(x, y, width, nR, nG, nB, pY++, pUV);
72 + }
73 + }
74 +}
75 +
76 +void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output,
77 + const int width, const int height) {
78 + uint8_t* pY = output;
79 + uint8_t* pUV = output + (width * height);
80 + const uint16_t* in = input;
81 +
82 + for (int y = 0; y < height; y++) {
83 + for (int x = 0; x < width; x++) {
84 + const uint32_t rgb = *in++;
85 +
86 + const int r5 = ((rgb >> 11) & 0x1F);
87 + const int g6 = ((rgb >> 5) & 0x3F);
88 + const int b5 = (rgb & 0x1F);
89 +
90 + // Shift left, then fill in the empty low bits with a copy of the high
91 + // bits so we can stretch across the entire 0 - 255 range.
92 + const int r8 = r5 << 3 | r5 >> 2;
93 + const int g8 = g6 << 2 | g6 >> 4;
94 + const int b8 = b5 << 3 | b5 >> 2;
95 +
96 + WriteYUV(x, y, width, r8, g8, b8, pY++, pUV);
97 + }
98 + }
99 +}
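
The coefficients in WriteYUV above are the integer RGB-to-YUV approximations from the MSDN reference it cites, with the chroma terms pre-divided by 4 for the 2x2 block averaging. As a sanity check, the same coefficients applied to a single pixel at full resolution give neutral chroma for mid-gray; a small sketch:

// Sketch: the same integer coefficients applied to one pixel, without the
// 2x2 averaging, to show the expected values for mid-gray.
#include <cstdint>
#include <cstdio>

static void RgbToYuvPixel(int r, int g, int b,
                          uint8_t* y, uint8_t* u, uint8_t* v) {
  *y = static_cast<uint8_t>(((66 * r + 129 * g + 25 * b + 128) >> 8) + 16);
  *u = static_cast<uint8_t>(((-38 * r - 74 * g + 112 * b + 128) >> 8) + 128);
  *v = static_cast<uint8_t>(((112 * r - 94 * g - 18 * b + 128) >> 8) + 128);
}

int main() {
  uint8_t y, u, v;
  RgbToYuvPixel(128, 128, 128, &y, &u, &v);
  std::printf("gray(128): Y=%d U=%d V=%d\n", y, u, v);  // Y=126, U=128, V=128.
  return 0;
}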
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
17 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
18 +
19 +#include <stdint.h>
20 +
21 +#ifdef __cplusplus
22 +extern "C" {
23 +#endif
24 +
25 +void ConvertARGB8888ToYUV420SP(const uint32_t* const input,
26 + uint8_t* const output, int width, int height);
27 +
28 +void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output,
29 + const int width, const int height);
30 +
31 +#ifdef __cplusplus
32 +}
33 +#endif
34 +
35 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
1 +VERS_1.0 {
2 + # Export JNI symbols.
3 + global:
4 + Java_*;
5 + JNI_OnLoad;
6 + JNI_OnUnload;
7 +
8 + # Hide everything else.
9 + local:
10 + *;
11 +};
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// This is a collection of routines which convert various YUV image formats
17 +// to ARGB.
18 +
19 +#include "tensorflow/examples/android/jni/yuv2rgb.h"
20 +
21 +#ifndef MAX
22 +#define MAX(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a > _b ? _a : _b; })
23 +#define MIN(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a < _b ? _a : _b; })
24 +#endif
25 +
26 +// This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
27 +// are normalized to eight bits.
28 +static const int kMaxChannelValue = 262143;
29 +
30 +static inline uint32_t YUV2RGB(int nY, int nU, int nV) {
31 + nY -= 16;
32 + nU -= 128;
33 + nV -= 128;
34 + if (nY < 0) nY = 0;
35 +
36 + // This is the floating point equivalent. We do the conversion in integer
37 + // because some Android devices do not have floating point in hardware.
38 + // nR = (int)(1.164 * nY + 2.018 * nU);
39 + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
40 + // nB = (int)(1.164 * nY + 1.596 * nV);
41 +
42 + int nR = 1192 * nY + 1634 * nV;
43 + int nG = 1192 * nY - 833 * nV - 400 * nU;
44 + int nB = 1192 * nY + 2066 * nU;
45 +
46 + nR = MIN(kMaxChannelValue, MAX(0, nR));
47 + nG = MIN(kMaxChannelValue, MAX(0, nG));
48 + nB = MIN(kMaxChannelValue, MAX(0, nB));
49 +
50 + nR = (nR >> 10) & 0xff;
51 + nG = (nG >> 10) & 0xff;
52 + nB = (nB >> 10) & 0xff;
53 +
54 + return 0xff000000 | (nR << 16) | (nG << 8) | nB;
55 +}
56 +
57 +// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by
58 +// separate u and v planes with arbitrary row and column strides,
59 +// containing 8 bit 2x2 subsampled chroma samples.
60 +// Converts to a packed ARGB 32 bit output of the same pixel dimensions.
61 +void ConvertYUV420ToARGB8888(const uint8_t* const yData,
62 + const uint8_t* const uData,
63 + const uint8_t* const vData, uint32_t* const output,
64 + const int width, const int height,
65 + const int y_row_stride, const int uv_row_stride,
66 + const int uv_pixel_stride) {
67 + uint32_t* out = output;
68 +
69 + for (int y = 0; y < height; y++) {
70 + const uint8_t* pY = yData + y_row_stride * y;
71 +
72 + const int uv_row_start = uv_row_stride * (y >> 1);
73 + const uint8_t* pU = uData + uv_row_start;
74 + const uint8_t* pV = vData + uv_row_start;
75 +
76 + for (int x = 0; x < width; x++) {
77 + const int uv_offset = (x >> 1) * uv_pixel_stride;
78 + *out++ = YUV2RGB(pY[x], pU[uv_offset], pV[uv_offset]);
79 + }
80 + }
81 +}
82 +
83 +// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an
84 +// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples,
85 +// except the interleave order of U and V is reversed. Converts to a packed
86 +// ARGB 32 bit output of the same pixel dimensions.
87 +void ConvertYUV420SPToARGB8888(const uint8_t* const yData,
88 + const uint8_t* const uvData,
89 + uint32_t* const output, const int width,
90 + const int height) {
91 + const uint8_t* pY = yData;
92 + const uint8_t* pUV = uvData;
93 + uint32_t* out = output;
94 +
95 + for (int y = 0; y < height; y++) {
96 + for (int x = 0; x < width; x++) {
97 + int nY = *pY++;
98 + int offset = (y >> 1) * width + 2 * (x >> 1);
99 +#ifdef __APPLE__
100 + int nU = pUV[offset];
101 + int nV = pUV[offset + 1];
102 +#else
103 + int nV = pUV[offset];
104 + int nU = pUV[offset + 1];
105 +#endif
106 +
107 + *out++ = YUV2RGB(nY, nU, nV);
108 + }
109 + }
110 +}
111 +
112 +// The same as above, but downsamples each dimension to half size.
113 +void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input,
114 + uint32_t* const output, int width,
115 + int height) {
116 + const uint8_t* pY = input;
117 + const uint8_t* pUV = input + (width * height);
118 + uint32_t* out = output;
119 + int stride = width;
120 + width >>= 1;
121 + height >>= 1;
122 +
123 + for (int y = 0; y < height; y++) {
124 + for (int x = 0; x < width; x++) {
125 + int nY = (pY[0] + pY[1] + pY[stride] + pY[stride + 1]) >> 2;
126 + pY += 2;
127 +#ifdef __APPLE__
128 + int nU = *pUV++;
129 + int nV = *pUV++;
130 +#else
131 + int nV = *pUV++;
132 + int nU = *pUV++;
133 +#endif
134 +
135 + *out++ = YUV2RGB(nY, nU, nV);
136 + }
137 + pY += stride;
138 + }
139 +}
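
The half-size variant reads four Y samples per output pixel and averages them (a 2x2 box filter), while consuming one interleaved chroma pair per output pixel. A small Java sketch of that luma-averaging step, under the assumption of a row-major 8-bit luma plane (names are illustrative only):

final class LumaDownsample {
  /** Average of a 2x2 block of unsigned 8-bit luma samples at (x, y) in a row-major plane. */
  static int averageLuma2x2(byte[] yPlane, int stride, int x, int y) {
    final int sum = (yPlane[y * stride + x] & 0xff)
        + (yPlane[y * stride + x + 1] & 0xff)
        + (yPlane[(y + 1) * stride + x] & 0xff)
        + (yPlane[(y + 1) * stride + x + 1] & 0xff);
    return sum >> 2;  // Divide by four, matching the `>> 2` in the C routine above.
  }
}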
140 +
141 +// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an
142 +// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples,
143 +// except the interleave order of U and V is reversed. Converts to a packed
144 +// RGB 565 bit output of the same pixel dimensions.
145 +void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output,
146 + const int width, const int height) {
147 + const uint8_t* pY = input;
148 + const uint8_t* pUV = input + (width * height);
149 + uint16_t* out = output;
150 +
151 + for (int y = 0; y < height; y++) {
152 + for (int x = 0; x < width; x++) {
153 + int nY = *pY++;
154 + int offset = (y >> 1) * width + 2 * (x >> 1);
155 +#ifdef __APPLE__
156 + int nU = pUV[offset];
157 + int nV = pUV[offset + 1];
158 +#else
159 + int nV = pUV[offset];
160 + int nU = pUV[offset + 1];
161 +#endif
162 +
163 + nY -= 16;
164 + nU -= 128;
165 + nV -= 128;
166 + if (nY < 0) nY = 0;
167 +
168 + // This is the floating point equivalent. We do the conversion in integer
169 + // because some Android devices do not have floating point in hardware.
170 + // nR = (int)(1.164 * nY + 1.596 * nV);
171 + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
172 + // nB = (int)(1.164 * nY + 2.018 * nU);
173 +
174 + int nR = 1192 * nY + 1634 * nV;
175 + int nG = 1192 * nY - 833 * nV - 400 * nU;
176 + int nB = 1192 * nY + 2066 * nU;
177 +
178 + nR = MIN(kMaxChannelValue, MAX(0, nR));
179 + nG = MIN(kMaxChannelValue, MAX(0, nG));
180 + nB = MIN(kMaxChannelValue, MAX(0, nB));
181 +
182 + // Shift more than for ARGB8888 and apply appropriate bitmask.
183 + nR = (nR >> 13) & 0x1f;
184 + nG = (nG >> 12) & 0x3f;
185 + nB = (nB >> 13) & 0x1f;
186 +
187 + // R is high 5 bits, G is middle 6 bits, and B is low 5 bits.
188 + *out++ = (nR << 11) | (nG << 5) | nB;
189 + }
190 + }
191 +}
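
The RGB565 path shifts the same Q10 intermediates by 13 bits for R and B (10 to undo the fixed-point scale, 3 more to keep only the top 5 bits) and by 12 bits for G (keeping the top 6 bits), then packs R into bits 15-11, G into bits 10-5, and B into bits 4-0. A Java sketch of the equivalent packing from an already-converted 8-bit-per-channel pixel (helper name is illustrative, not part of this change):

final class Rgb565Packing {
  /** Packs 8-bit R, G, B into a 16-bit RGB565 value laid out as RRRRRGGG GGGBBBBB. */
  static short packRgb565(int r, int g, int b) {
    return (short) (((r >> 3) << 11)  // keep the top 5 bits of red
        | ((g >> 2) << 5)             // keep the top 6 bits of green
        | (b >> 3));                  // keep the top 5 bits of blue
  }
}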
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +// This is a collection of routines which converts various YUV image formats
17 +// to (A)RGB.
18 +
19 +#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
20 +#define TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
21 +
22 +#include <stdint.h>
23 +
24 +#ifdef __cplusplus
25 +extern "C" {
26 +#endif
27 +
28 +void ConvertYUV420ToARGB8888(const uint8_t* const yData,
29 + const uint8_t* const uData,
30 + const uint8_t* const vData, uint32_t* const output,
31 + const int width, const int height,
32 + const int y_row_stride, const int uv_row_stride,
33 + const int uv_pixel_stride);
34 +
35 +// Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width
36 +// and height. The input and output must already be allocated and non-null.
37 +// For efficiency, no error checking is performed.
38 +void ConvertYUV420SPToARGB8888(const uint8_t* const pY,
39 + const uint8_t* const pUV, uint32_t* const output,
40 + const int width, const int height);
41 +
42 +// The same as above, but downsamples each dimension to half size.
43 +void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input,
44 + uint32_t* const output, int width,
45 + int height);
46 +
47 +// Converts YUV420 semi-planar data to RGB 565 data using the supplied width
48 +// and height. The input and output must already be allocated and non-null.
49 +// For efficiency, no error checking is performed.
50 +void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output,
51 + const int width, const int height);
52 +
53 +#ifdef __cplusplus
54 +}
55 +#endif
56 +
57 +#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<set xmlns:android="http://schemas.android.com/apk/res/android"
17 + android:ordering="sequentially">
18 + <objectAnimator
19 + android:propertyName="backgroundColor"
20 + android:duration="375"
21 + android:valueFrom="0x00b3ccff"
22 + android:valueTo="0xffb3ccff"
23 + android:valueType="colorType"/>
24 + <objectAnimator
25 + android:propertyName="backgroundColor"
26 + android:duration="375"
27 + android:valueFrom="0xffb3ccff"
28 + android:valueTo="0x00b3ccff"
29 + android:valueType="colorType"/>
30 +</set>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle" >
17 + <solid android:color="#00000000" />
18 + <stroke android:width="1dip" android:color="#cccccc" />
19 +</shape>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 + xmlns:tools="http://schemas.android.com/tools"
18 + android:id="@+id/container"
19 + android:layout_width="match_parent"
20 + android:layout_height="match_parent"
21 + android:background="#000"
22 + tools:context="org.tensorflow.demo.CameraActivity" />
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<FrameLayout
17 + xmlns:android="http://schemas.android.com/apk/res/android"
18 + xmlns:app="http://schemas.android.com/apk/res-auto"
19 + xmlns:tools="http://schemas.android.com/tools"
20 + android:layout_width="match_parent"
21 + android:layout_height="match_parent"
22 + tools:context="org.tensorflow.demo.SpeechActivity">
23 +
24 + <TextView
25 + android:layout_width="wrap_content"
26 + android:layout_height="wrap_content"
27 + android:text="Say one of the words below!"
28 + android:id="@+id/textView"
29 + android:textAlignment="center"
30 + android:layout_gravity="top"
31 + android:textSize="24dp"
32 + android:layout_marginTop="10dp"
33 + android:layout_marginLeft="10dp"
34 + />
35 +
36 + <ListView
37 + android:id="@+id/list_view"
38 + android:layout_width="240dp"
39 + android:layout_height="wrap_content"
40 + android:background="@drawable/border"
41 + android:layout_gravity="top|center_horizontal"
42 + android:textAlignment="center"
43 + android:layout_marginTop="100dp"
44 + />
45 +
46 + <Button
47 + android:id="@+id/quit"
48 + android:layout_width="wrap_content"
49 + android:layout_height="wrap_content"
50 + android:text="Quit"
51 + android:layout_gravity="bottom|center_horizontal"
52 + android:layout_marginBottom="10dp"
53 + />
54 +
55 +</FrameLayout>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 + android:layout_width="match_parent"
18 + android:layout_height="match_parent">
19 +
20 + <org.tensorflow.demo.AutoFitTextureView
21 + android:id="@+id/texture"
22 + android:layout_width="wrap_content"
23 + android:layout_height="wrap_content"
24 + android:layout_alignParentBottom="true" />
25 +
26 + <org.tensorflow.demo.RecognitionScoreView
27 + android:id="@+id/results"
28 + android:layout_width="match_parent"
29 + android:layout_height="112dp"
30 + android:layout_alignParentTop="true" />
31 +
32 + <org.tensorflow.demo.OverlayView
33 + android:id="@+id/debug_overlay"
34 + android:layout_width="match_parent"
35 + android:layout_height="match_parent"
36 + android:layout_alignParentBottom="true" />
37 +
38 +</RelativeLayout>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 + android:orientation="vertical"
18 + android:layout_width="match_parent"
19 + android:layout_height="match_parent">
20 + <org.tensorflow.demo.AutoFitTextureView
21 + android:id="@+id/texture"
22 + android:layout_width="wrap_content"
23 + android:layout_height="wrap_content"
24 + android:layout_alignParentTop="true" />
25 +
26 + <RelativeLayout
27 + android:id="@+id/black"
28 + android:layout_width="match_parent"
29 + android:layout_height="match_parent"
30 + android:background="#FF000000" />
31 +
32 + <GridView
33 + android:id="@+id/grid_layout"
34 + android:numColumns="7"
35 + android:stretchMode="columnWidth"
36 + android:layout_alignParentBottom="true"
37 + android:layout_width="match_parent"
38 + android:layout_height="wrap_content" />
39 +
40 + <org.tensorflow.demo.OverlayView
41 + android:id="@+id/overlay"
42 + android:layout_width="match_parent"
43 + android:layout_height="match_parent"
44 + android:layout_alignParentTop="true" />
45 +
46 + <org.tensorflow.demo.OverlayView
47 + android:id="@+id/debug_overlay"
48 + android:layout_width="match_parent"
49 + android:layout_height="match_parent"
50 + android:layout_alignParentTop="true" />
51 +</RelativeLayout>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 + android:layout_width="match_parent"
18 + android:layout_height="match_parent">
19 +
20 + <org.tensorflow.demo.AutoFitTextureView
21 + android:id="@+id/texture"
22 + android:layout_width="wrap_content"
23 + android:layout_height="wrap_content"/>
24 +
25 + <org.tensorflow.demo.OverlayView
26 + android:id="@+id/tracking_overlay"
27 + android:layout_width="match_parent"
28 + android:layout_height="match_parent"/>
29 +
30 + <org.tensorflow.demo.OverlayView
31 + android:id="@+id/debug_overlay"
32 + android:layout_width="match_parent"
33 + android:layout_height="match_parent"/>
34 +</FrameLayout>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<TextView
17 + xmlns:android="http://schemas.android.com/apk/res/android"
18 + android:id="@+id/list_text_item"
19 + android:layout_width="match_parent"
20 + android:layout_height="wrap_content"
21 + android:text="TextView"
22 + android:textSize="24dp"
23 + android:textAlignment="center"
24 + android:gravity="center_horizontal"
25 + />
1 +<!--
2 + Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 + -->
16 +
17 +<resources>
18 +
19 + <!-- Semantic definitions -->
20 +
21 + <dimen name="horizontal_page_margin">@dimen/margin_huge</dimen>
22 + <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
23 +
24 +</resources>
1 +<!--
2 + Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 + -->
16 +
17 +<resources>
18 +
19 + <style name="Widget.SampleMessage">
20 + <item name="android:textAppearance">?android:textAppearanceLarge</item>
21 + <item name="android:lineSpacingMultiplier">1.2</item>
22 + <item name="android:shadowDy">-6.5</item>
23 + </style>
24 +
25 +</resources>
1 +<?xml version="1.0" encoding="utf-8"?>
2 +<resources>
3 +
4 + <!--
5 + Base application theme for API 11+. This theme completely replaces
6 + AppBaseTheme from res/values/styles.xml on API 11+ devices.
7 + -->
8 + <style name="AppBaseTheme" parent="android:Theme.Holo.Light">
9 + <!-- API 11 theme customizations can go here. -->
10 + </style>
11 +
12 + <style name="FullscreenTheme" parent="android:Theme.Holo">
13 + <item name="android:actionBarStyle">@style/FullscreenActionBarStyle</item>
14 + <item name="android:windowActionBarOverlay">true</item>
15 + <item name="android:windowBackground">@null</item>
16 + <item name="metaButtonBarStyle">?android:attr/buttonBarStyle</item>
17 + <item name="metaButtonBarButtonStyle">?android:attr/buttonBarButtonStyle</item>
18 + </style>
19 +
20 + <style name="FullscreenActionBarStyle" parent="android:Widget.Holo.ActionBar">
21 + <!-- <item name="android:background">@color/black_overlay</item> -->
22 + </style>
23 +
24 +</resources>
1 +<!--
2 + Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 + -->
16 +
17 +<resources>
18 +
19 + <!-- Activity themes -->
20 + <style name="Theme.Base" parent="android:Theme.Holo.Light" />
21 +
22 +</resources>
1 +<resources>
2 +
3 + <!--
4 + Base application theme for API 14+. This theme completely replaces
5 + AppBaseTheme from BOTH res/values/styles.xml and
6 + res/values-v11/styles.xml on API 14+ devices.
7 + -->
8 + <style name="AppBaseTheme" parent="android:Theme.Holo.Light.DarkActionBar">
9 + <!-- API 14 theme customizations can go here. -->
10 + </style>
11 +
12 +</resources>
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<!--
3 + Copyright 2013 The TensorFlow Authors. All Rights Reserved.
4 +
5 + Licensed under the Apache License, Version 2.0 (the "License");
6 + you may not use this file except in compliance with the License.
7 + You may obtain a copy of the License at
8 +
9 + http://www.apache.org/licenses/LICENSE-2.0
10 +
11 + Unless required by applicable law or agreed to in writing, software
12 + distributed under the License is distributed on an "AS IS" BASIS,
13 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 + See the License for the specific language governing permissions and
15 + limitations under the License.
16 +-->
17 +
18 +<resources>
19 +
20 +
21 +</resources>
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<!--
3 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4 +
5 + Licensed under the Apache License, Version 2.0 (the "License");
6 + you may not use this file except in compliance with the License.
7 + You may obtain a copy of the License at
8 +
9 + http://www.apache.org/licenses/LICENSE-2.0
10 +
11 + Unless required by applicable law or agreed to in writing, software
12 + distributed under the License is distributed on an "AS IS" BASIS,
13 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 + See the License for the specific language governing permissions and
15 + limitations under the License.
16 +-->
17 +
18 +<resources>
19 +
20 + <!-- Activity themes -->
21 + <style name="Theme.Base" parent="android:Theme.Material.Light">
22 + </style>
23 +
24 +</resources>
1 +<resources>
2 +
3 + <!--
4 + Declare custom theme attributes that allow changing which styles are
5 + used for button bars depending on the API level.
6 + ?android:attr/buttonBarStyle is new as of API 11 so this is
7 + necessary to support previous API levels.
8 + -->
9 + <declare-styleable name="ButtonBarContainerTheme">
10 + <attr name="metaButtonBarStyle" format="reference" />
11 + <attr name="metaButtonBarButtonStyle" format="reference" />
12 + </declare-styleable>
13 +
14 +</resources>
1 +<?xml version="1.0" encoding="UTF-8"?>
2 +<!--
3 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4 +
5 + Licensed under the Apache License, Version 2.0 (the "License");
6 + you may not use this file except in compliance with the License.
7 + You may obtain a copy of the License at
8 +
9 + http://www.apache.org/licenses/LICENSE-2.0
10 +
11 + Unless required by applicable law or agreed to in writing, software
12 + distributed under the License is distributed on an "AS IS" BASIS,
13 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 + See the License for the specific language governing permissions and
15 + limitations under the License.
16 +-->
17 +
18 +<resources>
19 + <string name="app_name">TensorFlow Demo</string>
20 + <string name="activity_name_classification">TF Classify</string>
21 + <string name="activity_name_detection">TF Detect</string>
22 + <string name="activity_name_stylize">TF Stylize</string>
23 + <string name="activity_name_speech">TF Speech</string>
24 +</resources>
1 +<?xml version="1.0" encoding="utf-8"?>
2 +<!--
3 + Copyright 2015 The TensorFlow Authors. All Rights Reserved.
4 +
5 + Licensed under the Apache License, Version 2.0 (the "License");
6 + you may not use this file except in compliance with the License.
7 + You may obtain a copy of the License at
8 +
9 + http://www.apache.org/licenses/LICENSE-2.0
10 +
11 + Unless required by applicable law or agreed to in writing, software
12 + distributed under the License is distributed on an "AS IS" BASIS,
13 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 + See the License for the specific language governing permissions and
15 + limitations under the License.
16 +-->
17 +<resources>
18 + <color name="control_background">#cc4285f4</color>
19 +</resources>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<resources>
17 + <string name="description_info">Info</string>
18 + <string name="request_permission">This sample needs camera permission.</string>
19 + <string name="camera_error">This device doesn\'t support Camera2 API.</string>
20 +</resources>
1 +<?xml version="1.0" encoding="utf-8"?><!--
2 + Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 +-->
16 +<resources>
17 + <style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" />
18 +</resources>
1 +<!--
2 + Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 + -->
16 +
17 +<resources>
18 +
19 + <!-- Define standard dimensions to comply with Holo-style grids and rhythm. -->
20 +
21 + <dimen name="margin_tiny">4dp</dimen>
22 + <dimen name="margin_small">8dp</dimen>
23 + <dimen name="margin_medium">16dp</dimen>
24 + <dimen name="margin_large">32dp</dimen>
25 + <dimen name="margin_huge">64dp</dimen>
26 +
27 + <!-- Semantic definitions -->
28 +
29 + <dimen name="horizontal_page_margin">@dimen/margin_medium</dimen>
30 + <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
31 +
32 +</resources>
1 +<!--
2 + Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 +
4 + Licensed under the Apache License, Version 2.0 (the "License");
5 + you may not use this file except in compliance with the License.
6 + You may obtain a copy of the License at
7 +
8 + http://www.apache.org/licenses/LICENSE-2.0
9 +
10 + Unless required by applicable law or agreed to in writing, software
11 + distributed under the License is distributed on an "AS IS" BASIS,
12 + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + See the License for the specific language governing permissions and
14 + limitations under the License.
15 + -->
16 +
17 +<resources>
18 +
19 + <!-- Activity themes -->
20 +
21 + <style name="Theme.Base" parent="android:Theme.Light" />
22 +
23 + <style name="Theme.Sample" parent="Theme.Base" />
24 +
25 + <style name="AppTheme" parent="Theme.Sample" />
26 + <!-- Widget styling -->
27 +
28 + <style name="Widget" />
29 +
30 + <style name="Widget.SampleMessage">
31 + <item name="android:textAppearance">?android:textAppearanceMedium</item>
32 + <item name="android:lineSpacingMultiplier">1.1</item>
33 + </style>
34 +
35 + <style name="Widget.SampleMessageTile">
36 + <item name="android:background">@drawable/tile</item>
37 + <item name="android:shadowColor">#7F000000</item>
38 + <item name="android:shadowDy">-3.5</item>
39 + <item name="android:shadowRadius">2</item>
40 + </style>
41 +
42 +</resources>
1 +/*
2 + * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +package org.tensorflow.demo;
18 +
19 +import android.content.Context;
20 +import android.util.AttributeSet;
21 +import android.view.TextureView;
22 +
23 +/**
24 + * A {@link TextureView} that can be adjusted to a specified aspect ratio.
25 + */
26 +public class AutoFitTextureView extends TextureView {
27 + private int ratioWidth = 0;
28 + private int ratioHeight = 0;
29 +
30 + public AutoFitTextureView(final Context context) {
31 + this(context, null);
32 + }
33 +
34 + public AutoFitTextureView(final Context context, final AttributeSet attrs) {
35 + this(context, attrs, 0);
36 + }
37 +
38 + public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) {
39 + super(context, attrs, defStyle);
40 + }
41 +
42 + /**
43 + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio
44 + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that
45 + * is, calling setAspectRatio(2, 3) and setAspectRatio(4, 6) produces the same result.
46 + *
47 + * @param width Relative horizontal size
48 + * @param height Relative vertical size
49 + */
50 + public void setAspectRatio(final int width, final int height) {
51 + if (width < 0 || height < 0) {
52 + throw new IllegalArgumentException("Size cannot be negative.");
53 + }
54 + ratioWidth = width;
55 + ratioHeight = height;
56 + requestLayout();
57 + }
58 +
59 + @Override
60 + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
61 + super.onMeasure(widthMeasureSpec, heightMeasureSpec);
62 + final int width = MeasureSpec.getSize(widthMeasureSpec);
63 + final int height = MeasureSpec.getSize(heightMeasureSpec);
64 + if (0 == ratioWidth || 0 == ratioHeight) {
65 + setMeasuredDimension(width, height);
66 + } else {
67 + if (width < height * ratioWidth / ratioHeight) {
68 + setMeasuredDimension(width, width * ratioHeight / ratioWidth);
69 + } else {
70 + setMeasuredDimension(height * ratioWidth / ratioHeight, height);
71 + }
72 + }
73 + }
74 +}
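
In practice the fragment that owns this view calls setAspectRatio with the chosen preview size so the preview is letterboxed rather than stretched; onMeasure then shrinks whichever dimension overshoots the ratio. A usage sketch under that assumption (helper and variable names are illustrative, not taken from this change):

import android.content.res.Configuration;
import android.util.Size;

final class AspectRatioHelper {
  /** Applies the camera preview's aspect ratio to the texture view. */
  static void applyPreviewAspectRatio(
      AutoFitTextureView textureView, Size previewSize, int orientation) {
    if (orientation == Configuration.ORIENTATION_LANDSCAPE) {
      textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight());
    } else {
      // In portrait the sensor's width maps to the view's height, so swap the ratio.
      textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth());
    }
  }
}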
1 +/*
2 + * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +package org.tensorflow.demo;
18 +
19 +import android.Manifest;
20 +import android.app.Activity;
21 +import android.app.Fragment;
22 +import android.content.Context;
23 +import android.content.pm.PackageManager;
24 +import android.hardware.Camera;
25 +import android.hardware.camera2.CameraAccessException;
26 +import android.hardware.camera2.CameraCharacteristics;
27 +import android.hardware.camera2.CameraManager;
28 +import android.hardware.camera2.params.StreamConfigurationMap;
29 +import android.media.Image;
30 +import android.media.Image.Plane;
31 +import android.media.ImageReader;
32 +import android.media.ImageReader.OnImageAvailableListener;
33 +import android.os.Build;
34 +import android.os.Bundle;
35 +import android.os.Handler;
36 +import android.os.HandlerThread;
37 +import android.os.Trace;
38 +import android.util.Size;
39 +import android.view.KeyEvent;
40 +import android.view.Surface;
41 +import android.view.WindowManager;
42 +import android.widget.Toast;
43 +import java.nio.ByteBuffer;
44 +import org.tensorflow.demo.env.ImageUtils;
45 +import org.tensorflow.demo.env.Logger;
46 +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
47 +
48 +public abstract class CameraActivity extends Activity
49 + implements OnImageAvailableListener, Camera.PreviewCallback {
50 + private static final Logger LOGGER = new Logger();
51 +
52 + private static final int PERMISSIONS_REQUEST = 1;
53 +
54 + private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA;
55 + private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
56 +
57 + private boolean debug = false;
58 +
59 + private Handler handler;
60 + private HandlerThread handlerThread;
61 + private boolean useCamera2API;
62 + private boolean isProcessingFrame = false;
63 + private byte[][] yuvBytes = new byte[3][];
64 + private int[] rgbBytes = null;
65 + private int yRowStride;
66 +
67 + protected int previewWidth = 0;
68 + protected int previewHeight = 0;
69 +
70 + private Runnable postInferenceCallback;
71 + private Runnable imageConverter;
72 +
73 + @Override
74 + protected void onCreate(final Bundle savedInstanceState) {
75 + LOGGER.d("onCreate " + this);
76 + super.onCreate(null);
77 + getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
78 +
79 + setContentView(R.layout.activity_camera);
80 +
81 + if (hasPermission()) {
82 + setFragment();
83 + } else {
84 + requestPermission();
85 + }
86 + }
87 +
88 + private byte[] lastPreviewFrame;
89 +
90 + protected int[] getRgbBytes() {
91 + imageConverter.run();
92 + return rgbBytes;
93 + }
94 +
95 + protected int getLuminanceStride() {
96 + return yRowStride;
97 + }
98 +
99 + protected byte[] getLuminance() {
100 + return yuvBytes[0];
101 + }
102 +
103 + /**
104 + * Callback for android.hardware.Camera API
105 + */
106 + @Override
107 + public void onPreviewFrame(final byte[] bytes, final Camera camera) {
108 + if (isProcessingFrame) {
109 + LOGGER.w("Dropping frame!");
110 + return;
111 + }
112 +
113 + try {
114 + // Initialize the storage bitmaps once when the resolution is known.
115 + if (rgbBytes == null) {
116 + Camera.Size previewSize = camera.getParameters().getPreviewSize();
117 + previewHeight = previewSize.height;
118 + previewWidth = previewSize.width;
119 + rgbBytes = new int[previewWidth * previewHeight];
120 + onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90);
121 + }
122 + } catch (final Exception e) {
123 + LOGGER.e(e, "Exception!");
124 + return;
125 + }
126 +
127 + isProcessingFrame = true;
128 + lastPreviewFrame = bytes;
129 + yuvBytes[0] = bytes;
130 + yRowStride = previewWidth;
131 +
132 + imageConverter =
133 + new Runnable() {
134 + @Override
135 + public void run() {
136 + ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes);
137 + }
138 + };
139 +
140 + postInferenceCallback =
141 + new Runnable() {
142 + @Override
143 + public void run() {
144 + camera.addCallbackBuffer(bytes);
145 + isProcessingFrame = false;
146 + }
147 + };
148 + processImage();
149 + }
150 +
151 + /**
152 + * Callback for Camera2 API
153 + */
154 + @Override
155 + public void onImageAvailable(final ImageReader reader) {
156 +    // We need to wait until we have some size from onPreviewSizeChosen.
157 + if (previewWidth == 0 || previewHeight == 0) {
158 + return;
159 + }
160 + if (rgbBytes == null) {
161 + rgbBytes = new int[previewWidth * previewHeight];
162 + }
163 + try {
164 + final Image image = reader.acquireLatestImage();
165 +
166 + if (image == null) {
167 + return;
168 + }
169 +
170 + if (isProcessingFrame) {
171 + image.close();
172 + return;
173 + }
174 + isProcessingFrame = true;
175 + Trace.beginSection("imageAvailable");
176 + final Plane[] planes = image.getPlanes();
177 + fillBytes(planes, yuvBytes);
178 + yRowStride = planes[0].getRowStride();
179 + final int uvRowStride = planes[1].getRowStride();
180 + final int uvPixelStride = planes[1].getPixelStride();
181 +
182 + imageConverter =
183 + new Runnable() {
184 + @Override
185 + public void run() {
186 + ImageUtils.convertYUV420ToARGB8888(
187 + yuvBytes[0],
188 + yuvBytes[1],
189 + yuvBytes[2],
190 + previewWidth,
191 + previewHeight,
192 + yRowStride,
193 + uvRowStride,
194 + uvPixelStride,
195 + rgbBytes);
196 + }
197 + };
198 +
199 + postInferenceCallback =
200 + new Runnable() {
201 + @Override
202 + public void run() {
203 + image.close();
204 + isProcessingFrame = false;
205 + }
206 + };
207 +
208 + processImage();
209 + } catch (final Exception e) {
210 + LOGGER.e(e, "Exception!");
211 + Trace.endSection();
212 + return;
213 + }
214 + Trace.endSection();
215 + }
216 +
217 + @Override
218 + public synchronized void onStart() {
219 + LOGGER.d("onStart " + this);
220 + super.onStart();
221 + }
222 +
223 + @Override
224 + public synchronized void onResume() {
225 + LOGGER.d("onResume " + this);
226 + super.onResume();
227 +
228 + handlerThread = new HandlerThread("inference");
229 + handlerThread.start();
230 + handler = new Handler(handlerThread.getLooper());
231 + }
232 +
233 + @Override
234 + public synchronized void onPause() {
235 + LOGGER.d("onPause " + this);
236 +
237 + if (!isFinishing()) {
238 + LOGGER.d("Requesting finish");
239 + finish();
240 + }
241 +
242 + handlerThread.quitSafely();
243 + try {
244 + handlerThread.join();
245 + handlerThread = null;
246 + handler = null;
247 + } catch (final InterruptedException e) {
248 + LOGGER.e(e, "Exception!");
249 + }
250 +
251 + super.onPause();
252 + }
253 +
254 + @Override
255 + public synchronized void onStop() {
256 + LOGGER.d("onStop " + this);
257 + super.onStop();
258 + }
259 +
260 + @Override
261 + public synchronized void onDestroy() {
262 + LOGGER.d("onDestroy " + this);
263 + super.onDestroy();
264 + }
265 +
266 + protected synchronized void runInBackground(final Runnable r) {
267 + if (handler != null) {
268 + handler.post(r);
269 + }
270 + }
271 +
272 + @Override
273 + public void onRequestPermissionsResult(
274 + final int requestCode, final String[] permissions, final int[] grantResults) {
275 + if (requestCode == PERMISSIONS_REQUEST) {
276 +      if (grantResults.length >= 2
277 + && grantResults[0] == PackageManager.PERMISSION_GRANTED
278 + && grantResults[1] == PackageManager.PERMISSION_GRANTED) {
279 + setFragment();
280 + } else {
281 + requestPermission();
282 + }
283 + }
284 + }
285 +
286 + private boolean hasPermission() {
287 + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
288 + return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED &&
289 + checkSelfPermission(PERMISSION_STORAGE) == PackageManager.PERMISSION_GRANTED;
290 + } else {
291 + return true;
292 + }
293 + }
294 +
295 + private void requestPermission() {
296 + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
297 + if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA) ||
298 + shouldShowRequestPermissionRationale(PERMISSION_STORAGE)) {
299 + Toast.makeText(CameraActivity.this,
300 + "Camera AND storage permission are required for this demo", Toast.LENGTH_LONG).show();
301 + }
302 + requestPermissions(new String[] {PERMISSION_CAMERA, PERMISSION_STORAGE}, PERMISSIONS_REQUEST);
303 + }
304 + }
305 +
306 + // Returns true if the device supports the required hardware level, or better.
307 + private boolean isHardwareLevelSupported(
308 + CameraCharacteristics characteristics, int requiredLevel) {
309 + int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL);
310 + if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) {
311 + return requiredLevel == deviceLevel;
312 + }
313 +    // deviceLevel is not LEGACY, so the levels can be compared numerically.
314 + return requiredLevel <= deviceLevel;
315 + }
316 +
317 + private String chooseCamera() {
318 + final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE);
319 + try {
320 + for (final String cameraId : manager.getCameraIdList()) {
321 + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
322 +
323 + // We don't use a front facing camera in this sample.
324 + final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING);
325 + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) {
326 + continue;
327 + }
328 +
329 + final StreamConfigurationMap map =
330 + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
331 +
332 + if (map == null) {
333 + continue;
334 + }
335 +
336 + // Fallback to camera1 API for internal cameras that don't have full support.
337 + // This should help with legacy situations where using the camera2 API causes
338 + // distorted or otherwise broken previews.
339 + useCamera2API = (facing == CameraCharacteristics.LENS_FACING_EXTERNAL)
340 + || isHardwareLevelSupported(characteristics,
341 + CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL);
342 + LOGGER.i("Camera API lv2?: %s", useCamera2API);
343 + return cameraId;
344 + }
345 + } catch (CameraAccessException e) {
346 + LOGGER.e(e, "Not allowed to access camera");
347 + }
348 +
349 + return null;
350 + }
351 +
352 + protected void setFragment() {
353 + String cameraId = chooseCamera();
354 + if (cameraId == null) {
355 + Toast.makeText(this, "No Camera Detected", Toast.LENGTH_SHORT).show();
356 + finish();
357 + }
358 +
359 + Fragment fragment;
360 + if (useCamera2API) {
361 + CameraConnectionFragment camera2Fragment =
362 + CameraConnectionFragment.newInstance(
363 + new CameraConnectionFragment.ConnectionCallback() {
364 + @Override
365 + public void onPreviewSizeChosen(final Size size, final int rotation) {
366 + previewHeight = size.getHeight();
367 + previewWidth = size.getWidth();
368 + CameraActivity.this.onPreviewSizeChosen(size, rotation);
369 + }
370 + },
371 + this,
372 + getLayoutId(),
373 + getDesiredPreviewFrameSize());
374 +
375 + camera2Fragment.setCamera(cameraId);
376 + fragment = camera2Fragment;
377 + } else {
378 + fragment =
379 + new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize());
380 + }
381 +
382 + getFragmentManager()
383 + .beginTransaction()
384 + .replace(R.id.container, fragment)
385 + .commit();
386 + }
387 +
388 + protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) {
389 + // Because of the variable row stride it's not possible to know in
390 + // advance the actual necessary dimensions of the yuv planes.
391 + for (int i = 0; i < planes.length; ++i) {
392 + final ByteBuffer buffer = planes[i].getBuffer();
393 + if (yuvBytes[i] == null) {
394 + LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity());
395 + yuvBytes[i] = new byte[buffer.capacity()];
396 + }
397 + buffer.get(yuvBytes[i]);
398 + }
399 + }
400 +
401 + public boolean isDebug() {
402 + return debug;
403 + }
404 +
405 + public void requestRender() {
406 + final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay);
407 + if (overlay != null) {
408 + overlay.postInvalidate();
409 + }
410 + }
411 +
412 + public void addCallback(final OverlayView.DrawCallback callback) {
413 + final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay);
414 + if (overlay != null) {
415 + overlay.addCallback(callback);
416 + }
417 + }
418 +
419 + public void onSetDebug(final boolean debug) {}
420 +
421 + @Override
422 + public boolean onKeyDown(final int keyCode, final KeyEvent event) {
423 + if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP
424 + || keyCode == KeyEvent.KEYCODE_BUTTON_L1 || keyCode == KeyEvent.KEYCODE_DPAD_CENTER) {
425 + debug = !debug;
426 + requestRender();
427 + onSetDebug(debug);
428 + return true;
429 + }
430 + return super.onKeyDown(keyCode, event);
431 + }
432 +
433 + protected void readyForNextImage() {
434 + if (postInferenceCallback != null) {
435 + postInferenceCallback.run();
436 + }
437 + }
438 +
439 + protected int getScreenOrientation() {
440 + switch (getWindowManager().getDefaultDisplay().getRotation()) {
441 + case Surface.ROTATION_270:
442 + return 270;
443 + case Surface.ROTATION_180:
444 + return 180;
445 + case Surface.ROTATION_90:
446 + return 90;
447 + default:
448 + return 0;
449 + }
450 + }
451 +
452 + protected abstract void processImage();
453 +
454 + protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
455 + protected abstract int getLayoutId();
456 + protected abstract Size getDesiredPreviewFrameSize();
457 +}
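
CameraActivity leaves four hooks to its subclasses: onPreviewSizeChosen (allocate per-resolution buffers), processImage (run inference on the current frame and call readyForNextImage when done), getLayoutId, and getDesiredPreviewFrameSize. A minimal sketch of a concrete subclass, with illustrative class, layout, and size names (the real demo activities appear elsewhere in this change):

import android.util.Size;

public class MinimalDemoActivity extends CameraActivity {
  private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);

  @Override
  protected void onPreviewSizeChosen(final Size size, final int rotation) {
    // Allocate bitmaps / transformation matrices for `size` here.
  }

  @Override
  protected void processImage() {
    final int[] rgbFrame = getRgbBytes();  // Runs the deferred YUV -> ARGB conversion.
    runInBackground(
        new Runnable() {
          @Override
          public void run() {
            // ... feed `rgbFrame` to the TensorFlow inference code here ...
            readyForNextImage();  // Release the frame once inference has finished.
          }
        });
  }

  @Override
  protected int getLayoutId() {
    return R.layout.camera_connection;  // Illustrative layout resource name.
  }

  @Override
  protected Size getDesiredPreviewFrameSize() {
    return DESIRED_PREVIEW_SIZE;
  }
}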
1 +/*
2 + * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +package org.tensorflow.demo;
18 +
19 +import android.app.Activity;
20 +import android.app.AlertDialog;
21 +import android.app.Dialog;
22 +import android.app.DialogFragment;
23 +import android.app.Fragment;
24 +import android.content.Context;
25 +import android.content.DialogInterface;
26 +import android.content.res.Configuration;
27 +import android.graphics.ImageFormat;
28 +import android.graphics.Matrix;
29 +import android.graphics.RectF;
30 +import android.graphics.SurfaceTexture;
31 +import android.hardware.camera2.CameraAccessException;
32 +import android.hardware.camera2.CameraCaptureSession;
33 +import android.hardware.camera2.CameraCharacteristics;
34 +import android.hardware.camera2.CameraDevice;
35 +import android.hardware.camera2.CameraManager;
36 +import android.hardware.camera2.CaptureRequest;
37 +import android.hardware.camera2.CaptureResult;
38 +import android.hardware.camera2.TotalCaptureResult;
39 +import android.hardware.camera2.params.StreamConfigurationMap;
40 +import android.media.ImageReader;
41 +import android.media.ImageReader.OnImageAvailableListener;
42 +import android.os.Bundle;
43 +import android.os.Handler;
44 +import android.os.HandlerThread;
45 +import android.text.TextUtils;
46 +import android.util.Size;
47 +import android.util.SparseIntArray;
48 +import android.view.LayoutInflater;
49 +import android.view.Surface;
50 +import android.view.TextureView;
51 +import android.view.View;
52 +import android.view.ViewGroup;
53 +import android.widget.Toast;
54 +import java.util.ArrayList;
55 +import java.util.Arrays;
56 +import java.util.Collections;
57 +import java.util.Comparator;
58 +import java.util.List;
59 +import java.util.concurrent.Semaphore;
60 +import java.util.concurrent.TimeUnit;
61 +import org.tensorflow.demo.env.Logger;
62 +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
63 +
64 +public class CameraConnectionFragment extends Fragment {
65 + private static final Logger LOGGER = new Logger();
66 +
67 + /**
68 + * The camera preview size will be chosen to be the smallest frame by pixel size capable of
69 +   * containing a square whose side is the requested input size, but no smaller than MINIMUM_PREVIEW_SIZE.
70 + */
71 + private static final int MINIMUM_PREVIEW_SIZE = 320;
72 +
73 + /**
74 + * Conversion from screen rotation to JPEG orientation.
75 + */
76 + private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
77 + private static final String FRAGMENT_DIALOG = "dialog";
78 +
79 + static {
80 + ORIENTATIONS.append(Surface.ROTATION_0, 90);
81 + ORIENTATIONS.append(Surface.ROTATION_90, 0);
82 + ORIENTATIONS.append(Surface.ROTATION_180, 270);
83 + ORIENTATIONS.append(Surface.ROTATION_270, 180);
84 + }
85 +
86 + /**
87 + * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a
88 + * {@link TextureView}.
89 + */
90 + private final TextureView.SurfaceTextureListener surfaceTextureListener =
91 + new TextureView.SurfaceTextureListener() {
92 + @Override
93 + public void onSurfaceTextureAvailable(
94 + final SurfaceTexture texture, final int width, final int height) {
95 + openCamera(width, height);
96 + }
97 +
98 + @Override
99 + public void onSurfaceTextureSizeChanged(
100 + final SurfaceTexture texture, final int width, final int height) {
101 + configureTransform(width, height);
102 + }
103 +
104 + @Override
105 + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) {
106 + return true;
107 + }
108 +
109 + @Override
110 + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
111 + };
112 +
113 + /**
114 + * Callback for Activities to use to initialize their data once the
115 + * selected preview size is known.
116 + */
117 + public interface ConnectionCallback {
118 + void onPreviewSizeChosen(Size size, int cameraRotation);
119 + }
120 +
121 + /**
122 + * ID of the current {@link CameraDevice}.
123 + */
124 + private String cameraId;
125 +
126 + /**
127 + * An {@link AutoFitTextureView} for camera preview.
128 + */
129 + private AutoFitTextureView textureView;
130 +
131 + /**
132 + * A {@link CameraCaptureSession } for camera preview.
133 + */
134 + private CameraCaptureSession captureSession;
135 +
136 + /**
137 + * A reference to the opened {@link CameraDevice}.
138 + */
139 + private CameraDevice cameraDevice;
140 +
141 + /**
142 + * The rotation in degrees of the camera sensor from the display.
143 + */
144 + private Integer sensorOrientation;
145 +
146 + /**
147 + * The {@link android.util.Size} of camera preview.
148 + */
149 + private Size previewSize;
150 +
151 + /**
152 + * {@link android.hardware.camera2.CameraDevice.StateCallback}
153 + * is called when {@link CameraDevice} changes its state.
154 + */
155 + private final CameraDevice.StateCallback stateCallback =
156 + new CameraDevice.StateCallback() {
157 + @Override
158 + public void onOpened(final CameraDevice cd) {
159 + // This method is called when the camera is opened. We start camera preview here.
160 + cameraOpenCloseLock.release();
161 + cameraDevice = cd;
162 + createCameraPreviewSession();
163 + }
164 +
165 + @Override
166 + public void onDisconnected(final CameraDevice cd) {
167 + cameraOpenCloseLock.release();
168 + cd.close();
169 + cameraDevice = null;
170 + }
171 +
172 + @Override
173 + public void onError(final CameraDevice cd, final int error) {
174 + cameraOpenCloseLock.release();
175 + cd.close();
176 + cameraDevice = null;
177 + final Activity activity = getActivity();
178 + if (null != activity) {
179 + activity.finish();
180 + }
181 + }
182 + };
183 +
184 + /**
185 + * An additional thread for running tasks that shouldn't block the UI.
186 + */
187 + private HandlerThread backgroundThread;
188 +
189 + /**
190 + * A {@link Handler} for running tasks in the background.
191 + */
192 + private Handler backgroundHandler;
193 +
194 + /**
195 + * An {@link ImageReader} that handles preview frame capture.
196 + */
197 + private ImageReader previewReader;
198 +
199 + /**
200 + * {@link android.hardware.camera2.CaptureRequest.Builder} for the camera preview
201 + */
202 + private CaptureRequest.Builder previewRequestBuilder;
203 +
204 + /**
205 + * {@link CaptureRequest} generated by {@link #previewRequestBuilder}
206 + */
207 + private CaptureRequest previewRequest;
208 +
209 + /**
210 + * A {@link Semaphore} to prevent the app from exiting before closing the camera.
211 + */
212 + private final Semaphore cameraOpenCloseLock = new Semaphore(1);
213 +
214 + /**
215 + * A {@link OnImageAvailableListener} to receive frames as they are available.
216 + */
217 + private final OnImageAvailableListener imageListener;
218 +
219 + /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */
220 + private final Size inputSize;
221 +
222 + /**
223 + * The layout identifier to inflate for this Fragment.
224 + */
225 + private final int layout;
226 +
227 +
228 + private final ConnectionCallback cameraConnectionCallback;
229 +
230 + private CameraConnectionFragment(
231 + final ConnectionCallback connectionCallback,
232 + final OnImageAvailableListener imageListener,
233 + final int layout,
234 + final Size inputSize) {
235 + this.cameraConnectionCallback = connectionCallback;
236 + this.imageListener = imageListener;
237 + this.layout = layout;
238 + this.inputSize = inputSize;
239 + }
240 +
241 + /**
242 + * Shows a {@link Toast} on the UI thread.
243 + *
244 + * @param text The message to show
245 + */
246 + private void showToast(final String text) {
247 + final Activity activity = getActivity();
248 + if (activity != null) {
249 + activity.runOnUiThread(
250 + new Runnable() {
251 + @Override
252 + public void run() {
253 + Toast.makeText(activity, text, Toast.LENGTH_SHORT).show();
254 + }
255 + });
256 + }
257 + }
258 +
259 + /**
260 + * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose
261 +   * width and height are both at least as large as the smaller of the requested dimensions, or an exact match if one exists.
262 + *
263 + * @param choices The list of sizes that the camera supports for the intended output class
264 + * @param width The minimum desired width
265 + * @param height The minimum desired height
266 + * @return The optimal {@code Size}, or an arbitrary one if none were big enough
267 + */
268 + protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) {
269 + final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE);
270 + final Size desiredSize = new Size(width, height);
271 +
272 + // Collect the supported resolutions that are at least as big as the preview Surface
273 + boolean exactSizeFound = false;
274 + final List<Size> bigEnough = new ArrayList<Size>();
275 + final List<Size> tooSmall = new ArrayList<Size>();
276 + for (final Size option : choices) {
277 + if (option.equals(desiredSize)) {
278 + // Set the size but don't return yet so that remaining sizes will still be logged.
279 + exactSizeFound = true;
280 + }
281 +
282 + if (option.getHeight() >= minSize && option.getWidth() >= minSize) {
283 + bigEnough.add(option);
284 + } else {
285 + tooSmall.add(option);
286 + }
287 + }
288 +
289 + LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize);
290 + LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]");
291 + LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]");
292 +
293 + if (exactSizeFound) {
294 + LOGGER.i("Exact size match found.");
295 + return desiredSize;
296 + }
297 +
298 + // Pick the smallest of those, assuming we found any
299 + if (bigEnough.size() > 0) {
300 + final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea());
301 + LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight());
302 + return chosenSize;
303 + } else {
304 + LOGGER.e("Couldn't find any suitable preview size");
305 + return choices[0];
306 + }
307 + }
308 +
309 + public static CameraConnectionFragment newInstance(
310 + final ConnectionCallback callback,
311 + final OnImageAvailableListener imageListener,
312 + final int layout,
313 + final Size inputSize) {
314 + return new CameraConnectionFragment(callback, imageListener, layout, inputSize);
315 + }
316 +
317 + @Override
318 + public View onCreateView(
319 + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) {
320 + return inflater.inflate(layout, container, false);
321 + }
322 +
323 + @Override
324 + public void onViewCreated(final View view, final Bundle savedInstanceState) {
325 + textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
326 + }
327 +
328 + @Override
329 + public void onActivityCreated(final Bundle savedInstanceState) {
330 + super.onActivityCreated(savedInstanceState);
331 + }
332 +
333 + @Override
334 + public void onResume() {
335 + super.onResume();
336 + startBackgroundThread();
337 +
338 + // When the screen is turned off and turned back on, the SurfaceTexture is already
339 + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
340 + // a camera and start preview from here (otherwise, we wait until the surface is ready in
341 + // the SurfaceTextureListener).
342 + if (textureView.isAvailable()) {
343 + openCamera(textureView.getWidth(), textureView.getHeight());
344 + } else {
345 + textureView.setSurfaceTextureListener(surfaceTextureListener);
346 + }
347 + }
348 +
349 + @Override
350 + public void onPause() {
351 + closeCamera();
352 + stopBackgroundThread();
353 + super.onPause();
354 + }
355 +
356 + public void setCamera(String cameraId) {
357 + this.cameraId = cameraId;
358 + }
359 +
360 + /**
361 + * Sets up member variables related to camera.
362 + */
363 + private void setUpCameraOutputs() {
364 + final Activity activity = getActivity();
365 + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
366 + try {
367 + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
368 +
369 + final StreamConfigurationMap map =
370 + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
371 +
372 + // For still image captures, we use the largest available size.
373 + final Size largest =
374 + Collections.max(
375 + Arrays.asList(map.getOutputSizes(ImageFormat.YUV_420_888)),
376 + new CompareSizesByArea());
377 +
378 + sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION);
379 +
380 + // Danger, W.R.! Attempting to use too large a preview size could exceed the camera
381 + // bus' bandwidth limitation, resulting in gorgeous previews but the storage of
382 + // garbage capture data.
383 + previewSize =
384 + chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class),
385 + inputSize.getWidth(),
386 + inputSize.getHeight());
387 +
388 + // We fit the aspect ratio of TextureView to the size of preview we picked.
389 + final int orientation = getResources().getConfiguration().orientation;
390 + if (orientation == Configuration.ORIENTATION_LANDSCAPE) {
391 + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight());
392 + } else {
393 + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth());
394 + }
395 + } catch (final CameraAccessException e) {
396 + LOGGER.e(e, "Exception!");
397 + } catch (final NullPointerException e) {
398 + // Currently an NPE is thrown when the Camera2 API is used but not supported on the
399 + // device this code runs on.
400 + // TODO(andrewharp): abstract ErrorDialog/RuntimeException handling out into new method and
401 + // reuse throughout app.
402 + ErrorDialog.newInstance(getString(R.string.camera_error))
403 + .show(getChildFragmentManager(), FRAGMENT_DIALOG);
404 + throw new RuntimeException(getString(R.string.camera_error));
405 + }
406 +
407 + cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation);
408 + }
409 +
410 + /**
411 + * Opens the camera specified by {@link CameraConnectionFragment#cameraId}.
412 + */
413 + private void openCamera(final int width, final int height) {
414 + setUpCameraOutputs();
415 + configureTransform(width, height);
416 + final Activity activity = getActivity();
417 + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
418 + try {
419 + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) {
420 + throw new RuntimeException("Time out waiting to lock camera opening.");
421 + }
422 + manager.openCamera(cameraId, stateCallback, backgroundHandler);
423 + } catch (final CameraAccessException e) {
424 + LOGGER.e(e, "Exception!");
425 + } catch (final InterruptedException e) {
426 + throw new RuntimeException("Interrupted while trying to lock camera opening.", e);
427 + }
428 + }
429 +
430 + /**
431 + * Closes the current {@link CameraDevice}.
432 + */
433 + private void closeCamera() {
434 + try {
435 + cameraOpenCloseLock.acquire();
436 + if (null != captureSession) {
437 + captureSession.close();
438 + captureSession = null;
439 + }
440 + if (null != cameraDevice) {
441 + cameraDevice.close();
442 + cameraDevice = null;
443 + }
444 + if (null != previewReader) {
445 + previewReader.close();
446 + previewReader = null;
447 + }
448 + } catch (final InterruptedException e) {
449 + throw new RuntimeException("Interrupted while trying to lock camera closing.", e);
450 + } finally {
451 + cameraOpenCloseLock.release();
452 + }
453 + }
454 +
455 + /**
456 + * Starts a background thread and its {@link Handler}.
457 + */
458 + private void startBackgroundThread() {
459 + backgroundThread = new HandlerThread("ImageListener");
460 + backgroundThread.start();
461 + backgroundHandler = new Handler(backgroundThread.getLooper());
462 + }
463 +
464 + /**
465 + * Stops the background thread and its {@link Handler}.
466 + */
467 + private void stopBackgroundThread() {
468 + backgroundThread.quitSafely();
469 + try {
470 + backgroundThread.join();
471 + backgroundThread = null;
472 + backgroundHandler = null;
473 + } catch (final InterruptedException e) {
474 + LOGGER.e(e, "Exception!");
475 + }
476 + }
477 +
478 + private final CameraCaptureSession.CaptureCallback captureCallback =
479 + new CameraCaptureSession.CaptureCallback() {
480 + @Override
481 + public void onCaptureProgressed(
482 + final CameraCaptureSession session,
483 + final CaptureRequest request,
484 + final CaptureResult partialResult) {}
485 +
486 + @Override
487 + public void onCaptureCompleted(
488 + final CameraCaptureSession session,
489 + final CaptureRequest request,
490 + final TotalCaptureResult result) {}
491 + };
492 +
493 + /**
494 + * Creates a new {@link CameraCaptureSession} for camera preview.
495 + */
496 + private void createCameraPreviewSession() {
497 + try {
498 + final SurfaceTexture texture = textureView.getSurfaceTexture();
499 + assert texture != null;
500 +
501 + // We configure the size of the default buffer to be the size of the camera preview we want.
502 + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight());
503 +
504 + // This is the output Surface we need to start preview.
505 + final Surface surface = new Surface(texture);
506 +
507 + // We set up a CaptureRequest.Builder with the output Surface.
508 + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
509 + previewRequestBuilder.addTarget(surface);
510 +
511 + LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight());
512 +
513 + // Create the reader for the preview frames.
514 + previewReader =
515 + ImageReader.newInstance(
516 + previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
517 +
518 + previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
519 + previewRequestBuilder.addTarget(previewReader.getSurface());
520 +
521 + // Here, we create a CameraCaptureSession for camera preview.
522 + cameraDevice.createCaptureSession(
523 + Arrays.asList(surface, previewReader.getSurface()),
524 + new CameraCaptureSession.StateCallback() {
525 +
526 + @Override
527 + public void onConfigured(final CameraCaptureSession cameraCaptureSession) {
528 + // The camera is already closed
529 + if (null == cameraDevice) {
530 + return;
531 + }
532 +
533 + // When the session is ready, we start displaying the preview.
534 + captureSession = cameraCaptureSession;
535 + try {
536 + // Auto focus should be continuous for camera preview.
537 + previewRequestBuilder.set(
538 + CaptureRequest.CONTROL_AF_MODE,
539 + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE);
540 + // Flash is automatically enabled when necessary.
541 + previewRequestBuilder.set(
542 + CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH);
543 +
544 + // Finally, we start displaying the camera preview.
545 + previewRequest = previewRequestBuilder.build();
546 + captureSession.setRepeatingRequest(
547 + previewRequest, captureCallback, backgroundHandler);
548 + } catch (final CameraAccessException e) {
549 + LOGGER.e(e, "Exception!");
550 + }
551 + }
552 +
553 + @Override
554 + public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) {
555 + showToast("Failed");
556 + }
557 + },
558 + null);
559 + } catch (final CameraAccessException e) {
560 + LOGGER.e(e, "Exception!");
561 + }
562 + }
563 +
564 + /**
565 + * Configures the necessary {@link android.graphics.Matrix} transformation to `textureView`.
566 + * This method should be called after the camera preview size is determined in
567 + * setUpCameraOutputs and after the size of `textureView` is fixed.
568 + *
569 + * @param viewWidth The width of `textureView`
570 + * @param viewHeight The height of `textureView`
571 + */
572 + private void configureTransform(final int viewWidth, final int viewHeight) {
573 + final Activity activity = getActivity();
574 + if (null == textureView || null == previewSize || null == activity) {
575 + return;
576 + }
577 + final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation();
578 + final Matrix matrix = new Matrix();
579 + final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight);
580 + final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth());
581 + final float centerX = viewRect.centerX();
582 + final float centerY = viewRect.centerY();
583 + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) {
584 + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY());
585 + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL);
586 + final float scale =
587 + Math.max(
588 + (float) viewHeight / previewSize.getHeight(),
589 + (float) viewWidth / previewSize.getWidth());
590 + matrix.postScale(scale, scale, centerX, centerY);
591 + matrix.postRotate(90 * (rotation - 2), centerX, centerY);
592 + } else if (Surface.ROTATION_180 == rotation) {
593 + matrix.postRotate(180, centerX, centerY);
594 + }
595 + textureView.setTransform(matrix);
596 + }
597 +
598 + /**
599 + * Compares two {@code Size}s based on their areas.
600 + */
601 + static class CompareSizesByArea implements Comparator<Size> {
602 + @Override
603 + public int compare(final Size lhs, final Size rhs) {
604 + // We cast here to ensure the multiplications won't overflow
605 + return Long.signum(
606 + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight());
607 + }
608 + }
609 +
610 + /**
611 + * Shows an error message dialog.
612 + */
613 + public static class ErrorDialog extends DialogFragment {
614 + private static final String ARG_MESSAGE = "message";
615 +
616 + public static ErrorDialog newInstance(final String message) {
617 + final ErrorDialog dialog = new ErrorDialog();
618 + final Bundle args = new Bundle();
619 + args.putString(ARG_MESSAGE, message);
620 + dialog.setArguments(args);
621 + return dialog;
622 + }
623 +
624 + @Override
625 + public Dialog onCreateDialog(final Bundle savedInstanceState) {
626 + final Activity activity = getActivity();
627 + return new AlertDialog.Builder(activity)
628 + .setMessage(getArguments().getString(ARG_MESSAGE))
629 + .setPositiveButton(
630 + android.R.string.ok,
631 + new DialogInterface.OnClickListener() {
632 + @Override
633 + public void onClick(final DialogInterface dialogInterface, final int i) {
634 + activity.finish();
635 + }
636 + })
637 + .create();
638 + }
639 + }
640 +}
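// A minimal usage sketch: how a host Activity might attach the fragment above. R.id.container
// and the imageAvailableListener/cameraId variables are illustrative assumptions, and the
// anonymous ConnectionCallback assumes the interface declares the onPreviewSizeChosen(Size, int)
// method that setUpCameraOutputs() invokes above.
final CameraConnectionFragment fragment =
    CameraConnectionFragment.newInstance(
        new CameraConnectionFragment.ConnectionCallback() {
          @Override
          public void onPreviewSizeChosen(final Size size, final int rotation) {
            // Allocate frame buffers and transforms once the preview size is known.
          }
        },
        imageAvailableListener,                // an OnImageAvailableListener (assumed)
        R.layout.camera_connection_fragment,   // layout referenced elsewhere in this change
        new Size(640, 480));                   // desired TensorFlow input size
fragment.setCamera(cameraId);                  // id obtained from CameraManager.getCameraIdList() (assumed)
getFragmentManager().beginTransaction().replace(R.id.container, fragment).commit();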
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import android.graphics.Bitmap;
19 +import android.graphics.RectF;
20 +import java.util.List;
21 +
22 +/**
23 + * Generic interface for interacting with different recognition engines.
24 + */
25 +public interface Classifier {
26 + /**
27 + * An immutable result returned by a Classifier describing what was recognized.
28 + */
29 + public class Recognition {
30 + /**
31 + * A unique identifier for what has been recognized. Specific to the class, not the instance of
32 + * the object.
33 + */
34 + private final String id;
35 +
36 + /**
37 + * Display name for the recognition.
38 + */
39 + private final String title;
40 +
41 + /**
42 + * A sortable score for how good the recognition is relative to others. Higher should be better.
43 + */
44 + private final Float confidence;
45 +
46 + /** Optional location within the source image for the location of the recognized object. */
47 + private RectF location;
48 +
49 + public Recognition(
50 + final String id, final String title, final Float confidence, final RectF location) {
51 + this.id = id;
52 + this.title = title;
53 + this.confidence = confidence;
54 + this.location = location;
55 + }
56 +
57 + public String getId() {
58 + return id;
59 + }
60 +
61 + public String getTitle() {
62 + return title;
63 + }
64 +
65 + public Float getConfidence() {
66 + return confidence;
67 + }
68 +
69 + public RectF getLocation() {
70 + return new RectF(location);
71 + }
72 +
73 + public void setLocation(RectF location) {
74 + this.location = location;
75 + }
76 +
77 + @Override
78 + public String toString() {
79 + String resultString = "";
80 + if (id != null) {
81 + resultString += "[" + id + "] ";
82 + }
83 +
84 + if (title != null) {
85 + resultString += title + " ";
86 + }
87 +
88 + if (confidence != null) {
89 + resultString += String.format("(%.1f%%) ", confidence * 100.0f);
90 + }
91 +
92 + if (location != null) {
93 + resultString += location + " ";
94 + }
95 +
96 + return resultString.trim();
97 + }
98 + }
99 +
100 + List<Recognition> recognizeImage(Bitmap bitmap);
101 +
102 + void enableStatLogging(final boolean debug);
103 +
104 + String getStatString();
105 +
106 + void close();
107 +}
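// A minimal sketch of consuming the Classifier interface above; `classifier` and `bitmap`
// are assumed to be supplied by the caller, and 0.5f is an arbitrary example threshold.
final List<Classifier.Recognition> results = classifier.recognizeImage(bitmap);
for (final Classifier.Recognition recognition : results) {
  final Float confidence = recognition.getConfidence();
  if (confidence != null && confidence >= 0.5f) {
    Log.i("Classifier", recognition.toString());  // formatted by Recognition.toString() above
  }
}
classifier.close();  // release the underlying model when finished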
1 +/*
2 + * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +package org.tensorflow.demo;
18 +
19 +import android.graphics.Bitmap;
20 +import android.graphics.Bitmap.Config;
21 +import android.graphics.Canvas;
22 +import android.graphics.Matrix;
23 +import android.graphics.Paint;
24 +import android.graphics.Typeface;
25 +import android.media.ImageReader.OnImageAvailableListener;
26 +import android.os.SystemClock;
27 +import android.util.Size;
28 +import android.util.TypedValue;
29 +import android.view.Display;
30 +import android.view.Surface;
31 +import java.util.List;
32 +import java.util.Vector;
33 +import org.tensorflow.demo.OverlayView.DrawCallback;
34 +import org.tensorflow.demo.env.BorderedText;
35 +import org.tensorflow.demo.env.ImageUtils;
36 +import org.tensorflow.demo.env.Logger;
37 +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
38 +
39 +public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {
40 + private static final Logger LOGGER = new Logger();
41 +
42 + protected static final boolean SAVE_PREVIEW_BITMAP = false;
43 +
44 + private ResultsView resultsView;
45 +
46 + private Bitmap rgbFrameBitmap = null;
47 + private Bitmap croppedBitmap = null;
48 + private Bitmap cropCopyBitmap = null;
49 +
50 + private long lastProcessingTimeMs;
51 +
52 + // These are the settings for the original v1 Inception model. If you want to
53 + // use a model that's been produced from the TensorFlow for Poets codelab,
54 + // you'll need to set INPUT_SIZE = 299, IMAGE_MEAN = 128, IMAGE_STD = 128,
55 + // INPUT_NAME = "Mul", and OUTPUT_NAME = "final_result".
56 + // You'll also need to update the MODEL_FILE and LABEL_FILE paths to point to
57 + // the ones you produced.
58 + //
59 + // To use an Inception v3 model, strip the DecodeJpeg Op from your retrained
60 + // model first:
61 + //
62 + // python strip_unused.py \
63 + // --input_graph=<retrained-pb-file> \
64 + // --output_graph=<your-stripped-pb-file> \
65 + // --input_node_names="Mul" \
66 + // --output_node_names="final_result" \
67 + // --input_binary=true
68 + private static final int INPUT_SIZE = 224;
69 + private static final int IMAGE_MEAN = 117;
70 + private static final float IMAGE_STD = 1;
71 + private static final String INPUT_NAME = "input";
72 + private static final String OUTPUT_NAME = "output";
73 +
74 +
75 + private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
76 + private static final String LABEL_FILE =
77 + "file:///android_asset/imagenet_comp_graph_label_strings.txt";
78 +
79 +
80 + private static final boolean MAINTAIN_ASPECT = true;
81 +
82 + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
83 +
84 +
85 + private Integer sensorOrientation;
86 + private Classifier classifier;
87 + private Matrix frameToCropTransform;
88 + private Matrix cropToFrameTransform;
89 +
90 +
91 + private BorderedText borderedText;
92 +
93 +
94 + @Override
95 + protected int getLayoutId() {
96 + return R.layout.camera_connection_fragment;
97 + }
98 +
99 + @Override
100 + protected Size getDesiredPreviewFrameSize() {
101 + return DESIRED_PREVIEW_SIZE;
102 + }
103 +
104 + private static final float TEXT_SIZE_DIP = 10;
105 +
106 + @Override
107 + public void onPreviewSizeChosen(final Size size, final int rotation) {
108 + final float textSizePx = TypedValue.applyDimension(
109 + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
110 + borderedText = new BorderedText(textSizePx);
111 + borderedText.setTypeface(Typeface.MONOSPACE);
112 +
113 + classifier =
114 + TensorFlowImageClassifier.create(
115 + getAssets(),
116 + MODEL_FILE,
117 + LABEL_FILE,
118 + INPUT_SIZE,
119 + IMAGE_MEAN,
120 + IMAGE_STD,
121 + INPUT_NAME,
122 + OUTPUT_NAME);
123 +
124 + previewWidth = size.getWidth();
125 + previewHeight = size.getHeight();
126 +
127 + sensorOrientation = rotation - getScreenOrientation();
128 + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
129 +
130 + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
131 + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
132 + croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
133 +
134 + frameToCropTransform = ImageUtils.getTransformationMatrix(
135 + previewWidth, previewHeight,
136 + INPUT_SIZE, INPUT_SIZE,
137 + sensorOrientation, MAINTAIN_ASPECT);
138 +
139 + cropToFrameTransform = new Matrix();
140 + frameToCropTransform.invert(cropToFrameTransform);
141 +
142 + addCallback(
143 + new DrawCallback() {
144 + @Override
145 + public void drawCallback(final Canvas canvas) {
146 + renderDebug(canvas);
147 + }
148 + });
149 + }
150 +
151 + @Override
152 + protected void processImage() {
153 + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
154 + final Canvas canvas = new Canvas(croppedBitmap);
155 + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
156 +
157 + // For examining the actual TF input.
158 + if (SAVE_PREVIEW_BITMAP) {
159 + ImageUtils.saveBitmap(croppedBitmap);
160 + }
161 + runInBackground(
162 + new Runnable() {
163 + @Override
164 + public void run() {
165 + final long startTime = SystemClock.uptimeMillis();
166 + final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
167 + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
168 + LOGGER.i("Detect: %s", results);
169 + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
170 + if (resultsView == null) {
171 + resultsView = (ResultsView) findViewById(R.id.results);
172 + }
173 + resultsView.setResults(results);
174 + requestRender();
175 + readyForNextImage();
176 + }
177 + });
178 + }
179 +
180 + @Override
181 + public void onSetDebug(boolean debug) {
182 + classifier.enableStatLogging(debug);
183 + }
184 +
185 + private void renderDebug(final Canvas canvas) {
186 + if (!isDebug()) {
187 + return;
188 + }
189 + final Bitmap copy = cropCopyBitmap;
190 + if (copy != null) {
191 + final Matrix matrix = new Matrix();
192 + final float scaleFactor = 2;
193 + matrix.postScale(scaleFactor, scaleFactor);
194 + matrix.postTranslate(
195 + canvas.getWidth() - copy.getWidth() * scaleFactor,
196 + canvas.getHeight() - copy.getHeight() * scaleFactor);
197 + canvas.drawBitmap(copy, matrix, new Paint());
198 +
199 + final Vector<String> lines = new Vector<String>();
200 + if (classifier != null) {
201 + String statString = classifier.getStatString();
202 + String[] statLines = statString.split("\n");
203 + for (String line : statLines) {
204 + lines.add(line);
205 + }
206 + }
207 +
208 + lines.add("Frame: " + previewWidth + "x" + previewHeight);
209 + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
210 + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
211 + lines.add("Rotation: " + sensorOrientation);
212 + lines.add("Inference time: " + lastProcessingTimeMs + "ms");
213 +
214 + borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
215 + }
216 + }
217 +}
1 +/*
2 + * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +package org.tensorflow.demo;
18 +
19 +import android.graphics.Bitmap;
20 +import android.graphics.Bitmap.Config;
21 +import android.graphics.Canvas;
22 +import android.graphics.Color;
23 +import android.graphics.Matrix;
24 +import android.graphics.Paint;
25 +import android.graphics.Paint.Style;
26 +import android.graphics.RectF;
27 +import android.graphics.Typeface;
28 +import android.media.ImageReader.OnImageAvailableListener;
29 +import android.os.SystemClock;
30 +import android.util.Size;
31 +import android.util.TypedValue;
32 +import android.view.Display;
33 +import android.view.Surface;
34 +import android.widget.Toast;
35 +import java.io.IOException;
36 +import java.util.LinkedList;
37 +import java.util.List;
38 +import java.util.Vector;
39 +import org.tensorflow.demo.OverlayView.DrawCallback;
40 +import org.tensorflow.demo.env.BorderedText;
41 +import org.tensorflow.demo.env.ImageUtils;
42 +import org.tensorflow.demo.env.Logger;
43 +import org.tensorflow.demo.tracking.MultiBoxTracker;
44 +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
45 +
46 +/**
47 + * An activity that uses a TensorFlow detection model (YOLO, MultiBox, or the Object Detection API
48 + * graph) together with ObjectTracker to detect and then track objects.
49 + */
50 +public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
51 + private static final Logger LOGGER = new Logger();
52 +
53 + // Configuration values for the prepackaged multibox model.
54 + private static final int MB_INPUT_SIZE = 224;
55 + private static final int MB_IMAGE_MEAN = 128;
56 + private static final float MB_IMAGE_STD = 128;
57 + private static final String MB_INPUT_NAME = "ResizeBilinear";
58 + private static final String MB_OUTPUT_LOCATIONS_NAME = "output_locations/Reshape";
59 + private static final String MB_OUTPUT_SCORES_NAME = "output_scores/Reshape";
60 + private static final String MB_MODEL_FILE = "file:///android_asset/multibox_model.pb";
61 + private static final String MB_LOCATION_FILE =
62 + "file:///android_asset/multibox_location_priors.txt";
63 +
64 + private static final int TF_OD_API_INPUT_SIZE = 300;
65 + private static final String TF_OD_API_MODEL_FILE =
66 + "file:///android_asset/ssd_mobilenet_v1_android_export.pb";
67 + private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
68 +
69 + // Configuration values for the YOLO detector (configured below to load yolov3.pb). Note that the graph is not included with TensorFlow and
70 + // must be manually placed in the assets/ directory by the user.
71 + // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via
72 + // DarkFlow (https://github.com/thtrieu/darkflow). Sample command:
73 + // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise
74 + private static final String YOLO_MODEL_FILE = "file:///android_asset/yolov3.pb";
75 + private static final int YOLO_INPUT_SIZE = 416;
76 + private static final String YOLO_INPUT_NAME = "input";
77 + private static final String YOLO_OUTPUT_NAMES = "output";
78 + private static final int YOLO_BLOCK_SIZE = 32;
79 +
80 + // Which detection model to use. This demo is configured to use YOLO (see MODE below); the
81 + // TensorFlow Object Detection API frozen checkpoints and the legacy MultiBox model (trained
82 + // with an older version of the API) are also supported.
83 + private enum DetectorMode {
84 + TF_OD_API, MULTIBOX, YOLO;
85 + }
86 + private static final DetectorMode MODE = DetectorMode.YOLO;
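  // For example, to run the bundled SSD-MobileNet graph (the TF_OD_API branch below) instead of
  // YOLO, MODE would simply be switched:
  // private static final DetectorMode MODE = DetectorMode.TF_OD_API;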
87 +
88 + // Minimum detection confidence to track a detection.
89 + private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f;
90 + private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f;
91 + private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f;
92 +
93 + private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO;
94 +
95 + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
96 +
97 + private static final boolean SAVE_PREVIEW_BITMAP = false;
98 + private static final float TEXT_SIZE_DIP = 10;
99 +
100 + private Integer sensorOrientation;
101 +
102 + private Classifier detector;
103 +
104 + private long lastProcessingTimeMs;
105 + private Bitmap rgbFrameBitmap = null;
106 + private Bitmap croppedBitmap = null;
107 + private Bitmap cropCopyBitmap = null;
108 +
109 + private boolean computingDetection = false;
110 +
111 + private long timestamp = 0;
112 +
113 + private Matrix frameToCropTransform;
114 + private Matrix cropToFrameTransform;
115 +
116 + private MultiBoxTracker tracker;
117 +
118 + private byte[] luminanceCopy;
119 +
120 + private BorderedText borderedText;
121 + @Override
122 + public void onPreviewSizeChosen(final Size size, final int rotation) {
123 + final float textSizePx =
124 + TypedValue.applyDimension(
125 + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
126 + borderedText = new BorderedText(textSizePx);
127 + borderedText.setTypeface(Typeface.MONOSPACE);
128 +
129 + tracker = new MultiBoxTracker(this);
130 +
131 + int cropSize = TF_OD_API_INPUT_SIZE;
132 + if (MODE == DetectorMode.YOLO) {
133 + detector =
134 + TensorFlowYoloDetector.create(
135 + getAssets(),
136 + YOLO_MODEL_FILE,
137 + YOLO_INPUT_SIZE,
138 + YOLO_INPUT_NAME,
139 + YOLO_OUTPUT_NAMES,
140 + YOLO_BLOCK_SIZE);
141 + cropSize = YOLO_INPUT_SIZE;
142 + } else if (MODE == DetectorMode.MULTIBOX) {
143 + detector =
144 + TensorFlowMultiBoxDetector.create(
145 + getAssets(),
146 + MB_MODEL_FILE,
147 + MB_LOCATION_FILE,
148 + MB_IMAGE_MEAN,
149 + MB_IMAGE_STD,
150 + MB_INPUT_NAME,
151 + MB_OUTPUT_LOCATIONS_NAME,
152 + MB_OUTPUT_SCORES_NAME);
153 + cropSize = MB_INPUT_SIZE;
154 + } else {
155 + try {
156 + detector = TensorFlowObjectDetectionAPIModel.create(
157 + getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
158 + cropSize = TF_OD_API_INPUT_SIZE;
159 + } catch (final IOException e) {
160 + LOGGER.e(e, "Exception initializing classifier!");
161 + Toast toast =
162 + Toast.makeText(
163 + getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
164 + toast.show();
165 + finish();
166 + }
167 + }
168 +
169 + previewWidth = size.getWidth();
170 + previewHeight = size.getHeight();
171 +
172 + sensorOrientation = rotation - getScreenOrientation();
173 + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
174 +
175 + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
176 + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
177 + croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
178 +
179 + frameToCropTransform =
180 + ImageUtils.getTransformationMatrix(
181 + previewWidth, previewHeight,
182 + cropSize, cropSize,
183 + sensorOrientation, MAINTAIN_ASPECT);
184 +
185 + cropToFrameTransform = new Matrix();
186 + frameToCropTransform.invert(cropToFrameTransform);
187 +
188 + trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
189 + trackingOverlay.addCallback(
190 + new DrawCallback() {
191 + @Override
192 + public void drawCallback(final Canvas canvas) {
193 + tracker.draw(canvas);
194 + if (isDebug()) {
195 + tracker.drawDebug(canvas);
196 + }
197 + }
198 + });
199 +
200 + addCallback(
201 + new DrawCallback() {
202 + @Override
203 + public void drawCallback(final Canvas canvas) {
204 + if (!isDebug()) {
205 + return;
206 + }
207 + final Bitmap copy = cropCopyBitmap;
208 + if (copy == null) {
209 + return;
210 + }
211 +
212 + final int backgroundColor = Color.argb(100, 0, 0, 0);
213 + canvas.drawColor(backgroundColor);
214 +
215 + final Matrix matrix = new Matrix();
216 + final float scaleFactor = 2;
217 + matrix.postScale(scaleFactor, scaleFactor);
218 + matrix.postTranslate(
219 + canvas.getWidth() - copy.getWidth() * scaleFactor,
220 + canvas.getHeight() - copy.getHeight() * scaleFactor);
221 + canvas.drawBitmap(copy, matrix, new Paint());
222 +
223 + final Vector<String> lines = new Vector<String>();
224 + if (detector != null) {
225 + final String statString = detector.getStatString();
226 + final String[] statLines = statString.split("\n");
227 + for (final String line : statLines) {
228 + lines.add(line);
229 + }
230 + }
231 + lines.add("");
232 +
233 + lines.add("Frame: " + previewWidth + "x" + previewHeight);
234 + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
235 + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
236 + lines.add("Rotation: " + sensorOrientation);
237 + lines.add("Inference time: " + lastProcessingTimeMs + "ms");
238 +
239 + borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
240 + }
241 + });
242 + }
243 +
244 + OverlayView trackingOverlay;
245 +
246 + @Override
247 + protected void processImage() {
248 + ++timestamp;
249 + final long currTimestamp = timestamp;
250 + byte[] originalLuminance = getLuminance();
251 + tracker.onFrame(
252 + previewWidth,
253 + previewHeight,
254 + getLuminanceStride(),
255 + sensorOrientation,
256 + originalLuminance,
257 + timestamp);
258 + trackingOverlay.postInvalidate();
259 +
260 + // No mutex needed as this method is not reentrant.
261 + if (computingDetection) {
262 + readyForNextImage();
263 + return;
264 + }
265 + computingDetection = true;
266 + LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
267 +
268 + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
269 +
270 + if (luminanceCopy == null) {
271 + luminanceCopy = new byte[originalLuminance.length];
272 + }
273 + System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length);
274 + readyForNextImage();
275 +
276 + final Canvas canvas = new Canvas(croppedBitmap);
277 + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
278 + // For examining the actual TF input.
279 + if (SAVE_PREVIEW_BITMAP) {
280 + ImageUtils.saveBitmap(croppedBitmap);
281 + }
282 +
283 + runInBackground(
284 + new Runnable() {
285 + @Override
286 + public void run() {
287 + LOGGER.i("Running detection on image " + currTimestamp);
288 + final long startTime = SystemClock.uptimeMillis();
289 + final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
290 + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
291 +
292 + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
293 + final Canvas canvas = new Canvas(cropCopyBitmap);
294 + final Paint paint = new Paint();
295 + paint.setColor(Color.RED);
296 + paint.setStyle(Style.STROKE);
297 + paint.setStrokeWidth(2.0f);
298 +
299 + float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
300 + switch (MODE) {
301 + case TF_OD_API:
302 + minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
303 + break;
304 + case MULTIBOX:
305 + minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX;
306 + break;
307 + case YOLO:
308 + minimumConfidence = MINIMUM_CONFIDENCE_YOLO;
309 + break;
310 + }
311 +
312 + final List<Classifier.Recognition> mappedRecognitions =
313 + new LinkedList<Classifier.Recognition>();
314 +
315 + for (final Classifier.Recognition result : results) {
316 + final RectF location = result.getLocation();
317 + if (location != null && result.getConfidence() >= minimumConfidence) {
318 + canvas.drawRect(location, paint);
319 +
320 + cropToFrameTransform.mapRect(location);
321 + result.setLocation(location);
322 + mappedRecognitions.add(result);
323 + }
324 + }
325 +
326 + tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp);
327 + trackingOverlay.postInvalidate();
328 +
329 + requestRender();
330 + computingDetection = false;
331 + }
332 + });
333 + }
334 +
335 + @Override
336 + protected int getLayoutId() {
337 + return R.layout.camera_connection_fragment_tracking;
338 + }
339 +
340 + @Override
341 + protected Size getDesiredPreviewFrameSize() {
342 + return DESIRED_PREVIEW_SIZE;
343 + }
344 +
345 + @Override
346 + public void onSetDebug(final boolean debug) {
347 + detector.enableStatLogging(debug);
348 + }
349 +}
1 +package org.tensorflow.demo;
2 +
3 +/*
4 + * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 + *
6 + * Licensed under the Apache License, Version 2.0 (the "License");
7 + * you may not use this file except in compliance with the License.
8 + * You may obtain a copy of the License at
9 + *
10 + * http://www.apache.org/licenses/LICENSE-2.0
11 + *
12 + * Unless required by applicable law or agreed to in writing, software
13 + * distributed under the License is distributed on an "AS IS" BASIS,
14 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 + * See the License for the specific language governing permissions and
16 + * limitations under the License.
17 + */
18 +
19 +import android.app.Fragment;
20 +import android.graphics.SurfaceTexture;
21 +import android.hardware.Camera;
22 +import android.hardware.Camera.CameraInfo;
23 +import android.os.Bundle;
24 +import android.os.Handler;
25 +import android.os.HandlerThread;
26 +import android.util.Size;
27 +import android.util.SparseIntArray;
28 +import android.view.LayoutInflater;
29 +import android.view.Surface;
30 +import android.view.TextureView;
31 +import android.view.View;
32 +import android.view.ViewGroup;
33 +import java.io.IOException;
34 +import java.util.List;
35 +import org.tensorflow.demo.env.ImageUtils;
36 +import org.tensorflow.demo.env.Logger;
37 +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
38 +
39 +public class LegacyCameraConnectionFragment extends Fragment {
40 + private Camera camera;
41 + private static final Logger LOGGER = new Logger();
42 + private Camera.PreviewCallback imageListener;
43 + private Size desiredSize;
44 +
45 + /**
46 + * The layout identifier to inflate for this Fragment.
47 + */
48 + private int layout;
49 +
50 + public LegacyCameraConnectionFragment(
51 + final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) {
52 + this.imageListener = imageListener;
53 + this.layout = layout;
54 + this.desiredSize = desiredSize;
55 + }
56 +
57 + /**
58 + * Conversion from screen rotation to JPEG orientation.
59 + */
60 + private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
61 +
62 + static {
63 + ORIENTATIONS.append(Surface.ROTATION_0, 90);
64 + ORIENTATIONS.append(Surface.ROTATION_90, 0);
65 + ORIENTATIONS.append(Surface.ROTATION_180, 270);
66 + ORIENTATIONS.append(Surface.ROTATION_270, 180);
67 + }
68 +
69 + /**
70 + * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a
71 + * {@link TextureView}.
72 + */
73 + private final TextureView.SurfaceTextureListener surfaceTextureListener =
74 + new TextureView.SurfaceTextureListener() {
75 + @Override
76 + public void onSurfaceTextureAvailable(
77 + final SurfaceTexture texture, final int width, final int height) {
78 +
79 + int index = getCameraId();
80 + camera = Camera.open(index);
81 +
82 + try {
83 + Camera.Parameters parameters = camera.getParameters();
84 + List<String> focusModes = parameters.getSupportedFocusModes();
85 + if (focusModes != null
86 + && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) {
87 + parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE);
88 + }
89 + List<Camera.Size> cameraSizes = parameters.getSupportedPreviewSizes();
90 + Size[] sizes = new Size[cameraSizes.size()];
91 + int i = 0;
92 + for (Camera.Size size : cameraSizes) {
93 + sizes[i++] = new Size(size.width, size.height);
94 + }
95 + Size previewSize =
96 + CameraConnectionFragment.chooseOptimalSize(
97 + sizes, desiredSize.getWidth(), desiredSize.getHeight());
98 + parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight());
99 + camera.setDisplayOrientation(90);
100 + camera.setParameters(parameters);
101 + camera.setPreviewTexture(texture);
102 + } catch (IOException exception) {
103 + camera.release();
104 + }
105 +
106 + camera.setPreviewCallbackWithBuffer(imageListener);
107 + Camera.Size s = camera.getParameters().getPreviewSize();
108 + camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]);
109 +
110 + textureView.setAspectRatio(s.height, s.width);
111 +
112 + camera.startPreview();
113 + }
114 +
115 + @Override
116 + public void onSurfaceTextureSizeChanged(
117 + final SurfaceTexture texture, final int width, final int height) {}
118 +
119 + @Override
120 + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) {
121 + return true;
122 + }
123 +
124 + @Override
125 + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
126 + };
127 +
128 + /**
129 + * An {@link AutoFitTextureView} for camera preview.
130 + */
131 + private AutoFitTextureView textureView;
132 +
133 + /**
134 + * An additional thread for running tasks that shouldn't block the UI.
135 + */
136 + private HandlerThread backgroundThread;
137 +
138 + @Override
139 + public View onCreateView(
140 + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) {
141 + return inflater.inflate(layout, container, false);
142 + }
143 +
144 + @Override
145 + public void onViewCreated(final View view, final Bundle savedInstanceState) {
146 + textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
147 + }
148 +
149 + @Override
150 + public void onActivityCreated(final Bundle savedInstanceState) {
151 + super.onActivityCreated(savedInstanceState);
152 + }
153 +
154 + @Override
155 + public void onResume() {
156 + super.onResume();
157 + startBackgroundThread();
158 + // When the screen is turned off and turned back on, the SurfaceTexture is already
159 + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
160 + // a camera and start preview from here (otherwise, we wait until the surface is ready in
161 + // the SurfaceTextureListener).
162 +
163 + if (textureView.isAvailable()) {
164 + camera.startPreview();
165 + } else {
166 + textureView.setSurfaceTextureListener(surfaceTextureListener);
167 + }
168 + }
169 +
170 + @Override
171 + public void onPause() {
172 + stopCamera();
173 + stopBackgroundThread();
174 + super.onPause();
175 + }
176 +
177 + /**
178 + * Starts a background thread and its {@link Handler}.
179 + */
180 + private void startBackgroundThread() {
181 + backgroundThread = new HandlerThread("CameraBackground");
182 + backgroundThread.start();
183 + }
184 +
185 + /**
186 + * Stops the background thread and its {@link Handler}.
187 + */
188 + private void stopBackgroundThread() {
189 + backgroundThread.quitSafely();
190 + try {
191 + backgroundThread.join();
192 + backgroundThread = null;
193 + } catch (final InterruptedException e) {
194 + LOGGER.e(e, "Exception!");
195 + }
196 + }
197 +
198 + protected void stopCamera() {
199 + if (camera != null) {
200 + camera.stopPreview();
201 + camera.setPreviewCallback(null);
202 + camera.release();
203 + camera = null;
204 + }
205 + }
206 +
207 + private int getCameraId() {
208 + CameraInfo ci = new CameraInfo();
209 + for (int i = 0; i < Camera.getNumberOfCameras(); i++) {
210 + Camera.getCameraInfo(i, ci);
211 + if (ci.facing == CameraInfo.CAMERA_FACING_BACK)
212 + return i;
213 + }
214 + return -1; // No camera found
215 + }
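  // A hedged sketch, following the android.hardware.Camera documentation, of deriving the
  // display orientation from CameraInfo rather than the hard-coded setDisplayOrientation(90)
  // used in onSurfaceTextureAvailable above; not wired in, shown for illustration only.
  private int computeDisplayOrientation(final int cameraId) {
    final CameraInfo info = new CameraInfo();
    Camera.getCameraInfo(cameraId, info);
    final int rotation = getActivity().getWindowManager().getDefaultDisplay().getRotation();
    int degrees = 0;
    switch (rotation) {
      case Surface.ROTATION_0: degrees = 0; break;
      case Surface.ROTATION_90: degrees = 90; break;
      case Surface.ROTATION_180: degrees = 180; break;
      case Surface.ROTATION_270: degrees = 270; break;
    }
    if (info.facing == CameraInfo.CAMERA_FACING_FRONT) {
      return (360 - (info.orientation + degrees) % 360) % 360;  // mirror compensation for front camera
    }
    return (info.orientation - degrees + 360) % 360;
  }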
216 +}
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import android.content.Context;
19 +import android.graphics.Canvas;
20 +import android.util.AttributeSet;
21 +import android.view.View;
22 +import java.util.LinkedList;
23 +import java.util.List;
24 +
25 +/**
26 + * A simple View providing a render callback to other classes.
27 + */
28 +public class OverlayView extends View {
29 + private final List<DrawCallback> callbacks = new LinkedList<DrawCallback>();
30 +
31 + public OverlayView(final Context context, final AttributeSet attrs) {
32 + super(context, attrs);
33 + }
34 +
35 + /**
36 + * Interface defining the callback for client classes.
37 + */
38 + public interface DrawCallback {
39 + public void drawCallback(final Canvas canvas);
40 + }
41 +
42 + public void addCallback(final DrawCallback callback) {
43 + callbacks.add(callback);
44 + }
45 +
46 + @Override
47 + public synchronized void draw(final Canvas canvas) {
48 + for (final DrawCallback callback : callbacks) {
49 + callback.drawCallback(canvas);
50 + }
51 + }
52 +}
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import android.content.Context;
19 +import android.graphics.Canvas;
20 +import android.graphics.Paint;
21 +import android.util.AttributeSet;
22 +import android.util.TypedValue;
23 +import android.view.View;
24 +
25 +import org.tensorflow.demo.Classifier.Recognition;
26 +
27 +import java.util.List;
28 +
29 +public class RecognitionScoreView extends View implements ResultsView {
30 + private static final float TEXT_SIZE_DIP = 24;
31 + private List<Recognition> results;
32 + private final float textSizePx;
33 + private final Paint fgPaint;
34 + private final Paint bgPaint;
35 +
36 + public RecognitionScoreView(final Context context, final AttributeSet set) {
37 + super(context, set);
38 +
39 + textSizePx =
40 + TypedValue.applyDimension(
41 + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
42 + fgPaint = new Paint();
43 + fgPaint.setTextSize(textSizePx);
44 +
45 + bgPaint = new Paint();
46 + bgPaint.setColor(0xcc4285f4);
47 + }
48 +
49 + @Override
50 + public void setResults(final List<Recognition> results) {
51 + this.results = results;
52 + postInvalidate();
53 + }
54 +
55 + @Override
56 + public void onDraw(final Canvas canvas) {
57 + final int x = 10;
58 + int y = (int) (fgPaint.getTextSize() * 1.5f);
59 +
60 + canvas.drawPaint(bgPaint);
61 +
62 + if (results != null) {
63 + for (final Recognition recog : results) {
64 + canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint);
65 + y += fgPaint.getTextSize() * 1.5f;
66 + }
67 + }
68 + }
69 +}
1 +/*
2 + * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +package org.tensorflow.demo;
18 +
19 +import android.util.Log;
20 +import android.util.Pair;
21 +import java.util.ArrayDeque;
22 +import java.util.ArrayList;
23 +import java.util.Arrays;
24 +import java.util.Deque;
25 +import java.util.List;
26 +
27 +/** Reads in results from an instantaneous audio recognition model and smoothes them over time. */
28 +public class RecognizeCommands {
29 + // Configuration settings.
30 + private List<String> labels = new ArrayList<String>();
31 + private long averageWindowDurationMs;
32 + private float detectionThreshold;
33 + private int suppressionMs;
34 + private int minimumCount;
35 + private long minimumTimeBetweenSamplesMs;
36 +
37 + // Working variables.
38 + private Deque<Pair<Long, float[]>> previousResults = new ArrayDeque<Pair<Long, float[]>>();
39 + private String previousTopLabel;
40 + private int labelsCount;
41 + private long previousTopLabelTime;
42 + private float previousTopLabelScore;
43 +
44 + private static final String SILENCE_LABEL = "_silence_";
45 + private static final long MINIMUM_TIME_FRACTION = 4;
46 +
47 + public RecognizeCommands(
48 + List<String> inLabels,
49 + long inAverageWindowDurationMs,
50 + float inDetectionThreshold,
51 + int inSuppressionMS,
52 + int inMinimumCount,
53 + long inMinimumTimeBetweenSamplesMS) {
54 + labels = inLabels;
55 + averageWindowDurationMs = inAverageWindowDurationMs;
56 + detectionThreshold = inDetectionThreshold;
57 + suppressionMs = inSuppressionMS;
58 + minimumCount = inMinimumCount;
59 + labelsCount = inLabels.size();
60 + previousTopLabel = SILENCE_LABEL;
61 + previousTopLabelTime = Long.MIN_VALUE;
62 + previousTopLabelScore = 0.0f;
63 + minimumTimeBetweenSamplesMs = inMinimumTimeBetweenSamplesMS;
64 + }
65 +
66 + /** Holds information about what's been recognized. */
67 + public static class RecognitionResult {
68 + public final String foundCommand;
69 + public final float score;
70 + public final boolean isNewCommand;
71 +
72 + public RecognitionResult(String inFoundCommand, float inScore, boolean inIsNewCommand) {
73 + foundCommand = inFoundCommand;
74 + score = inScore;
75 + isNewCommand = inIsNewCommand;
76 + }
77 + }
78 +
79 + private static class ScoreForSorting implements Comparable<ScoreForSorting> {
80 + public final float score;
81 + public final int index;
82 +
83 + public ScoreForSorting(float inScore, int inIndex) {
84 + score = inScore;
85 + index = inIndex;
86 + }
87 +
88 + @Override
89 + public int compareTo(ScoreForSorting other) {
90 + if (this.score > other.score) {
91 + return -1;
92 + } else if (this.score < other.score) {
93 + return 1;
94 + } else {
95 + return 0;
96 + }
97 + }
98 + }
99 +
100 + public RecognitionResult processLatestResults(float[] currentResults, long currentTimeMS) {
101 + if (currentResults.length != labelsCount) {
102 + throw new RuntimeException(
103 + "The results for recognition should contain "
104 + + labelsCount
105 + + " elements, but there are "
106 + + currentResults.length);
107 + }
108 +
109 + if ((!previousResults.isEmpty()) && (currentTimeMS < previousResults.getFirst().first)) {
110 + throw new RuntimeException(
111 + "You must feed results in increasing time order, but received a timestamp of "
112 + + currentTimeMS
113 + + " that was earlier than the previous one of "
114 + + previousResults.getFirst().first);
115 + }
116 +
117 + final int howManyResults = previousResults.size();
118 + // Ignore any results that are coming in too frequently.
119 + if (howManyResults > 1) {
120 + final long timeSinceMostRecent = currentTimeMS - previousResults.getLast().first;
121 + if (timeSinceMostRecent < minimumTimeBetweenSamplesMs) {
122 + return new RecognitionResult(previousTopLabel, previousTopLabelScore, false);
123 + }
124 + }
125 +
126 + // Add the latest results to the end of the queue.
127 + previousResults.addLast(new Pair<Long, float[]>(currentTimeMS, currentResults));
128 +
129 + // Prune any earlier results that are too old for the averaging window.
130 + final long timeLimit = currentTimeMS - averageWindowDurationMs;
131 + while (previousResults.getFirst().first < timeLimit) {
132 + previousResults.removeFirst();
133 + }
134 +
135 + // If there are too few results, assume the result will be unreliable and
136 + // bail.
137 + final long earliestTime = previousResults.getFirst().first;
138 + final long samplesDuration = currentTimeMS - earliestTime;
139 + if ((howManyResults < minimumCount)
140 + || (samplesDuration < (averageWindowDurationMs / MINIMUM_TIME_FRACTION))) {
141 + Log.v("RecognizeResult", "Too few results");
142 + return new RecognitionResult(previousTopLabel, 0.0f, false);
143 + }
144 +
145 + // Calculate the average score across all the results in the window.
146 + float[] averageScores = new float[labelsCount];
147 + for (Pair<Long, float[]> previousResult : previousResults) {
148 + final float[] scoresTensor = previousResult.second;
149 + int i = 0;
150 + while (i < scoresTensor.length) {
151 + averageScores[i] += scoresTensor[i] / howManyResults;
152 + ++i;
153 + }
154 + }
155 +
156 + // Sort the averaged results in descending score order.
157 + ScoreForSorting[] sortedAverageScores = new ScoreForSorting[labelsCount];
158 + for (int i = 0; i < labelsCount; ++i) {
159 + sortedAverageScores[i] = new ScoreForSorting(averageScores[i], i);
160 + }
161 + Arrays.sort(sortedAverageScores);
162 +
163 + // See if the latest top score is enough to trigger a detection.
164 + final int currentTopIndex = sortedAverageScores[0].index;
165 + final String currentTopLabel = labels.get(currentTopIndex);
166 + final float currentTopScore = sortedAverageScores[0].score;
167 + // If we've recently had another label trigger, assume one that occurs too
168 + // soon afterwards is a bad result.
169 + long timeSinceLastTop;
170 + if (previousTopLabel.equals(SILENCE_LABEL) || (previousTopLabelTime == Long.MIN_VALUE)) {
171 + timeSinceLastTop = Long.MAX_VALUE;
172 + } else {
173 + timeSinceLastTop = currentTimeMS - previousTopLabelTime;
174 + }
175 + boolean isNewCommand;
176 + if ((currentTopScore > detectionThreshold) && (timeSinceLastTop > suppressionMs)) {
177 + previousTopLabel = currentTopLabel;
178 + previousTopLabelTime = currentTimeMS;
179 + previousTopLabelScore = currentTopScore;
180 + isNewCommand = true;
181 + } else {
182 + isNewCommand = false;
183 + }
184 + return new RecognitionResult(currentTopLabel, currentTopScore, isNewCommand);
185 + }
186 +}
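// A minimal usage sketch for RecognizeCommands; the label list is illustrative, the tuning
// values mirror the constants used by SpeechActivity later in this change, and `scores`
// (one float per label) is assumed to come from the caller's model output.
final RecognizeCommands recognizer =
    new RecognizeCommands(
        Arrays.asList("_silence_", "_unknown_", "yes", "no"),  // example labels
        500,     // averageWindowDurationMs
        0.70f,   // detectionThreshold
        1500,    // suppressionMs
        3,       // minimumCount
        30);     // minimumTimeBetweenSamplesMs
final RecognizeCommands.RecognitionResult result =
    recognizer.processLatestResults(scores, SystemClock.uptimeMillis());
if (result.isNewCommand) {
  Log.v("RecognizeCommands", result.foundCommand + " (" + result.score + ")");
}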
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import org.tensorflow.demo.Classifier.Recognition;
19 +
20 +import java.util.List;
21 +
22 +public interface ResultsView {
23 + public void setResults(final List<Recognition> results);
24 +}
1 +/*
2 + * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +/* Demonstrates how to run an audio recognition model in Android.
18 +
19 +This example loads a simple speech recognition model trained by the tutorial at
20 +https://www.tensorflow.org/tutorials/audio_training
21 +
22 +The model files should be downloaded automatically from the TensorFlow website,
23 +but if you have a custom model you can update the LABEL_FILENAME and
24 +MODEL_FILENAME constants to point to your own files.
25 +
26 +The example application displays a list view with all of the known audio labels,
27 +and highlights each one when it thinks it has detected one through the
28 +microphone. The averaging of results to give a more reliable signal happens in
29 +the RecognizeCommands helper class.
30 +*/
31 +
32 +package org.tensorflow.demo;
33 +
34 +import android.animation.AnimatorInflater;
35 +import android.animation.AnimatorSet;
36 +import android.app.Activity;
37 +import android.content.pm.PackageManager;
38 +import android.media.AudioFormat;
39 +import android.media.AudioRecord;
40 +import android.media.MediaRecorder;
41 +import android.os.Build;
42 +import android.os.Bundle;
43 +import android.util.Log;
44 +import android.view.View;
45 +import android.widget.ArrayAdapter;
46 +import android.widget.Button;
47 +import android.widget.ListView;
48 +import java.io.BufferedReader;
49 +import java.io.IOException;
50 +import java.io.InputStreamReader;
51 +import java.util.ArrayList;
52 +import java.util.List;
53 +import java.util.concurrent.locks.ReentrantLock;
54 +import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
55 +import org.tensorflow.demo.R;
56 +
57 +/**
58 + * An activity that listens for audio and then uses a TensorFlow model to detect particular classes,
59 + * by default a small set of action words.
60 + */
61 +public class SpeechActivity extends Activity {
62 +
63 + // Constants that control the behavior of the recognition code and model
64 + // settings. See the audio recognition tutorial for a detailed explanation of
65 + // all these, but you should customize them to match your training settings if
66 + // you are running your own model.
67 + private static final int SAMPLE_RATE = 16000;
68 + private static final int SAMPLE_DURATION_MS = 1000;
69 + private static final int RECORDING_LENGTH = (int) (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000);
70 + private static final long AVERAGE_WINDOW_DURATION_MS = 500;
71 + private static final float DETECTION_THRESHOLD = 0.70f;
72 + private static final int SUPPRESSION_MS = 1500;
73 + private static final int MINIMUM_COUNT = 3;
74 + private static final long MINIMUM_TIME_BETWEEN_SAMPLES_MS = 30;
75 + private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_labels.txt";
76 + private static final String MODEL_FILENAME = "file:///android_asset/conv_actions_frozen.pb";
77 + private static final String INPUT_DATA_NAME = "decoded_sample_data:0";
78 + private static final String SAMPLE_RATE_NAME = "decoded_sample_data:1";
79 + private static final String OUTPUT_SCORES_NAME = "labels_softmax";
80 +
81 + // UI elements.
82 + private static final int REQUEST_RECORD_AUDIO = 13;
83 + private Button quitButton;
84 + private ListView labelsListView;
85 + private static final String LOG_TAG = SpeechActivity.class.getSimpleName();
86 +
87 + // Working variables.
88 + short[] recordingBuffer = new short[RECORDING_LENGTH];
89 + int recordingOffset = 0;
90 + boolean shouldContinue = true;
91 + private Thread recordingThread;
92 + boolean shouldContinueRecognition = true;
93 + private Thread recognitionThread;
94 + private final ReentrantLock recordingBufferLock = new ReentrantLock();
95 + private TensorFlowInferenceInterface inferenceInterface;
96 + private List<String> labels = new ArrayList<String>();
97 + private List<String> displayedLabels = new ArrayList<>();
98 + private RecognizeCommands recognizeCommands = null;
99 +
100 + @Override
101 + protected void onCreate(Bundle savedInstanceState) {
102 + // Set up the UI.
103 + super.onCreate(savedInstanceState);
104 + setContentView(R.layout.activity_speech);
105 + quitButton = (Button) findViewById(R.id.quit);
106 + quitButton.setOnClickListener(
107 + new View.OnClickListener() {
108 + @Override
109 + public void onClick(View view) {
110 + moveTaskToBack(true);
111 + android.os.Process.killProcess(android.os.Process.myPid());
112 + System.exit(1);
113 + }
114 + });
115 + labelsListView = (ListView) findViewById(R.id.list_view);
116 +
117 + // Load the labels for the model, but only display those that don't start
118 + // with an underscore.
119 + String actualFilename = LABEL_FILENAME.split("file:///android_asset/")[1];
120 + Log.i(LOG_TAG, "Reading labels from: " + actualFilename);
121 + BufferedReader br = null;
122 + try {
123 + br = new BufferedReader(new InputStreamReader(getAssets().open(actualFilename)));
124 + String line;
125 + while ((line = br.readLine()) != null) {
126 + labels.add(line);
127 + if (line.charAt(0) != '_') {
128 + displayedLabels.add(line.substring(0, 1).toUpperCase() + line.substring(1));
129 + }
130 + }
131 + br.close();
132 + } catch (IOException e) {
133 + throw new RuntimeException("Problem reading label file!", e);
134 + }
135 +
136 + // Build a list view based on these labels.
137 + ArrayAdapter<String> arrayAdapter =
138 + new ArrayAdapter<String>(this, R.layout.list_text_item, displayedLabels);
139 + labelsListView.setAdapter(arrayAdapter);
140 +
141 + // Set up an object to smooth recognition results to increase accuracy.
142 + recognizeCommands =
143 + new RecognizeCommands(
144 + labels,
145 + AVERAGE_WINDOW_DURATION_MS,
146 + DETECTION_THRESHOLD,
147 + SUPPRESSION_MS,
148 + MINIMUM_COUNT,
149 + MINIMUM_TIME_BETWEEN_SAMPLES_MS);
150 +
151 + // Load the TensorFlow model.
152 + inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME);
153 +
154 + // Start the recording and recognition threads.
155 + requestMicrophonePermission();
156 + startRecording();
157 + startRecognition();
158 + }
159 +
160 + private void requestMicrophonePermission() {
161 + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
162 + requestPermissions(
163 + new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
164 + }
165 + }
166 +
167 + @Override
168 + public void onRequestPermissionsResult(
169 + int requestCode, String[] permissions, int[] grantResults) {
170 + if (requestCode == REQUEST_RECORD_AUDIO
171 + && grantResults.length > 0
172 + && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
173 + startRecording();
174 + startRecognition();
175 + }
176 + }
177 +
178 + public synchronized void startRecording() {
179 + if (recordingThread != null) {
180 + return;
181 + }
182 + shouldContinue = true;
183 + recordingThread =
184 + new Thread(
185 + new Runnable() {
186 + @Override
187 + public void run() {
188 + record();
189 + }
190 + });
191 + recordingThread.start();
192 + }
193 +
194 + public synchronized void stopRecording() {
195 + if (recordingThread == null) {
196 + return;
197 + }
198 + shouldContinue = false;
199 + recordingThread = null;
200 + }
201 +
202 + private void record() {
203 + android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);
204 +
205 + // Estimate the buffer size we'll need for this device.
206 + int bufferSize =
207 + AudioRecord.getMinBufferSize(
208 + SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT);
209 + if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) {
210 + bufferSize = SAMPLE_RATE * 2;
211 + }
212 + short[] audioBuffer = new short[bufferSize / 2];
213 +
214 + AudioRecord record =
215 + new AudioRecord(
216 + MediaRecorder.AudioSource.DEFAULT,
217 + SAMPLE_RATE,
218 + AudioFormat.CHANNEL_IN_MONO,
219 + AudioFormat.ENCODING_PCM_16BIT,
220 + bufferSize);
221 +
222 + if (record.getState() != AudioRecord.STATE_INITIALIZED) {
223 + Log.e(LOG_TAG, "Audio Record can't initialize!");
224 + return;
225 + }
226 +
227 + record.startRecording();
228 +
229 + Log.v(LOG_TAG, "Start recording");
230 +
231 + // Loop, gathering audio data and copying it to a round-robin buffer.
232 + while (shouldContinue) {
233 + int numberRead = record.read(audioBuffer, 0, audioBuffer.length);
234 + int maxLength = recordingBuffer.length;
235 + int newRecordingOffset = recordingOffset + numberRead;
236 + int secondCopyLength = Math.max(0, newRecordingOffset - maxLength);
237 + int firstCopyLength = numberRead - secondCopyLength;
238 + // We store off all the data for the recognition thread to access. The ML
239 + // thread will copy out of this buffer into its own, while holding the
240 + // lock, so this should be thread safe.
241 + recordingBufferLock.lock();
242 + try {
243 + System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength);
244 + System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength);
245 + recordingOffset = newRecordingOffset % maxLength;
246 + } finally {
247 + recordingBufferLock.unlock();
248 + }
249 + }
250 +
251 + record.stop();
252 + record.release();
253 + }
254 +
255 + public synchronized void startRecognition() {
256 + if (recognitionThread != null) {
257 + return;
258 + }
259 + shouldContinueRecognition = true;
260 + recognitionThread =
261 + new Thread(
262 + new Runnable() {
263 + @Override
264 + public void run() {
265 + recognize();
266 + }
267 + });
268 + recognitionThread.start();
269 + }
270 +
271 + public synchronized void stopRecognition() {
272 + if (recognitionThread == null) {
273 + return;
274 + }
275 + shouldContinueRecognition = false;
276 + recognitionThread = null;
277 + }
278 +
279 + private void recognize() {
280 + Log.v(LOG_TAG, "Start recognition");
281 +
282 + short[] inputBuffer = new short[RECORDING_LENGTH];
283 + float[] floatInputBuffer = new float[RECORDING_LENGTH];
284 + float[] outputScores = new float[labels.size()];
285 + String[] outputScoresNames = new String[] {OUTPUT_SCORES_NAME};
286 + int[] sampleRateList = new int[] {SAMPLE_RATE};
287 +
288 + // Loop, grabbing recorded data and running the recognition model on it.
289 + while (shouldContinueRecognition) {
290 + // The recording thread places data in this round-robin buffer, so lock to
291 + // make sure there's no writing happening and then copy it to our own
292 + // local version.
293 + recordingBufferLock.lock();
294 + try {
295 + int maxLength = recordingBuffer.length;
296 + int firstCopyLength = maxLength - recordingOffset;
297 + int secondCopyLength = recordingOffset;
298 + System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, firstCopyLength);
299 + System.arraycopy(recordingBuffer, 0, inputBuffer, firstCopyLength, secondCopyLength);
300 + } finally {
301 + recordingBufferLock.unlock();
302 + }
303 +
304 + // We need to feed in float values between -1.0f and 1.0f, so divide the
305 + // signed 16-bit inputs.
306 + for (int i = 0; i < RECORDING_LENGTH; ++i) {
307 + floatInputBuffer[i] = inputBuffer[i] / 32767.0f;
308 + }
309 +
310 + // Run the model.
311 + inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList);
312 + inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1);
313 + inferenceInterface.run(outputScoresNames);
314 + inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores);
315 +
316 + // Use the smoother to figure out if we've had a real recognition event.
317 + long currentTime = System.currentTimeMillis();
318 + final RecognizeCommands.RecognitionResult result =
319 + recognizeCommands.processLatestResults(outputScores, currentTime);
320 +
321 + runOnUiThread(
322 + new Runnable() {
323 + @Override
324 + public void run() {
325 + // If we do have a new command, highlight the right list entry.
326 + if (!result.foundCommand.startsWith("_") && result.isNewCommand) {
327 + int labelIndex = -1;
328 + for (int i = 0; i < labels.size(); ++i) {
329 + if (labels.get(i).equals(result.foundCommand)) {
330 + labelIndex = i;
331 + }
332 + }
 333 +                  final View labelView = labelsListView.getChildAt(labelIndex - 2); // Offset past the hidden '_' labels.
334 +
335 + AnimatorSet colorAnimation =
336 + (AnimatorSet)
337 + AnimatorInflater.loadAnimator(
338 + SpeechActivity.this, R.animator.color_animation);
339 + colorAnimation.setTarget(labelView);
340 + colorAnimation.start();
341 + }
342 + }
343 + });
344 + try {
345 + // We don't need to run too frequently, so snooze for a bit.
346 + Thread.sleep(MINIMUM_TIME_BETWEEN_SAMPLES_MS);
347 + } catch (InterruptedException e) {
348 + // Ignore
349 + }
350 + }
351 +
352 + Log.v(LOG_TAG, "End recognition");
353 + }
354 +}
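// ---------------------------------------------------------------------------
// Editorial sketch (not part of the diff): the two buffer transforms used by
// record() and recognize() above, pulled out into a plain helper so they can
// be exercised off-device. The class and method names are invented for
// illustration; the arithmetic mirrors SpeechActivity exactly.
package org.tensorflow.demo;

final class AudioBufferUtils {
  private AudioBufferUtils() {}

  /** Copies {@code numberRead} samples into a circular buffer, wrapping at the end. */
  static int appendToCircularBuffer(
      final short[] recordingBuffer, final int recordingOffset,
      final short[] audioBuffer, final int numberRead) {
    final int maxLength = recordingBuffer.length;
    final int newRecordingOffset = recordingOffset + numberRead;
    final int secondCopyLength = Math.max(0, newRecordingOffset - maxLength);
    final int firstCopyLength = numberRead - secondCopyLength;
    System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength);
    System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength);
    return newRecordingOffset % maxLength;  // New write offset.
  }

  /** Scales signed 16-bit PCM samples into the -1.0f to 1.0f range the model expects. */
  static float[] toFloatSamples(final short[] inputBuffer) {
    final float[] floatInputBuffer = new float[inputBuffer.length];
    for (int i = 0; i < inputBuffer.length; ++i) {
      floatInputBuffer[i] = inputBuffer[i] / 32767.0f;
    }
    return floatInputBuffer;
  }
}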
1 +/*
2 + * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 + *
4 + * Licensed under the Apache License, Version 2.0 (the "License");
5 + * you may not use this file except in compliance with the License.
6 + * You may obtain a copy of the License at
7 + *
8 + * http://www.apache.org/licenses/LICENSE-2.0
9 + *
10 + * Unless required by applicable law or agreed to in writing, software
11 + * distributed under the License is distributed on an "AS IS" BASIS,
12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 + * See the License for the specific language governing permissions and
14 + * limitations under the License.
15 + */
16 +
17 +package org.tensorflow.demo;
18 +
19 +import android.app.UiModeManager;
20 +import android.content.Context;
21 +import android.content.res.AssetManager;
22 +import android.content.res.Configuration;
23 +import android.graphics.Bitmap;
24 +import android.graphics.Bitmap.Config;
25 +import android.graphics.BitmapFactory;
26 +import android.graphics.Canvas;
27 +import android.graphics.Color;
28 +import android.graphics.Matrix;
29 +import android.graphics.Paint;
30 +import android.graphics.Paint.Style;
31 +import android.graphics.Rect;
32 +import android.graphics.Typeface;
33 +import android.media.ImageReader.OnImageAvailableListener;
34 +import android.os.Bundle;
35 +import android.os.SystemClock;
36 +import android.util.DisplayMetrics;
37 +import android.util.Size;
38 +import android.util.TypedValue;
39 +import android.view.Display;
40 +import android.view.KeyEvent;
41 +import android.view.MotionEvent;
42 +import android.view.View;
43 +import android.view.View.OnClickListener;
44 +import android.view.View.OnTouchListener;
45 +import android.view.ViewGroup;
46 +import android.widget.BaseAdapter;
47 +import android.widget.Button;
48 +import android.widget.GridView;
49 +import android.widget.ImageView;
50 +import android.widget.RelativeLayout;
51 +import android.widget.Toast;
52 +import java.io.IOException;
53 +import java.io.InputStream;
54 +import java.util.ArrayList;
55 +import java.util.Collections;
56 +import java.util.Vector;
57 +import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
58 +import org.tensorflow.demo.OverlayView.DrawCallback;
59 +import org.tensorflow.demo.env.BorderedText;
60 +import org.tensorflow.demo.env.ImageUtils;
61 +import org.tensorflow.demo.env.Logger;
62 +import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
63 +
64 +/**
65 + * Sample activity that stylizes the camera preview according to "A Learned Representation For
66 + * Artistic Style" (https://arxiv.org/abs/1610.07629)
67 + */
68 +public class StylizeActivity extends CameraActivity implements OnImageAvailableListener {
69 + private static final Logger LOGGER = new Logger();
70 +
71 + private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb";
72 + private static final String INPUT_NODE = "input";
73 + private static final String STYLE_NODE = "style_num";
74 + private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid";
75 + private static final int NUM_STYLES = 26;
76 +
77 + private static final boolean SAVE_PREVIEW_BITMAP = false;
78 +
79 + // Whether to actively manipulate non-selected sliders so that sum of activations always appears
80 + // to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless.
81 + private static final boolean NORMALIZE_SLIDERS = true;
82 +
83 + private static final float TEXT_SIZE_DIP = 12;
84 +
85 + private static final boolean DEBUG_MODEL = false;
86 +
87 + private static final int[] SIZES = {128, 192, 256, 384, 512, 720};
88 +
89 + private static final Size DESIRED_PREVIEW_SIZE = new Size(1280, 720);
90 +
91 + // Start at a medium size, but let the user step up through smaller sizes so they don't get
92 + // immediately stuck processing a large image.
93 + private int desiredSizeIndex = -1;
94 + private int desiredSize = 256;
95 + private int initializedSize = 0;
96 +
97 + private Integer sensorOrientation;
98 +
99 + private long lastProcessingTimeMs;
100 + private Bitmap rgbFrameBitmap = null;
101 + private Bitmap croppedBitmap = null;
102 + private Bitmap cropCopyBitmap = null;
103 +
104 + private final float[] styleVals = new float[NUM_STYLES];
105 + private int[] intValues;
106 + private float[] floatValues;
107 +
108 + private int frameNum = 0;
109 +
110 + private Bitmap textureCopyBitmap;
111 +
112 + private Matrix frameToCropTransform;
113 + private Matrix cropToFrameTransform;
114 +
115 + private BorderedText borderedText;
116 +
117 + private TensorFlowInferenceInterface inferenceInterface;
118 +
119 + private int lastOtherStyle = 1;
120 +
121 + private boolean allZero = false;
122 +
123 + private ImageGridAdapter adapter;
124 + private GridView grid;
125 +
126 + private final OnTouchListener gridTouchAdapter =
127 + new OnTouchListener() {
128 + ImageSlider slider = null;
129 +
130 + @Override
131 + public boolean onTouch(final View v, final MotionEvent event) {
132 + switch (event.getActionMasked()) {
133 + case MotionEvent.ACTION_DOWN:
134 + for (int i = 0; i < NUM_STYLES; ++i) {
135 + final ImageSlider child = adapter.items[i];
136 + final Rect rect = new Rect();
137 + child.getHitRect(rect);
138 + if (rect.contains((int) event.getX(), (int) event.getY())) {
139 + slider = child;
140 + slider.setHilighted(true);
141 + }
142 + }
143 + break;
144 +
145 + case MotionEvent.ACTION_MOVE:
146 + if (slider != null) {
147 + final Rect rect = new Rect();
148 + slider.getHitRect(rect);
149 +
150 + final float newSliderVal =
151 + (float)
152 + Math.min(
153 + 1.0,
154 + Math.max(
155 + 0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight()));
156 +
157 + setStyle(slider, newSliderVal);
158 + }
159 + break;
160 +
161 + case MotionEvent.ACTION_UP:
162 + if (slider != null) {
163 + slider.setHilighted(false);
164 + slider = null;
165 + }
166 + break;
167 +
168 + default: // fall out
169 +
170 + }
171 + return true;
172 + }
173 + };
174 +
175 + @Override
176 + public void onCreate(final Bundle savedInstanceState) {
177 + super.onCreate(savedInstanceState);
178 + }
179 +
180 + @Override
181 + protected int getLayoutId() {
182 + return R.layout.camera_connection_fragment_stylize;
183 + }
184 +
185 + @Override
186 + protected Size getDesiredPreviewFrameSize() {
187 + return DESIRED_PREVIEW_SIZE;
188 + }
189 +
190 + public static Bitmap getBitmapFromAsset(final Context context, final String filePath) {
191 + final AssetManager assetManager = context.getAssets();
192 +
193 + Bitmap bitmap = null;
194 + try {
195 + final InputStream inputStream = assetManager.open(filePath);
196 + bitmap = BitmapFactory.decodeStream(inputStream);
197 + } catch (final IOException e) {
198 + LOGGER.e("Error opening bitmap!", e);
199 + }
200 +
201 + return bitmap;
202 + }
203 +
204 + private class ImageSlider extends ImageView {
205 + private float value = 0.0f;
206 + private boolean hilighted = false;
207 +
208 + private final Paint boxPaint;
209 + private final Paint linePaint;
210 +
211 + public ImageSlider(final Context context) {
212 + super(context);
213 + value = 0.0f;
214 +
215 + boxPaint = new Paint();
216 + boxPaint.setColor(Color.BLACK);
217 + boxPaint.setAlpha(128);
218 +
219 + linePaint = new Paint();
220 + linePaint.setColor(Color.WHITE);
221 + linePaint.setStrokeWidth(10.0f);
222 + linePaint.setStyle(Style.STROKE);
223 + }
224 +
225 + @Override
226 + public void onDraw(final Canvas canvas) {
227 + super.onDraw(canvas);
228 + final float y = (1.0f - value) * canvas.getHeight();
229 +
230 + // If all sliders are zero, don't bother shading anything.
231 + if (!allZero) {
232 + canvas.drawRect(0, 0, canvas.getWidth(), y, boxPaint);
233 + }
234 +
235 + if (value > 0.0f) {
236 + canvas.drawLine(0, y, canvas.getWidth(), y, linePaint);
237 + }
238 +
239 + if (hilighted) {
240 + canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint);
241 + }
242 + }
243 +
244 + @Override
245 + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
246 + super.onMeasure(widthMeasureSpec, heightMeasureSpec);
247 + setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
248 + }
249 +
250 + public void setValue(final float value) {
251 + this.value = value;
252 + postInvalidate();
253 + }
254 +
255 + public void setHilighted(final boolean highlighted) {
256 + this.hilighted = highlighted;
257 + this.postInvalidate();
258 + }
259 + }
260 +
261 + private class ImageGridAdapter extends BaseAdapter {
262 + final ImageSlider[] items = new ImageSlider[NUM_STYLES];
263 + final ArrayList<Button> buttons = new ArrayList<>();
264 +
265 + {
266 + final Button sizeButton =
267 + new Button(StylizeActivity.this) {
268 + @Override
269 + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
270 + super.onMeasure(widthMeasureSpec, heightMeasureSpec);
271 + setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
272 + }
273 + };
274 + sizeButton.setText("" + desiredSize);
275 + sizeButton.setOnClickListener(
276 + new OnClickListener() {
277 + @Override
278 + public void onClick(final View v) {
279 + desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length;
280 + desiredSize = SIZES[desiredSizeIndex];
281 + sizeButton.setText("" + desiredSize);
282 + sizeButton.postInvalidate();
283 + }
284 + });
285 +
286 + final Button saveButton =
287 + new Button(StylizeActivity.this) {
288 + @Override
289 + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
290 + super.onMeasure(widthMeasureSpec, heightMeasureSpec);
291 + setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
292 + }
293 + };
294 + saveButton.setText("save");
295 + saveButton.setTextSize(12);
296 +
297 + saveButton.setOnClickListener(
298 + new OnClickListener() {
299 + @Override
300 + public void onClick(final View v) {
301 + if (textureCopyBitmap != null) {
302 + // TODO(andrewharp): Save as jpeg with guaranteed unique filename.
303 + ImageUtils.saveBitmap(textureCopyBitmap, "stylized" + frameNum + ".png");
304 + Toast.makeText(
305 + StylizeActivity.this,
306 + "Saved image to: /sdcard/tensorflow/" + "stylized" + frameNum + ".png",
307 + Toast.LENGTH_LONG)
308 + .show();
309 + }
310 + }
311 + });
312 +
313 + buttons.add(sizeButton);
314 + buttons.add(saveButton);
315 +
316 + for (int i = 0; i < NUM_STYLES; ++i) {
317 + LOGGER.v("Creating item %d", i);
318 +
319 + if (items[i] == null) {
320 + final ImageSlider slider = new ImageSlider(StylizeActivity.this);
321 + final Bitmap bm =
322 + getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg");
323 + slider.setImageBitmap(bm);
324 +
325 + items[i] = slider;
326 + }
327 + }
328 + }
329 +
330 + @Override
331 + public int getCount() {
332 + return buttons.size() + NUM_STYLES;
333 + }
334 +
335 + @Override
336 + public Object getItem(final int position) {
337 + if (position < buttons.size()) {
338 + return buttons.get(position);
339 + } else {
340 + return items[position - buttons.size()];
341 + }
342 + }
343 +
344 + @Override
345 + public long getItemId(final int position) {
346 + return getItem(position).hashCode();
347 + }
348 +
349 + @Override
350 + public View getView(final int position, final View convertView, final ViewGroup parent) {
351 + if (convertView != null) {
352 + return convertView;
353 + }
354 + return (View) getItem(position);
355 + }
356 + }
357 +
358 + @Override
359 + public void onPreviewSizeChosen(final Size size, final int rotation) {
360 + final float textSizePx = TypedValue.applyDimension(
361 + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
362 + borderedText = new BorderedText(textSizePx);
363 + borderedText.setTypeface(Typeface.MONOSPACE);
364 +
365 + inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
366 +
367 + previewWidth = size.getWidth();
368 + previewHeight = size.getHeight();
369 +
370 + final Display display = getWindowManager().getDefaultDisplay();
371 + final int screenOrientation = display.getRotation();
372 +
373 + LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
374 +
375 + sensorOrientation = rotation + screenOrientation;
376 +
377 + addCallback(
378 + new DrawCallback() {
379 + @Override
380 + public void drawCallback(final Canvas canvas) {
381 + renderDebug(canvas);
382 + }
383 + });
384 +
385 + adapter = new ImageGridAdapter();
386 + grid = (GridView) findViewById(R.id.grid_layout);
387 + grid.setAdapter(adapter);
388 + grid.setOnTouchListener(gridTouchAdapter);
389 +
390 + // Change UI on Android TV
391 + UiModeManager uiModeManager = (UiModeManager) getSystemService(UI_MODE_SERVICE);
392 + if (uiModeManager.getCurrentModeType() == Configuration.UI_MODE_TYPE_TELEVISION) {
393 + DisplayMetrics displayMetrics = new DisplayMetrics();
394 + getWindowManager().getDefaultDisplay().getMetrics(displayMetrics);
395 + int styleSelectorHeight = displayMetrics.heightPixels;
396 + int styleSelectorWidth = displayMetrics.widthPixels - styleSelectorHeight;
397 + RelativeLayout.LayoutParams layoutParams = new RelativeLayout.LayoutParams(styleSelectorWidth, ViewGroup.LayoutParams.MATCH_PARENT);
398 +
 399 +      // Calculate the number of styles per row so that every style is visible without scrolling.
400 + int numOfStylePerRow = 3;
401 + while (styleSelectorWidth / numOfStylePerRow * Math.ceil((float) (adapter.getCount() - 2) / numOfStylePerRow) > styleSelectorHeight) {
402 + numOfStylePerRow++;
403 + }
404 + grid.setNumColumns(numOfStylePerRow);
405 + layoutParams.addRule(RelativeLayout.ALIGN_PARENT_RIGHT);
406 + grid.setLayoutParams(layoutParams);
407 + adapter.buttons.clear();
408 + }
409 +
410 + setStyle(adapter.items[0], 1.0f);
411 + }
412 +
413 + private void setStyle(final ImageSlider slider, final float value) {
414 + slider.setValue(value);
415 +
416 + if (NORMALIZE_SLIDERS) {
417 + // Slider vals correspond directly to the input tensor vals, and normalization is visually
418 + // maintained by remanipulating non-selected sliders.
419 + float otherSum = 0.0f;
420 +
421 + for (int i = 0; i < NUM_STYLES; ++i) {
422 + if (adapter.items[i] != slider) {
423 + otherSum += adapter.items[i].value;
424 + }
425 + }
426 +
427 + if (otherSum > 0.0) {
428 + float highestOtherVal = 0;
429 + final float factor = otherSum > 0.0f ? (1.0f - value) / otherSum : 0.0f;
430 + for (int i = 0; i < NUM_STYLES; ++i) {
431 + final ImageSlider child = adapter.items[i];
432 + if (child == slider) {
433 + continue;
434 + }
435 + final float newVal = child.value * factor;
436 + child.setValue(newVal > 0.01f ? newVal : 0.0f);
437 +
438 + if (child.value > highestOtherVal) {
439 + lastOtherStyle = i;
440 + highestOtherVal = child.value;
441 + }
442 + }
443 + } else {
444 + // Everything else is 0, so just pick a suitable slider to push up when the
445 + // selected one goes down.
446 + if (adapter.items[lastOtherStyle] == slider) {
447 + lastOtherStyle = (lastOtherStyle + 1) % NUM_STYLES;
448 + }
449 + adapter.items[lastOtherStyle].setValue(1.0f - value);
450 + }
451 + }
452 +
453 + final boolean lastAllZero = allZero;
454 + float sum = 0.0f;
455 + for (int i = 0; i < NUM_STYLES; ++i) {
456 + sum += adapter.items[i].value;
457 + }
458 + allZero = sum == 0.0f;
459 +
460 + // Now update the values used for the input tensor. If nothing is set, mix in everything
461 + // equally. Otherwise everything is normalized to sum to 1.0.
462 + for (int i = 0; i < NUM_STYLES; ++i) {
463 + styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum;
464 +
465 + if (lastAllZero != allZero) {
466 + adapter.items[i].postInvalidate();
467 + }
468 + }
469 + }
470 +
471 + private void resetPreviewBuffers() {
472 + croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
473 +
474 + frameToCropTransform = ImageUtils.getTransformationMatrix(
475 + previewWidth, previewHeight,
476 + desiredSize, desiredSize,
477 + sensorOrientation, true);
478 +
479 + cropToFrameTransform = new Matrix();
480 + frameToCropTransform.invert(cropToFrameTransform);
481 + intValues = new int[desiredSize * desiredSize];
482 + floatValues = new float[desiredSize * desiredSize * 3];
483 + initializedSize = desiredSize;
484 + }
485 +
486 + @Override
487 + protected void processImage() {
488 + if (desiredSize != initializedSize) {
489 + LOGGER.i(
490 + "Initializing at size preview size %dx%d, stylize size %d",
491 + previewWidth, previewHeight, desiredSize);
492 +
493 + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
494 + croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
495 + frameToCropTransform = ImageUtils.getTransformationMatrix(
496 + previewWidth, previewHeight,
497 + desiredSize, desiredSize,
498 + sensorOrientation, true);
499 +
500 + cropToFrameTransform = new Matrix();
501 + frameToCropTransform.invert(cropToFrameTransform);
502 + intValues = new int[desiredSize * desiredSize];
503 + floatValues = new float[desiredSize * desiredSize * 3];
504 + initializedSize = desiredSize;
505 + }
506 + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
507 + final Canvas canvas = new Canvas(croppedBitmap);
508 + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
509 +
510 + // For examining the actual TF input.
511 + if (SAVE_PREVIEW_BITMAP) {
512 + ImageUtils.saveBitmap(croppedBitmap);
513 + }
514 +
515 + runInBackground(
516 + new Runnable() {
517 + @Override
518 + public void run() {
519 + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
520 + final long startTime = SystemClock.uptimeMillis();
521 + stylizeImage(croppedBitmap);
522 + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
523 + textureCopyBitmap = Bitmap.createBitmap(croppedBitmap);
524 + requestRender();
525 + readyForNextImage();
526 + }
527 + });
528 + if (desiredSize != initializedSize) {
529 + resetPreviewBuffers();
530 + }
531 + }
532 +
533 + private void stylizeImage(final Bitmap bitmap) {
534 + ++frameNum;
535 + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
536 +
537 + if (DEBUG_MODEL) {
538 + // Create a white square that steps through a black background 1 pixel per frame.
539 + final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth();
540 + final int centerY = bitmap.getHeight() / 2;
541 + final int squareSize = 10;
542 + for (int i = 0; i < intValues.length; ++i) {
543 + final int x = i % bitmap.getWidth();
 544 +        final int y = i / bitmap.getWidth();
545 + final float val =
546 + Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f;
547 + floatValues[i * 3] = val;
548 + floatValues[i * 3 + 1] = val;
549 + floatValues[i * 3 + 2] = val;
550 + }
551 + } else {
552 + for (int i = 0; i < intValues.length; ++i) {
553 + final int val = intValues[i];
554 + floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
555 + floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
556 + floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
557 + }
558 + }
559 +
560 + // Copy the input data into TensorFlow.
 561 +    LOGGER.i("Width: %d, Height: %d", bitmap.getWidth(), bitmap.getHeight());
562 + inferenceInterface.feed(
563 + INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3);
564 + inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES);
565 +
566 + inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug());
567 + inferenceInterface.fetch(OUTPUT_NODE, floatValues);
568 +
569 + for (int i = 0; i < intValues.length; ++i) {
570 + intValues[i] =
571 + 0xFF000000
572 + | (((int) (floatValues[i * 3] * 255)) << 16)
573 + | (((int) (floatValues[i * 3 + 1] * 255)) << 8)
574 + | ((int) (floatValues[i * 3 + 2] * 255));
575 + }
576 +
577 + bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
578 + }
579 +
580 + private void renderDebug(final Canvas canvas) {
581 + // TODO(andrewharp): move result display to its own View instead of using debug overlay.
582 + final Bitmap texture = textureCopyBitmap;
583 + if (texture != null) {
584 + final Matrix matrix = new Matrix();
585 + final float scaleFactor =
586 + DEBUG_MODEL
587 + ? 4.0f
588 + : Math.min(
589 + (float) canvas.getWidth() / texture.getWidth(),
590 + (float) canvas.getHeight() / texture.getHeight());
591 + matrix.postScale(scaleFactor, scaleFactor);
592 + canvas.drawBitmap(texture, matrix, new Paint());
593 + }
594 +
595 + if (!isDebug()) {
596 + return;
597 + }
598 +
599 + final Bitmap copy = cropCopyBitmap;
600 + if (copy == null) {
601 + return;
602 + }
603 +
604 + canvas.drawColor(0x55000000);
605 +
606 + final Matrix matrix = new Matrix();
607 + final float scaleFactor = 2;
608 + matrix.postScale(scaleFactor, scaleFactor);
609 + matrix.postTranslate(
610 + canvas.getWidth() - copy.getWidth() * scaleFactor,
611 + canvas.getHeight() - copy.getHeight() * scaleFactor);
612 + canvas.drawBitmap(copy, matrix, new Paint());
613 +
614 + final Vector<String> lines = new Vector<>();
615 +
616 + final String[] statLines = inferenceInterface.getStatString().split("\n");
617 + Collections.addAll(lines, statLines);
618 +
619 + lines.add("");
620 +
621 + lines.add("Frame: " + previewWidth + "x" + previewHeight);
622 + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
623 + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
624 + lines.add("Rotation: " + sensorOrientation);
625 + lines.add("Inference time: " + lastProcessingTimeMs + "ms");
626 + lines.add("Desired size: " + desiredSize);
627 + lines.add("Initialized size: " + initializedSize);
628 +
629 + borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
630 + }
631 +
632 + @Override
633 + public boolean onKeyDown(int keyCode, KeyEvent event) {
634 + int moveOffset = 0;
635 + switch (keyCode) {
636 + case KeyEvent.KEYCODE_DPAD_LEFT:
637 + moveOffset = -1;
638 + break;
639 + case KeyEvent.KEYCODE_DPAD_RIGHT:
640 + moveOffset = 1;
641 + break;
642 + case KeyEvent.KEYCODE_DPAD_UP:
643 + moveOffset = -1 * grid.getNumColumns();
644 + break;
645 + case KeyEvent.KEYCODE_DPAD_DOWN:
646 + moveOffset = grid.getNumColumns();
647 + break;
648 + default:
649 + return super.onKeyDown(keyCode, event);
650 + }
651 +
652 + // get the highest selected style
653 + int currentSelect = 0;
654 + float highestValue = 0;
 655 +    for (int i = 0; i < NUM_STYLES; i++) {
656 + if (adapter.items[i].value > highestValue) {
657 + currentSelect = i;
658 + highestValue = adapter.items[i].value;
659 + }
660 + }
 661 +    setStyle(adapter.items[(currentSelect + moveOffset + NUM_STYLES) % NUM_STYLES], 1.0f);
662 +
663 + return true;
664 + }
665 +}
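// ---------------------------------------------------------------------------
// Editorial sketch (not part of the diff): the style-weight normalization that
// setStyle() applies before the values are fed to STYLE_NODE, extracted as a
// pure function. The class and method names are invented for illustration; the
// math matches the code above (uniform mix when every slider is zero,
// otherwise rescale so the weights sum to 1.0).
package org.tensorflow.demo;

final class StyleMixUtils {
  private StyleMixUtils() {}

  static float[] normalizeStyleWeights(final float[] sliderValues) {
    float sum = 0.0f;
    for (final float v : sliderValues) {
      sum += v;
    }
    final float[] styleVals = new float[sliderValues.length];
    for (int i = 0; i < sliderValues.length; ++i) {
      // If nothing is selected, mix every style equally; otherwise normalize.
      styleVals[i] = sum == 0.0f ? 1.0f / sliderValues.length : sliderValues[i] / sum;
    }
    return styleVals;
  }
}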
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import android.content.res.AssetManager;
19 +import android.graphics.Bitmap;
20 +import android.os.Trace;
21 +import android.util.Log;
22 +import java.io.BufferedReader;
23 +import java.io.IOException;
24 +import java.io.InputStreamReader;
25 +import java.util.ArrayList;
26 +import java.util.Comparator;
27 +import java.util.List;
28 +import java.util.PriorityQueue;
29 +import java.util.Vector;
30 +import org.tensorflow.Operation;
31 +import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
32 +
33 +/** A classifier specialized to label images using TensorFlow. */
34 +public class TensorFlowImageClassifier implements Classifier {
35 + private static final String TAG = "TensorFlowImageClassifier";
36 +
37 + // Only return this many results with at least this confidence.
38 + private static final int MAX_RESULTS = 3;
39 + private static final float THRESHOLD = 0.1f;
40 +
41 + // Config values.
42 + private String inputName;
43 + private String outputName;
44 + private int inputSize;
45 + private int imageMean;
46 + private float imageStd;
47 +
48 + // Pre-allocated buffers.
49 + private Vector<String> labels = new Vector<String>();
50 + private int[] intValues;
51 + private float[] floatValues;
52 + private float[] outputs;
53 + private String[] outputNames;
54 +
55 + private boolean logStats = false;
56 +
57 + private TensorFlowInferenceInterface inferenceInterface;
58 +
59 + private TensorFlowImageClassifier() {}
60 +
61 + /**
62 + * Initializes a native TensorFlow session for classifying images.
63 + *
64 + * @param assetManager The asset manager to be used to load assets.
65 + * @param modelFilename The filepath of the model GraphDef protocol buffer.
66 + * @param labelFilename The filepath of label file for classes.
67 + * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
68 + * @param imageMean The assumed mean of the image values.
69 + * @param imageStd The assumed std of the image values.
70 + * @param inputName The label of the image input node.
71 + * @param outputName The label of the output node.
  72 +   * @throws RuntimeException if the label file cannot be read.
73 + */
74 + public static Classifier create(
75 + AssetManager assetManager,
76 + String modelFilename,
77 + String labelFilename,
78 + int inputSize,
79 + int imageMean,
80 + float imageStd,
81 + String inputName,
82 + String outputName) {
83 + TensorFlowImageClassifier c = new TensorFlowImageClassifier();
84 + c.inputName = inputName;
85 + c.outputName = outputName;
86 +
87 + // Read the label names into memory.
88 + // TODO(andrewharp): make this handle non-assets.
89 + String actualFilename = labelFilename.split("file:///android_asset/")[1];
90 + Log.i(TAG, "Reading labels from: " + actualFilename);
91 + BufferedReader br = null;
92 + try {
93 + br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
94 + String line;
95 + while ((line = br.readLine()) != null) {
96 + c.labels.add(line);
97 + }
98 + br.close();
99 + } catch (IOException e) {
 100 +      throw new RuntimeException("Problem reading label file!", e);
101 + }
102 +
103 + c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
104 +
105 + // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
106 + final Operation operation = c.inferenceInterface.graphOperation(outputName);
107 + final int numClasses = (int) operation.output(0).shape().size(1);
108 + Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
109 +
110 + // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
111 + // the placeholder node for input in the graphdef typically used does not specify a shape, so it
112 + // must be passed in as a parameter.
113 + c.inputSize = inputSize;
114 + c.imageMean = imageMean;
115 + c.imageStd = imageStd;
116 +
117 + // Pre-allocate buffers.
118 + c.outputNames = new String[] {outputName};
119 + c.intValues = new int[inputSize * inputSize];
120 + c.floatValues = new float[inputSize * inputSize * 3];
121 + c.outputs = new float[numClasses];
122 +
123 + return c;
124 + }
125 +
126 + @Override
127 + public List<Recognition> recognizeImage(final Bitmap bitmap) {
128 + // Log this method so that it can be analyzed with systrace.
129 + Trace.beginSection("recognizeImage");
130 +
131 + Trace.beginSection("preprocessBitmap");
132 + // Preprocess the image data from 0-255 int to normalized float based
133 + // on the provided parameters.
134 + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
135 + for (int i = 0; i < intValues.length; ++i) {
136 + final int val = intValues[i];
137 + floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
138 + floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
139 + floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
140 + }
141 + Trace.endSection();
142 +
143 + // Copy the input data into TensorFlow.
144 + Trace.beginSection("feed");
145 + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
146 + Trace.endSection();
147 +
148 + // Run the inference call.
149 + Trace.beginSection("run");
150 + inferenceInterface.run(outputNames, logStats);
151 + Trace.endSection();
152 +
153 + // Copy the output Tensor back into the output array.
154 + Trace.beginSection("fetch");
155 + inferenceInterface.fetch(outputName, outputs);
156 + Trace.endSection();
157 +
158 + // Find the best classifications.
159 + PriorityQueue<Recognition> pq =
160 + new PriorityQueue<Recognition>(
161 + 3,
162 + new Comparator<Recognition>() {
163 + @Override
164 + public int compare(Recognition lhs, Recognition rhs) {
165 + // Intentionally reversed to put high confidence at the head of the queue.
166 + return Float.compare(rhs.getConfidence(), lhs.getConfidence());
167 + }
168 + });
169 + for (int i = 0; i < outputs.length; ++i) {
170 + if (outputs[i] > THRESHOLD) {
171 + pq.add(
172 + new Recognition(
173 + "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
174 + }
175 + }
176 + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
177 + int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
178 + for (int i = 0; i < recognitionsSize; ++i) {
179 + recognitions.add(pq.poll());
180 + }
181 + Trace.endSection(); // "recognizeImage"
182 + return recognitions;
183 + }
184 +
185 + @Override
186 + public void enableStatLogging(boolean logStats) {
187 + this.logStats = logStats;
188 + }
189 +
190 + @Override
191 + public String getStatString() {
192 + return inferenceInterface.getStatString();
193 + }
194 +
195 + @Override
196 + public void close() {
197 + inferenceInterface.close();
198 + }
199 +}
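// ---------------------------------------------------------------------------
// Editorial sketch (not part of the diff): how an activity might wire up this
// classifier. The class name, the asset filenames, the input size, the
// mean/std values, and the node names below are illustrative placeholders,
// not constants defined in this change.
package org.tensorflow.demo;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import java.util.List;
import org.tensorflow.demo.Classifier.Recognition;

final class ImageClassifierSketch {
  private ImageClassifierSketch() {}

  static List<Recognition> classify(final AssetManager assets, final Bitmap squareBitmap) {
    // Hypothetical model/label assets and preprocessing constants.
    final Classifier classifier =
        TensorFlowImageClassifier.create(
            assets,
            "file:///android_asset/some_frozen_graph.pb",
            "file:///android_asset/some_labels.txt",
            /* inputSize= */ 224,
            /* imageMean= */ 117,
            /* imageStd= */ 1.0f,
            /* inputName= */ "input",
            /* outputName= */ "output");
    try {
      // The bitmap must already be inputSize x inputSize; scaling is the caller's job.
      return classifier.recognizeImage(squareBitmap);
    } finally {
      classifier.close();
    }
  }
}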
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import android.content.res.AssetManager;
19 +import android.graphics.Bitmap;
20 +import android.graphics.RectF;
21 +import android.os.Trace;
22 +import java.io.BufferedReader;
23 +import java.io.FileInputStream;
24 +import java.io.IOException;
25 +import java.io.InputStream;
26 +import java.io.InputStreamReader;
27 +import java.util.ArrayList;
28 +import java.util.Comparator;
29 +import java.util.List;
30 +import java.util.PriorityQueue;
31 +import java.util.StringTokenizer;
32 +import org.tensorflow.Graph;
33 +import org.tensorflow.Operation;
34 +import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
35 +import org.tensorflow.demo.env.Logger;
36 +
37 +/**
38 + * A detector for general purpose object detection as described in Scalable Object Detection using
39 + * Deep Neural Networks (https://arxiv.org/abs/1312.2249).
40 + */
41 +public class TensorFlowMultiBoxDetector implements Classifier {
42 + private static final Logger LOGGER = new Logger();
43 +
44 + // Only return this many results.
45 + private static final int MAX_RESULTS = Integer.MAX_VALUE;
46 +
47 + // Config values.
48 + private String inputName;
49 + private int inputSize;
50 + private int imageMean;
51 + private float imageStd;
52 +
53 + // Pre-allocated buffers.
54 + private int[] intValues;
55 + private float[] floatValues;
56 + private float[] outputLocations;
57 + private float[] outputScores;
58 + private String[] outputNames;
59 + private int numLocations;
60 +
61 + private boolean logStats = false;
62 +
63 + private TensorFlowInferenceInterface inferenceInterface;
64 +
65 + private float[] boxPriors;
66 +
67 + /**
  68 +   * Initializes a native TensorFlow session for detecting objects in images.
  69 +   *
  70 +   * @param assetManager The asset manager to be used to load assets.
  71 +   * @param modelFilename The filepath of the model GraphDef protocol buffer.
  72 +   * @param locationFilename The filepath of the text file holding the box prior values.
  73 +   * @param imageMean The assumed mean of the image values.
  74 +   * @param imageStd The assumed std of the image values.
  75 +   * @param inputName The label of the image input node.
  76 +   * @param outputLocationsName The label of the output node for box location encodings.
  77 +   * @param outputScoresName The label of the output node for box score encodings.
78 + */
79 + public static Classifier create(
80 + final AssetManager assetManager,
81 + final String modelFilename,
82 + final String locationFilename,
83 + final int imageMean,
84 + final float imageStd,
85 + final String inputName,
86 + final String outputLocationsName,
87 + final String outputScoresName) {
88 + final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
89 +
90 + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
91 +
92 + final Graph g = d.inferenceInterface.graph();
93 +
94 + d.inputName = inputName;
95 + // The inputName node has a shape of [N, H, W, C], where
96 + // N is the batch size
97 + // H = W are the height and width
98 + // C is the number of channels (3 for our purposes - RGB)
99 + final Operation inputOp = g.operation(inputName);
100 + if (inputOp == null) {
101 + throw new RuntimeException("Failed to find input Node '" + inputName + "'");
102 + }
103 + d.inputSize = (int) inputOp.output(0).shape().size(1);
104 + d.imageMean = imageMean;
105 + d.imageStd = imageStd;
106 + // The outputScoresName node has a shape of [N, NumLocations], where N
107 + // is the batch size.
108 + final Operation outputOp = g.operation(outputScoresName);
109 + if (outputOp == null) {
110 + throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'");
111 + }
112 + d.numLocations = (int) outputOp.output(0).shape().size(1);
113 +
114 + d.boxPriors = new float[d.numLocations * 8];
115 +
116 + try {
117 + d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
118 + } catch (final IOException e) {
 119 +      throw new RuntimeException("Error initializing box priors from " + locationFilename, e);
120 + }
121 +
122 + // Pre-allocate buffers.
123 + d.outputNames = new String[] {outputLocationsName, outputScoresName};
124 + d.intValues = new int[d.inputSize * d.inputSize];
125 + d.floatValues = new float[d.inputSize * d.inputSize * 3];
126 + d.outputScores = new float[d.numLocations];
127 + d.outputLocations = new float[d.numLocations * 4];
128 +
129 + return d;
130 + }
131 +
132 + private TensorFlowMultiBoxDetector() {}
133 +
134 + private void loadCoderOptions(
135 + final AssetManager assetManager, final String locationFilename, final float[] boxPriors)
136 + throws IOException {
137 + // Try to be intelligent about opening from assets or sdcard depending on prefix.
138 + final String assetPrefix = "file:///android_asset/";
139 + InputStream is;
140 + if (locationFilename.startsWith(assetPrefix)) {
141 + is = assetManager.open(locationFilename.split(assetPrefix)[1]);
142 + } else {
143 + is = new FileInputStream(locationFilename);
144 + }
145 +
146 + // Read values. Number of values per line doesn't matter, as long as they are separated
147 + // by commas and/or whitespace, and there are exactly numLocations * 8 values total.
148 + // Values are in the order mean, std for each consecutive corner of each box, for a total of 8
149 + // per location.
150 + final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
151 + int priorIndex = 0;
152 + String line;
153 + while ((line = reader.readLine()) != null) {
154 + final StringTokenizer st = new StringTokenizer(line, ", ");
155 + while (st.hasMoreTokens()) {
156 + final String token = st.nextToken();
157 + try {
158 + final float number = Float.parseFloat(token);
159 + boxPriors[priorIndex++] = number;
160 + } catch (final NumberFormatException e) {
161 + // Silently ignore.
162 + }
163 + }
164 + }
165 + if (priorIndex != boxPriors.length) {
166 + throw new RuntimeException(
167 + "BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length);
168 + }
169 + }
170 +
171 + private float[] decodeLocationsEncoding(final float[] locationEncoding) {
172 + final float[] locations = new float[locationEncoding.length];
173 + boolean nonZero = false;
174 + for (int i = 0; i < numLocations; ++i) {
175 + for (int j = 0; j < 4; ++j) {
176 + final float currEncoding = locationEncoding[4 * i + j];
177 + nonZero = nonZero || currEncoding != 0.0f;
178 +
179 + final float mean = boxPriors[i * 8 + j * 2];
180 + final float stdDev = boxPriors[i * 8 + j * 2 + 1];
181 + float currentLocation = currEncoding * stdDev + mean;
182 + currentLocation = Math.max(currentLocation, 0.0f);
183 + currentLocation = Math.min(currentLocation, 1.0f);
184 + locations[4 * i + j] = currentLocation;
185 + }
186 + }
187 +
188 + if (!nonZero) {
189 + LOGGER.w("No non-zero encodings; check log for inference errors.");
190 + }
191 + return locations;
192 + }
193 +
194 + private float[] decodeScoresEncoding(final float[] scoresEncoding) {
195 + final float[] scores = new float[scoresEncoding.length];
196 + for (int i = 0; i < scoresEncoding.length; ++i) {
197 + scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i])));
198 + }
199 + return scores;
200 + }
201 +
202 + @Override
203 + public List<Recognition> recognizeImage(final Bitmap bitmap) {
204 + // Log this method so that it can be analyzed with systrace.
205 + Trace.beginSection("recognizeImage");
206 +
207 + Trace.beginSection("preprocessBitmap");
208 + // Preprocess the image data from 0-255 int to normalized float based
209 + // on the provided parameters.
210 + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
211 +
212 + for (int i = 0; i < intValues.length; ++i) {
213 + floatValues[i * 3 + 0] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd;
214 + floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd;
215 + floatValues[i * 3 + 2] = ((intValues[i] & 0xFF) - imageMean) / imageStd;
216 + }
217 + Trace.endSection(); // preprocessBitmap
218 +
219 + // Copy the input data into TensorFlow.
220 + Trace.beginSection("feed");
221 + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
222 + Trace.endSection();
223 +
224 + // Run the inference call.
225 + Trace.beginSection("run");
226 + inferenceInterface.run(outputNames, logStats);
227 + Trace.endSection();
228 +
229 + // Copy the output Tensor back into the output array.
230 + Trace.beginSection("fetch");
231 + final float[] outputScoresEncoding = new float[numLocations];
232 + final float[] outputLocationsEncoding = new float[numLocations * 4];
233 + inferenceInterface.fetch(outputNames[0], outputLocationsEncoding);
234 + inferenceInterface.fetch(outputNames[1], outputScoresEncoding);
235 + Trace.endSection();
236 +
237 + outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
238 + outputScores = decodeScoresEncoding(outputScoresEncoding);
239 +
240 + // Find the best detections.
241 + final PriorityQueue<Recognition> pq =
242 + new PriorityQueue<Recognition>(
243 + 1,
244 + new Comparator<Recognition>() {
245 + @Override
246 + public int compare(final Recognition lhs, final Recognition rhs) {
247 + // Intentionally reversed to put high confidence at the head of the queue.
248 + return Float.compare(rhs.getConfidence(), lhs.getConfidence());
249 + }
250 + });
251 +
252 + // Scale them back to the input size.
253 + for (int i = 0; i < outputScores.length; ++i) {
254 + final RectF detection =
255 + new RectF(
256 + outputLocations[4 * i] * inputSize,
257 + outputLocations[4 * i + 1] * inputSize,
258 + outputLocations[4 * i + 2] * inputSize,
259 + outputLocations[4 * i + 3] * inputSize);
260 + pq.add(new Recognition("" + i, null, outputScores[i], detection));
261 + }
262 +
263 + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
264 + for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
265 + recognitions.add(pq.poll());
266 + }
267 + Trace.endSection(); // "recognizeImage"
268 + return recognitions;
269 + }
270 +
271 + @Override
272 + public void enableStatLogging(final boolean logStats) {
273 + this.logStats = logStats;
274 + }
275 +
276 + @Override
277 + public String getStatString() {
278 + return inferenceInterface.getStatString();
279 + }
280 +
281 + @Override
282 + public void close() {
283 + inferenceInterface.close();
284 + }
285 +}
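// ---------------------------------------------------------------------------
// Editorial sketch (not part of the diff): the per-box decoding performed by
// decodeLocationsEncoding() and decodeScoresEncoding() above, written out for
// a single box so the prior layout (mean, std per corner) is easier to follow.
// The class and method names are invented for illustration.
package org.tensorflow.demo;

final class MultiBoxDecodeSketch {
  private MultiBoxDecodeSketch() {}

  /**
   * Decodes one box: {@code locationEncoding} holds 4 corner offsets and
   * {@code priors} holds 8 values (mean, std for each corner); the score is a
   * raw logit passed through a sigmoid, exactly as in the detector above.
   * Returns the 4 clamped corners followed by the decoded score.
   */
  static float[] decodeSingleBox(
      final float[] locationEncoding, final float[] priors, final float scoreLogit) {
    final float[] decoded = new float[5];
    for (int j = 0; j < 4; ++j) {
      final float mean = priors[j * 2];
      final float stdDev = priors[j * 2 + 1];
      float corner = locationEncoding[j] * stdDev + mean;
      corner = Math.min(Math.max(corner, 0.0f), 1.0f);  // Clamp to [0, 1].
      decoded[j] = corner;
    }
    decoded[4] = 1.0f / (1.0f + (float) Math.exp(-scoreLogit));  // Sigmoid score.
    return decoded;
  }
}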
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import android.content.res.AssetManager;
19 +import android.graphics.Bitmap;
20 +import android.graphics.RectF;
21 +import android.os.Trace;
22 +import java.io.BufferedReader;
23 +import java.io.IOException;
24 +import java.io.InputStream;
25 +import java.io.InputStreamReader;
26 +import java.util.ArrayList;
27 +import java.util.Comparator;
28 +import java.util.List;
29 +import java.util.PriorityQueue;
30 +import java.util.Vector;
31 +import org.tensorflow.Graph;
32 +import org.tensorflow.Operation;
33 +import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
34 +import org.tensorflow.demo.env.Logger;
35 +
36 +/**
  37 + * Wrapper for frozen detection models trained using the TensorFlow Object Detection API:
38 + * github.com/tensorflow/models/tree/master/research/object_detection
39 + */
40 +public class TensorFlowObjectDetectionAPIModel implements Classifier {
41 + private static final Logger LOGGER = new Logger();
42 +
43 + // Only return this many results.
44 + private static final int MAX_RESULTS = 100;
45 +
46 + // Config values.
47 + private String inputName;
48 + private int inputSize;
49 +
50 + // Pre-allocated buffers.
51 + private Vector<String> labels = new Vector<String>();
52 + private int[] intValues;
53 + private byte[] byteValues;
54 + private float[] outputLocations;
55 + private float[] outputScores;
56 + private float[] outputClasses;
57 + private float[] outputNumDetections;
58 + private String[] outputNames;
59 +
60 + private boolean logStats = false;
61 +
62 + private TensorFlowInferenceInterface inferenceInterface;
63 +
64 + /**
  65 +   * Initializes a native TensorFlow session for detecting objects in images.
  66 +   * @param assetManager The asset manager to be used to load assets.
  67 +   * @param modelFilename The filepath of the model GraphDef protocol buffer.
  68 +   * @param labelFilename The filepath of the label file for classes.
  69 +   * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
  70 +   */
71 + public static Classifier create(
72 + final AssetManager assetManager,
73 + final String modelFilename,
74 + final String labelFilename,
75 + final int inputSize) throws IOException {
76 + final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();
77 +
78 + InputStream labelsInput = null;
79 + String actualFilename = labelFilename.split("file:///android_asset/")[1];
80 + labelsInput = assetManager.open(actualFilename);
81 + BufferedReader br = null;
82 + br = new BufferedReader(new InputStreamReader(labelsInput));
83 + String line;
84 + while ((line = br.readLine()) != null) {
85 + LOGGER.w(line);
86 + d.labels.add(line);
87 + }
88 + br.close();
89 +
90 +
91 + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
92 +
93 + final Graph g = d.inferenceInterface.graph();
94 +
95 + d.inputName = "image_tensor";
96 + // The inputName node has a shape of [N, H, W, C], where
97 + // N is the batch size
98 + // H = W are the height and width
99 + // C is the number of channels (3 for our purposes - RGB)
100 + final Operation inputOp = g.operation(d.inputName);
101 + if (inputOp == null) {
102 + throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
103 + }
104 + d.inputSize = inputSize;
105 + // The outputScoresName node has a shape of [N, NumLocations], where N
106 + // is the batch size.
107 + final Operation outputOp1 = g.operation("detection_scores");
108 + if (outputOp1 == null) {
109 + throw new RuntimeException("Failed to find output Node 'detection_scores'");
110 + }
111 + final Operation outputOp2 = g.operation("detection_boxes");
112 + if (outputOp2 == null) {
113 + throw new RuntimeException("Failed to find output Node 'detection_boxes'");
114 + }
115 + final Operation outputOp3 = g.operation("detection_classes");
116 + if (outputOp3 == null) {
117 + throw new RuntimeException("Failed to find output Node 'detection_classes'");
118 + }
119 +
120 + // Pre-allocate buffers.
121 + d.outputNames = new String[] {"detection_boxes", "detection_scores",
122 + "detection_classes", "num_detections"};
123 + d.intValues = new int[d.inputSize * d.inputSize];
124 + d.byteValues = new byte[d.inputSize * d.inputSize * 3];
125 + d.outputScores = new float[MAX_RESULTS];
126 + d.outputLocations = new float[MAX_RESULTS * 4];
127 + d.outputClasses = new float[MAX_RESULTS];
128 + d.outputNumDetections = new float[1];
129 + return d;
130 + }
131 +
132 + private TensorFlowObjectDetectionAPIModel() {}
133 +
134 + @Override
135 + public List<Recognition> recognizeImage(final Bitmap bitmap) {
136 + // Log this method so that it can be analyzed with systrace.
137 + Trace.beginSection("recognizeImage");
138 +
139 + Trace.beginSection("preprocessBitmap");
140 +    // Preprocess the image data by extracting the R, G and B bytes from each pixel,
141 +    // which is packed as a 0x00RRGGBB int.
142 + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
143 +
144 + for (int i = 0; i < intValues.length; ++i) {
145 + byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
146 + byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
147 + byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
148 + }
149 + Trace.endSection(); // preprocessBitmap
150 +
151 + // Copy the input data into TensorFlow.
152 + Trace.beginSection("feed");
153 + inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
154 + Trace.endSection();
155 +
156 + // Run the inference call.
157 + Trace.beginSection("run");
158 + inferenceInterface.run(outputNames, logStats);
159 + Trace.endSection();
160 +
161 + // Copy the output Tensor back into the output array.
162 + Trace.beginSection("fetch");
163 + outputLocations = new float[MAX_RESULTS * 4];
164 + outputScores = new float[MAX_RESULTS];
165 + outputClasses = new float[MAX_RESULTS];
166 + outputNumDetections = new float[1];
167 + inferenceInterface.fetch(outputNames[0], outputLocations);
168 + inferenceInterface.fetch(outputNames[1], outputScores);
169 + inferenceInterface.fetch(outputNames[2], outputClasses);
170 + inferenceInterface.fetch(outputNames[3], outputNumDetections);
171 + Trace.endSection();
172 +
173 + // Find the best detections.
174 + final PriorityQueue<Recognition> pq =
175 + new PriorityQueue<Recognition>(
176 + 1,
177 + new Comparator<Recognition>() {
178 + @Override
179 + public int compare(final Recognition lhs, final Recognition rhs) {
180 + // Intentionally reversed to put high confidence at the head of the queue.
181 + return Float.compare(rhs.getConfidence(), lhs.getConfidence());
182 + }
183 + });
184 +
185 + // Scale them back to the input size.
186 + for (int i = 0; i < outputScores.length; ++i) {
187 + final RectF detection =
188 + new RectF(
189 + outputLocations[4 * i + 1] * inputSize,
190 + outputLocations[4 * i] * inputSize,
191 + outputLocations[4 * i + 3] * inputSize,
192 + outputLocations[4 * i + 2] * inputSize);
193 + pq.add(
194 + new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection));
195 + }
196 +
197 + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
198 +    for (int i = Math.min(pq.size(), MAX_RESULTS); i > 0; --i) {
199 + recognitions.add(pq.poll());
200 + }
201 + Trace.endSection(); // "recognizeImage"
202 + return recognitions;
203 + }
204 +
205 + @Override
206 + public void enableStatLogging(final boolean logStats) {
207 + this.logStats = logStats;
208 + }
209 +
210 + @Override
211 + public String getStatString() {
212 + return inferenceInterface.getStatString();
213 + }
214 +
215 + @Override
216 + public void close() {
217 + inferenceInterface.close();
218 + }
219 +}
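
As a rough usage sketch, the detector above is typically constructed once and then fed cropped camera frames. The asset names, the 300-pixel input size, and the surrounding Activity/Logger are illustrative assumptions and must match whatever model and label file are actually bundled under assets/.

// Sketch only: assumes this runs inside an Activity and that the named assets exist.
// LOGGER is an org.tensorflow.demo.env.Logger field, as in the classes above.
private void runDetectorOnce(final Bitmap croppedBitmap) throws IOException {
  final int inputSize = 300;  // must match the exported model's expected input size
  final Classifier detector =
      TensorFlowObjectDetectionAPIModel.create(
          getAssets(),
          "file:///android_asset/ssd_mobilenet_v1_android_export.pb",  // placeholder name
          "file:///android_asset/coco_labels_list.txt",                // placeholder name
          inputSize);
  // croppedBitmap is expected to be an inputSize x inputSize ARGB_8888 bitmap holding the frame.
  final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
  for (final Classifier.Recognition result : results) {
    LOGGER.i("Detected %s (%.2f) at %s",
        result.getTitle(), result.getConfidence(), result.getLocation());
  }
  detector.close();
}
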
1 +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo;
17 +
18 +import android.content.res.AssetManager;
19 +import android.graphics.Bitmap;
20 +import android.graphics.RectF;
21 +import android.os.Trace;
22 +import java.util.ArrayList;
23 +import java.util.Comparator;
24 +import java.util.List;
25 +import java.util.PriorityQueue;
26 +import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
27 +import org.tensorflow.demo.env.Logger;
28 +import org.tensorflow.demo.env.SplitTimer;
29 +
30 +/** An object detector that uses TF and a YOLO model to detect objects. */
31 +public class TensorFlowYoloDetector implements Classifier {
32 + private static final Logger LOGGER = new Logger();
33 +
34 +  // Only return this many of the highest-scoring results.
35 + private static final int MAX_RESULTS = 5;
36 +
37 + private static final int NUM_CLASSES = 1;
38 +
39 + private static final int NUM_BOXES_PER_BLOCK = 5;
40 +
41 + // TODO(andrewharp): allow loading anchors and classes
42 + // from files.
43 + private static final double[] ANCHORS = {
44 + 1.08, 1.19,
45 + 3.42, 4.41,
46 + 6.63, 11.38,
47 + 9.42, 5.11,
48 + 16.62, 10.52
49 + };
50 +
51 + private static final String[] LABELS = {
52 + "dog"
53 + };
54 +
55 + // Config values.
56 + private String inputName;
57 + private int inputSize;
58 +
59 + // Pre-allocated buffers.
60 + private int[] intValues;
61 + private float[] floatValues;
62 + private String[] outputNames;
63 +
64 + private int blockSize;
65 +
66 + private boolean logStats = false;
67 +
68 + private TensorFlowInferenceInterface inferenceInterface;
69 +
70 + /** Initializes a native TensorFlow session for classifying images. */
71 + public static Classifier create(
72 + final AssetManager assetManager,
73 + final String modelFilename,
74 + final int inputSize,
75 + final String inputName,
76 + final String outputName,
77 + final int blockSize) {
78 + TensorFlowYoloDetector d = new TensorFlowYoloDetector();
79 + d.inputName = inputName;
80 + d.inputSize = inputSize;
81 +
82 + // Pre-allocate buffers.
83 + d.outputNames = outputName.split(",");
84 + d.intValues = new int[inputSize * inputSize];
85 + d.floatValues = new float[inputSize * inputSize * 3];
86 + d.blockSize = blockSize;
87 +
88 + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
89 +
90 + return d;
91 + }
92 +
93 + private TensorFlowYoloDetector() {}
94 +
95 + private float expit(final float x) {
96 + return (float) (1. / (1. + Math.exp(-x)));
97 + }
98 +
99 + private void softmax(final float[] vals) {
100 + float max = Float.NEGATIVE_INFINITY;
101 + for (final float val : vals) {
102 + max = Math.max(max, val);
103 + }
104 + float sum = 0.0f;
105 + for (int i = 0; i < vals.length; ++i) {
106 + vals[i] = (float) Math.exp(vals[i] - max);
107 + sum += vals[i];
108 + }
109 + for (int i = 0; i < vals.length; ++i) {
110 + vals[i] = vals[i] / sum;
111 + }
112 + }
113 +
114 + @Override
115 + public List<Recognition> recognizeImage(final Bitmap bitmap) {
116 + final SplitTimer timer = new SplitTimer("recognizeImage");
117 +
118 + // Log this method so that it can be analyzed with systrace.
119 + Trace.beginSection("recognizeImage");
120 +
121 + Trace.beginSection("preprocessBitmap");
122 + // Preprocess the image data from 0-255 int to normalized float based
123 + // on the provided parameters.
124 + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
125 +
126 + for (int i = 0; i < intValues.length; ++i) {
127 + floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f;
128 + floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f;
129 + floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f;
130 + }
131 + Trace.endSection(); // preprocessBitmap
132 +
133 + // Copy the input data into TensorFlow.
134 + Trace.beginSection("feed");
135 + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
136 + Trace.endSection();
137 +
138 + timer.endSplit("ready for inference");
139 +
140 + // Run the inference call.
141 + Trace.beginSection("run");
142 + inferenceInterface.run(outputNames, logStats);
143 + Trace.endSection();
144 +
145 + timer.endSplit("ran inference");
146 +
147 + // Copy the output Tensor back into the output array.
148 + Trace.beginSection("fetch");
149 + final int gridWidth = bitmap.getWidth() / blockSize;
150 + final int gridHeight = bitmap.getHeight() / blockSize;
151 + final float[] output =
152 + new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];
153 + inferenceInterface.fetch(outputNames[0], output);
154 + Trace.endSection();
155 +
156 + // Find the best detections.
157 + final PriorityQueue<Recognition> pq =
158 + new PriorityQueue<Recognition>(
159 + 1,
160 + new Comparator<Recognition>() {
161 + @Override
162 + public int compare(final Recognition lhs, final Recognition rhs) {
163 + // Intentionally reversed to put high confidence at the head of the queue.
164 + return Float.compare(rhs.getConfidence(), lhs.getConfidence());
165 + }
166 + });
167 +
168 + for (int y = 0; y < gridHeight; ++y) {
169 + for (int x = 0; x < gridWidth; ++x) {
170 + for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
171 + final int offset =
172 + (gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y
173 + + (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x
174 + + (NUM_CLASSES + 5) * b;
175 +
176 + final float xPos = (x + expit(output[offset + 0])) * blockSize;
177 + final float yPos = (y + expit(output[offset + 1])) * blockSize;
178 +
179 + final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * blockSize;
180 + final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * blockSize;
181 +
182 + final RectF rect =
183 + new RectF(
184 + Math.max(0, xPos - w / 2),
185 + Math.max(0, yPos - h / 2),
186 + Math.min(bitmap.getWidth() - 1, xPos + w / 2),
187 + Math.min(bitmap.getHeight() - 1, yPos + h / 2));
188 + final float confidence = expit(output[offset + 4]);
189 +
190 + int detectedClass = -1;
191 + float maxClass = 0;
192 +
193 + final float[] classes = new float[NUM_CLASSES];
194 + for (int c = 0; c < NUM_CLASSES; ++c) {
195 + classes[c] = output[offset + 5 + c];
196 + }
197 + softmax(classes);
198 +
199 + for (int c = 0; c < NUM_CLASSES; ++c) {
200 + if (classes[c] > maxClass) {
201 + detectedClass = c;
202 + maxClass = classes[c];
203 + }
204 + }
205 +
206 + final float confidenceInClass = maxClass * confidence;
207 + if (confidenceInClass > 0.01) {
208 + LOGGER.i(
209 + "%s (%d) %f %s", LABELS[detectedClass], detectedClass, confidenceInClass, rect);
210 + pq.add(new Recognition("" + offset, LABELS[detectedClass], confidenceInClass, rect));
211 + }
212 + }
213 + }
214 + }
215 + timer.endSplit("decoded results");
216 +
217 + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
218 +    for (int i = Math.min(pq.size(), MAX_RESULTS); i > 0; --i) {
219 + recognitions.add(pq.poll());
220 + }
221 + Trace.endSection(); // "recognizeImage"
222 +
223 + timer.endSplit("processed results");
224 +
225 + return recognitions;
226 + }
227 +
228 + @Override
229 + public void enableStatLogging(final boolean logStats) {
230 + this.logStats = logStats;
231 + }
232 +
233 + @Override
234 + public String getStatString() {
235 + return inferenceInterface.getStatString();
236 + }
237 +
238 + @Override
239 + public void close() {
240 + inferenceInterface.close();
241 + }
242 +}
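
For comparison, here is a hedged sketch of how this YOLO variant might be created. The asset name, tensor names, 416-pixel input size, and 32-pixel block size are typical Tiny YOLO values and are not guaranteed to match the frozen graph actually shipped with this change.

// Sketch only: croppedBitmap is assumed to be a 416x416 ARGB_8888 frame prepared elsewhere.
final Classifier yoloDetector =
    TensorFlowYoloDetector.create(
        getAssets(),
        "file:///android_asset/graph-tiny-yolo-voc.pb",  // placeholder asset name
        416,       // inputSize
        "input",   // inputName
        "output",  // outputName (comma-separated if the graph exposes several outputs)
        32);       // blockSize: inputSize / grid cells per side (416 / 13)
final List<Classifier.Recognition> detections = yoloDetector.recognizeImage(croppedBitmap);
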
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo.env;
17 +
18 +import android.graphics.Canvas;
19 +import android.graphics.Color;
20 +import android.graphics.Paint;
21 +import android.graphics.Paint.Align;
22 +import android.graphics.Paint.Style;
23 +import android.graphics.Rect;
24 +import android.graphics.Typeface;
25 +import java.util.Vector;
26 +
27 +/**
28 + * A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas.
29 + */
30 +public class BorderedText {
31 + private final Paint interiorPaint;
32 + private final Paint exteriorPaint;
33 +
34 + private final float textSize;
35 +
36 + /**
37 + * Creates a left-aligned bordered text object with a white interior, and a black exterior with
38 + * the specified text size.
39 + *
40 + * @param textSize text size in pixels
41 + */
42 + public BorderedText(final float textSize) {
43 + this(Color.WHITE, Color.BLACK, textSize);
44 + }
45 +
46 + /**
47 + * Create a bordered text object with the specified interior and exterior colors, text size and
48 + * alignment.
49 + *
50 + * @param interiorColor the interior text color
51 + * @param exteriorColor the exterior text color
52 + * @param textSize text size in pixels
53 + */
54 + public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) {
55 + interiorPaint = new Paint();
56 + interiorPaint.setTextSize(textSize);
57 + interiorPaint.setColor(interiorColor);
58 + interiorPaint.setStyle(Style.FILL);
59 + interiorPaint.setAntiAlias(false);
60 + interiorPaint.setAlpha(255);
61 +
62 + exteriorPaint = new Paint();
63 + exteriorPaint.setTextSize(textSize);
64 + exteriorPaint.setColor(exteriorColor);
65 + exteriorPaint.setStyle(Style.FILL_AND_STROKE);
66 + exteriorPaint.setStrokeWidth(textSize / 8);
67 + exteriorPaint.setAntiAlias(false);
68 + exteriorPaint.setAlpha(255);
69 +
70 + this.textSize = textSize;
71 + }
72 +
73 + public void setTypeface(Typeface typeface) {
74 + interiorPaint.setTypeface(typeface);
75 + exteriorPaint.setTypeface(typeface);
76 + }
77 +
78 + public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
79 + canvas.drawText(text, posX, posY, exteriorPaint);
80 + canvas.drawText(text, posX, posY, interiorPaint);
81 + }
82 +
83 + public void drawLines(Canvas canvas, final float posX, final float posY, Vector<String> lines) {
84 + int lineNum = 0;
85 + for (final String line : lines) {
86 + drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line);
87 + ++lineNum;
88 + }
89 + }
90 +
91 + public void setInteriorColor(final int color) {
92 + interiorPaint.setColor(color);
93 + }
94 +
95 + public void setExteriorColor(final int color) {
96 + exteriorPaint.setColor(color);
97 + }
98 +
99 + public float getTextSize() {
100 + return textSize;
101 + }
102 +
103 + public void setAlpha(final int alpha) {
104 + interiorPaint.setAlpha(alpha);
105 + exteriorPaint.setAlpha(alpha);
106 + }
107 +
108 + public void getTextBounds(
109 + final String line, final int index, final int count, final Rect lineBounds) {
110 + interiorPaint.getTextBounds(line, index, count, lineBounds);
111 + }
112 +
113 + public void setTextAlign(final Align align) {
114 + interiorPaint.setTextAlign(align);
115 + exteriorPaint.setTextAlign(align);
116 + }
117 +}
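
A minimal usage sketch for BorderedText, assuming it is called from a view or overlay that already has a Canvas; the 18-dip size mirrors the TEXT_SIZE_DIP value used elsewhere in the demo, and the label text is illustrative.

// Sketch: draw a legible debug label onto an overlay canvas.
final float textSizePx =
    TypedValue.applyDimension(
        TypedValue.COMPLEX_UNIT_DIP, 18, getResources().getDisplayMetrics());
final BorderedText borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
borderedText.drawText(canvas, 10.0f, 10.0f + textSizePx, "Inference time: 64 ms");
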
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo.env;
17 +
18 +import android.graphics.Bitmap;
19 +import android.graphics.Matrix;
20 +import android.os.Environment;
21 +import java.io.File;
22 +import java.io.FileOutputStream;
23 +
24 +/**
25 + * Utility class for manipulating images.
26 + **/
27 +public class ImageUtils {
28 + @SuppressWarnings("unused")
29 + private static final Logger LOGGER = new Logger();
30 +
31 + static {
32 + try {
33 + System.loadLibrary("tensorflow_demo");
34 + } catch (UnsatisfiedLinkError e) {
35 + LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable.");
36 + }
37 + }
38 +
39 + /**
40 + * Utility method to compute the allocated size in bytes of a YUV420SP image
41 + * of the given dimensions.
42 + */
43 + public static int getYUVByteSize(final int width, final int height) {
44 + // The luminance plane requires 1 byte per pixel.
45 + final int ySize = width * height;
46 +
47 + // The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up.
48 + // Each 2x2 block takes 2 bytes to encode, one each for U and V.
49 + final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2;
50 +
51 + return ySize + uvSize;
52 + }
53 +
54 + /**
55 + * Saves a Bitmap object to disk for analysis.
56 + *
57 + * @param bitmap The bitmap to save.
58 + */
59 + public static void saveBitmap(final Bitmap bitmap) {
60 + saveBitmap(bitmap, "preview.png");
61 + }
62 +
63 + /**
64 + * Saves a Bitmap object to disk for analysis.
65 + *
66 + * @param bitmap The bitmap to save.
67 + * @param filename The location to save the bitmap to.
68 + */
69 + public static void saveBitmap(final Bitmap bitmap, final String filename) {
70 + final String root =
71 + Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow";
72 + LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root);
73 + final File myDir = new File(root);
74 +
75 + if (!myDir.mkdirs()) {
76 +      LOGGER.i("Directory not created; it may already exist");
77 + }
78 +
79 + final String fname = filename;
80 + final File file = new File(myDir, fname);
81 + if (file.exists()) {
82 + file.delete();
83 + }
84 + try {
85 + final FileOutputStream out = new FileOutputStream(file);
86 + bitmap.compress(Bitmap.CompressFormat.PNG, 99, out);
87 + out.flush();
88 + out.close();
89 + } catch (final Exception e) {
90 + LOGGER.e(e, "Exception!");
91 + }
92 + }
93 +
94 + // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
95 + // are normalized to eight bits.
96 + static final int kMaxChannelValue = 262143;
97 +
98 + // Always prefer the native implementation if available.
99 + private static boolean useNativeConversion = true;
100 +
101 + public static void convertYUV420SPToARGB8888(
102 + byte[] input,
103 + int width,
104 + int height,
105 + int[] output) {
106 + if (useNativeConversion) {
107 + try {
108 + ImageUtils.convertYUV420SPToARGB8888(input, output, width, height, false);
109 + return;
110 + } catch (UnsatisfiedLinkError e) {
111 + LOGGER.w(
112 + "Native YUV420SP -> RGB implementation not found, falling back to Java implementation");
113 + useNativeConversion = false;
114 + }
115 + }
116 +
117 +    // Java fallback implementation of the YUV420SP to ARGB8888 conversion.
118 + final int frameSize = width * height;
119 + for (int j = 0, yp = 0; j < height; j++) {
120 + int uvp = frameSize + (j >> 1) * width;
121 + int u = 0;
122 + int v = 0;
123 +
124 + for (int i = 0; i < width; i++, yp++) {
125 + int y = 0xff & input[yp];
126 + if ((i & 1) == 0) {
127 + v = 0xff & input[uvp++];
128 + u = 0xff & input[uvp++];
129 + }
130 +
131 + output[yp] = YUV2RGB(y, u, v);
132 + }
133 + }
134 + }
135 +
136 + private static int YUV2RGB(int y, int u, int v) {
137 + // Adjust and check YUV values
138 + y = (y - 16) < 0 ? 0 : (y - 16);
139 + u -= 128;
140 + v -= 128;
141 +
142 + // This is the floating point equivalent. We do the conversion in integer
143 + // because some Android devices do not have floating point in hardware.
144 + // nR = (int)(1.164 * nY + 2.018 * nU);
145 + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
146 + // nB = (int)(1.164 * nY + 1.596 * nV);
147 + int y1192 = 1192 * y;
148 + int r = (y1192 + 1634 * v);
149 + int g = (y1192 - 833 * v - 400 * u);
150 + int b = (y1192 + 2066 * u);
151 +
152 +    // Clip the RGB values to the range [0, kMaxChannelValue].
153 + r = r > kMaxChannelValue ? kMaxChannelValue : (r < 0 ? 0 : r);
154 + g = g > kMaxChannelValue ? kMaxChannelValue : (g < 0 ? 0 : g);
155 + b = b > kMaxChannelValue ? kMaxChannelValue : (b < 0 ? 0 : b);
156 +
157 + return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff);
158 + }
159 +
160 +
161 + public static void convertYUV420ToARGB8888(
162 + byte[] yData,
163 + byte[] uData,
164 + byte[] vData,
165 + int width,
166 + int height,
167 + int yRowStride,
168 + int uvRowStride,
169 + int uvPixelStride,
170 + int[] out) {
171 + if (useNativeConversion) {
172 + try {
173 + convertYUV420ToARGB8888(
174 + yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false);
175 + return;
176 + } catch (UnsatisfiedLinkError e) {
177 + LOGGER.w(
178 + "Native YUV420 -> RGB implementation not found, falling back to Java implementation");
179 + useNativeConversion = false;
180 + }
181 + }
182 +
183 + int yp = 0;
184 + for (int j = 0; j < height; j++) {
185 + int pY = yRowStride * j;
186 + int pUV = uvRowStride * (j >> 1);
187 +
188 + for (int i = 0; i < width; i++) {
189 + int uv_offset = pUV + (i >> 1) * uvPixelStride;
190 +
191 + out[yp++] = YUV2RGB(
192 + 0xff & yData[pY + i],
193 + 0xff & uData[uv_offset],
194 + 0xff & vData[uv_offset]);
195 + }
196 + }
197 + }
198 +
199 +
200 + /**
201 + * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The
202 + * input and output must already be allocated and non-null. For efficiency, no error checking is
203 + * performed.
204 + *
205 + * @param input The array of YUV 4:2:0 input data.
206 + * @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
207 + * @param width The width of the input image.
208 + * @param height The height of the input image.
209 + * @param halfSize If true, downsample to 50% in each dimension, otherwise not.
210 + */
211 + private static native void convertYUV420SPToARGB8888(
212 + byte[] input, int[] output, int width, int height, boolean halfSize);
213 +
214 + /**
215 +   * Converts YUV420 data with separate Y, U and V planes to ARGB 8888 data using the
216 +   * supplied width and height. The input and output must already be allocated and non-null.
217 + * For efficiency, no error checking is performed.
218 + *
219 +   * @param y The array of Y (luminance) plane data.
220 +   * @param u The array of U (chrominance) plane data.
221 +   * @param v The array of V (chrominance) plane data.
222 +   * @param uvPixelStride The byte stride between adjacent U/V samples.
223 + * @param width The width of the input image.
224 + * @param height The height of the input image.
225 + * @param halfSize If true, downsample to 50% in each dimension, otherwise not.
226 + * @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
227 + */
228 + private static native void convertYUV420ToARGB8888(
229 + byte[] y,
230 + byte[] u,
231 + byte[] v,
232 + int[] output,
233 + int width,
234 + int height,
235 + int yRowStride,
236 + int uvRowStride,
237 + int uvPixelStride,
238 + boolean halfSize);
239 +
240 + /**
241 + * Converts YUV420 semi-planar data to RGB 565 data using the supplied width
242 + * and height. The input and output must already be allocated and non-null.
243 + * For efficiency, no error checking is performed.
244 + *
245 + * @param input The array of YUV 4:2:0 input data.
246 + * @param output A pre-allocated array for the RGB 5:6:5 output data.
247 + * @param width The width of the input image.
248 + * @param height The height of the input image.
249 + */
250 + private static native void convertYUV420SPToRGB565(
251 + byte[] input, byte[] output, int width, int height);
252 +
253 + /**
254 + * Converts 32-bit ARGB8888 image data to YUV420SP data. This is useful, for
255 + * instance, in creating data to feed the classes that rely on raw camera
256 + * preview frames.
257 + *
258 + * @param input An array of input pixels in ARGB8888 format.
259 + * @param output A pre-allocated array for the YUV420SP output data.
260 + * @param width The width of the input image.
261 + * @param height The height of the input image.
262 + */
263 + private static native void convertARGB8888ToYUV420SP(
264 + int[] input, byte[] output, int width, int height);
265 +
266 + /**
267 + * Converts 16-bit RGB565 image data to YUV420SP data. This is useful, for
268 + * instance, in creating data to feed the classes that rely on raw camera
269 + * preview frames.
270 + *
271 + * @param input An array of input pixels in RGB565 format.
272 + * @param output A pre-allocated array for the YUV420SP output data.
273 + * @param width The width of the input image.
274 + * @param height The height of the input image.
275 + */
276 + private static native void convertRGB565ToYUV420SP(
277 + byte[] input, byte[] output, int width, int height);
278 +
279 + /**
280 + * Returns a transformation matrix from one reference frame into another.
281 + * Handles cropping (if maintaining aspect ratio is desired) and rotation.
282 + *
283 + * @param srcWidth Width of source frame.
284 + * @param srcHeight Height of source frame.
285 + * @param dstWidth Width of destination frame.
286 + * @param dstHeight Height of destination frame.
287 + * @param applyRotation Amount of rotation to apply from one frame to another.
288 + * Must be a multiple of 90.
289 + * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant,
290 + * cropping the image if necessary.
291 + * @return The transformation fulfilling the desired requirements.
292 + */
293 + public static Matrix getTransformationMatrix(
294 + final int srcWidth,
295 + final int srcHeight,
296 + final int dstWidth,
297 + final int dstHeight,
298 + final int applyRotation,
299 + final boolean maintainAspectRatio) {
300 + final Matrix matrix = new Matrix();
301 +
302 + if (applyRotation != 0) {
303 + if (applyRotation % 90 != 0) {
304 +        LOGGER.w("Rotation of %d is not a multiple of 90", applyRotation);
305 + }
306 +
307 + // Translate so center of image is at origin.
308 + matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);
309 +
310 + // Rotate around origin.
311 + matrix.postRotate(applyRotation);
312 + }
313 +
314 + // Account for the already applied rotation, if any, and then determine how
315 + // much scaling is needed for each axis.
316 + final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;
317 +
318 + final int inWidth = transpose ? srcHeight : srcWidth;
319 + final int inHeight = transpose ? srcWidth : srcHeight;
320 +
321 + // Apply scaling if necessary.
322 + if (inWidth != dstWidth || inHeight != dstHeight) {
323 + final float scaleFactorX = dstWidth / (float) inWidth;
324 + final float scaleFactorY = dstHeight / (float) inHeight;
325 +
326 + if (maintainAspectRatio) {
327 +        // Scale by the larger factor so that dst is filled completely while
328 +        // maintaining the aspect ratio; some of the image may fall off the edge.
329 + final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
330 + matrix.postScale(scaleFactor, scaleFactor);
331 + } else {
332 + // Scale exactly to fill dst from src.
333 + matrix.postScale(scaleFactorX, scaleFactorY);
334 + }
335 + }
336 +
337 + if (applyRotation != 0) {
338 + // Translate back from origin centered reference to destination frame.
339 + matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
340 + }
341 +
342 + return matrix;
343 + }
344 +}
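
To illustrate getTransformationMatrix, here is a hedged sketch of the usual frame-to-crop mapping; the 640x480 frame, 300x300 crop, and 90-degree rotation are example values, and rgbFrameBitmap/croppedBitmap are assumed to be allocated elsewhere.

// Sketch: map a 640x480 camera frame into a 300x300 model input for a sensor rotated
// 90 degrees, cropping as needed to preserve the aspect ratio.
final Matrix frameToCropTransform =
    ImageUtils.getTransformationMatrix(640, 480, 300, 300, 90, true);
final Canvas canvas = new Canvas(croppedBitmap);                // croppedBitmap: 300x300 ARGB_8888
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);  // rgbFrameBitmap: 640x480 frame
// Inverting the matrix maps detection boxes back into frame coordinates.
final Matrix cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
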
1 +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo.env;
17 +
18 +import android.util.Log;
19 +
20 +import java.util.HashSet;
21 +import java.util.Set;
22 +
23 +/**
24 + * Wrapper for the platform log function, allows convenient message prefixing and log disabling.
25 + */
26 +public final class Logger {
27 + private static final String DEFAULT_TAG = "tensorflow";
28 + private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG;
29 +
30 + // Classes to be ignored when examining the stack trace
31 + private static final Set<String> IGNORED_CLASS_NAMES;
32 +
33 + static {
34 + IGNORED_CLASS_NAMES = new HashSet<String>(3);
35 + IGNORED_CLASS_NAMES.add("dalvik.system.VMStack");
36 + IGNORED_CLASS_NAMES.add("java.lang.Thread");
37 + IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName());
38 + }
39 +
40 + private final String tag;
41 + private final String messagePrefix;
42 + private int minLogLevel = DEFAULT_MIN_LOG_LEVEL;
43 +
44 + /**
45 + * Creates a Logger using the class name as the message prefix.
46 + *
47 + * @param clazz the simple name of this class is used as the message prefix.
48 + */
49 + public Logger(final Class<?> clazz) {
50 + this(clazz.getSimpleName());
51 + }
52 +
53 + /**
54 + * Creates a Logger using the specified message prefix.
55 + *
56 + * @param messagePrefix is prepended to the text of every message.
57 + */
58 + public Logger(final String messagePrefix) {
59 + this(DEFAULT_TAG, messagePrefix);
60 + }
61 +
62 + /**
63 + * Creates a Logger with a custom tag and a custom message prefix. If the message prefix
64 +   * is set to {@code null}, the caller's class name is used as the prefix.
65 + *
66 + * @param tag identifies the source of a log message.
67 +   * @param messagePrefix prepended to every message if non-null. If null, the caller's class
68 +   *     name is used instead.
69 + */
70 + public Logger(final String tag, final String messagePrefix) {
71 + this.tag = tag;
72 + final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix;
73 + this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix;
74 + }
75 +
76 + /**
77 + * Creates a Logger using the caller's class name as the message prefix.
78 + */
79 + public Logger() {
80 + this(DEFAULT_TAG, null);
81 + }
82 +
83 + /**
84 +   * Creates a Logger with the caller's class name as prefix and the given minimum log level.
85 + */
86 + public Logger(final int minLogLevel) {
87 + this(DEFAULT_TAG, null);
88 + this.minLogLevel = minLogLevel;
89 + }
90 +
91 + public void setMinLogLevel(final int minLogLevel) {
92 + this.minLogLevel = minLogLevel;
93 + }
94 +
95 + public boolean isLoggable(final int logLevel) {
96 + return logLevel >= minLogLevel || Log.isLoggable(tag, logLevel);
97 + }
98 +
99 + /**
100 + * Return caller's simple name.
101 + *
102 + * Android getStackTrace() returns an array that looks like this:
103 + * stackTrace[0]: dalvik.system.VMStack
104 + * stackTrace[1]: java.lang.Thread
105 + * stackTrace[2]: com.google.android.apps.unveil.env.UnveilLogger
106 + * stackTrace[3]: com.google.android.apps.unveil.BaseApplication
107 + *
108 + * This function returns the simple version of the first non-filtered name.
109 + *
110 + * @return caller's simple name
111 + */
112 + private static String getCallerSimpleName() {
113 + // Get the current callstack so we can pull the class of the caller off of it.
114 + final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
115 +
116 + for (final StackTraceElement elem : stackTrace) {
117 + final String className = elem.getClassName();
118 + if (!IGNORED_CLASS_NAMES.contains(className)) {
119 + // We're only interested in the simple name of the class, not the complete package.
120 + final String[] classParts = className.split("\\.");
121 + return classParts[classParts.length - 1];
122 + }
123 + }
124 +
125 + return Logger.class.getSimpleName();
126 + }
127 +
128 + private String toMessage(final String format, final Object... args) {
129 + return messagePrefix + (args.length > 0 ? String.format(format, args) : format);
130 + }
131 +
132 + public void v(final String format, final Object... args) {
133 + if (isLoggable(Log.VERBOSE)) {
134 + Log.v(tag, toMessage(format, args));
135 + }
136 + }
137 +
138 + public void v(final Throwable t, final String format, final Object... args) {
139 + if (isLoggable(Log.VERBOSE)) {
140 + Log.v(tag, toMessage(format, args), t);
141 + }
142 + }
143 +
144 + public void d(final String format, final Object... args) {
145 + if (isLoggable(Log.DEBUG)) {
146 + Log.d(tag, toMessage(format, args));
147 + }
148 + }
149 +
150 + public void d(final Throwable t, final String format, final Object... args) {
151 + if (isLoggable(Log.DEBUG)) {
152 + Log.d(tag, toMessage(format, args), t);
153 + }
154 + }
155 +
156 + public void i(final String format, final Object... args) {
157 + if (isLoggable(Log.INFO)) {
158 + Log.i(tag, toMessage(format, args));
159 + }
160 + }
161 +
162 + public void i(final Throwable t, final String format, final Object... args) {
163 + if (isLoggable(Log.INFO)) {
164 + Log.i(tag, toMessage(format, args), t);
165 + }
166 + }
167 +
168 + public void w(final String format, final Object... args) {
169 + if (isLoggable(Log.WARN)) {
170 + Log.w(tag, toMessage(format, args));
171 + }
172 + }
173 +
174 + public void w(final Throwable t, final String format, final Object... args) {
175 + if (isLoggable(Log.WARN)) {
176 + Log.w(tag, toMessage(format, args), t);
177 + }
178 + }
179 +
180 + public void e(final String format, final Object... args) {
181 + if (isLoggable(Log.ERROR)) {
182 + Log.e(tag, toMessage(format, args));
183 + }
184 + }
185 +
186 + public void e(final Throwable t, final String format, final Object... args) {
187 + if (isLoggable(Log.ERROR)) {
188 + Log.e(tag, toMessage(format, args), t);
189 + }
190 + }
191 +}
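
A short, illustrative sketch of how the Logger is typically held and called; the field placement and message arguments are assumptions, not code from this change.

// Sketch: with the no-arg constructor the prefix defaults to the calling class's simple name.
private static final Logger LOGGER = new Logger();

void onModelLoaded(final String modelFilename, final long elapsedMs) {
  LOGGER.i("Loaded %s in %d ms", modelFilename, elapsedMs);
  LOGGER.setMinLogLevel(android.util.Log.INFO);  // silence VERBOSE/DEBUG output
}
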
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo.env;
17 +
18 +import android.graphics.Bitmap;
19 +import android.text.TextUtils;
20 +import java.io.Serializable;
21 +import java.util.ArrayList;
22 +import java.util.List;
23 +
24 +/**
25 + * Size class independent of a Camera object.
26 + */
27 +public class Size implements Comparable<Size>, Serializable {
28 +
29 + // 1.4 went out with this UID so we'll need to maintain it to preserve pending queries when
30 + // upgrading.
31 + public static final long serialVersionUID = 7689808733290872361L;
32 +
33 + public final int width;
34 + public final int height;
35 +
36 + public Size(final int width, final int height) {
37 + this.width = width;
38 + this.height = height;
39 + }
40 +
41 + public Size(final Bitmap bmp) {
42 + this.width = bmp.getWidth();
43 + this.height = bmp.getHeight();
44 + }
45 +
46 + /**
47 + * Rotate a size by the given number of degrees.
48 + * @param size Size to rotate.
49 + * @param rotation Degrees {0, 90, 180, 270} to rotate the size.
50 + * @return Rotated size.
51 + */
52 + public static Size getRotatedSize(final Size size, final int rotation) {
53 + if (rotation % 180 != 0) {
54 + // The phone is portrait, therefore the camera is sideways and frame should be rotated.
55 + return new Size(size.height, size.width);
56 + }
57 + return size;
58 + }
59 +
60 + public static Size parseFromString(String sizeString) {
61 + if (TextUtils.isEmpty(sizeString)) {
62 + return null;
63 + }
64 +
65 + sizeString = sizeString.trim();
66 +
67 + // The expected format is "<width>x<height>".
68 + final String[] components = sizeString.split("x");
69 + if (components.length == 2) {
70 + try {
71 + final int width = Integer.parseInt(components[0]);
72 + final int height = Integer.parseInt(components[1]);
73 + return new Size(width, height);
74 + } catch (final NumberFormatException e) {
75 + return null;
76 + }
77 + } else {
78 + return null;
79 + }
80 + }
81 +
82 + public static List<Size> sizeStringToList(final String sizes) {
83 + final List<Size> sizeList = new ArrayList<Size>();
84 + if (sizes != null) {
85 + final String[] pairs = sizes.split(",");
86 + for (final String pair : pairs) {
87 + final Size size = Size.parseFromString(pair);
88 + if (size != null) {
89 + sizeList.add(size);
90 + }
91 + }
92 + }
93 + return sizeList;
94 + }
95 +
96 + public static String sizeListToString(final List<Size> sizes) {
97 + String sizesString = "";
98 + if (sizes != null && sizes.size() > 0) {
99 + sizesString = sizes.get(0).toString();
100 + for (int i = 1; i < sizes.size(); i++) {
101 + sizesString += "," + sizes.get(i).toString();
102 + }
103 + }
104 + return sizesString;
105 + }
106 +
107 + public final float aspectRatio() {
108 + return (float) width / (float) height;
109 + }
110 +
111 + @Override
112 + public int compareTo(final Size other) {
113 + return width * height - other.width * other.height;
114 + }
115 +
116 + @Override
117 + public boolean equals(final Object other) {
118 + if (other == null) {
119 + return false;
120 + }
121 +
122 + if (!(other instanceof Size)) {
123 + return false;
124 + }
125 +
126 + final Size otherSize = (Size) other;
127 + return (width == otherSize.width && height == otherSize.height);
128 + }
129 +
130 + @Override
131 + public int hashCode() {
132 + return width * 32713 + height;
133 + }
134 +
135 + @Override
136 + public String toString() {
137 + return dimensionsAsString(width, height);
138 + }
139 +
140 + public static final String dimensionsAsString(final int width, final int height) {
141 + return width + "x" + height;
142 + }
143 +}
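
A small sketch of the string round-trips these helpers implement; the size strings are examples only.

// Sketch: parse and serialize "<width>x<height>" size strings.
final Size preview = Size.parseFromString("640x480");          // 640x480, or null if malformed
final List<Size> candidates = Size.sizeStringToList("1280x720,640x480,320x240");
final String serialized = Size.sizeListToString(candidates);   // "1280x720,640x480,320x240"
final Size rotated = Size.getRotatedSize(preview, 90);         // 480x640 for a portrait device
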
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo.env;
17 +
18 +import android.os.SystemClock;
19 +
20 +/**
21 + * A simple utility timer for measuring CPU time and wall-clock splits.
22 + */
23 +public class SplitTimer {
24 + private final Logger logger;
25 +
26 + private long lastWallTime;
27 + private long lastCpuTime;
28 +
29 + public SplitTimer(final String name) {
30 + logger = new Logger(name);
31 + newSplit();
32 + }
33 +
34 + public void newSplit() {
35 + lastWallTime = SystemClock.uptimeMillis();
36 + lastCpuTime = SystemClock.currentThreadTimeMillis();
37 + }
38 +
39 + public void endSplit(final String splitName) {
40 + final long currWallTime = SystemClock.uptimeMillis();
41 + final long currCpuTime = SystemClock.currentThreadTimeMillis();
42 +
43 + logger.i(
44 + "%s: cpu=%dms wall=%dms",
45 + splitName, currCpuTime - lastCpuTime, currWallTime - lastWallTime);
46 +
47 + lastWallTime = currWallTime;
48 + lastCpuTime = currCpuTime;
49 + }
50 +}
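
A brief usage sketch for SplitTimer; the stage names are illustrative.

// Sketch: log per-stage CPU and wall-clock times for one frame.
final SplitTimer timer = new SplitTimer("processImage");
// ... convert YUV to RGB ...
timer.endSplit("preprocessed frame");  // time since construction
// ... run the model ...
timer.endSplit("ran inference");       // time since the previous split
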
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo.tracking;
17 +
18 +import android.content.Context;
19 +import android.graphics.Canvas;
20 +import android.graphics.Color;
21 +import android.graphics.Matrix;
22 +import android.graphics.Paint;
23 +import android.graphics.Paint.Cap;
24 +import android.graphics.Paint.Join;
25 +import android.graphics.Paint.Style;
26 +import android.graphics.RectF;
27 +import android.text.TextUtils;
28 +import android.util.Pair;
29 +import android.util.TypedValue;
30 +import android.widget.Toast;
31 +import java.util.LinkedList;
32 +import java.util.List;
33 +import java.util.Queue;
34 +import org.tensorflow.demo.Classifier.Recognition;
35 +import org.tensorflow.demo.env.BorderedText;
36 +import org.tensorflow.demo.env.ImageUtils;
37 +import org.tensorflow.demo.env.Logger;
38 +
39 +/**
40 + * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
41 + * objects to new detections.
42 + */
43 +public class MultiBoxTracker {
44 + private final Logger logger = new Logger();
45 +
46 + private static final float TEXT_SIZE_DIP = 18;
47 +
48 + // Maximum percentage of a box that can be overlapped by another box at detection time. Otherwise
49 + // the lower scored box (new or old) will be removed.
50 + private static final float MAX_OVERLAP = 0.2f;
51 +
52 + private static final float MIN_SIZE = 16.0f;
53 +
54 + // Allow replacement of the tracked box with new results if
55 + // correlation has dropped below this level.
56 + private static final float MARGINAL_CORRELATION = 0.75f;
57 +
58 + // Consider object to be lost if correlation falls below this threshold.
59 + private static final float MIN_CORRELATION = 0.3f;
60 +
61 + private static final int[] COLORS = {
62 + Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE,
63 + Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"),
64 + Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"),
65 + Color.parseColor("#AA33AA"), Color.parseColor("#0D0068")
66 + };
67 +
68 + private final Queue<Integer> availableColors = new LinkedList<Integer>();
69 +
70 + public ObjectTracker objectTracker;
71 +
72 + final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();
73 +
74 + private static class TrackedRecognition {
75 + ObjectTracker.TrackedObject trackedObject;
76 + RectF location;
77 + float detectionConfidence;
78 + int color;
79 + String title;
80 + }
81 +
82 + private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();
83 +
84 + private final Paint boxPaint = new Paint();
85 +
86 + private final float textSizePx;
87 + private final BorderedText borderedText;
88 +
89 + private Matrix frameToCanvasMatrix;
90 +
91 + private int frameWidth;
92 + private int frameHeight;
93 +
94 + private int sensorOrientation;
95 + private Context context;
96 +
97 + public MultiBoxTracker(final Context context) {
98 + this.context = context;
99 + for (final int color : COLORS) {
100 + availableColors.add(color);
101 + }
102 +
103 + boxPaint.setColor(Color.RED);
104 + boxPaint.setStyle(Style.STROKE);
105 + boxPaint.setStrokeWidth(12.0f);
106 + boxPaint.setStrokeCap(Cap.ROUND);
107 + boxPaint.setStrokeJoin(Join.ROUND);
108 + boxPaint.setStrokeMiter(100);
109 +
110 + textSizePx =
111 + TypedValue.applyDimension(
112 + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics());
113 + borderedText = new BorderedText(textSizePx);
114 + }
115 +
116 + private Matrix getFrameToCanvasMatrix() {
117 + return frameToCanvasMatrix;
118 + }
119 +
120 + public synchronized void drawDebug(final Canvas canvas) {
121 + final Paint textPaint = new Paint();
122 + textPaint.setColor(Color.WHITE);
123 + textPaint.setTextSize(60.0f);
124 +
125 + final Paint boxPaint = new Paint();
126 + boxPaint.setColor(Color.RED);
127 + boxPaint.setAlpha(200);
128 + boxPaint.setStyle(Style.STROKE);
129 +
130 + for (final Pair<Float, RectF> detection : screenRects) {
131 + final RectF rect = detection.second;
132 + canvas.drawRect(rect, boxPaint);
133 + canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
134 + borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
135 + }
136 +
137 + if (objectTracker == null) {
138 + return;
139 + }
140 +
141 + // Draw correlations.
142 + for (final TrackedRecognition recognition : trackedObjects) {
143 + final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
144 +
145 + final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
146 +
147 + if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
148 + final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
149 + borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
150 + }
151 + }
152 +
153 + final Matrix matrix = getFrameToCanvasMatrix();
154 + objectTracker.drawDebug(canvas, matrix);
155 + }
156 +
157 + public synchronized void trackResults(
158 + final List<Recognition> results, final byte[] frame, final long timestamp) {
159 + logger.i("Processing %d results from %d", results.size(), timestamp);
160 + processResults(timestamp, results, frame);
161 + }
162 +
163 + public synchronized void draw(final Canvas canvas) {
164 + final boolean rotated = sensorOrientation % 180 == 90;
165 + final float multiplier =
166 + Math.min(canvas.getHeight() / (float) (rotated ? frameWidth : frameHeight),
167 + canvas.getWidth() / (float) (rotated ? frameHeight : frameWidth));
168 + frameToCanvasMatrix =
169 + ImageUtils.getTransformationMatrix(
170 + frameWidth,
171 + frameHeight,
172 + (int) (multiplier * (rotated ? frameHeight : frameWidth)),
173 + (int) (multiplier * (rotated ? frameWidth : frameHeight)),
174 + sensorOrientation,
175 + false);
176 + for (final TrackedRecognition recognition : trackedObjects) {
177 + final RectF trackedPos =
178 + (objectTracker != null)
179 + ? recognition.trackedObject.getTrackedPositionInPreviewFrame()
180 + : new RectF(recognition.location);
181 +
182 + getFrameToCanvasMatrix().mapRect(trackedPos);
183 + boxPaint.setColor(recognition.color);
184 +
185 + final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
186 + canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);
187 +
188 + final String labelString =
189 + !TextUtils.isEmpty(recognition.title)
190 + ? String.format("%s %.2f", recognition.title, recognition.detectionConfidence)
191 + : String.format("%.2f", recognition.detectionConfidence);
192 + borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
193 + }
194 + }
195 +
196 + private boolean initialized = false;
197 +
198 + public synchronized void onFrame(
199 + final int w,
200 + final int h,
201 + final int rowStride,
202 + final int sensorOrientation,
203 + final byte[] frame,
204 + final long timestamp) {
205 + if (objectTracker == null && !initialized) {
206 + ObjectTracker.clearInstance();
207 +
208 + logger.i("Initializing ObjectTracker: %dx%d", w, h);
209 + objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
210 + frameWidth = w;
211 + frameHeight = h;
212 + this.sensorOrientation = sensorOrientation;
213 + initialized = true;
214 +
215 + if (objectTracker == null) {
216 + String message =
217 + "Object tracking support not found. "
218 + + "See tensorflow/examples/android/README.md for details.";
219 + Toast.makeText(context, message, Toast.LENGTH_LONG).show();
220 + logger.e(message);
221 + }
222 + }
223 +
224 + if (objectTracker == null) {
225 + return;
226 + }
227 +
228 + objectTracker.nextFrame(frame, null, timestamp, null, true);
229 +
230 + // Clean up any objects not worth tracking any more.
231 + final LinkedList<TrackedRecognition> copyList =
232 + new LinkedList<TrackedRecognition>(trackedObjects);
233 + for (final TrackedRecognition recognition : copyList) {
234 + final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
235 + final float correlation = trackedObject.getCurrentCorrelation();
236 + if (correlation < MIN_CORRELATION) {
237 + logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
238 + trackedObject.stopTracking();
239 + trackedObjects.remove(recognition);
240 +
241 + availableColors.add(recognition.color);
242 + }
243 + }
244 + }
245 +
246 + private void processResults(
247 + final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
248 + final List<Pair<Float, Recognition>> rectsToTrack = new LinkedList<Pair<Float, Recognition>>();
249 +
250 + screenRects.clear();
251 + final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());
252 +
253 + for (final Recognition result : results) {
254 + if (result.getLocation() == null) {
255 + continue;
256 + }
257 + final RectF detectionFrameRect = new RectF(result.getLocation());
258 +
259 + final RectF detectionScreenRect = new RectF();
260 + rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);
261 +
262 + logger.v(
263 + "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect);
264 +
265 + screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));
266 +
267 + if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
268 + logger.w("Degenerate rectangle! " + detectionFrameRect);
269 + continue;
270 + }
271 +
272 + rectsToTrack.add(new Pair<Float, Recognition>(result.getConfidence(), result));
273 + }
274 +
275 + if (rectsToTrack.isEmpty()) {
276 + logger.v("Nothing to track, aborting.");
277 + return;
278 + }
279 +
280 + if (objectTracker == null) {
281 + trackedObjects.clear();
282 + for (final Pair<Float, Recognition> potential : rectsToTrack) {
283 + final TrackedRecognition trackedRecognition = new TrackedRecognition();
284 + trackedRecognition.detectionConfidence = potential.first;
285 + trackedRecognition.location = new RectF(potential.second.getLocation());
286 + trackedRecognition.trackedObject = null;
287 + trackedRecognition.title = potential.second.getTitle();
288 + trackedRecognition.color = COLORS[trackedObjects.size()];
289 + trackedObjects.add(trackedRecognition);
290 +
291 + if (trackedObjects.size() >= COLORS.length) {
292 + break;
293 + }
294 + }
295 + return;
296 + }
297 +
298 + logger.i("%d rects to track", rectsToTrack.size());
299 + for (final Pair<Float, Recognition> potential : rectsToTrack) {
300 + handleDetection(originalFrame, timestamp, potential);
301 + }
302 + }
303 +
304 + private void handleDetection(
305 + final byte[] frameCopy, final long timestamp, final Pair<Float, Recognition> potential) {
306 + final ObjectTracker.TrackedObject potentialObject =
307 + objectTracker.trackObject(potential.second.getLocation(), timestamp, frameCopy);
308 +
309 + final float potentialCorrelation = potentialObject.getCurrentCorrelation();
310 + logger.v(
311 + "Tracked object went from %s to %s with correlation %.2f",
312 + potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);
313 +
314 + if (potentialCorrelation < MARGINAL_CORRELATION) {
315 + logger.v("Correlation too low to begin tracking %s.", potentialObject);
316 + potentialObject.stopTracking();
317 + return;
318 + }
319 +
320 + final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();
321 +
322 + float maxIntersect = 0.0f;
323 +
324 + // This is the current tracked object whose color we will take. If left null we'll take the
325 + // first one from the color queue.
326 + TrackedRecognition recogToReplace = null;
327 +
328 + // Look for intersections that will be overridden by this object or an intersection that would
329 + // prevent this one from being placed.
330 + for (final TrackedRecognition trackedRecognition : trackedObjects) {
331 + final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
332 + final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
333 + final RectF intersection = new RectF();
334 + final boolean intersects = intersection.setIntersect(a, b);
335 +
336 + final float intersectArea = intersection.width() * intersection.height();
337 + final float totalArea = a.width() * a.height() + b.width() * b.height() - intersectArea;
338 + final float intersectOverUnion = intersectArea / totalArea;
339 +
340 + // If there is an intersection with this currently tracked box above the maximum overlap
341 + // percentage allowed, either the new recognition needs to be dismissed or the old
342 + // recognition needs to be removed and possibly replaced with the new one.
343 + if (intersects && intersectOverUnion > MAX_OVERLAP) {
344 + if (potential.first < trackedRecognition.detectionConfidence
345 + && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
346 + // If track for the existing object is still going strong and the detection score was
347 + // good, reject this new object.
348 + potentialObject.stopTracking();
349 + return;
350 + } else {
351 + removeList.add(trackedRecognition);
352 +
353 + // Let the previously tracked object with max intersection amount donate its color to
354 + // the new object.
355 + if (intersectOverUnion > maxIntersect) {
356 + maxIntersect = intersectOverUnion;
357 + recogToReplace = trackedRecognition;
358 + }
359 + }
360 + }
361 + }
362 +
363 + // If we're already tracking the max object and no intersections were found to bump off,
364 + // pick the worst current tracked object to remove, if it's also worse than this candidate
365 + // object.
366 + if (availableColors.isEmpty() && removeList.isEmpty()) {
367 + for (final TrackedRecognition candidate : trackedObjects) {
368 + if (candidate.detectionConfidence < potential.first) {
369 + if (recogToReplace == null
370 + || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
371 + // Save it so that we use this color for the new object.
372 + recogToReplace = candidate;
373 + }
374 + }
375 + }
376 + if (recogToReplace != null) {
377 + logger.v("Found non-intersecting object to remove.");
378 + removeList.add(recogToReplace);
379 + } else {
380 +        logger.v("No non-intersecting object found to remove.");
381 + }
382 + }
383 +
384 + // Remove everything that got intersected.
385 + for (final TrackedRecognition trackedRecognition : removeList) {
386 + logger.v(
387 + "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
388 + trackedRecognition.trackedObject,
389 + trackedRecognition.detectionConfidence,
390 + trackedRecognition.trackedObject.getCurrentCorrelation());
391 + trackedRecognition.trackedObject.stopTracking();
392 + trackedObjects.remove(trackedRecognition);
393 + if (trackedRecognition != recogToReplace) {
394 + availableColors.add(trackedRecognition.color);
395 + }
396 + }
397 +
398 + if (recogToReplace == null && availableColors.isEmpty()) {
399 + logger.e("No room to track this object, aborting.");
400 + potentialObject.stopTracking();
401 + return;
402 + }
403 +
404 + // Finally safe to say we can track this object.
405 + logger.v(
406 + "Tracking object %s (%s) with detection confidence %.2f at position %s",
407 + potentialObject,
408 + potential.second.getTitle(),
409 + potential.first,
410 + potential.second.getLocation());
411 + final TrackedRecognition trackedRecognition = new TrackedRecognition();
412 + trackedRecognition.detectionConfidence = potential.first;
413 + trackedRecognition.trackedObject = potentialObject;
414 + trackedRecognition.title = potential.second.getTitle();
415 +
416 + // Use the color from a replaced object before taking one from the color queue.
417 + trackedRecognition.color =
418 + recogToReplace != null ? recogToReplace.color : availableColors.poll();
419 + trackedObjects.add(trackedRecognition);
420 + }
421 +}
1 +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 +
3 +Licensed under the Apache License, Version 2.0 (the "License");
4 +you may not use this file except in compliance with the License.
5 +You may obtain a copy of the License at
6 +
7 + http://www.apache.org/licenses/LICENSE-2.0
8 +
9 +Unless required by applicable law or agreed to in writing, software
10 +distributed under the License is distributed on an "AS IS" BASIS,
11 +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +See the License for the specific language governing permissions and
13 +limitations under the License.
14 +==============================================================================*/
15 +
16 +package org.tensorflow.demo.tracking;
17 +
18 +import android.graphics.Canvas;
19 +import android.graphics.Color;
20 +import android.graphics.Matrix;
21 +import android.graphics.Paint;
22 +import android.graphics.PointF;
23 +import android.graphics.RectF;
24 +import android.graphics.Typeface;
25 +import java.util.ArrayList;
26 +import java.util.HashMap;
27 +import java.util.LinkedList;
28 +import java.util.List;
29 +import java.util.Map;
30 +import java.util.Vector;
31 +import javax.microedition.khronos.opengles.GL10;
32 +import org.tensorflow.demo.env.Logger;
33 +import org.tensorflow.demo.env.Size;
34 +
35 +/**
36 + * True object detector/tracker class that tracks objects across consecutive preview frames.
37 + * It provides a simplified Java interface to the analogous native object defined by
38 + * jni/client_vision/tracking/object_tracker.*.
39 + *
40 + * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must
41 + * be allocated by ObjectTracker.getInstance(). In addition, release() should be called
42 + * as soon as the ObjectTracker is no longer needed, and before a new one is created.
43 + *
44 + * nextFrame() should be called as new frames become available, preferably as often as possible.
45 + *
46 + * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects
47 + * are associated with the ObjectTracker that created them, and are only valid while that
48 + * ObjectTracker still exists.
49 + */
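As a rough usage sketch of the lifecycle described above (the frame size, row stride, timestamp, and detection rectangle below are placeholder values, not taken from the demo, and uvData and the alignment matrix are simply left null as a simplification):

package org.tensorflow.demo.tracking;  // hypothetical sketch placed alongside ObjectTracker

import android.graphics.RectF;

final class TrackerLifecycleSketch {
  static void runOnce(final byte[] luma, final long timestampNs) {
    // Only one ObjectTracker may exist at a time; getInstance() returns null if the
    // native library could not be loaded.
    final ObjectTracker tracker = ObjectTracker.getInstance(640, 480, 640, false);
    if (tracker == null) {
      return;
    }
    // Feed each preview frame as it arrives.
    tracker.nextFrame(luma, null, timestampNs, null, false);
    // Begin tracking a detection box and read back its current position.
    final ObjectTracker.TrackedObject tracked =
        tracker.trackObject(new RectF(100, 100, 300, 300), timestampNs, luma);
    final RectF position = tracked.getTrackedPositionInPreviewFrame();
    // ... use position for drawing or reporting ...
    // Release the track and the tracker once they are no longer needed.
    tracked.stopTracking();
    ObjectTracker.clearInstance();
  }
}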
50 +public class ObjectTracker {
51 + private static final Logger LOGGER = new Logger();
52 +
53 + private static boolean libraryFound = false;
54 +
55 + static {
56 + try {
57 + System.loadLibrary("tensorflow_demo");
58 + libraryFound = true;
59 + } catch (UnsatisfiedLinkError e) {
60 + LOGGER.e("libtensorflow_demo.so not found, tracking unavailable");
61 + }
62 + }
63 +
64 + private static final boolean DRAW_TEXT = false;
65 +
66 + /**
67 + * How many history points to keep track of and draw in the red history line.
68 + */
69 + private static final int MAX_DEBUG_HISTORY_SIZE = 30;
70 +
71 + /**
72 + * How many frames of optical flow deltas to record.
73 + * TODO(andrewharp): Push this down to the native level so it can be polled
74 +   * efficiently into an array for upload, instead of keeping a duplicate
75 +   * copy in Java.
76 + */
77 + private static final int MAX_FRAME_HISTORY_SIZE = 200;
78 +
79 + private static final int DOWNSAMPLE_FACTOR = 2;
80 +
81 + private final byte[] downsampledFrame;
82 +
83 + protected static ObjectTracker instance;
84 +
85 + private final Map<String, TrackedObject> trackedObjects;
86 +
87 + private long lastTimestamp;
88 +
89 + private FrameChange lastKeypoints;
90 +
91 + private final Vector<PointF> debugHistory;
92 +
93 + private final LinkedList<TimestampedDeltas> timestampedDeltas;
94 +
95 + protected final int frameWidth;
96 + protected final int frameHeight;
97 + private final int rowStride;
98 + protected final boolean alwaysTrack;
99 +
100 + private static class TimestampedDeltas {
101 + final long timestamp;
102 + final byte[] deltas;
103 +
104 + public TimestampedDeltas(final long timestamp, final byte[] deltas) {
105 + this.timestamp = timestamp;
106 + this.deltas = deltas;
107 + }
108 + }
109 +
110 + /**
111 +   * A simple class that records keypoint information, namely the keypoint's location,
112 +   * score, and type. It is used when calculating
113 +   * FrameChange.
114 + */
115 + public static class Keypoint {
116 + public final float x;
117 + public final float y;
118 + public final float score;
119 + public final int type;
120 +
121 + public Keypoint(final float x, final float y) {
122 + this.x = x;
123 + this.y = y;
124 + this.score = 0;
125 + this.type = -1;
126 + }
127 +
128 + public Keypoint(final float x, final float y, final float score, final int type) {
129 + this.x = x;
130 + this.y = y;
131 + this.score = score;
132 + this.type = type;
133 + }
134 +
135 + Keypoint delta(final Keypoint other) {
136 + return new Keypoint(this.x - other.x, this.y - other.y);
137 + }
138 + }
139 +
140 + /**
141 +   * A simple class that calculates the delta between two Keypoints.
142 +   * It is used when calculating the frame translation delta
143 +   * for optical flow.
144 + */
145 + public static class PointChange {
146 + public final Keypoint keypointA;
147 + public final Keypoint keypointB;
148 + Keypoint pointDelta;
149 + private final boolean wasFound;
150 +
151 + public PointChange(final float x1, final float y1,
152 + final float x2, final float y2,
153 + final float score, final int type,
154 + final boolean wasFound) {
155 + this.wasFound = wasFound;
156 +
157 + keypointA = new Keypoint(x1, y1, score, type);
158 + keypointB = new Keypoint(x2, y2);
159 + }
160 +
161 + public Keypoint getDelta() {
162 + if (pointDelta == null) {
163 + pointDelta = keypointB.delta(keypointA);
164 + }
165 + return pointDelta;
166 + }
167 + }
168 +
169 + /** A class that records a timestamped frame translation delta for optical flow. */
170 + public static class FrameChange {
171 + public static final int KEYPOINT_STEP = 7;
172 +
173 + public final Vector<PointChange> pointDeltas;
174 +
175 + private final float minScore;
176 + private final float maxScore;
177 +
178 + public FrameChange(final float[] framePoints) {
179 + float minScore = 100.0f;
180 + float maxScore = -100.0f;
181 +
182 + pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP);
183 +
184 + for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) {
185 + final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR;
186 + final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR;
187 +
188 + final boolean wasFound = framePoints[i + 2] > 0.0f;
189 +
190 + final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR;
191 + final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR;
192 + final float score = framePoints[i + 5];
193 + final int type = (int) framePoints[i + 6];
194 +
195 + minScore = Math.min(minScore, score);
196 + maxScore = Math.max(maxScore, score);
197 +
198 + pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound));
199 + }
200 +
201 + this.minScore = minScore;
202 + this.maxScore = maxScore;
203 + }
204 + }
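  // Example of the packed layout parsed above (values are made up): one keypoint encoded as
  //   {x1, y1, found, x2, y2, score, type} = {10, 20, 1, 12, 21, 0.8, 3}
  // yields, after scaling by DOWNSAMPLE_FACTOR (2), keypointA = (20, 40) and keypointB = (24, 42),
  // so PointChange.getDelta() returns (4, 2) in full-frame coordinates.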
205 +
206 + public static synchronized ObjectTracker getInstance(
207 + final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
208 + if (!libraryFound) {
209 + LOGGER.e(
210 + "Native object tracking support not found. "
211 + + "See tensorflow/examples/android/README.md for details.");
212 + return null;
213 + }
214 +
215 + if (instance == null) {
216 + instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack);
217 + instance.init();
218 + } else {
219 + throw new RuntimeException(
220 +          "Tried to create a new ObjectTracker before releasing the old one!");
221 + }
222 + return instance;
223 + }
224 +
225 + public static synchronized void clearInstance() {
226 + if (instance != null) {
227 + instance.release();
228 + }
229 + }
230 +
231 + protected ObjectTracker(
232 + final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
233 + this.frameWidth = frameWidth;
234 + this.frameHeight = frameHeight;
235 + this.rowStride = rowStride;
236 + this.alwaysTrack = alwaysTrack;
237 + this.timestampedDeltas = new LinkedList<TimestampedDeltas>();
238 +
239 + trackedObjects = new HashMap<String, TrackedObject>();
240 +
241 + debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE);
242 +
243 + downsampledFrame =
244 + new byte
245 + [(frameWidth + DOWNSAMPLE_FACTOR - 1)
246 + / DOWNSAMPLE_FACTOR
247 + * (frameHeight + DOWNSAMPLE_FACTOR - 1)
248 + / DOWNSAMPLE_FACTOR];
249 + }
250 +
251 + protected void init() {
252 + // The native tracker never sees the full frame, so pre-scale dimensions
253 + // by the downsample factor.
254 + initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack);
255 + }
256 +
257 + private final float[] matrixValues = new float[9];
258 +
259 + private long downsampledTimestamp;
260 +
261 + @SuppressWarnings("unused")
262 + public synchronized void drawOverlay(final GL10 gl,
263 + final Size cameraViewSize, final Matrix matrix) {
264 + final Matrix tempMatrix = new Matrix(matrix);
265 + tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR);
266 + tempMatrix.getValues(matrixValues);
267 + drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues);
268 + }
269 +
270 + public synchronized void nextFrame(
271 + final byte[] frameData, final byte[] uvData,
272 + final long timestamp, final float[] transformationMatrix,
273 + final boolean updateDebugInfo) {
274 + if (downsampledTimestamp != timestamp) {
275 + ObjectTracker.downsampleImageNative(
276 + frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
277 + downsampledTimestamp = timestamp;
278 + }
279 +
280 +    // Do Lucas-Kanade optical flow using the full-frame initializer.
281 + nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix);
282 +
283 + timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR)));
284 + while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) {
285 + timestampedDeltas.removeFirst();
286 + }
287 +
288 + for (final TrackedObject trackedObject : trackedObjects.values()) {
289 + trackedObject.updateTrackedPosition();
290 + }
291 +
292 + if (updateDebugInfo) {
293 + updateDebugHistory();
294 + }
295 +
296 + lastTimestamp = timestamp;
297 + }
298 +
299 + public synchronized void release() {
300 + releaseMemoryNative();
301 + synchronized (ObjectTracker.class) {
302 + instance = null;
303 + }
304 + }
305 +
306 + private void drawHistoryDebug(final Canvas canvas) {
307 + drawHistoryPoint(
308 + canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2);
309 + }
310 +
311 + private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) {
312 + final Paint p = new Paint();
313 + p.setAntiAlias(false);
314 + p.setTypeface(Typeface.SERIF);
315 +
316 + p.setColor(Color.RED);
317 + p.setStrokeWidth(2.0f);
318 +
319 + // Draw the center circle.
320 + p.setColor(Color.GREEN);
321 + canvas.drawCircle(startX, startY, 3.0f, p);
322 +
323 + p.setColor(Color.RED);
324 +
325 + // Iterate through in backwards order.
326 + synchronized (debugHistory) {
327 + final int numPoints = debugHistory.size();
328 + float lastX = startX;
329 + float lastY = startY;
330 + for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) {
331 + final PointF delta = debugHistory.get(numPoints - keypointNum - 1);
332 + final float newX = lastX + delta.x;
333 + final float newY = lastY + delta.y;
334 + canvas.drawLine(lastX, lastY, newX, newY, p);
335 + lastX = newX;
336 + lastY = newY;
337 + }
338 + }
339 + }
340 +
341 + private static int floatToChar(final float value) {
342 + return Math.max(0, Math.min((int) (value * 255.999f), 255));
343 + }
344 +
345 + private void drawKeypointsDebug(final Canvas canvas) {
346 + final Paint p = new Paint();
347 + if (lastKeypoints == null) {
348 + return;
349 + }
350 + final int keypointSize = 3;
351 +
352 + final float minScore = lastKeypoints.minScore;
353 + final float maxScore = lastKeypoints.maxScore;
354 +
355 + for (final PointChange keypoint : lastKeypoints.pointDeltas) {
356 + if (keypoint.wasFound) {
357 + final int r =
358 + floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore));
359 + final int b =
360 + floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore));
361 +
362 + final int color = 0xFF000000 | (r << 16) | b;
363 + p.setColor(color);
364 +
365 + final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y,
366 + keypoint.keypointB.x, keypoint.keypointB.y};
367 + canvas.drawRect(screenPoints[2] - keypointSize,
368 + screenPoints[3] - keypointSize,
369 + screenPoints[2] + keypointSize,
370 + screenPoints[3] + keypointSize, p);
371 + p.setColor(Color.CYAN);
372 + canvas.drawLine(screenPoints[2], screenPoints[3],
373 + screenPoints[0], screenPoints[1], p);
374 +
375 + if (DRAW_TEXT) {
376 + p.setColor(Color.WHITE);
377 + canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score,
378 + keypoint.keypointA.x, keypoint.keypointA.y, p);
379 + }
380 + } else {
381 + p.setColor(Color.YELLOW);
382 + final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y};
383 + canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p);
384 + }
385 + }
386 + }
387 +
388 + private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX,
389 + final float positionY, final float radius) {
390 + final RectF currPosition = getCurrentPosition(timestamp,
391 + new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius));
392 + return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY);
393 + }
394 +
395 +  private synchronized RectF getCurrentPosition(
396 +      final long timestamp, final RectF oldPosition) {
397 + final RectF downscaledFrameRect = downscaleRect(oldPosition);
398 +
399 + final float[] delta = new float[4];
400 + getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top,
401 + downscaledFrameRect.right, downscaledFrameRect.bottom, delta);
402 +
403 + final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
404 +
405 + return upscaleRect(newPosition);
406 + }
407 +
408 + private void updateDebugHistory() {
409 + lastKeypoints = new FrameChange(getKeypointsNative(false));
410 +
411 + if (lastTimestamp == 0) {
412 + return;
413 + }
414 +
415 + final PointF delta =
416 + getAccumulatedDelta(
417 + lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100);
418 +
419 + synchronized (debugHistory) {
420 + debugHistory.add(delta);
421 +
422 + while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) {
423 + debugHistory.remove(0);
424 + }
425 + }
426 + }
427 +
428 + public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) {
429 + canvas.save();
430 + canvas.setMatrix(frameToCanvas);
431 +
432 + drawHistoryDebug(canvas);
433 + drawKeypointsDebug(canvas);
434 +
435 + canvas.restore();
436 + }
437 +
438 + public Vector<String> getDebugText() {
439 + final Vector<String> lines = new Vector<String>();
440 +
441 + if (lastKeypoints != null) {
442 + lines.add("Num keypoints " + lastKeypoints.pointDeltas.size());
443 + lines.add("Min score: " + lastKeypoints.minScore);
444 + lines.add("Max score: " + lastKeypoints.maxScore);
445 + }
446 +
447 + return lines;
448 + }
449 +
450 + public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) {
451 + final List<byte[]> frameDeltas = new ArrayList<byte[]>();
452 + while (timestampedDeltas.size() > 0) {
453 + final TimestampedDeltas currentDeltas = timestampedDeltas.peek();
454 + if (currentDeltas.timestamp <= endFrameTime) {
455 + frameDeltas.add(currentDeltas.deltas);
456 + timestampedDeltas.removeFirst();
457 + } else {
458 + break;
459 + }
460 + }
461 +
462 + return frameDeltas;
463 + }
464 +
465 + private RectF downscaleRect(final RectF fullFrameRect) {
466 + return new RectF(
467 + fullFrameRect.left / DOWNSAMPLE_FACTOR,
468 + fullFrameRect.top / DOWNSAMPLE_FACTOR,
469 + fullFrameRect.right / DOWNSAMPLE_FACTOR,
470 + fullFrameRect.bottom / DOWNSAMPLE_FACTOR);
471 + }
472 +
473 + private RectF upscaleRect(final RectF downsampledFrameRect) {
474 + return new RectF(
475 + downsampledFrameRect.left * DOWNSAMPLE_FACTOR,
476 + downsampledFrameRect.top * DOWNSAMPLE_FACTOR,
477 + downsampledFrameRect.right * DOWNSAMPLE_FACTOR,
478 + downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR);
479 + }
480 +
481 + /**
482 +   * A TrackedObject represents a native TrackedObject, and provides access to the
483 +   * relevant native tracking information available after every frame update. Instances may
484 +   * be safely passed around and accessed externally, but become invalid after
485 +   * stopTracking() is called or the ObjectTracker that created them is deactivated.
486 + *
487 + * @author andrewharp@google.com (Andrew Harp)
488 + */
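A brief per-frame maintenance sketch for the contract described above (tracker, tracked, luma, and timestampNs are the same placeholders as in the earlier lifecycle sketch; the 0.3f correlation threshold is purely illustrative, the demo's MultiBoxTracker uses its own constant):

// Inside the per-frame loop, after handing the new frame to the tracker:
tracker.nextFrame(luma, null, timestampNs, null, false);
if (!tracked.visibleInLastPreviewFrame() || tracked.getCurrentCorrelation() < 0.3f) {
  // Drop the track; the TrackedObject is invalid from this point on and further
  // accessor calls will throw.
  tracked.stopTracking();
} else {
  final RectF box = tracked.getTrackedPositionInPreviewFrame();
  // ... draw or report the box ...
}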
489 + public class TrackedObject {
490 + private final String id;
491 +
492 + private long lastExternalPositionTime;
493 +
494 + private RectF lastTrackedPosition;
495 + private boolean visibleInLastFrame;
496 +
497 + private boolean isDead;
498 +
499 + TrackedObject(final RectF position, final long timestamp, final byte[] data) {
500 + isDead = false;
501 +
502 + id = Integer.toString(this.hashCode());
503 +
504 + lastExternalPositionTime = timestamp;
505 +
506 + synchronized (ObjectTracker.this) {
507 + registerInitialAppearance(position, data);
508 + setPreviousPosition(position, timestamp);
509 + trackedObjects.put(id, this);
510 + }
511 + }
512 +
513 + public void stopTracking() {
514 + checkValidObject();
515 +
516 + synchronized (ObjectTracker.this) {
517 + isDead = true;
518 + forgetNative(id);
519 + trackedObjects.remove(id);
520 + }
521 + }
522 +
523 + public float getCurrentCorrelation() {
524 + checkValidObject();
525 + return ObjectTracker.this.getCurrentCorrelation(id);
526 + }
527 +
528 + void registerInitialAppearance(final RectF position, final byte[] data) {
529 + final RectF externalPosition = downscaleRect(position);
530 + registerNewObjectWithAppearanceNative(id,
531 + externalPosition.left, externalPosition.top,
532 + externalPosition.right, externalPosition.bottom,
533 + data);
534 + }
535 +
536 + synchronized void setPreviousPosition(final RectF position, final long timestamp) {
537 + checkValidObject();
538 + synchronized (ObjectTracker.this) {
539 + if (lastExternalPositionTime > timestamp) {
540 + LOGGER.w("Tried to use older position time!");
541 + return;
542 + }
543 + final RectF externalPosition = downscaleRect(position);
544 + lastExternalPositionTime = timestamp;
545 +
546 + setPreviousPositionNative(id,
547 + externalPosition.left, externalPosition.top,
548 + externalPosition.right, externalPosition.bottom,
549 + lastExternalPositionTime);
550 +
551 + updateTrackedPosition();
552 + }
553 + }
554 +
555 + void setCurrentPosition(final RectF position) {
556 + checkValidObject();
557 + final RectF downsampledPosition = downscaleRect(position);
558 + synchronized (ObjectTracker.this) {
559 + setCurrentPositionNative(id,
560 + downsampledPosition.left, downsampledPosition.top,
561 + downsampledPosition.right, downsampledPosition.bottom);
562 + }
563 + }
564 +
565 + private synchronized void updateTrackedPosition() {
566 + checkValidObject();
567 +
568 + final float[] delta = new float[4];
569 + getTrackedPositionNative(id, delta);
570 + lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
571 +
572 + visibleInLastFrame = isObjectVisible(id);
573 + }
574 +
575 + public synchronized RectF getTrackedPositionInPreviewFrame() {
576 + checkValidObject();
577 +
578 + if (lastTrackedPosition == null) {
579 + return null;
580 + }
581 + return upscaleRect(lastTrackedPosition);
582 + }
583 +
584 + synchronized long getLastExternalPositionTime() {
585 + return lastExternalPositionTime;
586 + }
587 +
588 + public synchronized boolean visibleInLastPreviewFrame() {
589 + return visibleInLastFrame;
590 + }
591 +
592 + private void checkValidObject() {
593 + if (isDead) {
594 + throw new RuntimeException("TrackedObject already removed from tracking!");
595 + } else if (ObjectTracker.this != instance) {
596 + throw new RuntimeException("TrackedObject created with another ObjectTracker!");
597 + }
598 + }
599 + }
600 +
601 + public synchronized TrackedObject trackObject(
602 + final RectF position, final long timestamp, final byte[] frameData) {
603 + if (downsampledTimestamp != timestamp) {
604 + ObjectTracker.downsampleImageNative(
605 + frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
606 + downsampledTimestamp = timestamp;
607 + }
608 + return new TrackedObject(position, timestamp, downsampledFrame);
609 + }
610 +
611 + public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) {
612 + return new TrackedObject(position, lastTimestamp, frameData);
613 + }
614 +
615 + /** ********************* NATIVE CODE ************************************ */
616 +
617 + /** This will contain an opaque pointer to the native ObjectTracker */
618 + private long nativeObjectTracker;
619 +
620 + private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack);
621 +
622 + protected native void registerNewObjectWithAppearanceNative(
623 + String objectId, float x1, float y1, float x2, float y2, byte[] data);
624 +
625 + protected native void setPreviousPositionNative(
626 + String objectId, float x1, float y1, float x2, float y2, long timestamp);
627 +
628 + protected native void setCurrentPositionNative(
629 + String objectId, float x1, float y1, float x2, float y2);
630 +
631 + protected native void forgetNative(String key);
632 +
633 + protected native String getModelIdNative(String key);
634 +
635 + protected native boolean haveObject(String key);
636 + protected native boolean isObjectVisible(String key);
637 + protected native float getCurrentCorrelation(String key);
638 +
639 + protected native float getMatchScore(String key);
640 +
641 + protected native void getTrackedPositionNative(String key, float[] points);
642 +
643 + protected native void nextFrameNative(
644 + byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix);
645 +
646 + protected native void releaseMemoryNative();
647 +
648 + protected native void getCurrentPositionNative(long timestamp,
649 + final float positionX1, final float positionY1,
650 + final float positionX2, final float positionY2,
651 + final float[] delta);
652 +
653 + protected native byte[] getKeypointsPacked(float scaleFactor);
654 +
655 + protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints);
656 +
657 + protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas);
658 +
659 + protected static native void downsampleImageNative(
660 + int width, int height, int rowStride, byte[] input, int factor, byte[] output);
661 +}
1 +# Copyright 2015 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +# ==============================================================================
15 +"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
16 +This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
17 +variable values stored in a checkpoint file, and output a GraphDef with all of
18 +the variable ops converted into const ops containing the values of the
19 +variables.
20 +It's useful to do this when we need to load a single file in C++, especially in
21 +environments like mobile or embedded where we may not have access to the
22 +RestoreTensor ops and file loading calls that they rely on.
23 +An example of command-line usage is:
24 +bazel build tensorflow/python/tools:freeze_graph && \
25 +bazel-bin/tensorflow/python/tools/freeze_graph \
26 +--input_graph=some_graph_def.pb \
27 +--input_checkpoint=model.ckpt-8361242 \
28 +--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
29 +You can also look at freeze_graph_test.py for an example of how to use it.
30 +"""
31 +from __future__ import absolute_import
32 +from __future__ import division
33 +from __future__ import print_function
34 +
35 +import tensorflow as tf
36 +
37 +from google.protobuf import text_format
38 +from tensorflow.python.framework import graph_util
39 +
40 +
41 +FLAGS = tf.app.flags.FLAGS
42 +
43 +tf.app.flags.DEFINE_string("input_graph", "",
44 + """TensorFlow 'GraphDef' file to load.""")
45 +tf.app.flags.DEFINE_string("input_saver", "",
46 + """TensorFlow saver file to load.""")
47 +tf.app.flags.DEFINE_string("input_checkpoint", "",
48 + """TensorFlow variables file to load.""")
49 +tf.app.flags.DEFINE_string("output_graph", "",
50 + """Output 'GraphDef' file name.""")
51 +tf.app.flags.DEFINE_boolean("input_binary", False,
52 + """Whether the input files are in binary format.""")
53 +tf.app.flags.DEFINE_string("output_node_names", "",
54 + """The name of the output nodes, comma separated.""")
55 +tf.app.flags.DEFINE_string("restore_op_name", "save/restore_all",
56 + """The name of the master restore operator.""")
57 +tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0",
58 + """The name of the tensor holding the save path.""")
59 +tf.app.flags.DEFINE_boolean("clear_devices", True,
60 + """Whether to remove device specifications.""")
61 +tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
62 + "initializer nodes to run before freezing.")
63 +
64 +
65 +def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
66 + output_node_names, restore_op_name, filename_tensor_name,
67 + output_graph, clear_devices, initializer_nodes):
68 + """Converts all variables in a graph and checkpoint into constants."""
69 +
70 + if not tf.gfile.Exists(input_graph):
71 + print("Input graph file '" + input_graph + "' does not exist!")
72 + return -1
73 +
74 + if input_saver and not tf.gfile.Exists(input_saver):
75 + print("Input saver file '" + input_saver + "' does not exist!")
76 + return -1
77 +
78 + if not tf.gfile.Glob(input_checkpoint):
79 + print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
80 + return -1
81 +
82 + if not output_node_names:
83 + print("You need to supply the name of a node to --output_node_names.")
84 + return -1
85 +
86 + input_graph_def = tf.GraphDef()
87 + mode = "rb" if input_binary else "r"
88 + with tf.gfile.FastGFile(input_graph, mode) as f:
89 + if input_binary:
90 + input_graph_def.ParseFromString(f.read())
91 + else:
92 + text_format.Merge(f.read(), input_graph_def)
93 + # Remove all the explicit device specifications for this node. This helps to
94 + # make the graph more portable.
95 + if clear_devices:
96 + for node in input_graph_def.node:
97 + node.device = ""
98 + _ = tf.import_graph_def(input_graph_def, name="")
99 +
100 + with tf.Session() as sess:
101 + if input_saver:
102 + with tf.gfile.FastGFile(input_saver, mode) as f:
103 + saver_def = tf.train.SaverDef()
104 + if input_binary:
105 + saver_def.ParseFromString(f.read())
106 + else:
107 + text_format.Merge(f.read(), saver_def)
108 + saver = tf.train.Saver(saver_def=saver_def)
109 + saver.restore(sess, input_checkpoint)
110 + else:
111 + sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
112 + if initializer_nodes:
113 + sess.run(initializer_nodes)
114 + output_graph_def = graph_util.convert_variables_to_constants(
115 + sess, input_graph_def, output_node_names.split(","))
116 +
117 + with tf.gfile.GFile(output_graph, "wb") as f:
118 + f.write(output_graph_def.SerializeToString())
119 + print("%d ops in the final graph." % len(output_graph_def.node))
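For reference, a minimal sketch of calling freeze_graph() above directly from Python instead of through the command-line flags (all paths are hypothetical, the module is assumed to be importable as tensorflow.python.tools.freeze_graph per the bazel target in the docstring, and the graph is assumed to contain the default Saver ops so the restore_op_name/filename_tensor_name defaults apply):

from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(
    input_graph="/tmp/model/graph.pbtxt",          # text-format GraphDef
    input_saver="",                                # no separate SaverDef file
    input_binary=False,
    input_checkpoint="/tmp/model/model.ckpt-1000",
    output_node_names="softmax",
    restore_op_name="save/restore_all",
    filename_tensor_name="save/Const:0",
    output_graph="/tmp/model/frozen_graph.pb",
    clear_devices=True,
    initializer_nodes="")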
120 +
121 +
122 +def main(unused_args):
123 + freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
124 + FLAGS.input_checkpoint, FLAGS.output_node_names,
125 + FLAGS.restore_op_name, FLAGS.filename_tensor_name,
126 + FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
127 +
128 +if __name__ == "__main__":
129 + tf.app.run()
...\ No newline at end of file ...\ No newline at end of file