장수창

modified test app

Showing 114 changed files with 0 additions and 16626 deletions
1 -# This file is based on https://github.com/github/gitignore/blob/master/Android.gitignore
2 -*.iml
3 -.idea/compiler.xml
4 -.idea/copyright
5 -.idea/dictionaries
6 -.idea/gradle.xml
7 -.idea/libraries
8 -.idea/inspectionProfiles
9 -.idea/misc.xml
10 -.idea/modules.xml
11 -.idea/runConfigurations.xml
12 -.idea/tasks.xml
13 -.idea/workspace.xml
14 -.gradle
15 -local.properties
16 -.DS_Store
17 -build/
18 -gradleBuild/
19 -*.apk
20 -*.ap_
21 -*.dex
22 -*.class
23 -bin/
24 -gen/
25 -out/
26 -*.log
27 -.navigation/
28 -/captures
29 -.externalNativeBuild
1 -<component name="ProjectCodeStyleConfiguration">
2 - <code_scheme name="Project" version="173">
3 - <codeStyleSettings language="XML">
4 - <indentOptions>
5 - <option name="CONTINUATION_INDENT_SIZE" value="4" />
6 - </indentOptions>
7 - <arrangement>
8 - <rules>
9 - <section>
10 - <rule>
11 - <match>
12 - <AND>
13 - <NAME>xmlns:android</NAME>
14 - <XML_ATTRIBUTE />
15 - <XML_NAMESPACE>^$</XML_NAMESPACE>
16 - </AND>
17 - </match>
18 - </rule>
19 - </section>
20 - <section>
21 - <rule>
22 - <match>
23 - <AND>
24 - <NAME>xmlns:.*</NAME>
25 - <XML_ATTRIBUTE />
26 - <XML_NAMESPACE>^$</XML_NAMESPACE>
27 - </AND>
28 - </match>
29 - <order>BY_NAME</order>
30 - </rule>
31 - </section>
32 - <section>
33 - <rule>
34 - <match>
35 - <AND>
36 - <NAME>.*:id</NAME>
37 - <XML_ATTRIBUTE />
38 - <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
39 - </AND>
40 - </match>
41 - </rule>
42 - </section>
43 - <section>
44 - <rule>
45 - <match>
46 - <AND>
47 - <NAME>.*:name</NAME>
48 - <XML_ATTRIBUTE />
49 - <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
50 - </AND>
51 - </match>
52 - </rule>
53 - </section>
54 - <section>
55 - <rule>
56 - <match>
57 - <AND>
58 - <NAME>name</NAME>
59 - <XML_ATTRIBUTE />
60 - <XML_NAMESPACE>^$</XML_NAMESPACE>
61 - </AND>
62 - </match>
63 - </rule>
64 - </section>
65 - <section>
66 - <rule>
67 - <match>
68 - <AND>
69 - <NAME>style</NAME>
70 - <XML_ATTRIBUTE />
71 - <XML_NAMESPACE>^$</XML_NAMESPACE>
72 - </AND>
73 - </match>
74 - </rule>
75 - </section>
76 - <section>
77 - <rule>
78 - <match>
79 - <AND>
80 - <NAME>.*</NAME>
81 - <XML_ATTRIBUTE />
82 - <XML_NAMESPACE>^$</XML_NAMESPACE>
83 - </AND>
84 - </match>
85 - <order>BY_NAME</order>
86 - </rule>
87 - </section>
88 - <section>
89 - <rule>
90 - <match>
91 - <AND>
92 - <NAME>.*</NAME>
93 - <XML_ATTRIBUTE />
94 - <XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
95 - </AND>
96 - </match>
97 - <order>ANDROID_ATTRIBUTE_ORDER</order>
98 - </rule>
99 - </section>
100 - <section>
101 - <rule>
102 - <match>
103 - <AND>
104 - <NAME>.*</NAME>
105 - <XML_ATTRIBUTE />
106 - <XML_NAMESPACE>.*</XML_NAMESPACE>
107 - </AND>
108 - </match>
109 - <order>BY_NAME</order>
110 - </rule>
111 - </section>
112 - </rules>
113 - </arrangement>
114 - </codeStyleSettings>
115 - </code_scheme>
116 -</component>
\ No newline at end of file
1 -<?xml version="1.0" encoding="UTF-8"?>
2 -<project version="4">
3 - <component name="VcsDirectoryMappings">
4 - <mapping directory="$PROJECT_DIR$/../../.." vcs="Git" />
5 - </component>
6 -</project>
\ No newline at end of file
1 -<?xml version="1.0" encoding="UTF-8"?>
2 -<!--
3 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4 -
5 - Licensed under the Apache License, Version 2.0 (the "License");
6 - you may not use this file except in compliance with the License.
7 - You may obtain a copy of the License at
8 -
9 - http://www.apache.org/licenses/LICENSE-2.0
10 -
11 - Unless required by applicable law or agreed to in writing, software
12 - distributed under the License is distributed on an "AS IS" BASIS,
13 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 - See the License for the specific language governing permissions and
15 - limitations under the License.
16 --->
17 -
18 -<manifest xmlns:android="http://schemas.android.com/apk/res/android"
19 - package="org.tensorflow.demo">
20 -
21 - <uses-permission android:name="android.permission.CAMERA" />
22 - <uses-feature android:name="android.hardware.camera" />
23 - <uses-feature android:name="android.hardware.camera.autofocus" />
24 - <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
25 - <uses-permission android:name="android.permission.RECORD_AUDIO" />
26 -
27 - <application android:allowBackup="true"
28 - android:debuggable="true"
29 - android:label="@string/app_name"
30 - android:icon="@drawable/ic_launcher"
31 - android:theme="@style/MaterialTheme">
32 -
33 -<!-- <activity android:name="org.tensorflow.demo.ClassifierActivity"-->
34 -<!-- android:screenOrientation="portrait"-->
35 -<!-- android:label="@string/activity_name_classification">-->
36 -<!-- <intent-filter>-->
37 -<!-- <action android:name="android.intent.action.MAIN" />-->
38 -<!-- <category android:name="android.intent.category.LAUNCHER" />-->
39 -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
40 -<!-- </intent-filter>-->
41 -<!-- </activity>-->
42 -
43 - <activity android:name="org.tensorflow.demo.DetectorActivity"
44 - android:screenOrientation="portrait"
45 - android:label="@string/activity_name_detection">
46 - <intent-filter>
47 - <action android:name="android.intent.action.MAIN" />
48 - <category android:name="android.intent.category.LAUNCHER" />
49 - <category android:name="android.intent.category.LEANBACK_LAUNCHER" />
50 - </intent-filter>
51 - </activity>
52 -
53 -<!-- <activity android:name="org.tensorflow.demo.StylizeActivity"-->
54 -<!-- android:screenOrientation="portrait"-->
55 -<!-- android:label="@string/activity_name_stylize">-->
56 -<!-- <intent-filter>-->
57 -<!-- <action android:name="android.intent.action.MAIN" />-->
58 -<!-- <category android:name="android.intent.category.LAUNCHER" />-->
59 -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
60 -<!-- </intent-filter>-->
61 -<!-- </activity>-->
62 -
63 -<!-- <activity android:name="org.tensorflow.demo.SpeechActivity"-->
64 -<!-- android:screenOrientation="portrait"-->
65 -<!-- android:label="@string/activity_name_speech">-->
66 -<!-- <intent-filter>-->
67 -<!-- <action android:name="android.intent.action.MAIN" />-->
68 -<!-- <category android:name="android.intent.category.LAUNCHER" />-->
69 -<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
70 -<!-- </intent-filter>-->
71 -<!-- </activity>-->
72 - </application>
73 -
74 -</manifest>
1 -# Description:
2 -# TensorFlow camera demo app for Android.
3 -
4 -load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
5 -load(
6 - "//tensorflow:tensorflow.bzl",
7 - "tf_copts",
8 -)
9 -
10 -package(
11 - default_visibility = ["//visibility:public"],
12 - licenses = ["notice"], # Apache 2.0
13 -)
14 -
15 -exports_files(["LICENSE"])
16 -
17 -LINKER_SCRIPT = "jni/version_script.lds"
18 -
19 -# libtensorflow_demo.so contains the native code for image colorspace conversion
20 -# and object tracking used by the demo. It does not require TF as a dependency
21 -# to build if STANDALONE_DEMO_LIB is defined.
22 -# TF support for the demo is provided separately by libtensorflow_inference.so.
23 -cc_binary(
24 - name = "libtensorflow_demo.so",
25 - srcs = glob([
26 - "jni/**/*.cc",
27 - "jni/**/*.h",
28 - ]),
29 - copts = tf_copts(),
30 - defines = ["STANDALONE_DEMO_LIB"],
31 - linkopts = [
32 - "-landroid",
33 - "-ldl",
34 - "-ljnigraphics",
35 - "-llog",
36 - "-lm",
37 - "-z defs",
38 - "-s",
39 - "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
40 - ],
41 - linkshared = 1,
42 - linkstatic = 1,
43 - tags = [
44 - "manual",
45 - "notap",
46 - ],
47 - deps = [
48 - LINKER_SCRIPT,
49 - ],
50 -)
51 -
52 -cc_library(
53 - name = "tensorflow_native_libs",
54 - srcs = [
55 - ":libtensorflow_demo.so",
56 - "//tensorflow/tools/android/inference_interface:libtensorflow_inference.so",
57 - ],
58 - tags = [
59 - "manual",
60 - "notap",
61 - ],
62 -)
63 -
64 -android_binary(
65 - name = "tensorflow_demo",
66 - srcs = glob([
67 - "src/**/*.java",
68 - ]),
69 - # Package assets from assets dir as well as all model targets. Remove undesired models
70 - # (and corresponding Activities in source) to reduce APK size.
71 - assets = [
72 - "//tensorflow/examples/android/assets:asset_files",
73 - ":external_assets",
74 - ],
75 - assets_dir = "",
76 - custom_package = "org.tensorflow.demo",
77 - manifest = "AndroidManifest.xml",
78 - resource_files = glob(["res/**"]),
79 - tags = [
80 - "manual",
81 - "notap",
82 - ],
83 - deps = [
84 - ":tensorflow_native_libs",
85 - "//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java",
86 - ],
87 -)
88 -
89 -# LINT.IfChange
90 -filegroup(
91 - name = "external_assets",
92 - srcs = [
93 - "@inception_v1//:model_files",
94 - "@mobile_ssd//:model_files",
95 - "@speech_commands//:model_files",
96 - "@stylize//:model_files",
97 - ],
98 -)
99 -# LINT.ThenChange(//tensorflow/examples/android/download-models.gradle)
100 -
101 -filegroup(
102 - name = "java_files",
103 - srcs = glob(["src/**/*.java"]),
104 -)
105 -
106 -filegroup(
107 - name = "jni_files",
108 - srcs = glob([
109 - "jni/**/*.cc",
110 - "jni/**/*.h",
111 - ]),
112 -)
113 -
114 -filegroup(
115 - name = "resource_files",
116 - srcs = glob(["res/**"]),
117 -)
118 -
119 -exports_files([
120 - "AndroidManifest.xml",
121 -])
1 -# TensorFlow Android Camera Demo
2 -
3 -This folder contains an example application utilizing TensorFlow for Android
4 -devices.
5 -
6 -## Description
7 -
8 -The demos in this folder are designed to give straightforward samples of using
9 -TensorFlow in mobile applications.
10 -
11 -Inference is done using the [TensorFlow Android Inference
12 -Interface](../../tools/android/inference_interface), which may be built
13 -separately if you want a standalone library to drop into your existing
14 -application. Object tracking and efficient YUV -> RGB conversion are handled by
15 -`libtensorflow_demo.so`.
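
For orientation, here is a minimal sketch of driving that interface from Java. The `feed`/`run`/`fetch` calls follow the `org.tensorflow.contrib.android.TensorFlowInferenceInterface` API of this era; the model path, tensor names, and sizes below are placeholders, not the demo's actual values:

```java
import android.content.res.AssetManager;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public final class InferenceSketch {
  // Placeholder asset path and tensor names; real values are model-specific.
  private static final String MODEL_FILE = "file:///android_asset/model.pb";
  private static final String INPUT_NAME = "input";
  private static final String OUTPUT_NAME = "output";

  public static float[] run(final AssetManager assets, final float[] pixels) {
    final TensorFlowInferenceInterface inference =
        new TensorFlowInferenceInterface(assets, MODEL_FILE);

    // Feed a batch of one 224x224 RGB image, run the graph, fetch the scores.
    inference.feed(INPUT_NAME, pixels, 1, 224, 224, 3);
    inference.run(new String[] {OUTPUT_NAME});

    final float[] outputs = new float[1001];
    inference.fetch(OUTPUT_NAME, outputs);
    return outputs;
  }
}
```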
16 -
17 -A device running Android 5.0 (API 21) or higher is required to run the demo due
18 -to the use of the camera2 API, although the native libraries themselves can run
19 -on API >= 14 devices.
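
The API 21 floor comes from the `android.hardware.camera2` package, which only exists on Android 5.0+. As a hedged sketch (not code from this demo), opening a camera through that API looks roughly like:

```java
import android.content.Context;
import android.hardware.camera2.CameraAccessException;
import android.hardware.camera2.CameraDevice;
import android.hardware.camera2.CameraManager;
import android.os.Handler;

public final class CameraSketch {
  // Assumes the CAMERA permission from the manifest has already been granted.
  public static void openFirstCamera(final Context context, final Handler handler)
      throws CameraAccessException {
    final CameraManager manager =
        (CameraManager) context.getSystemService(Context.CAMERA_SERVICE);
    final String cameraId = manager.getCameraIdList()[0];
    manager.openCamera(cameraId, new CameraDevice.StateCallback() {
      @Override public void onOpened(final CameraDevice camera) {
        // Start a capture session and hand frames to the classifier/detector.
      }
      @Override public void onDisconnected(final CameraDevice camera) { camera.close(); }
      @Override public void onError(final CameraDevice camera, final int error) { camera.close(); }
    }, handler);
  }
}
```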
20 -
21 -## Current samples:
22 -
23 -1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java):
24 - Uses the [Google Inception](https://arxiv.org/abs/1409.4842)
25 - model to classify camera frames in real-time, displaying the top results
26 - in an overlay on the camera image.
27 -2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
28 - Demonstrates an SSD-Mobilenet model trained using the
29 -    [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection/)
30 - introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to
31 - localize and track objects (from 80 categories) in the camera preview
32 - in real-time.
33 -3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java):
34 - Uses a model based on [A Learned Representation For Artistic
35 - Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
36 -    image in the styles of a number of different artists.
37 -4. [TF
38 - Speech](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java):
39 - Runs a simple speech recognition model built by the [audio training
40 -    tutorial](https://www.tensorflow.org/versions/master/tutorials/audio_recognition). It listens
41 -    for a small set of words and highlights them in the UI when they are
42 -    recognized.
43 -
44 -<img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%">
45 -
46 -## Prebuilt Components:
47 -
48 -The fastest path to trying the demo is to download the [prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
49 -
50 -Also available are precompiled native libraries and a JCenter package that you
51 -may simply drop into your own applications. See
52 -[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
53 -for more details.
54 -
55 -## Running the Demo
56 -
57 -Once the app is installed, it can be started via the "TF Classify", "TF Detect",
58 -"TF Stylize", and "TF Speech" launcher entries, all of which use the orange
59 -TensorFlow logo as their icon.
60 -
61 -While running the activities, pressing the volume keys on your device will
62 -toggle debug visualizations on/off, rendering additional info to the screen that
63 -may be useful for development purposes.
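
A minimal sketch of that volume-key toggle (the field and activity names are illustrative, not the demo's exact code):

```java
import android.app.Activity;
import android.view.KeyEvent;

public class DebugToggleActivity extends Activity {
  private boolean debugMode = false;

  // Volume keys flip the debug flag; the demo activities use the same idea
  // to overlay extra diagnostics on the camera preview. (Illustrative only.)
  @Override
  public boolean onKeyDown(final int keyCode, final KeyEvent event) {
    if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN
        || keyCode == KeyEvent.KEYCODE_VOLUME_UP) {
      debugMode = !debugMode;
      return true;
    }
    return super.onKeyDown(keyCode, event);
  }
}
```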
64 -
65 -## Building in Android Studio using the TensorFlow AAR from JCenter
66 -
67 -The simplest way to compile the demo app yourself and try out changes to the
68 -project code is to use Android Studio. Simply set this `android` directory as
69 -the project root.
70 -
71 -Then edit the `build.gradle` file and change the value of `nativeBuildSystem` to
72 -`'none'` so that the project is built in the simplest way possible:
73 -
74 -```None
75 -def nativeBuildSystem = 'none'
76 -```
77 -
78 -While this project includes full build integration for TensorFlow, this setting
79 -disables it, and uses the TensorFlow Inference Interface package from JCenter.
80 -
81 -Note: Currently, in this build mode, YUV -> RGB is done using a less efficient
82 -Java implementation, and object tracking is not available in the "TF Detect"
83 -activity. Setting the build system to `'cmake'` currently only builds
84 -`libtensorflow_demo.so`, which provides fast YUV -> RGB conversion and object
85 -tracking, while still acquiring TensorFlow support via the downloaded AAR, so it
86 -may be a lightweight way to enable these features.
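
To make the trade-off concrete, the pure-Java fallback must do per-pixel colorspace math like the following (a sketch of standard BT.601 conversion, not the demo's exact implementation), which is exactly what the native path in `libtensorflow_demo.so` accelerates:

```java
public final class YuvSketch {
  // BT.601 YUV -> ARGB for one pixel; a sketch of the kind of per-pixel
  // math the Java fallback performs (not the demo's exact code).
  public static int yuvToArgb(final int y, final int u, final int v) {
    final int yv = Math.max(y - 16, 0);
    final int uv = u - 128;
    final int vv = v - 128;
    int r = Math.round(1.164f * yv + 1.596f * vv);
    int g = Math.round(1.164f * yv - 0.813f * vv - 0.391f * uv);
    int b = Math.round(1.164f * yv + 2.018f * uv);
    r = Math.min(Math.max(r, 0), 255);
    g = Math.min(Math.max(g, 0), 255);
    b = Math.min(Math.max(b, 0), 255);
    return 0xff000000 | (r << 16) | (g << 8) | b;
  }
}
```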
87 -
88 -For any project that does not include custom low level TensorFlow code, this is
89 -likely sufficient.
90 -
91 -For details on how to include this JCenter package in your own project see
92 -[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
93 -
94 -## Building the Demo with TensorFlow from Source
95 -
96 -Pick your preferred approach below. At the moment, we have full support for
97 -Bazel, and partial support for Gradle, CMake, Make, and Android Studio.
98 -
99 -As a first step for all build types, clone the TensorFlow repo with:
100 -
101 -```
102 -git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
103 -```
104 -
105 -Note that `--recurse-submodules` is necessary to prevent some issues with
106 -protobuf compilation.
107 -
108 -### Bazel
109 -
110 -NOTE: Bazel does not currently support building for Android on Windows. Full
111 -support for Gradle/CMake builds is coming soon, but in the meantime we suggest
112 -that Windows users download the
113 -[prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk)
114 -instead.
115 -
116 -##### Install Bazel and Android Prerequisites
117 -
118 -Bazel is the primary build system for TensorFlow. To build with Bazel, you will
119 -need Bazel itself plus the Android NDK and SDK installed on your system.
120 -
121 -1. Install the latest version of Bazel as per the instructions [on the Bazel
122 - website](https://bazel.build/versions/master/docs/install.html).
123 -2. The Android NDK is required to build the native (C/C++) TensorFlow code. The
124 - current recommended version is 14b, which may be found
125 - [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
126 -3. The Android SDK and build tools may be obtained
127 - [here](https://developer.android.com/tools/revisions/build-tools.html), or
128 - alternatively as part of [Android
129 - Studio](https://developer.android.com/studio/index.html). Build tools API >=
130 - 23 is required to build the TF Android demo (though it will run on API >= 21
131 - devices).
132 -
133 -##### Edit WORKSPACE
134 -
135 -NOTE: As long as you have the SDK and NDK installed, the `./configure` script
136 -will create these rules for you. Answer "Yes" when the script asks to
137 -automatically configure the `./WORKSPACE`.
138 -
139 -The Android entries in
140 -[`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L19-L36) must be uncommented
141 -with the paths filled in appropriately depending on where you installed the NDK
142 -and SDK. Otherwise an error such as "The external label
143 -'//external:android/sdk' is not bound to anything" will be reported.
144 -
145 -Also edit the API levels for the SDK in WORKSPACE to the highest level you have
146 -installed in your SDK. This must be >= 23 (this is completely independent of the
147 -API level of the demo, which is defined in AndroidManifest.xml). The NDK API
148 -level may remain at 14.
149 -
150 -##### Install Model Files (optional)
151 -
152 -The TensorFlow `GraphDef`s that contain the model definitions and weights are
153 -not packaged in the repo because of their size. They are downloaded
154 -automatically and packaged with the APK by Bazel via a `new_http_archive` defined
155 -in `WORKSPACE` during the build process, and by Gradle via
156 -`download-models.gradle`.
157 -
158 -**Optional**: If you wish to place the models in your assets manually, remove
159 -all of the `model_files` entries from the `assets` list in `tensorflow_demo`
160 -found in the [`BUILD`](BUILD#L92) file. Then download and extract the archives
161 -yourself to the `assets` directory in the source tree:
162 -
163 -```bash
164 -BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models
165 -for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip
166 -do
167 - curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP}
168 - unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/
169 -done
170 -```
171 -
172 -This will extract the models and their associated metadata files to the local
173 -assets/ directory.
174 -
175 -If you are using Gradle, make sure to remove the `download-models.gradle` reference
176 -from `build.gradle` after you manually download the models; otherwise Gradle might
177 -download the models again and overwrite them.
178 -
179 -##### Build
180 -
181 -After editing your WORKSPACE file to update the SDK/NDK configuration, you may
182 -build the APK. Run this from your workspace root:
183 -
184 -```bash
185 -bazel build --cxxopt='--std=c++11' -c opt //tensorflow/examples/android:tensorflow_demo
186 -```
187 -
188 -##### Install
189 -
190 -Make sure that adb debugging is enabled on your Android 5.0 (API 21) or later
191 -device, then after building use the following command from your workspace root
192 -to install the APK:
193 -
194 -```bash
195 -adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
196 -```
197 -
198 -### Android Studio with Bazel
199 -
200 -Android Studio may be used to build the demo in conjunction with Bazel. First,
201 -make sure that you can build with Bazel following the above directions. Then,
202 -look at [build.gradle](build.gradle) and make sure that the path to Bazel
203 -matches that of your system.
204 -
205 -At this point you can add the tensorflow/examples/android directory as a new
206 -Android Studio project. Click through installing all the Gradle extensions it
207 -requests, and you should be able to have Android Studio build the demo like any
208 -other application (it will call out to Bazel to build the native code with the
209 -NDK).
210 -
211 -### CMake
212 -
213 -Full CMake support for the demo is coming soon, but for now it is possible to
214 -build the TensorFlow Android Inference library using
215 -[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake).
1 -package(
2 - default_visibility = ["//visibility:public"],
3 - licenses = ["notice"], # Apache 2.0
4 -)
5 -
6 -# It is necessary to use this filegroup rather than globbing the files in this
7 -# folder directly in the examples/android:tensorflow_demo target due to the fact
8 -# that assets_dir is necessarily set to "" there (to allow using other
9 -# arbitrary targets as assets).
10 -filegroup(
11 - name = "asset_files",
12 - srcs = glob(
13 - ["**/*"],
14 - exclude = ["BUILD"],
15 - ),
16 -)
This file is too large to display.
1 -// This file provides basic support for building the TensorFlow demo
2 -// in Android Studio with Gradle.
3 -//
4 -// Note that Bazel is still used by default to compile the native libs,
5 -// and should be installed at the location noted below. This build file
6 -// automates the process of calling out to it and copying the compiled
7 -// libraries back into the appropriate directory.
8 -//
9 -// Alternatively, experimental support for Makefile builds is provided by
10 -// setting nativeBuildSystem below to 'makefile'. This will allow building the demo
11 -// on Windows machines, but note that full equivalence with the Bazel
12 -// build is not yet guaranteed. See comments below for caveats and tips
13 -// for speeding up the build, such as enabling ccache.
14 -// NOTE: Running a make build will cause subsequent Bazel builds to *fail*
15 -// unless the contrib/makefile/downloads/ and gen/ dirs are deleted afterwards.
16 -
17 -// The cmake build only creates libtensorflow_demo.so. In this situation,
18 -// libtensorflow_inference.so will be acquired via the tensorflow.aar dependency.
19 -
20 -// It is necessary to customize Gradle's build directory, as otherwise
21 -// it will conflict with the BUILD file used by Bazel on case-insensitive OSs.
22 -project.buildDir = 'gradleBuild'
23 -getProject().setBuildDir('gradleBuild')
24 -
25 -buildscript {
26 - repositories {
27 - jcenter()
28 - google()
29 - }
30 -
31 - dependencies {
32 - classpath 'com.android.tools.build:gradle:3.3.1'
33 - classpath 'org.apache.httpcomponents:httpclient:4.5.4'
34 - }
35 -}
36 -
37 -allprojects {
38 - repositories {
39 - jcenter()
40 - google()
41 - }
42 -}
43 -
44 -// set to 'bazel', 'cmake', 'makefile', 'none'
45 -def nativeBuildSystem = 'none'
46 -
47 -// Controls output directory in APK and CPU type for Bazel builds.
48 -// NOTE: Does not affect the Makefile build target API (yet), which currently
49 -// assumes armeabi-v7a. If building with make, changing this will require
50 -// editing the Makefile as well.
51 -// The CMake build has only been tested with armeabi-v7a; others may not work.
52 -def cpuType = 'armeabi-v7a'
53 -
54 -// Output directory in the local directory for packaging into the APK.
55 -def nativeOutDir = 'libs/' + cpuType
56 -
57 -// Default to building with Bazel and override with make if requested.
58 -def nativeBuildRule = 'buildNativeBazel'
59 -def demoLibPath = '../../../bazel-bin/tensorflow/examples/android/libtensorflow_demo.so'
60 -def inferenceLibPath = '../../../bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so'
61 -
62 -// Override for Makefile builds.
63 -if (nativeBuildSystem == 'makefile') {
64 - nativeBuildRule = 'buildNativeMake'
65 - demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_demo.so'
66 - inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_inference.so'
67 -}
68 -
69 -// If building with Bazel, this is the location of the bazel binary.
70 -// NOTE: Bazel does not yet support building for Android on Windows,
71 -// so in this case the Makefile build must be used as described above.
72 -def bazelLocation = '/usr/local/bin/bazel'
73 -
74 -// import DownloadModels task
75 -project.ext.ASSET_DIR = projectDir.toString() + '/assets'
76 -project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
77 -
78 -// Download default models; if you wish to use your own models then
79 -// place them in the "assets" directory and comment out this line.
80 -apply from: "download-models.gradle"
81 -
82 -apply plugin: 'com.android.application'
83 -
84 -android {
85 - compileSdkVersion 23
86 -
87 - if (nativeBuildSystem == 'cmake') {
88 - defaultConfig {
89 - applicationId = 'org.tensorflow.demo'
90 - minSdkVersion 21
91 - targetSdkVersion 23
92 - ndk {
93 - abiFilters "${cpuType}"
94 - }
95 - externalNativeBuild {
96 - cmake {
97 - arguments '-DANDROID_STL=c++_static'
98 - }
99 - }
100 - }
101 - externalNativeBuild {
102 - cmake {
103 - path './jni/CMakeLists.txt'
104 - }
105 - }
106 - }
107 -
108 - lintOptions {
109 - abortOnError false
110 - }
111 -
112 - sourceSets {
113 - main {
114 - if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
115 - // TensorFlow Java API sources.
116 - java {
117 - srcDir '../../java/src/main/java'
118 - exclude '**/examples/**'
119 - }
120 -
121 - // Android TensorFlow wrappers, etc.
122 - java {
123 - srcDir '../../tools/android/inference_interface/java'
124 - }
125 - }
126 - // Android demo app sources.
127 - java {
128 - srcDir 'src'
129 - }
130 -
131 - manifest.srcFile 'AndroidManifest.xml'
132 - resources.srcDirs = ['src']
133 - aidl.srcDirs = ['src']
134 - renderscript.srcDirs = ['src']
135 - res.srcDirs = ['res']
136 - assets.srcDirs = [project.ext.ASSET_DIR]
137 - jniLibs.srcDirs = ['libs']
138 - }
139 -
140 - debug.setRoot('build-types/debug')
141 - release.setRoot('build-types/release')
142 - }
143 - defaultConfig {
144 - targetSdkVersion 23
145 - minSdkVersion 21
146 - }
147 -}
148 -
149 -task buildNativeBazel(type: Exec) {
150 - workingDir '../../..'
151 - commandLine bazelLocation, 'build', '-c', 'opt', \
152 - 'tensorflow/examples/android:tensorflow_native_libs', \
153 - '--crosstool_top=//external:android/crosstool', \
154 - '--cpu=' + cpuType, \
155 - '--host_crosstool_top=@bazel_tools//tools/cpp:toolchain'
156 -}
157 -
158 -task buildNativeMake(type: Exec) {
159 - environment "NDK_ROOT", android.ndkDirectory
160 - // Tip: install ccache and uncomment the following to speed up
161 - // builds significantly.
162 - // environment "CC_PREFIX", 'ccache'
163 - workingDir '../../..'
164 - commandLine 'tensorflow/contrib/makefile/build_all_android.sh', \
165 - '-s', \
166 - 'tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in', \
167 - '-t', \
168 - 'libtensorflow_inference.so libtensorflow_demo.so all' \
169 - , '-a', cpuType \
170 - //, '-T' // Uncomment to skip protobuf and speed up subsequent builds.
171 -}
172 -
173 -
174 -task copyNativeLibs(type: Copy) {
175 - from demoLibPath
176 - from inferenceLibPath
177 - into nativeOutDir
178 - duplicatesStrategy = 'include'
179 - dependsOn nativeBuildRule
180 - fileMode 0644
181 -}
182 -
183 -tasks.whenTaskAdded { task ->
184 - if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
185 - if (task.name == 'assembleDebug') {
186 - task.dependsOn 'copyNativeLibs'
187 - }
188 - if (task.name == 'assembleRelease') {
189 - task.dependsOn 'copyNativeLibs'
190 - }
191 - }
192 -}
193 -
194 -dependencies {
195 - if (nativeBuildSystem == 'cmake' || nativeBuildSystem == 'none') {
196 - implementation 'org.tensorflow:tensorflow-android:+'
197 - }
198 -}
1 -/*
2 - * download-models.gradle
3 - * Downloads model files from ${MODEL_URL} into application's asset folder
4 - * Input:
5 - * project.ext.TMP_DIR: absolute path to hold downloaded zip files
6 - * project.ext.ASSET_DIR: absolute path to save unzipped model files
7 - * Output:
8 - * the model archives listed below will be downloaded and unzipped into the folder given by ext.ASSET_DIR
9 - */
10 -// hard coded model files
11 -// LINT.IfChange
12 -def models = ['inception_v1.zip',
13 - 'object_detection/ssd_mobilenet_v1_android_export.zip',
14 - 'stylize_v1.zip',
15 - 'speech_commands_conv_actions.zip']
16 -// LINT.ThenChange(//tensorflow/examples/android/BUILD)
17 -
18 -// Root URL for model archives
19 -def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models'
20 -
21 -buildscript {
22 - repositories {
23 - jcenter()
24 - }
25 - dependencies {
26 - classpath 'de.undercouch:gradle-download-task:3.2.0'
27 - }
28 -}
29 -
30 -import de.undercouch.gradle.tasks.download.Download
31 -task downloadFile(type: Download){
32 - for (f in models) {
33 - src "${MODEL_URL}/" + f
34 - }
35 - dest new File(project.ext.TMP_DIR)
36 - overwrite true
37 -}
38 -
39 -task extractModels(type: Copy) {
40 - for (f in models) {
41 - def localFile = f.split("/")[-1]
42 - from zipTree(project.ext.TMP_DIR + '/' + localFile)
43 - }
44 -
45 - into file(project.ext.ASSET_DIR)
46 - fileMode 0644
47 - exclude '**/LICENSE'
48 -
49 - def needDownload = false
50 - for (f in models) {
51 - def localFile = f.split("/")[-1]
52 - if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) {
53 - needDownload = true
54 - }
55 - }
56 -
57 - if (needDownload) {
58 - dependsOn downloadFile
59 - }
60 -}
61 -
62 -tasks.whenTaskAdded { task ->
63 - if (task.name == 'assembleDebug') {
64 - task.dependsOn 'extractModels'
65 - }
66 - if (task.name == 'assembleRelease') {
67 - task.dependsOn 'extractModels'
68 - }
69 -}
70 -
1 -#Sat Nov 18 15:06:47 CET 2017
2 -distributionBase=GRADLE_USER_HOME
3 -distributionPath=wrapper/dists
4 -zipStoreBase=GRADLE_USER_HOME
5 -zipStorePath=wrapper/dists
6 -distributionUrl=https\://services.gradle.org/distributions/gradle-4.1-all.zip
1 -#!/usr/bin/env bash
2 -
3 -##############################################################################
4 -##
5 -## Gradle start up script for UN*X
6 -##
7 -##############################################################################
8 -
9 -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
10 -DEFAULT_JVM_OPTS=""
11 -
12 -APP_NAME="Gradle"
13 -APP_BASE_NAME=`basename "$0"`
14 -
15 -# Use the maximum available, or set MAX_FD != -1 to use that value.
16 -MAX_FD="maximum"
17 -
18 -warn ( ) {
19 - echo "$*"
20 -}
21 -
22 -die ( ) {
23 - echo
24 - echo "$*"
25 - echo
26 - exit 1
27 -}
28 -
29 -# OS specific support (must be 'true' or 'false').
30 -cygwin=false
31 -msys=false
32 -darwin=false
33 -case "`uname`" in
34 - CYGWIN* )
35 - cygwin=true
36 - ;;
37 - Darwin* )
38 - darwin=true
39 - ;;
40 - MINGW* )
41 - msys=true
42 - ;;
43 -esac
44 -
45 -# Attempt to set APP_HOME
46 -# Resolve links: $0 may be a link
47 -PRG="$0"
48 -# Need this for relative symlinks.
49 -while [ -h "$PRG" ] ; do
50 - ls=`ls -ld "$PRG"`
51 - link=`expr "$ls" : '.*-> \(.*\)$'`
52 - if expr "$link" : '/.*' > /dev/null; then
53 - PRG="$link"
54 - else
55 - PRG=`dirname "$PRG"`"/$link"
56 - fi
57 -done
58 -SAVED="`pwd`"
59 -cd "`dirname \"$PRG\"`/" >/dev/null
60 -APP_HOME="`pwd -P`"
61 -cd "$SAVED" >/dev/null
62 -
63 -CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
64 -
65 -# Determine the Java command to use to start the JVM.
66 -if [ -n "$JAVA_HOME" ] ; then
67 - if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
68 - # IBM's JDK on AIX uses strange locations for the executables
69 - JAVACMD="$JAVA_HOME/jre/sh/java"
70 - else
71 - JAVACMD="$JAVA_HOME/bin/java"
72 - fi
73 - if [ ! -x "$JAVACMD" ] ; then
74 - die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
75 -
76 -Please set the JAVA_HOME variable in your environment to match the
77 -location of your Java installation."
78 - fi
79 -else
80 - JAVACMD="java"
81 - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
82 -
83 -Please set the JAVA_HOME variable in your environment to match the
84 -location of your Java installation."
85 -fi
86 -
87 -# Increase the maximum file descriptors if we can.
88 -if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
89 - MAX_FD_LIMIT=`ulimit -H -n`
90 - if [ $? -eq 0 ] ; then
91 - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
92 - MAX_FD="$MAX_FD_LIMIT"
93 - fi
94 - ulimit -n $MAX_FD
95 - if [ $? -ne 0 ] ; then
96 - warn "Could not set maximum file descriptor limit: $MAX_FD"
97 - fi
98 - else
99 - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
100 - fi
101 -fi
102 -
103 -# For Darwin, add options to specify how the application appears in the dock
104 -if $darwin; then
105 - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
106 -fi
107 -
108 -# For Cygwin, switch paths to Windows format before running java
109 -if $cygwin ; then
110 - APP_HOME=`cygpath --path --mixed "$APP_HOME"`
111 - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
112 - JAVACMD=`cygpath --unix "$JAVACMD"`
113 -
114 - # We build the pattern for arguments to be converted via cygpath
115 - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
116 - SEP=""
117 - for dir in $ROOTDIRSRAW ; do
118 - ROOTDIRS="$ROOTDIRS$SEP$dir"
119 - SEP="|"
120 - done
121 - OURCYGPATTERN="(^($ROOTDIRS))"
122 - # Add a user-defined pattern to the cygpath arguments
123 - if [ "$GRADLE_CYGPATTERN" != "" ] ; then
124 - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
125 - fi
126 - # Now convert the arguments - kludge to limit ourselves to /bin/sh
127 - i=0
128 - for arg in "$@" ; do
129 - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
130 - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
131 -
132 - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
133 - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
134 - else
135 - eval `echo args$i`="\"$arg\""
136 - fi
137 - i=$((i+1))
138 - done
139 - case $i in
140 - (0) set -- ;;
141 - (1) set -- "$args0" ;;
142 - (2) set -- "$args0" "$args1" ;;
143 - (3) set -- "$args0" "$args1" "$args2" ;;
144 - (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
145 - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
146 - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
147 - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
148 - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
149 - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
150 - esac
151 -fi
152 -
153 -# Split up the JVM_OPTS and GRADLE_OPTS values into an array, following the shell quoting and substitution rules
154 -function splitJvmOpts() {
155 - JVM_OPTS=("$@")
156 -}
157 -eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
158 -JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
159 -
160 -exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
1 -@if "%DEBUG%" == "" @echo off
2 -@rem ##########################################################################
3 -@rem
4 -@rem Gradle startup script for Windows
5 -@rem
6 -@rem ##########################################################################
7 -
8 -@rem Set local scope for the variables with windows NT shell
9 -if "%OS%"=="Windows_NT" setlocal
10 -
11 -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
12 -set DEFAULT_JVM_OPTS=
13 -
14 -set DIRNAME=%~dp0
15 -if "%DIRNAME%" == "" set DIRNAME=.
16 -set APP_BASE_NAME=%~n0
17 -set APP_HOME=%DIRNAME%
18 -
19 -@rem Find java.exe
20 -if defined JAVA_HOME goto findJavaFromJavaHome
21 -
22 -set JAVA_EXE=java.exe
23 -%JAVA_EXE% -version >NUL 2>&1
24 -if "%ERRORLEVEL%" == "0" goto init
25 -
26 -echo.
27 -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
28 -echo.
29 -echo Please set the JAVA_HOME variable in your environment to match the
30 -echo location of your Java installation.
31 -
32 -goto fail
33 -
34 -:findJavaFromJavaHome
35 -set JAVA_HOME=%JAVA_HOME:"=%
36 -set JAVA_EXE=%JAVA_HOME%/bin/java.exe
37 -
38 -if exist "%JAVA_EXE%" goto init
39 -
40 -echo.
41 -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
42 -echo.
43 -echo Please set the JAVA_HOME variable in your environment to match the
44 -echo location of your Java installation.
45 -
46 -goto fail
47 -
48 -:init
49 -@rem Get command-line arguments, handling Windows variants
50 -
51 -if not "%OS%" == "Windows_NT" goto win9xME_args
52 -if "%@eval[2+2]" == "4" goto 4NT_args
53 -
54 -:win9xME_args
55 -@rem Slurp the command line arguments.
56 -set CMD_LINE_ARGS=
57 -set _SKIP=2
58 -
59 -:win9xME_args_slurp
60 -if "x%~1" == "x" goto execute
61 -
62 -set CMD_LINE_ARGS=%*
63 -goto execute
64 -
65 -:4NT_args
66 -@rem Get arguments from the 4NT Shell from JP Software
67 -set CMD_LINE_ARGS=%$
68 -
69 -:execute
70 -@rem Setup the command line
71 -
72 -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
73 -
74 -@rem Execute Gradle
75 -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
76 -
77 -:end
78 -@rem End local scope for the variables with windows NT shell
79 -if "%ERRORLEVEL%"=="0" goto mainEnd
80 -
81 -:fail
82 -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
83 -rem the _cmd.exe /c_ return code!
84 -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
85 -exit /b 1
86 -
87 -:mainEnd
88 -if "%OS%"=="Windows_NT" endlocal
89 -
90 -:omega
1 -#
2 -# Copyright (C) 2016 The Android Open Source Project
3 -#
4 -# Licensed under the Apache License, Version 2.0 (the "License");
5 -# you may not use this file except in compliance with the License.
6 -# You may obtain a copy of the License at
7 -#
8 -# http://www.apache.org/licenses/LICENSE-2.0
9 -#
10 -# Unless required by applicable law or agreed to in writing, software
11 -# distributed under the License is distributed on an "AS IS" BASIS,
12 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 -# See the License for the specific language governing permissions and
14 -# limitations under the License.
15 -#
16 -
17 -project(TENSORFLOW_DEMO)
18 -cmake_minimum_required(VERSION 3.4.1)
19 -
20 -set(CMAKE_VERBOSE_MAKEFILE on)
21 -
22 -get_filename_component(TF_SRC_ROOT ${CMAKE_SOURCE_DIR}/../../../.. ABSOLUTE)
23 -get_filename_component(SAMPLE_SRC_DIR ${CMAKE_SOURCE_DIR}/.. ABSOLUTE)
24 -
25 -if (ANDROID_ABI MATCHES "^armeabi-v7a$")
26 - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
27 -elseif(ANDROID_ABI MATCHES "^arm64-v8a")
28 - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -ftree-vectorize")
29 -endif()
30 -
31 -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSTANDALONE_DEMO_LIB \
32 - -std=c++11 -fno-exceptions -fno-rtti -O2 -Wno-narrowing \
33 - -fPIE")
34 -set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
35 - -Wl,--allow-multiple-definition \
36 - -Wl,--whole-archive -fPIE -v")
37 -
38 -file(GLOB_RECURSE tensorflow_demo_sources ${SAMPLE_SRC_DIR}/jni/*.*)
39 -add_library(tensorflow_demo SHARED
40 - ${tensorflow_demo_sources})
41 -target_include_directories(tensorflow_demo PRIVATE
42 - ${TF_SRC_ROOT}
43 - ${CMAKE_SOURCE_DIR})
44 -
45 -target_link_libraries(tensorflow_demo
46 - android
47 - log
48 - jnigraphics
49 - m
50 - atomic
51 - z)
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// This file binds the native image utility code to the Java class
17 -// which exposes them.
18 -
19 -#include <jni.h>
20 -#include <stdio.h>
21 -#include <stdlib.h>
22 -
23 -#include "tensorflow/examples/android/jni/rgb2yuv.h"
24 -#include "tensorflow/examples/android/jni/yuv2rgb.h"
25 -
26 -#define IMAGEUTILS_METHOD(METHOD_NAME) \
27 - Java_org_tensorflow_demo_env_ImageUtils_##METHOD_NAME // NOLINT
28 -
29 -#ifdef __cplusplus
30 -extern "C" {
31 -#endif
32 -
33 -JNIEXPORT void JNICALL
34 -IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
35 - JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
36 - jint width, jint height, jboolean halfSize);
37 -
38 -JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
39 - JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
40 - jintArray output, jint width, jint height, jint y_row_stride,
41 - jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize);
42 -
43 -JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
44 - JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
45 - jint height);
46 -
47 -JNIEXPORT void JNICALL
48 -IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
49 - JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
50 - jint width, jint height);
51 -
52 -JNIEXPORT void JNICALL
53 -IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
54 - JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
55 - jint width, jint height);
56 -
57 -#ifdef __cplusplus
58 -}
59 -#endif
60 -
61 -JNIEXPORT void JNICALL
62 -IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
63 - JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
64 - jint width, jint height, jboolean halfSize) {
65 - jboolean inputCopy = JNI_FALSE;
66 - jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
67 -
68 - jboolean outputCopy = JNI_FALSE;
69 - jint* const o = env->GetIntArrayElements(output, &outputCopy);
70 -
71 - if (halfSize) {
72 - ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(i),
73 - reinterpret_cast<uint32_t*>(o), width,
74 - height);
75 - } else {
76 - ConvertYUV420SPToARGB8888(reinterpret_cast<uint8_t*>(i),
77 - reinterpret_cast<uint8_t*>(i) + width * height,
78 - reinterpret_cast<uint32_t*>(o), width, height);
79 - }
80 -
81 - env->ReleaseByteArrayElements(input, i, JNI_ABORT);
82 - env->ReleaseIntArrayElements(output, o, 0);
83 -}
84 -
85 -JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
86 - JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
87 - jintArray output, jint width, jint height, jint y_row_stride,
88 - jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize) {
89 - jboolean inputCopy = JNI_FALSE;
90 - jbyte* const y_buff = env->GetByteArrayElements(y, &inputCopy);
91 - jboolean outputCopy = JNI_FALSE;
92 - jint* const o = env->GetIntArrayElements(output, &outputCopy);
93 -
94 - if (halfSize) {
95 - ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(y_buff),
96 - reinterpret_cast<uint32_t*>(o), width,
97 - height);
98 - } else {
99 - jbyte* const u_buff = env->GetByteArrayElements(u, &inputCopy);
100 - jbyte* const v_buff = env->GetByteArrayElements(v, &inputCopy);
101 -
102 - ConvertYUV420ToARGB8888(
103 - reinterpret_cast<uint8_t*>(y_buff), reinterpret_cast<uint8_t*>(u_buff),
104 - reinterpret_cast<uint8_t*>(v_buff), reinterpret_cast<uint32_t*>(o),
105 - width, height, y_row_stride, uv_row_stride, uv_pixel_stride);
106 -
107 - env->ReleaseByteArrayElements(u, u_buff, JNI_ABORT);
108 - env->ReleaseByteArrayElements(v, v_buff, JNI_ABORT);
109 - }
110 -
111 - env->ReleaseByteArrayElements(y, y_buff, JNI_ABORT);
112 - env->ReleaseIntArrayElements(output, o, 0);
113 -}
114 -
115 -JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
116 - JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
117 - jint height) {
118 - jboolean inputCopy = JNI_FALSE;
119 - jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
120 -
121 - jboolean outputCopy = JNI_FALSE;
122 - jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
123 -
124 - ConvertYUV420SPToRGB565(reinterpret_cast<uint8_t*>(i),
125 - reinterpret_cast<uint16_t*>(o), width, height);
126 -
127 - env->ReleaseByteArrayElements(input, i, JNI_ABORT);
128 - env->ReleaseByteArrayElements(output, o, 0);
129 -}
130 -
131 -JNIEXPORT void JNICALL
132 -IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
133 - JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
134 - jint width, jint height) {
135 - jboolean inputCopy = JNI_FALSE;
136 - jint* const i = env->GetIntArrayElements(input, &inputCopy);
137 -
138 - jboolean outputCopy = JNI_FALSE;
139 - jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
140 -
141 - ConvertARGB8888ToYUV420SP(reinterpret_cast<uint32_t*>(i),
142 - reinterpret_cast<uint8_t*>(o), width, height);
143 -
144 - env->ReleaseIntArrayElements(input, i, JNI_ABORT);
145 - env->ReleaseByteArrayElements(output, o, 0);
146 -}
147 -
148 -JNIEXPORT void JNICALL
149 -IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
150 - JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
151 - jint width, jint height) {
152 - jboolean inputCopy = JNI_FALSE;
153 - jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
154 -
155 - jboolean outputCopy = JNI_FALSE;
156 - jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
157 -
158 - ConvertRGB565ToYUV420SP(reinterpret_cast<uint16_t*>(i),
159 - reinterpret_cast<uint8_t*>(o), width, height);
160 -
161 - env->ReleaseByteArrayElements(input, i, JNI_ABORT);
162 - env->ReleaseByteArrayElements(output, o, 0);
163 -}
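
For reference, the Java-side declarations implied by the `Java_org_tensorflow_demo_env_ImageUtils_*` entry points above would look like this (reconstructed from the JNI signatures; the demo's actual `ImageUtils` class also contains pure-Java helpers not shown in this diff):

```java
package org.tensorflow.demo.env;

public class ImageUtils {
  static {
    // Matches the libtensorflow_demo.so target built above.
    System.loadLibrary("tensorflow_demo");
  }

  // Signatures reconstructed from the JNI bindings in imageutils_jni.cc.
  public static native void convertYUV420SPToARGB8888(
      byte[] input, int[] output, int width, int height, boolean halfSize);

  public static native void convertYUV420ToARGB8888(
      byte[] y, byte[] u, byte[] v, int[] output, int width, int height,
      int yRowStride, int uvRowStride, int uvPixelStride, boolean halfSize);

  public static native void convertYUV420SPToRGB565(
      byte[] input, byte[] output, int width, int height);

  public static native void convertARGB8888ToYUV420SP(
      int[] input, byte[] output, int width, int height);

  public static native void convertRGB565ToYUV420SP(
      byte[] input, byte[] output, int width, int height);
}
```
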
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
18 -
19 -#include <math.h>
20 -
21 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 -
23 -namespace tf_tracking {
24 -
25 -// Arbitrary keypoint type ids for labeling the origin of tracked keypoints.
26 -enum KeypointType {
27 - KEYPOINT_TYPE_DEFAULT = 0,
28 - KEYPOINT_TYPE_FAST = 1,
29 - KEYPOINT_TYPE_INTEREST = 2
30 -};
31 -
32 -// Struct that can be used to more richly store the results of a detection
33 -// than a single number, while still maintaining comparability.
34 -struct MatchScore {
35 - explicit MatchScore(double val) : value(val) {}
36 - MatchScore() { value = 0.0; }
37 -
38 - double value;
39 -
40 - MatchScore& operator+(const MatchScore& rhs) {
41 - value += rhs.value;
42 - return *this;
43 - }
44 -
45 - friend std::ostream& operator<<(std::ostream& stream,
46 - const MatchScore& detection) {
47 - stream << detection.value;
48 - return stream;
49 - }
50 -};
51 -inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) {
52 - return cC1.value < cC2.value;
53 -}
54 -inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) {
55 - return cC1.value > cC2.value;
56 -}
57 -inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) {
58 - return cC1.value >= cC2.value;
59 -}
60 -inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) {
61 - return cC1.value <= cC2.value;
62 -}
63 -
64 -// Fixed seed used for all random number generators.
65 -static const int kRandomNumberSeed = 11111;
66 -
67 -// TODO(andrewharp): Move as many of these settings as possible into a settings
68 -// object which can be passed in from Java at runtime.
69 -
70 -// Whether or not to use ESM instead of LK flow.
71 -static const bool kUseEsm = false;
72 -
73 -// This constant gets added to the diagonal of the Hessian
74 -// before solving for translation in 2dof ESM.
75 -// It ensures better behavior especially in the absence of
76 -// strong texture.
77 -static const int kEsmRegularizer = 20;
78 -
79 -// Do we want to brightness-normalize each keypoint patch when we compute
80 -// its flow using ESM?
81 -static const bool kDoBrightnessNormalize = true;
82 -
83 -// Whether or not to use fixed-point interpolated pixel lookups in optical flow.
84 -#define USE_FIXED_POINT_FLOW 1
85 -
86 -// Whether to normalize keypoint windows for intensity in LK optical flow.
87 -// This is a define for now because it helps keep the code streamlined.
88 -#define NORMALIZE 1
89 -
90 -// Number of keypoints to store per frame.
91 -static const int kMaxKeypoints = 76;
92 -
93 -// Keypoint detection.
94 -static const int kMaxTempKeypoints = 1024;
95 -
96 -// Number of floats each keypoint takes up when exporting to an array.
97 -static const int kKeypointStep = 7;
98 -
99 -// Number of frame deltas to keep around in the circular queue.
100 -static const int kNumFrames = 512;
101 -
102 -// Number of iterations to do tracking on each keypoint at each pyramid level.
103 -static const int kNumIterations = 3;
104 -
105 -// The number of bins (on a side) to divide each bin from the previous
106 -// cache level into. Higher numbers will decrease performance by increasing
107 -// cache misses, but mean that cache hits are more locally relevant.
108 -static const int kCacheBranchFactor = 2;
109 -
110 -// Number of levels to put in the cache.
111 -// Each level of the cache is a square grid of bins, length:
112 -// branch_factor^(level - 1) on each side.
113 -//
114 -// This may be greater than kNumPyramidLevels. Setting it to 0 means no
115 -// caching is enabled.
116 -static const int kNumCacheLevels = 3;
117 -
118 -// The level at which the cache pyramid gets cut off and replaced by a matrix
119 -// transform if such a matrix has been provided to the cache.
120 -static const int kCacheCutoff = 1;
121 -
122 -static const int kNumPyramidLevels = 4;
123 -
124 -// The maximum number of keypoints allowed in an object's area.
125 -static const int kMaxKeypointsForObject = 16;
126 -
127 -// Minimum number of pyramid levels to use after getting cached value.
128 -// This allows fine-scale adjustment from the cached value, which is taken
129 -// from the center of the corresponding top cache level box.
130 -// Can be [0, kNumPyramidLevels).
131 -static const int kMinNumPyramidLevelsToUseForAdjustment = 1;
132 -
133 -// Window size to integrate over to find local image derivative.
134 -static const int kFlowIntegrationWindowSize = 3;
135 -
136 -// Total area of integration windows.
137 -static const int kFlowArraySize =
138 - (2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1);
139 -
140 -// Error that's considered good enough to early abort tracking.
141 -static const float kTrackingAbortThreshold = 0.03f;
142 -
143 -// Maximum number of deviations a keypoint-correspondence delta can be from the
144 -// weighted average before being thrown out for region-based queries.
145 -static const float kNumDeviations = 2.0f;
146 -
147 -// The length of the allowed delta between the forward and the backward
148 -// flow deltas in terms of the length of the forward flow vector.
149 -static const float kMaxForwardBackwardErrorAllowed = 0.5f;
150 -
151 -// Threshold for pixels to be considered different.
152 -static const int kFastDiffAmount = 10;
153 -
154 -// How far from edge of frame to stop looking for FAST keypoints.
155 -static const int kFastBorderBuffer = 10;
156 -
157 -// Determines if non-detected arbitrary keypoints should be added to regions.
158 -// This will help if no keypoints have been detected in the region yet.
159 -static const bool kAddArbitraryKeypoints = true;
160 -
161 -// How many arbitrary keypoints to add along each axis as candidates for each
162 -// region?
163 -static const int kNumToAddAsCandidates = 1;
164 -
165 -// In terms of region dimensions, how closely can we place keypoints
166 -// next to each other?
167 -static const float kClosestPercent = 0.6f;
168 -
169 -// How many FAST qualifying pixels must be connected to a pixel for it to be
170 -// considered a candidate keypoint for Harris filtering.
171 -static const int kMinNumConnectedForFastKeypoint = 8;
172 -
173 -// Size of the window to integrate over for Harris filtering.
174 -// Compare to kFlowIntegrationWindowSize.
175 -static const int kHarrisWindowSize = 2;
176 -
177 -
178 -// DETECTOR PARAMETERS
179 -
180 -// Before relocalizing, make sure the new proposed position is better than
181 -// the existing position by a small amount to prevent thrashing.
182 -static const MatchScore kMatchScoreBuffer(0.01f);
183 -
184 -// Minimum score a tracked object can have and still be considered a match.
185 -// TODO(andrewharp): Make this a per detector thing.
186 -static const MatchScore kMinimumMatchScore(0.5f);
187 -
188 -static const float kMinimumCorrelationForTracking = 0.4f;
189 -
190 -static const MatchScore kMatchScoreForImmediateTermination(0.0f);
191 -
192 -// Run the detector every N frames.
193 -static const int kDetectEveryNFrames = 4;
194 -
195 -// How many features does each feature_set contain?
196 -static const int kFeaturesPerFeatureSet = 10;
197 -
198 -// The number of FeatureSets managed by the object detector.
199 -// More FeatureSets can increase recall at the cost of performance.
200 -static const int kNumFeatureSets = 7;
201 -
202 -// How many FeatureSets must respond affirmatively for a candidate descriptor
203 -// and position to be given more thorough attention?
204 -static const int kNumFeatureSetsForCandidate = 2;
205 -
206 -// How large the thumbnails used for correlation validation are. Used for both
207 -// width and height.
208 -static const int kNormalizedThumbnailSize = 11;
209 -
210 -// The area of intersection divided by union for the bounding boxes that tells
211 -// if this tracking has slipped enough to invalidate all unlocked examples.
212 -static const float kPositionOverlapThreshold = 0.6f;
213 -
214 -// The number of detection failures allowed before an object goes invisible.
215 -// Tracking will still occur, so if it is actually still being tracked and
216 -// comes back into a detectable position, it's likely to be found.
217 -static const int kMaxNumDetectionFailures = 4;
218 -
219 -
220 -// Minimum square size to scan with sliding window.
221 -static const float kScanMinSquareSize = 16.0f;
222 -
223 -// Maximum square size to scan with sliding window.
224 -static const float kScanMaxSquareSize = 64.0f;
225 -
226 -// Scale difference for consecutive scans of the sliding window.
227 -static const float kScanScaleFactor = sqrtf(2.0f);
228 -
229 -// Step size for sliding window.
230 -static const int kScanStepSize = 10;
231 -
232 -
233 -// How tightly to pack the descriptor boxes for confirmed exemplars.
234 -static const float kLockedScaleFactor = 1 / sqrtf(2.0f);
235 -
236 -// How tightly to pack the descriptor boxes for unconfirmed exemplars.
237 -static const float kUnlockedScaleFactor = 1 / 2.0f;
238 -
239 -// How tightly the boxes to scan centered at the last known position will be
240 -// packed.
241 -static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f);
242 -
243 -// The bounds on how close a new object example must be to existing object
244 -// examples for detection to be valid.
245 -static const float kMinCorrelationForNewExample = 0.75f;
246 -static const float kMaxCorrelationForNewExample = 0.99f;
247 -
248 -
249 -// The number of safe tries an exemplar has after being created before
250 -// missed detections count against it.
251 -static const int kFreeTries = 5;
252 -
253 -// A false positive is worth this many missed detections.
254 -static const int kFalsePositivePenalty = 5;
255 -
256 -struct ObjectDetectorConfig {
257 - const Size image_size;
258 -
259 - explicit ObjectDetectorConfig(const Size& image_size)
260 - : image_size(image_size) {}
261 - virtual ~ObjectDetectorConfig() = default;
262 -};
263 -
264 -struct KeypointDetectorConfig {
265 - const Size image_size;
266 -
267 - bool detect_skin;
268 -
269 - explicit KeypointDetectorConfig(const Size& image_size)
270 - : image_size(image_size),
271 - detect_skin(false) {}
272 -};
273 -
274 -
275 -struct OpticalFlowConfig {
276 - const Size image_size;
277 -
278 - explicit OpticalFlowConfig(const Size& image_size)
279 - : image_size(image_size) {}
280 -};
281 -
282 -struct TrackerConfig {
283 - const Size image_size;
284 - KeypointDetectorConfig keypoint_detector_config;
285 - OpticalFlowConfig flow_config;
286 - bool always_track;
287 -
288 - float object_box_scale_factor_for_features;
289 -
290 - explicit TrackerConfig(const Size& image_size)
291 - : image_size(image_size),
292 - keypoint_detector_config(image_size),
293 - flow_config(image_size),
294 - always_track(false),
295 - object_box_scale_factor_for_features(1.0f) {}
296 -};
297 -
298 -} // namespace tf_tracking
299 -
300 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
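For orientation, here is a hedged sketch of how these nested configs are constructed; the frame size and field values are illustrative only, and the include assumes the header's original location and its own transitive includes:

#include "tensorflow/examples/android/jni/object_tracking/config.h"

void ConfigureTracker() {
  const tf_tracking::Size frame_size(640, 480);  // Hypothetical camera size.

  // TrackerConfig's constructor builds the keypoint-detector and
  // optical-flow configs from the same frame size.
  tf_tracking::TrackerConfig config(frame_size);
  config.always_track = true;  // Keep tracking even without fresh detections.
  config.object_box_scale_factor_for_features = 1.2f;  // Look past the box edge.
}
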
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
18 -
19 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
21 -
22 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
24 -
25 -namespace tf_tracking {
26 -
27 -// Class that helps OpticalFlow to speed up flow computation
28 -// by caching coarse-grained flow.
29 -class FlowCache {
30 - public:
31 - explicit FlowCache(const OpticalFlowConfig* const config)
32 - : config_(config),
33 - image_size_(config->image_size),
34 - optical_flow_(config),
35 - fullframe_matrix_(NULL) {
36 - for (int i = 0; i < kNumCacheLevels; ++i) {
37 - const int curr_dims = BlockDimForCacheLevel(i);
38 - has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
39 - displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
40 - }
41 - }
42 -
43 - ~FlowCache() {
44 - for (int i = 0; i < kNumCacheLevels; ++i) {
45 - SAFE_DELETE(has_cache_[i]);
46 - SAFE_DELETE(displacements_[i]);
47 - }
48 - delete[](fullframe_matrix_);
49 - fullframe_matrix_ = NULL;
50 - }
51 -
52 - void NextFrame(ImageData* const new_frame,
53 - const float* const align_matrix23) {
54 - ClearCache();
55 - SetFullframeAlignmentMatrix(align_matrix23);
56 - optical_flow_.NextFrame(new_frame);
57 - }
58 -
59 - void ClearCache() {
60 - for (int i = 0; i < kNumCacheLevels; ++i) {
61 - has_cache_[i]->Clear(false);
62 - }
63 - delete[](fullframe_matrix_);
64 - fullframe_matrix_ = NULL;
65 - }
66 -
67 - // Finds the flow at a point, using the cache for performance.
68 - bool FindFlowAtPoint(const float u_x, const float u_y,
69 - float* const flow_x, float* const flow_y) const {
70 - // Get the best guess from the cache.
71 - const Point2f guess_from_cache = LookupGuess(u_x, u_y);
72 -
73 - *flow_x = guess_from_cache.x;
74 - *flow_y = guess_from_cache.y;
75 -
76 - // Now refine the guess using the image pyramid.
77 - for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
78 - pyramid_level >= 0; --pyramid_level) {
79 - if (!optical_flow_.FindFlowAtPointSingleLevel(
80 - pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
81 - return false;
82 - }
83 - }
84 -
85 - return true;
86 - }
87 -
88 - // Determines the displacement of a point, and uses that to calculate a new
89 - // position.
90 - // Returns true iff the displacement determination worked and the new position
91 - // is in the image.
92 - bool FindNewPositionOfPoint(const float u_x, const float u_y,
93 - float* final_x, float* final_y) const {
94 - float flow_x;
95 - float flow_y;
96 - if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
97 - return false;
98 - }
99 -
100 - // Add in the displacement to get the final position.
101 - *final_x = u_x + flow_x;
102 - *final_y = u_y + flow_y;
103 -
104 - // Assign the best guess, if we're still in the image.
105 - if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
106 - InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
107 - return true;
108 - } else {
109 - return false;
110 - }
111 - }
112 -
113 -  // Comparison function for qsort; returns the sign of the difference.
114 -  static int Compare(const void* a, const void* b) {
115 -    const float diff = *reinterpret_cast<const float*>(a) -
116 -                       *reinterpret_cast<const float*>(b);
117 -    return (diff > 0.0f) - (diff < 0.0f);
118 -  }
119 - // Returns the median flow within the given bounding box as determined
120 - // by a grid_width x grid_height grid.
121 - Point2f GetMedianFlow(const BoundingBox& bounding_box,
122 - const bool filter_by_fb_error,
123 - const int grid_width,
124 - const int grid_height) const {
125 - const int kMaxPoints = 100;
126 - SCHECK(grid_width * grid_height <= kMaxPoints,
127 - "Too many points for Median flow!");
128 -
129 - const BoundingBox valid_box = bounding_box.Intersect(
130 - BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
131 -
132 - if (valid_box.GetArea() <= 0.0f) {
133 - return Point2f(0, 0);
134 - }
135 -
136 - float x_deltas[kMaxPoints];
137 - float y_deltas[kMaxPoints];
138 -
139 - int curr_offset = 0;
140 - for (int i = 0; i < grid_width; ++i) {
141 - for (int j = 0; j < grid_height; ++j) {
142 - const float x_in = valid_box.left_ +
143 - (valid_box.GetWidth() * i) / (grid_width - 1);
144 -
145 - const float y_in = valid_box.top_ +
146 - (valid_box.GetHeight() * j) / (grid_height - 1);
147 -
148 - float curr_flow_x;
149 - float curr_flow_y;
150 - const bool success = FindNewPositionOfPoint(x_in, y_in,
151 - &curr_flow_x, &curr_flow_y);
152 -
153 - if (success) {
154 - x_deltas[curr_offset] = curr_flow_x;
155 - y_deltas[curr_offset] = curr_flow_y;
156 - ++curr_offset;
157 - } else {
158 - LOGW("Tracking failure!");
159 - }
160 - }
161 - }
162 -
163 - if (curr_offset > 0) {
164 - qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
165 - qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
166 -
167 - return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
168 - }
169 -
170 - LOGW("No points were valid!");
171 - return Point2f(0, 0);
172 - }
173 -
174 - void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
175 - if (align_matrix23 != NULL) {
176 - if (fullframe_matrix_ == NULL) {
177 - fullframe_matrix_ = new float[6];
178 - }
179 -
180 - memcpy(fullframe_matrix_, align_matrix23,
181 - 6 * sizeof(fullframe_matrix_[0]));
182 - }
183 - }
184 -
185 - private:
186 - Point2f LookupGuessFromLevel(
187 - const int cache_level, const float x, const float y) const {
188 - // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
189 -
190 - // Cutoff at the target level and use the matrix transform instead.
191 - if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
192 - const float xnew = x * fullframe_matrix_[0] +
193 - y * fullframe_matrix_[1] +
194 - fullframe_matrix_[2];
195 - const float ynew = x * fullframe_matrix_[3] +
196 - y * fullframe_matrix_[4] +
197 - fullframe_matrix_[5];
198 -
199 - return Point2f(xnew - x, ynew - y);
200 - }
201 -
202 - const int level_dim = BlockDimForCacheLevel(cache_level);
203 - const int pixels_per_cache_block_x =
204 - (image_size_.width + level_dim - 1) / level_dim;
205 - const int pixels_per_cache_block_y =
206 - (image_size_.height + level_dim - 1) / level_dim;
207 - const int index_x = x / pixels_per_cache_block_x;
208 - const int index_y = y / pixels_per_cache_block_y;
209 -
210 - Point2f displacement;
211 - if (!(*has_cache_[cache_level])[index_y][index_x]) {
212 - (*has_cache_[cache_level])[index_y][index_x] = true;
213 -
214 - // Get the lower cache level's best guess, if it exists.
215 - displacement = cache_level >= kNumCacheLevels - 1 ?
216 - Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
217 - // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
218 - // best_guess.x, best_guess.y);
219 -
220 - // Find the center of the block.
221 - const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
222 - const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
223 - const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
224 -
225 - // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
226 - // "Querying %5.2f, %5.2f at pyramid level %d, ",
227 - // cache_level, index_x, index_y,
228 - // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
229 - // center_x, center_y, pyramid_level);
230 -
231 - // TODO(andrewharp): Turn on FB error filtering.
232 - const bool success = optical_flow_.FindFlowAtPointSingleLevel(
233 - pyramid_level, center_x, center_y, false,
234 - &displacement.x, &displacement.y);
235 -
236 - if (!success) {
237 - LOGV("Computation of cached value failed for level %d!", cache_level);
238 - }
239 -
240 - // Store the value for later use.
241 - (*displacements_[cache_level])[index_y][index_x] = displacement;
242 - } else {
243 - displacement = (*displacements_[cache_level])[index_y][index_x];
244 - }
245 -
246 - // LOGI("Returning %5.2f, %5.2f for level %d",
247 - // displacement.x, displacement.y, cache_level);
248 - return displacement;
249 - }
250 -
251 - Point2f LookupGuess(const float x, const float y) const {
252 - if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
253 - return Point2f(0, 0);
254 - }
255 -
256 - // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
257 - if (kNumCacheLevels > 0) {
258 - return LookupGuessFromLevel(0, x, y);
259 - } else {
260 - return Point2f(0, 0);
261 - }
262 - }
263 -
264 - // Returns the number of cache bins in each dimension for a given level
265 - // of the cache.
266 - int BlockDimForCacheLevel(const int cache_level) const {
267 - // The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
268 - // thus if there are 4 cache levels, requesting level 3 (0-based) should
269 - // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
270 - // and so on.
271 -    int block_dim = kCacheBranchFactor;
272 - for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
273 - --curr_level) {
274 - block_dim *= kCacheBranchFactor;
275 - }
276 - return block_dim;
277 - }
278 -
279 - // Returns the level of the image pyramid that a given cache level maps to.
280 - int PyramidLevelForCacheLevel(const int cache_level) const {
281 - // Higher cache and pyramid levels have smaller dimensions. The highest
282 - // cache level should refer to the highest image pyramid level. The
283 - // lower, finer image pyramid levels are uncached (assuming
284 - // kNumCacheLevels < kNumPyramidLevels).
285 - return cache_level + (kNumPyramidLevels - kNumCacheLevels);
286 - }
287 -
288 - const OpticalFlowConfig* const config_;
289 -
290 - const Size image_size_;
291 - OpticalFlow optical_flow_;
292 -
293 - float* fullframe_matrix_;
294 -
295 - // Whether this value is currently present in the cache.
296 - Image<bool>* has_cache_[kNumCacheLevels];
297 -
298 - // The cached displacement values.
299 - Image<Point2f>* displacements_[kNumCacheLevels];
300 -
301 - TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
302 -};
303 -
304 -} // namespace tf_tracking
305 -
306 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
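To make the cache geometry concrete: per the comments above, the coarsest cache level has kCacheBranchFactor bins per side and maps to the coarsest pyramid level, and each finer cache level multiplies the bin count by kCacheBranchFactor. A standalone sketch of that arithmetic; the constant values are assumptions for illustration, since the real ones live in config.h, which this diff does not show:

#include <cstdio>

// Assumed values for illustration only; the real constants are in config.h.
static const int kNumCacheLevels = 3;
static const int kCacheBranchFactor = 2;
static const int kNumPyramidLevels = 4;

// Mirrors BlockDimForCacheLevel's documented behavior.
int BlockDim(const int cache_level) {
  int block_dim = kCacheBranchFactor;  // Coarsest level.
  for (int level = kNumCacheLevels - 1; level > cache_level; --level) {
    block_dim *= kCacheBranchFactor;
  }
  return block_dim;
}

int main() {
  for (int level = 0; level < kNumCacheLevels; ++level) {
    printf("cache level %d: %dx%d bins -> pyramid level %d\n", level,
           BlockDim(level), BlockDim(level),
           level + (kNumPyramidLevels - kNumCacheLevels));
  }
  return 0;  // Prints 8x8 -> 1, 4x4 -> 2, 2x2 -> 3 for these values.
}
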
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#include <float.h>
17 -
18 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
19 -#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
20 -
21 -namespace tf_tracking {
22 -
23 -void FramePair::Init(const int64_t start_time, const int64_t end_time) {
24 - start_time_ = start_time;
25 - end_time_ = end_time;
26 - memset(optical_flow_found_keypoint_, false,
27 - sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
28 - number_of_keypoints_ = 0;
29 -}
30 -
31 -void FramePair::AdjustBox(const BoundingBox box,
32 - float* const translation_x,
33 - float* const translation_y,
34 - float* const scale_x,
35 - float* const scale_y) const {
36 - static float weights[kMaxKeypoints];
37 - static Point2f deltas[kMaxKeypoints];
38 -  memset(weights, 0, sizeof(*weights) * kMaxKeypoints);
39 -
40 - BoundingBox resized_box(box);
41 - resized_box.Scale(0.4f, 0.4f);
42 - FillWeights(resized_box, weights);
43 - FillTranslations(deltas);
44 -
45 - const Point2f translation = GetWeightedMedian(weights, deltas);
46 -
47 - *translation_x = translation.x;
48 - *translation_y = translation.y;
49 -
50 - const Point2f old_center = box.GetCenter();
51 - const int good_scale_points =
52 - FillScales(old_center, translation, weights, deltas);
53 -
54 - // Default scale factor is 1 for x and y.
55 - *scale_x = 1.0f;
56 - *scale_y = 1.0f;
57 -
58 - // The assumption is that all deltas that make it to this stage with a
59 - // corresponding optical_flow_found_keypoint_[i] == true are not in
60 - // themselves degenerate.
61 - //
62 -  // Scale becomes degenerate when points lie too close to the object's
63 -  // center, making the scale-ratio estimate unstable or undefined.
64 -  //
65 -  // The check for kMinNumInRange is not a degeneracy check, but merely an
66 -  // attempt to ensure some stability. The actual degeneracy check is the
67 -  // comparison to EPSILON in FillScales, which also returns the number of
68 -  // good points remaining.
69 - static const int kMinNumInRange = 5;
70 - if (good_scale_points >= kMinNumInRange) {
71 - const float scale_factor = GetWeightedMedianScale(weights, deltas);
72 -
73 - if (scale_factor > 0.0f) {
74 - *scale_x = scale_factor;
75 - *scale_y = scale_factor;
76 - }
77 - }
78 -}
79 -
80 -int FramePair::FillWeights(const BoundingBox& box,
81 - float* const weights) const {
82 - // Compute the max score.
83 - float max_score = -FLT_MAX;
84 - float min_score = FLT_MAX;
85 - for (int i = 0; i < kMaxKeypoints; ++i) {
86 - if (optical_flow_found_keypoint_[i]) {
87 - max_score = MAX(max_score, frame1_keypoints_[i].score_);
88 - min_score = MIN(min_score, frame1_keypoints_[i].score_);
89 - }
90 - }
91 -
92 - int num_in_range = 0;
93 - for (int i = 0; i < kMaxKeypoints; ++i) {
94 - if (!optical_flow_found_keypoint_[i]) {
95 - weights[i] = 0.0f;
96 - continue;
97 - }
98 -
99 - const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
100 - if (in_box) {
101 - ++num_in_range;
102 - }
103 -
104 -    // The weighting based on distance. Anything within the bounding box
105 - // has a weight of 1, and everything outside of that is within the range
106 - // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
107 - float distance_score = 1.0f;
108 - if (!in_box) {
109 - const Point2f initial = box.GetCenter();
110 - const float sq_x_dist =
111 - Square(initial.x - frame1_keypoints_[i].pos_.x);
112 - const float sq_y_dist =
113 - Square(initial.y - frame1_keypoints_[i].pos_.y);
114 - const float squared_half_width = Square(box.GetWidth() / 2.0f);
115 - const float squared_half_height = Square(box.GetHeight() / 2.0f);
116 -
117 - static const float kOutOfBoxMultiplier = 0.5f;
118 - distance_score = kOutOfBoxMultiplier *
119 - MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
120 - }
121 -
122 -    // The weighting based on relative score strength, in [kBaseScore, 1.0f].
123 - float intrinsic_score = 1.0f;
124 - if (max_score > min_score) {
125 - static const float kBaseScore = 0.5f;
126 - intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
127 - (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
128 - }
129 -
130 - // The final score will be in the range [0, 1].
131 - weights[i] = distance_score * intrinsic_score;
132 - }
133 -
134 - return num_in_range;
135 -}
136 -
137 -void FramePair::FillTranslations(Point2f* const translations) const {
138 - for (int i = 0; i < kMaxKeypoints; ++i) {
139 - if (!optical_flow_found_keypoint_[i]) {
140 - continue;
141 - }
142 - translations[i].x =
143 - frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
144 - translations[i].y =
145 - frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
146 - }
147 -}
148 -
149 -int FramePair::FillScales(const Point2f& old_center,
150 - const Point2f& translation,
151 - float* const weights,
152 - Point2f* const scales) const {
153 - int num_good = 0;
154 - for (int i = 0; i < kMaxKeypoints; ++i) {
155 - if (!optical_flow_found_keypoint_[i]) {
156 - continue;
157 - }
158 -
159 - const Keypoint keypoint1 = frame1_keypoints_[i];
160 - const Keypoint keypoint2 = frame2_keypoints_[i];
161 -
162 - const float dist1_x = keypoint1.pos_.x - old_center.x;
163 - const float dist1_y = keypoint1.pos_.y - old_center.y;
164 -
165 - const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
166 - const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
167 -
168 - // Make sure that the scale makes sense; points too close to the center
169 - // will result in either NaNs or infinite results for scale due to
170 - // limited tracking and floating point resolution.
171 - // Also check that the parity of the points is the same with respect to
172 - // x and y, as we can't really make sense of data that has flipped.
173 - if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
174 - (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
175 - ((dist2_y > EPSILON && dist1_y > EPSILON) ||
176 - (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
177 - scales[i].x = dist2_x / dist1_x;
178 - scales[i].y = dist2_y / dist1_y;
179 - ++num_good;
180 - } else {
181 - weights[i] = 0.0f;
182 - scales[i].x = 1.0f;
183 - scales[i].y = 1.0f;
184 - }
185 - }
186 - return num_good;
187 -}
188 -
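As a worked instance of the ratio FillScales computes: if a keypoint sat 20 px right of the old box center in frame 1 and, once the box translation is subtracted, sits 25 px right of it in frame 2, the per-axis scale estimate is 25 / 20 = 1.25, i.e. the object appears to have grown 25% along x. The sign checks above reject exactly the configurations where this ratio would flip sign or blow up. A toy version:

#include <cstdio>

int main() {
  const float dist1_x = 20.0f;  // Frame-1 offset from the old box center.
  const float dist2_x = 25.0f;  // Frame-2 offset, translation removed.

  // Same-sign offsets well away from zero give a meaningful per-axis scale.
  printf("scale_x = %.2f\n", dist2_x / dist1_x);  // 1.25: ~25% growth in x.
  return 0;
}
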
189 -struct WeightedDelta {
190 - float weight;
191 - float delta;
192 -};
193 -
194 -// Sorts by delta in descending order; weight is ignored.
195 -inline int WeightedDeltaCompare(const void* const a, const void* const b) {
196 - return (reinterpret_cast<const WeightedDelta*>(a)->delta -
197 - reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1;
198 -}
199 -
200 -// Returns the median delta from a sorted set of weighted deltas.
201 -static float GetMedian(const int num_items,
202 - const WeightedDelta* const weighted_deltas,
203 - const float sum) {
204 - if (num_items == 0 || sum < EPSILON) {
205 - return 0.0f;
206 - }
207 -
208 - float current_weight = 0.0f;
209 - const float target_weight = sum / 2.0f;
210 - for (int i = 0; i < num_items; ++i) {
211 - if (weighted_deltas[i].weight > 0.0f) {
212 - current_weight += weighted_deltas[i].weight;
213 - if (current_weight >= target_weight) {
214 - return weighted_deltas[i].delta;
215 - }
216 - }
217 - }
218 - LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
219 - return 0.0f;
220 -}
221 -
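GetMedian accumulates weight through the sorted deltas and returns the delta at which the running total first reaches half the overall weight, i.e. a weighted median. A self-contained toy illustration of that rule, assuming positive weights and pre-sorted deltas as the callers arrange:

#include <cstdio>

struct WeightedDelta { float weight; float delta; };

float WeightedMedian(const int n, const WeightedDelta* wd, const float sum) {
  float current_weight = 0.0f;
  for (int i = 0; i < n; ++i) {
    current_weight += wd[i].weight;
    if (current_weight >= sum / 2.0f) return wd[i].delta;
  }
  return 0.0f;  // Degenerate input: treat as no motion, as above.
}

int main() {
  // Total weight is 4.0; the running total reaches 2.0 at delta == 3.0,
  // so 3.0 is the weighted median.
  const WeightedDelta wd[] = {{1.0f, 1.0f}, {1.0f, 3.0f}, {2.0f, 7.0f}};
  printf("median = %.1f\n", WeightedMedian(3, wd, 4.0f));  // Prints 3.0.
  return 0;
}
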
222 -Point2f FramePair::GetWeightedMedian(
223 - const float* const weights, const Point2f* const deltas) const {
224 - Point2f median_delta;
225 -
226 - // TODO(andrewharp): only sort deltas that could possibly have an effect.
227 - static WeightedDelta weighted_deltas[kMaxKeypoints];
228 -
229 - // Compute median X value.
230 - {
231 - float total_weight = 0.0f;
232 -
233 - // Compute weighted mean and deltas.
234 - for (int i = 0; i < kMaxKeypoints; ++i) {
235 - weighted_deltas[i].delta = deltas[i].x;
236 - const float weight = weights[i];
237 - weighted_deltas[i].weight = weight;
238 - if (weight > 0.0f) {
239 - total_weight += weight;
240 - }
241 - }
242 - qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
243 - WeightedDeltaCompare);
244 - median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
245 - }
246 -
247 - // Compute median Y value.
248 - {
249 - float total_weight = 0.0f;
250 -
251 - // Compute weighted mean and deltas.
252 - for (int i = 0; i < kMaxKeypoints; ++i) {
253 - const float weight = weights[i];
254 - weighted_deltas[i].weight = weight;
255 - weighted_deltas[i].delta = deltas[i].y;
256 - if (weight > 0.0f) {
257 - total_weight += weight;
258 - }
259 - }
260 - qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
261 - WeightedDeltaCompare);
262 - median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
263 - }
264 -
265 - return median_delta;
266 -}
267 -
268 -float FramePair::GetWeightedMedianScale(
269 - const float* const weights, const Point2f* const deltas) const {
270 - float median_delta;
271 -
272 - // TODO(andrewharp): only sort deltas that could possibly have an effect.
273 - static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
274 -
275 - // Compute median scale value across x and y.
276 - {
277 - float total_weight = 0.0f;
278 -
279 - // Add X values.
280 - for (int i = 0; i < kMaxKeypoints; ++i) {
281 - weighted_deltas[i].delta = deltas[i].x;
282 - const float weight = weights[i];
283 - weighted_deltas[i].weight = weight;
284 - if (weight > 0.0f) {
285 - total_weight += weight;
286 - }
287 - }
288 -
289 - // Add Y values.
290 - for (int i = 0; i < kMaxKeypoints; ++i) {
291 - weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
292 - const float weight = weights[i];
293 - weighted_deltas[i + kMaxKeypoints].weight = weight;
294 - if (weight > 0.0f) {
295 - total_weight += weight;
296 - }
297 - }
298 -
299 - qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
300 - WeightedDeltaCompare);
301 -
302 - median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
303 - }
304 -
305 - return median_delta;
306 -}
307 -
308 -} // namespace tf_tracking
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
18 -
19 -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
20 -
21 -namespace tf_tracking {
22 -
23 -// A class that records keypoint correspondences from pairs of
24 -// consecutive frames.
25 -class FramePair {
26 - public:
27 - FramePair()
28 - : start_time_(0),
29 - end_time_(0),
30 - number_of_keypoints_(0) {}
31 -
32 -  // Resets the FramePair so that it can be reused.
33 - void Init(const int64_t start_time, const int64_t end_time);
34 -
35 - void AdjustBox(const BoundingBox box,
36 - float* const translation_x,
37 - float* const translation_y,
38 - float* const scale_x,
39 - float* const scale_y) const;
40 -
41 - private:
42 - // Returns the weighted median of the given deltas, computed independently on
43 - // x and y. Returns 0,0 in case of failure. The assumption is that a
44 - // translation of 0.0 in the degenerate case is the best that can be done, and
45 - // should not be considered an error.
46 - //
47 -  // In the case of scale, a slight exception is made just to be safe: there
48 -  // is an explicit check for 0.0, though that should never happen naturally
49 -  // because of the non-zero and parity checks in FillScales.
50 - Point2f GetWeightedMedian(const float* const weights,
51 - const Point2f* const deltas) const;
52 -
53 - float GetWeightedMedianScale(const float* const weights,
54 - const Point2f* const deltas) const;
55 -
56 -  // Weights points by box distance and score; returns the count in the box.
57 - int FillWeights(const BoundingBox& box,
58 - float* const weights) const;
59 -
60 - // Fills in the array of deltas with the translations of the points
61 - // between frames.
62 - void FillTranslations(Point2f* const translations) const;
63 -
64 - // Fills in the array of deltas with the relative scale factor of points
65 - // relative to a given center. Has the ability to override the weight to 0 if
66 - // a degenerate scale is detected.
67 - // Translation is the amount the center of the box has moved from one frame to
68 - // the next.
69 - int FillScales(const Point2f& old_center,
70 - const Point2f& translation,
71 - float* const weights,
72 - Point2f* const scales) const;
73 -
74 - // TODO(andrewharp): Make these private.
75 - public:
76 - // The time at frame1.
77 - int64_t start_time_;
78 -
79 - // The time at frame2.
80 - int64_t end_time_;
81 -
82 - // This array will contain the keypoints found in frame 1.
83 - Keypoint frame1_keypoints_[kMaxKeypoints];
84 -
85 -  // Contains the locations of the keypoints from frame 1 in frame 2.
86 - Keypoint frame2_keypoints_[kMaxKeypoints];
87 -
88 - // The number of keypoints in frame 1.
89 - int number_of_keypoints_;
90 -
91 - // Keeps track of which keypoint correspondences were actually found from one
92 - // frame to another.
93 - // The i-th element of this array will be non-zero if and only if the i-th
94 - // keypoint of frame 1 was found in frame 2.
95 - bool optical_flow_found_keypoint_[kMaxKeypoints];
96 -
97 - private:
98 - TF_DISALLOW_COPY_AND_ASSIGN(FramePair);
99 -};
100 -
101 -} // namespace tf_tracking
102 -
103 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
18 -
19 -#include "tensorflow/examples/android/jni/object_tracking/logging.h"
20 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
21 -
22 -namespace tf_tracking {
23 -
24 -struct Size {
25 - Size(const int width, const int height) : width(width), height(height) {}
26 -
27 - int width;
28 - int height;
29 -};
30 -
31 -
32 -class Point2f {
33 - public:
34 - Point2f() : x(0.0f), y(0.0f) {}
35 - Point2f(const float x, const float y) : x(x), y(y) {}
36 -
37 - inline Point2f operator- (const Point2f& that) const {
38 - return Point2f(this->x - that.x, this->y - that.y);
39 - }
40 -
41 - inline Point2f operator+ (const Point2f& that) const {
42 - return Point2f(this->x + that.x, this->y + that.y);
43 - }
44 -
45 - inline Point2f& operator+= (const Point2f& that) {
46 - this->x += that.x;
47 - this->y += that.y;
48 - return *this;
49 - }
50 -
51 - inline Point2f& operator-= (const Point2f& that) {
52 - this->x -= that.x;
53 - this->y -= that.y;
54 - return *this;
55 - }
60 -
61 - inline float LengthSquared() {
62 - return Square(this->x) + Square(this->y);
63 - }
64 -
65 - inline float Length() {
66 - return sqrtf(LengthSquared());
67 - }
68 -
69 - inline float DistanceSquared(const Point2f& that) {
70 - return Square(this->x - that.x) + Square(this->y - that.y);
71 - }
72 -
73 - inline float Distance(const Point2f& that) {
74 - return sqrtf(DistanceSquared(that));
75 - }
76 -
77 - float x;
78 - float y;
79 -};
80 -
81 -inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) {
82 - stream << point.x << "," << point.y;
83 - return stream;
84 -}
85 -
86 -class BoundingBox {
87 - public:
88 - BoundingBox()
89 - : left_(0),
90 - top_(0),
91 - right_(0),
92 - bottom_(0) {}
93 -
94 - BoundingBox(const BoundingBox& bounding_box)
95 - : left_(bounding_box.left_),
96 - top_(bounding_box.top_),
97 - right_(bounding_box.right_),
98 - bottom_(bounding_box.bottom_) {
99 - SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
100 - SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
101 - }
102 -
103 - BoundingBox(const float left,
104 - const float top,
105 - const float right,
106 - const float bottom)
107 - : left_(left),
108 - top_(top),
109 - right_(right),
110 - bottom_(bottom) {
111 - SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
112 - SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
113 - }
114 -
115 - BoundingBox(const Point2f& point1, const Point2f& point2)
116 - : left_(MIN(point1.x, point2.x)),
117 - top_(MIN(point1.y, point2.y)),
118 - right_(MAX(point1.x, point2.x)),
119 - bottom_(MAX(point1.y, point2.y)) {}
120 -
121 - inline void CopyToArray(float* const bounds_array) const {
122 - bounds_array[0] = left_;
123 - bounds_array[1] = top_;
124 - bounds_array[2] = right_;
125 - bounds_array[3] = bottom_;
126 - }
127 -
128 - inline float GetWidth() const {
129 - return right_ - left_;
130 - }
131 -
132 - inline float GetHeight() const {
133 - return bottom_ - top_;
134 - }
135 -
136 - inline float GetArea() const {
137 - const float width = GetWidth();
138 - const float height = GetHeight();
139 - if (width <= 0 || height <= 0) {
140 - return 0.0f;
141 - }
142 -
143 - return width * height;
144 - }
145 -
146 - inline Point2f GetCenter() const {
147 - return Point2f((left_ + right_) / 2.0f,
148 - (top_ + bottom_) / 2.0f);
149 - }
150 -
151 - inline bool ValidBox() const {
152 - return GetArea() > 0.0f;
153 - }
154 -
155 - // Returns a bounding box created from the overlapping area of these two.
156 - inline BoundingBox Intersect(const BoundingBox& that) const {
157 - const float new_left = MAX(this->left_, that.left_);
158 - const float new_right = MIN(this->right_, that.right_);
159 -
160 - if (new_left >= new_right) {
161 - return BoundingBox();
162 - }
163 -
164 - const float new_top = MAX(this->top_, that.top_);
165 - const float new_bottom = MIN(this->bottom_, that.bottom_);
166 -
167 - if (new_top >= new_bottom) {
168 - return BoundingBox();
169 - }
170 -
171 - return BoundingBox(new_left, new_top, new_right, new_bottom);
172 - }
173 -
174 - // Returns a bounding box that can contain both boxes.
175 - inline BoundingBox Union(const BoundingBox& that) const {
176 - return BoundingBox(MIN(this->left_, that.left_),
177 - MIN(this->top_, that.top_),
178 - MAX(this->right_, that.right_),
179 - MAX(this->bottom_, that.bottom_));
180 - }
181 -
182 - inline float PascalScore(const BoundingBox& that) const {
183 - SCHECK(GetArea() > 0.0f, "Empty bounding box!");
184 - SCHECK(that.GetArea() > 0.0f, "Empty bounding box!");
185 -
186 - const float intersect_area = this->Intersect(that).GetArea();
187 -
188 - if (intersect_area <= 0) {
189 - return 0;
190 - }
191 -
192 - const float score =
193 - intersect_area / (GetArea() + that.GetArea() - intersect_area);
194 - SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score);
195 - return score;
196 - }
197 -
198 -  // Returns whether the two boxes overlap in both the x and y dimensions.
199 -  inline bool Intersects(const BoundingBox& that) const {
200 -    return (left_ <= that.right_) && (right_ >= that.left_)
201 -        && (top_ <= that.bottom_)
202 -        && (bottom_ >= that.top_);
203 -  }
204 -
205 - // Returns whether another bounding box is completely inside of this bounding
206 - // box. Sharing edges is ok.
207 - inline bool Contains(const BoundingBox& that) const {
208 - return that.left_ >= left_ &&
209 - that.right_ <= right_ &&
210 - that.top_ >= top_ &&
211 - that.bottom_ <= bottom_;
212 - }
213 -
214 - inline bool Contains(const Point2f& point) const {
215 - return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_);
216 - }
217 -
218 - inline void Shift(const Point2f shift_amount) {
219 - left_ += shift_amount.x;
220 - top_ += shift_amount.y;
221 - right_ += shift_amount.x;
222 - bottom_ += shift_amount.y;
223 - }
224 -
225 - inline void ScaleOrigin(const float scale_x, const float scale_y) {
226 - left_ *= scale_x;
227 - right_ *= scale_x;
228 - top_ *= scale_y;
229 - bottom_ *= scale_y;
230 - }
231 -
232 - inline void Scale(const float scale_x, const float scale_y) {
233 - const Point2f center = GetCenter();
234 - const float half_width = GetWidth() / 2.0f;
235 - const float half_height = GetHeight() / 2.0f;
236 -
237 - left_ = center.x - half_width * scale_x;
238 - right_ = center.x + half_width * scale_x;
239 -
240 - top_ = center.y - half_height * scale_y;
241 - bottom_ = center.y + half_height * scale_y;
242 - }
243 -
244 - float left_;
245 - float top_;
246 - float right_;
247 - float bottom_;
248 -};
249 -inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) {
250 - stream << "[" << box.left_ << " - " << box.right_
251 - << ", " << box.top_ << " - " << box.bottom_
252 - << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]";
253 - return stream;
254 -}
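PascalScore above is the standard intersection-over-union measure: two 10x10 boxes offset by 5 px in each direction share a 5x5 = 25 intersection and a 100 + 100 - 25 = 175 union, giving 25 / 175, about 0.143. A minimal check of that arithmetic, assuming geom.h and its dependencies are available to compile against:

#include <cstdio>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"

int main() {
  const tf_tracking::BoundingBox a(0, 0, 10, 10);
  const tf_tracking::BoundingBox b(5, 5, 15, 15);
  // Intersection 25, union 175 -> IoU of roughly 0.143.
  printf("PascalScore = %.3f\n", a.PascalScore(b));
  return 0;
}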
255 -
256 -
257 -class BoundingSquare {
258 - public:
259 - BoundingSquare(const float x, const float y, const float size)
260 - : x_(x), y_(y), size_(size) {}
261 -
262 - explicit BoundingSquare(const BoundingBox& box)
263 - : x_(box.left_), y_(box.top_), size_(box.GetWidth()) {
264 -#ifdef SANITY_CHECKS
265 - if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) {
266 - LOG(WARNING) << "This is not a square: " << box << std::endl;
267 - }
268 -#endif
269 - }
270 -
271 - inline BoundingBox ToBoundingBox() const {
272 - return BoundingBox(x_, y_, x_ + size_, y_ + size_);
273 - }
274 -
275 - inline bool ValidBox() {
276 - return size_ > 0.0f;
277 - }
278 -
279 - inline void Shift(const Point2f shift_amount) {
280 - x_ += shift_amount.x;
281 - y_ += shift_amount.y;
282 - }
283 -
284 - inline void Scale(const float scale) {
285 - const float new_size = size_ * scale;
286 - const float position_diff = (new_size - size_) / 2.0f;
287 - x_ -= position_diff;
288 - y_ -= position_diff;
289 - size_ = new_size;
290 - }
291 -
292 - float x_;
293 - float y_;
294 - float size_;
295 -};
296 -inline std::ostream& operator<<(std::ostream& stream,
297 - const BoundingSquare& square) {
298 - stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]";
299 - return stream;
300 -}
301 -
302 -
303 -inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box,
304 - const float size) {
305 - const float width_diff = (original_box.GetWidth() - size) / 2.0f;
306 - const float height_diff = (original_box.GetHeight() - size) / 2.0f;
307 - return BoundingSquare(original_box.left_ + width_diff,
308 - original_box.top_ + height_diff,
309 - size);
310 -}
311 -
312 -inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
313 - return GetCenteredSquare(
314 - original_box, MIN(original_box.GetWidth(), original_box.GetHeight()));
315 -}
316 -
317 -} // namespace tf_tracking
318 -
319 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
18 -
19 -#include <GLES/gl.h>
20 -#include <GLES/glext.h>
21 -
22 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
23 -
24 -namespace tf_tracking {
25 -
26 -// Draws a box at the given position.
27 -inline static void DrawBox(const BoundingBox& bounding_box) {
28 - const GLfloat line[] = {
29 - bounding_box.left_, bounding_box.bottom_,
30 - bounding_box.left_, bounding_box.top_,
31 - bounding_box.left_, bounding_box.top_,
32 - bounding_box.right_, bounding_box.top_,
33 - bounding_box.right_, bounding_box.top_,
34 - bounding_box.right_, bounding_box.bottom_,
35 - bounding_box.right_, bounding_box.bottom_,
36 - bounding_box.left_, bounding_box.bottom_
37 - };
38 -
39 - glVertexPointer(2, GL_FLOAT, 0, line);
40 - glEnableClientState(GL_VERTEX_ARRAY);
41 -
42 - glDrawArrays(GL_LINES, 0, 8);
43 -}
44 -
45 -
46 -// Changes the coordinate system so that an arbitrary square in the world
47 -// can thereafter be drawn to using coordinates in [0, 1].
48 -inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) {
49 - glScalef(square.size_, square.size_, 1.0f);
50 - glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f);
51 -}
52 -
53 -} // namespace tf_tracking
54 -
55 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
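MapWorldSquareToUnitSquare leans on OpenGL post-multiplying the current matrix: a vertex v is translated by (x/size, y/size) first and then scaled by size, so v' = size * v + (x, y), and the unit square's corners land exactly on the world square. A plain-C++ check of that arithmetic, no GL context required:

#include <cstdio>

int main() {
  const float x = 30.0f, y = 40.0f, size = 50.0f;
  // Emulates glScalef(size, size, 1) then glTranslatef(x/size, y/size, 0),
  // which applies to a vertex v as v' = size * (v + offset).
  const float corners[2][2] = {{0.0f, 0.0f}, {1.0f, 1.0f}};
  for (const auto& c : corners) {
    printf("(%.0f, %.0f) -> (%.0f, %.0f)\n", c[0], c[1],
           size * (c[0] + x / size), size * (c[1] + y / size));
  }
  return 0;  // (0,0) -> (30,40) and (1,1) -> (80,90).
}
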
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
18 -
19 -#include <stdint.h>
20 -
21 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
24 -
25 -namespace tf_tracking {
26 -
27 -template <typename T>
28 -Image<T>::Image(const int width, const int height)
29 - : width_less_one_(width - 1),
30 - height_less_one_(height - 1),
31 - data_size_(width * height),
32 - own_data_(true),
33 - width_(width),
34 - height_(height),
35 - stride_(width) {
36 - Allocate();
37 -}
38 -
39 -template <typename T>
40 -Image<T>::Image(const Size& size)
41 - : width_less_one_(size.width - 1),
42 - height_less_one_(size.height - 1),
43 - data_size_(size.width * size.height),
44 - own_data_(true),
45 - width_(size.width),
46 - height_(size.height),
47 - stride_(size.width) {
48 - Allocate();
49 -}
50 -
51 -// Constructor that creates an image from preallocated data.
52 -// Note: The image takes ownership of the data lifecycle, unless own_data is
53 -// set to false.
54 -template <typename T>
55 -Image<T>::Image(const int width, const int height, T* const image_data,
56 - const bool own_data) :
57 - width_less_one_(width - 1),
58 - height_less_one_(height - 1),
59 - data_size_(width * height),
60 - own_data_(own_data),
61 - width_(width),
62 - height_(height),
63 - stride_(width) {
64 - image_data_ = image_data;
65 - SCHECK(image_data_ != NULL, "Can't create image with NULL data!");
66 -}
67 -
68 -template <typename T>
69 -Image<T>::~Image() {
70 - if (own_data_) {
71 - delete[] image_data_;
72 - }
73 - image_data_ = NULL;
74 -}
75 -
76 -template<typename T>
77 -template<class DstType>
78 -bool Image<T>::ExtractPatchAtSubpixelFixed1616(const int fp_x,
79 - const int fp_y,
80 - const int patchwidth,
81 - const int patchheight,
82 - DstType* to_data) const {
83 - // Calculate weights.
84 - const int trunc_x = fp_x >> 16;
85 - const int trunc_y = fp_y >> 16;
86 -
87 - if (trunc_x < 0 || trunc_y < 0 ||
88 - (trunc_x + patchwidth) >= width_less_one_ ||
89 - (trunc_y + patchheight) >= height_less_one_) {
90 - return false;
91 - }
92 -
93 - // Now walk over destination patch and fill from interpolated source image.
94 - for (int y = 0; y < patchheight; ++y, to_data += patchwidth) {
95 - for (int x = 0; x < patchwidth; ++x) {
96 - to_data[x] =
97 - static_cast<DstType>(GetPixelInterpFixed1616(fp_x + (x << 16),
98 - fp_y + (y << 16)));
99 - }
100 - }
101 -
102 - return true;
103 -}
104 -
105 -template <typename T>
106 -Image<T>* Image<T>::Crop(
107 - const int left, const int top, const int right, const int bottom) const {
108 - SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left);
109 - SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right);
110 - SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top);
111 - SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom);
112 -
113 - SCHECK(left <= right, "mismatch!");
114 - SCHECK(top <= bottom, "mismatch!");
115 -
116 - const int new_width = right - left + 1;
117 - const int new_height = bottom - top + 1;
118 -
119 - Image<T>* const cropped_image = new Image(new_width, new_height);
120 -
121 - for (int y = 0; y < new_height; ++y) {
122 - memcpy((*cropped_image)[y], ((*this)[y + top] + left),
123 - new_width * sizeof(T));
124 - }
125 -
126 - return cropped_image;
127 -}
128 -
129 -template <typename T>
130 -inline float Image<T>::GetPixelInterp(const float x, const float y) const {
131 - // Do int conversion one time.
132 - const int floored_x = static_cast<int>(x);
133 - const int floored_y = static_cast<int>(y);
134 -
135 -  // Note: the *_[min|max] values may be clipped while the a/b/c/d values
136 -  // are not (for speed), but that doesn't matter; in that case we just
137 -  // blend the pixel with itself.
138 - const float b = x - floored_x;
139 - const float a = 1.0f - b;
140 -
141 - const float d = y - floored_y;
142 - const float c = 1.0f - d;
143 -
144 - SCHECK(ValidInterpPixel(x, y),
145 - "x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)",
146 - x, width_less_one_, y, height_less_one_);
147 -
148 - const T* const pix_ptr = (*this)[floored_y] + floored_x;
149 -
150 - // Get the pixel values surrounding this point.
151 - const T& p1 = pix_ptr[0];
152 - const T& p2 = pix_ptr[1];
153 - const T& p3 = pix_ptr[width_];
154 - const T& p4 = pix_ptr[width_ + 1];
155 -
156 - // Simple bilinear interpolation between four reference pixels.
157 - // If x is the value requested:
158 - // a b
159 - // -------
160 - // c |p1 p2|
161 - // | x |
162 - // d |p3 p4|
163 - // -------
164 - return c * ((a * p1) + (b * p2)) +
165 - d * ((a * p3) + (b * p4));
166 -}
167 -
168 -
169 -template <typename T>
170 -inline T Image<T>::GetPixelInterpFixed1616(
171 - const int fp_x_whole, const int fp_y_whole) const {
172 - static const int kFixedPointOne = 0x00010000;
173 - static const int kFixedPointHalf = 0x00008000;
174 - static const int kFixedPointTruncateMask = 0xFFFF0000;
175 -
176 - int trunc_x = fp_x_whole & kFixedPointTruncateMask;
177 - int trunc_y = fp_y_whole & kFixedPointTruncateMask;
178 - const int fp_x = fp_x_whole - trunc_x;
179 - const int fp_y = fp_y_whole - trunc_y;
180 -
181 - // Scale the truncated values back to regular ints.
182 - trunc_x >>= 16;
183 - trunc_y >>= 16;
184 -
185 - const int one_minus_fp_x = kFixedPointOne - fp_x;
186 - const int one_minus_fp_y = kFixedPointOne - fp_y;
187 -
188 - const T* trunc_start = (*this)[trunc_y] + trunc_x;
189 -
190 - const T a = trunc_start[0];
191 - const T b = trunc_start[1];
192 - const T c = trunc_start[stride_];
193 - const T d = trunc_start[stride_ + 1];
194 -
195 - return (
196 - (one_minus_fp_y * static_cast<int64_t>(one_minus_fp_x * a + fp_x * b) +
197 - fp_y * static_cast<int64_t>(one_minus_fp_x * c + fp_x * d) +
198 - kFixedPointHalf) >>
199 - 32);
200 -}
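In the 16.16 fixed-point convention used above, a coordinate such as 2.75 is stored as 2.75 * 65536 = 180224 (0x0002C000); masking with 0xFFFF0000 keeps the integer part and the remainder is the subpixel fraction that becomes the interpolation weight. A short sketch of the decomposition:

#include <cstdio>

int main() {
  const float coord = 2.75f;
  const int fp = static_cast<int>(coord * 65536.0f);  // 16.16 fixed point.

  const int trunc = fp & 0xFFFF0000;  // Integer part, still shifted up.
  const int frac = fp - trunc;        // Fractional part in [0, 65536).

  printf("fp=0x%08X integer=%d fraction=%.4f\n",
         fp, trunc >> 16, frac / 65536.0f);  // integer=2, fraction=0.7500.
  return 0;
}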
201 -
202 -template <typename T>
203 -inline bool Image<T>::ValidPixel(const int x, const int y) const {
204 - return InRange(x, ZERO, width_less_one_) &&
205 - InRange(y, ZERO, height_less_one_);
206 -}
207 -
208 -template <typename T>
209 -inline BoundingBox Image<T>::GetContainingBox() const {
210 - return BoundingBox(
211 - 0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON);
212 -}
213 -
214 -template <typename T>
215 -inline bool Image<T>::Contains(const BoundingBox& bounding_box) const {
216 - // TODO(andrewharp): Come up with a more elegant way of ensuring that bounds
217 - // are ok.
218 - return GetContainingBox().Contains(bounding_box);
219 -}
220 -
221 -template <typename T>
222 -inline bool Image<T>::ValidInterpPixel(const float x, const float y) const {
223 - // Exclusive of max because we can be more efficient if we don't handle
224 - // interpolating on or past the last pixel.
225 - return (x >= ZERO) && (x < width_less_one_) &&
226 - (y >= ZERO) && (y < height_less_one_);
227 -}
228 -
229 -template <typename T>
230 -void Image<T>::DownsampleAveraged(const T* const original, const int stride,
231 - const int factor) {
232 -#ifdef __ARM_NEON
233 - if (factor == 4 || factor == 2) {
234 - DownsampleAveragedNeon(original, stride, factor);
235 - return;
236 - }
237 -#endif
238 -
239 - // TODO(andrewharp): delete or enable this for non-uint8_t downsamples.
240 - const int pixels_per_block = factor * factor;
241 -
242 - // For every pixel in resulting image.
243 - for (int y = 0; y < height_; ++y) {
244 - const int orig_y = y * factor;
245 - const int y_bound = orig_y + factor;
246 -
247 - // Sum up the original pixels.
248 - for (int x = 0; x < width_; ++x) {
249 - const int orig_x = x * factor;
250 - const int x_bound = orig_x + factor;
251 -
252 -      // Use an int32_t accumulator because type T might overflow.
253 - int32_t pixel_sum = 0;
254 -
255 - // Grab all the pixels that make up this pixel.
256 - for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) {
257 - const T* p = original + curr_y * stride + orig_x;
258 -
259 - for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) {
260 - pixel_sum += *p++;
261 - }
262 - }
263 -
264 - (*this)[y][x] = pixel_sum / pixels_per_block;
265 - }
266 - }
267 -}
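For factor == 2 the loop above reduces to taking the mean of each 2x2 block of the source. A tiny standalone version of the same inner loop on a 4x4 buffer:

#include <cstdint>
#include <cstdio>

int main() {
  const int factor = 2, src_w = 4, dst_w = src_w / factor;
  const uint8_t src[16] = { 10,  20,  30,  40,
                            50,  60,  70,  80,
                            90, 100, 110, 120,
                           130, 140, 150, 160};
  for (int y = 0; y < dst_w; ++y) {
    for (int x = 0; x < dst_w; ++x) {
      int32_t sum = 0;  // int32_t so the block sum cannot overflow.
      for (int dy = 0; dy < factor; ++dy)
        for (int dx = 0; dx < factor; ++dx)
          sum += src[(y * factor + dy) * src_w + (x * factor + dx)];
      printf("%4d", sum / (factor * factor));  // e.g. (10+20+50+60)/4 = 35.
    }
    printf("\n");
  }
  return 0;
}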
268 -
269 -template <typename T>
270 -void Image<T>::DownsampleInterpolateNearest(const Image<T>& original) {
271 - // Calculating the scaling factors based on target image size.
272 - const float factor_x = static_cast<float>(original.GetWidth()) /
273 - static_cast<float>(width_);
274 - const float factor_y = static_cast<float>(original.GetHeight()) /
275 - static_cast<float>(height_);
276 -
277 - // Calculating initial offset in x-axis.
278 - const float offset_x = 0.5f * (original.GetWidth() - width_) / width_;
279 -
280 - // Calculating initial offset in y-axis.
281 - const float offset_y = 0.5f * (original.GetHeight() - height_) / height_;
282 -
283 - float orig_y = offset_y;
284 -
285 - // For every pixel in resulting image.
286 - for (int y = 0; y < height_; ++y) {
287 - float orig_x = offset_x;
288 -
289 - // Finding nearest pixel on y-axis.
290 - const int nearest_y = static_cast<int>(orig_y + 0.5f);
291 - const T* row_data = original[nearest_y];
292 -
293 - T* pixel_ptr = (*this)[y];
294 -
295 - for (int x = 0; x < width_; ++x) {
296 - // Finding nearest pixel on x-axis.
297 - const int nearest_x = static_cast<int>(orig_x + 0.5f);
298 -
299 - *pixel_ptr++ = row_data[nearest_x];
300 -
301 - orig_x += factor_x;
302 - }
303 -
304 - orig_y += factor_y;
305 - }
306 -}
307 -
308 -template <typename T>
309 -void Image<T>::DownsampleInterpolateLinear(const Image<T>& original) {
310 - // TODO(andrewharp): Turn this into a general compare sizes/bulk
311 - // copy method.
312 - if (original.GetWidth() == GetWidth() &&
313 - original.GetHeight() == GetHeight() &&
314 - original.stride() == stride()) {
315 - memcpy(image_data_, original.data(), data_size_ * sizeof(T));
316 - return;
317 - }
318 -
319 - // Calculating the scaling factors based on target image size.
320 - const float factor_x = static_cast<float>(original.GetWidth()) /
321 - static_cast<float>(width_);
322 - const float factor_y = static_cast<float>(original.GetHeight()) /
323 - static_cast<float>(height_);
324 -
325 - // Calculating initial offset in x-axis.
326 - const float offset_x = 0;
327 - const int offset_x_fp = RealToFixed1616(offset_x);
328 -
329 - // Calculating initial offset in y-axis.
330 - const float offset_y = 0;
331 - const int offset_y_fp = RealToFixed1616(offset_y);
332 -
333 - // Get the fixed point scaling factor value.
334 - // Shift by 8 so we can fit everything into a 4 byte int later for speed
335 - // reasons. This means the precision is limited to 1 / 256th of a pixel,
336 - // but this should be good enough.
337 - const int factor_x_fp = RealToFixed1616(factor_x) >> 8;
338 - const int factor_y_fp = RealToFixed1616(factor_y) >> 8;
339 -
340 - int src_y_fp = offset_y_fp >> 8;
341 -
342 - static const int kFixedPointOne8 = 0x00000100;
343 - static const int kFixedPointHalf8 = 0x00000080;
344 - static const int kFixedPointTruncateMask8 = 0xFFFFFF00;
345 -
346 - // For every pixel in resulting image.
347 - for (int y = 0; y < height_; ++y) {
348 - int src_x_fp = offset_x_fp >> 8;
349 -
350 - int trunc_y = src_y_fp & kFixedPointTruncateMask8;
351 - const int fp_y = src_y_fp - trunc_y;
352 -
353 - // Scale the truncated values back to regular ints.
354 - trunc_y >>= 8;
355 -
356 - const int one_minus_fp_y = kFixedPointOne8 - fp_y;
357 -
358 - T* pixel_ptr = (*this)[y];
359 -
360 - // Make sure not to read from an invalid row.
361 - const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1);
362 - const T* other_top_ptr = original[trunc_y];
363 - const T* other_bot_ptr = original[trunc_y_b];
364 -
365 - int last_trunc_x = -1;
366 - int trunc_x = -1;
367 -
368 - T a = 0;
369 - T b = 0;
370 - T c = 0;
371 - T d = 0;
372 -
373 - for (int x = 0; x < width_; ++x) {
374 - trunc_x = src_x_fp & kFixedPointTruncateMask8;
375 -
376 - const int fp_x = (src_x_fp - trunc_x) >> 8;
377 -
378 - // Scale the truncated values back to regular ints.
379 - trunc_x >>= 8;
380 -
381 -      // We may be reading from the same source pixels as last iteration.
382 - if (trunc_x != last_trunc_x) {
383 - // Make sure not to read from an invalid column.
384 - const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1);
385 - a = other_top_ptr[trunc_x];
386 - b = other_top_ptr[trunc_x_b];
387 - c = other_bot_ptr[trunc_x];
388 - d = other_bot_ptr[trunc_x_b];
389 - last_trunc_x = trunc_x;
390 - }
391 -
392 - const int one_minus_fp_x = kFixedPointOne8 - fp_x;
393 -
394 -      const int32_t value =
395 -          ((one_minus_fp_y * (one_minus_fp_x * a + fp_x * b) +
396 -            fp_y * (one_minus_fp_x * c + fp_x * d)) + kFixedPointHalf8) >>
397 -          16;
398 -
399 - *pixel_ptr++ = value;
400 -
401 - src_x_fp += factor_x_fp;
402 - }
403 - src_y_fp += factor_y_fp;
404 - }
405 -}
406 -
407 -template <typename T>
408 -void Image<T>::DownsampleSmoothed3x3(const Image<T>& original) {
409 - for (int y = 0; y < height_; ++y) {
410 - const int orig_y = Clip(2 * y, ZERO, original.height_less_one_);
411 - const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_);
412 - const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_);
413 -
414 - for (int x = 0; x < width_; ++x) {
415 - const int orig_x = Clip(2 * x, ZERO, original.width_less_one_);
416 - const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_);
417 - const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_);
418 -
419 - // Center.
420 - int32_t pixel_sum = original[orig_y][orig_x] * 4;
421 -
422 - // Sides.
423 - pixel_sum += (original[orig_y][max_x] +
424 - original[orig_y][min_x] +
425 - original[max_y][orig_x] +
426 - original[min_y][orig_x]) * 2;
427 -
428 - // Diagonals.
429 - pixel_sum += (original[min_y][max_x] +
430 - original[min_y][min_x] +
431 - original[max_y][max_x] +
432 - original[max_y][min_x]);
433 -
434 -      (*this)[y][x] = pixel_sum >> 4;  // Divide by total weight of 16.
435 - }
436 - }
437 -}
438 -
439 -template <typename T>
440 -void Image<T>::DownsampleSmoothed5x5(const Image<T>& original) {
441 - const int max_x = original.width_less_one_;
442 - const int max_y = original.height_less_one_;
443 -
444 - // The JY Bouget paper on Lucas-Kanade recommends a
445 - // [1/16 1/4 3/8 1/4 1/16]^2 filter.
446 - // This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below.
447 - static const int window_radius = 2;
448 - static const int window_size = window_radius*2 + 1;
449 - static const int window_weights[] = {1, 4, 6, 4, 1, // 16 +
450 - 4, 16, 24, 16, 4, // 64 +
451 - 6, 24, 36, 24, 6, // 96 +
452 - 4, 16, 24, 16, 4, // 64 +
453 - 1, 4, 6, 4, 1}; // 16 = 256
454 -
455 - // We'll multiply and sum with the whole numbers first, then divide by
456 - // the total weight to normalize at the last moment.
457 - for (int y = 0; y < height_; ++y) {
458 - for (int x = 0; x < width_; ++x) {
459 - int32_t pixel_sum = 0;
460 -
461 - const int* w = window_weights;
462 - const int start_x = Clip((x << 1) - window_radius, ZERO, max_x);
463 -
464 - // Clip the boundaries to the size of the image.
465 - for (int window_y = 0; window_y < window_size; ++window_y) {
466 - const int start_y =
467 - Clip((y << 1) - window_radius + window_y, ZERO, max_y);
468 -
469 - const T* p = original[start_y] + start_x;
470 -
471 - for (int window_x = 0; window_x < window_size; ++window_x) {
472 - pixel_sum += *p++ * *w++;
473 - }
474 - }
475 -
476 - // Conversion to type T will happen here after shifting right 8 bits to
477 - // divide by 256.
478 - (*this)[y][x] = pixel_sum >> 8;
479 - }
480 - }
481 -}
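The normalization can be checked directly: the separable [1 4 6 4 1] taps sum to 16, the outer-product kernel therefore sums to 16 * 16 = 256, and the final >> 8 divides by exactly that total:

#include <cstdio>

int main() {
  const int taps[5] = {1, 4, 6, 4, 1};
  int total = 0;
  for (int y = 0; y < 5; ++y)
    for (int x = 0; x < 5; ++x)
      total += taps[y] * taps[x];  // Outer product of the 1D taps.
  printf("kernel total = %d\n", total);  // 256 == 1 << 8.
  return 0;
}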
482 -
483 -template <typename T>
484 -template <typename U>
485 -inline T Image<T>::ScharrPixelX(const Image<U>& original,
486 - const int center_x, const int center_y) const {
487 - const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_);
488 - const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_);
489 - const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_);
490 - const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_);
491 -
492 - // Convolution loop unrolled for performance...
493 - return (3 * (original[min_y][max_x]
494 - + original[max_y][max_x]
495 - - original[min_y][min_x]
496 - - original[max_y][min_x])
497 - + 10 * (original[center_y][max_x]
498 - - original[center_y][min_x])) / 32;
499 -}
500 -
501 -template <typename T>
502 -template <typename U>
503 -inline T Image<T>::ScharrPixelY(const Image<U>& original,
504 - const int center_x, const int center_y) const {
505 - const int min_x = Clip(center_x - 1, 0, original.width_less_one_);
506 - const int max_x = Clip(center_x + 1, 0, original.width_less_one_);
507 - const int min_y = Clip(center_y - 1, 0, original.height_less_one_);
508 - const int max_y = Clip(center_y + 1, 0, original.height_less_one_);
509 -
510 - // Convolution loop unrolled for performance...
511 - return (3 * (original[max_y][min_x]
512 - + original[max_y][max_x]
513 - - original[min_y][min_x]
514 - - original[min_y][max_x])
515 - + 10 * (original[max_y][center_x]
516 - - original[min_y][center_x])) / 32;
517 -}
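Written out as a kernel, ScharrPixelX is the 3x3 convolution [[-3, 0, 3], [-10, 0, 10], [-3, 0, 3]] / 32 and ScharrPixelY is its transpose; the unrolled expressions above simply group the corner taps and the mid-row taps. An equivalent direct form for a single interior pixel (boundary clamping omitted for brevity):

#include <cstdio>

// Scharr x-gradient at interior pixel (cx, cy) of a w-wide 8-bit image.
float ScharrXAt(const unsigned char* img, int w, int cx, int cy) {
  static const int kKernel[3][3] = {{-3, 0, 3}, {-10, 0, 10}, {-3, 0, 3}};
  int sum = 0;
  for (int ky = 0; ky < 3; ++ky)
    for (int kx = 0; kx < 3; ++kx)
      sum += kKernel[ky][kx] * img[(cy - 1 + ky) * w + (cx - 1 + kx)];
  return sum / 32.0f;  // Positive taps total 3 + 10 + 3 = 16; norm is 32.
}

int main() {
  // A vertical step edge: the x-gradient should be strongly positive.
  const unsigned char img[9] = {0, 0, 255,
                                0, 0, 255,
                                0, 0, 255};
  printf("gx = %.1f\n", ScharrXAt(img, 3, 1, 1));  // 16 * 255 / 32 = 127.5.
  return 0;
}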
518 -
519 -template <typename T>
520 -template <typename U>
521 -inline void Image<T>::ScharrX(const Image<U>& original) {
522 - for (int y = 0; y < height_; ++y) {
523 - for (int x = 0; x < width_; ++x) {
524 - SetPixel(x, y, ScharrPixelX(original, x, y));
525 - }
526 - }
527 -}
528 -
529 -template <typename T>
530 -template <typename U>
531 -inline void Image<T>::ScharrY(const Image<U>& original) {
532 - for (int y = 0; y < height_; ++y) {
533 - for (int x = 0; x < width_; ++x) {
534 - SetPixel(x, y, ScharrPixelY(original, x, y));
535 - }
536 - }
537 -}
538 -
539 -template <typename T>
540 -template <typename U>
541 -void Image<T>::DerivativeX(const Image<U>& original) {
542 - for (int y = 0; y < height_; ++y) {
543 - const U* const source_row = original[y];
544 - T* const dest_row = (*this)[y];
545 -
546 - // Compute first pixel. Approximated with forward difference.
547 - dest_row[0] = source_row[1] - source_row[0];
548 -
549 - // All the pixels in between. Central difference method.
550 - const U* source_prev_pixel = source_row;
551 - T* dest_pixel = dest_row + 1;
552 - const U* source_next_pixel = source_row + 2;
553 - for (int x = 1; x < width_less_one_; ++x) {
554 - *dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
555 - }
556 -
557 - // Last pixel. Approximated with backward difference.
558 - dest_row[width_less_one_] =
559 - source_row[width_less_one_] - source_row[width_less_one_ - 1];
560 - }
561 -}
562 -
563 -template <typename T>
564 -template <typename U>
565 -void Image<T>::DerivativeY(const Image<U>& original) {
566 - const int src_stride = original.stride();
567 -
568 - // Compute 1st row. Approximated with forward difference.
569 - {
570 - const U* const src_row = original[0];
571 - T* dest_row = (*this)[0];
572 - for (int x = 0; x < width_; ++x) {
573 - dest_row[x] = src_row[x + src_stride] - src_row[x];
574 - }
575 - }
576 -
577 - // Compute all rows in between using central difference.
578 - for (int y = 1; y < height_less_one_; ++y) {
579 - T* dest_row = (*this)[y];
580 -
581 - const U* source_prev_pixel = original[y - 1];
582 - const U* source_next_pixel = original[y + 1];
583 - for (int x = 0; x < width_; ++x) {
584 - *dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
585 - }
586 - }
587 -
588 - // Compute last row. Approximated with backward difference.
589 - {
590 - const U* const src_row = original[height_less_one_];
591 - T* dest_row = (*this)[height_less_one_];
592 - for (int x = 0; x < width_; ++x) {
593 - dest_row[x] = src_row[x] - src_row[x - src_stride];
594 - }
595 - }
596 -}
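Both derivative methods apply the same 1-D stencil along their axis: a one-sided difference at each border and the half central difference (f(x+1) - f(x-1)) / 2 in the interior. A standalone sketch on f(x) = x^2, whose central difference is exactly 2x:

#include <cassert>

int main() {
  const int f[5] = {0, 1, 4, 9, 16};   // f(x) = x^2 at x = 0..4.
  int d[5];
  d[0] = f[1] - f[0];                  // Forward difference at the border.
  for (int x = 1; x < 4; ++x) {
    d[x] = (f[x + 1] - f[x - 1]) / 2;  // HalfDiff in the interior.
  }
  d[4] = f[4] - f[3];                  // Backward difference at the border.
  assert(d[1] == 2 && d[2] == 4 && d[3] == 6);
  return 0;
}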
597 -
598 -template <typename T>
599 -template <typename U>
600 -inline T Image<T>::ConvolvePixel3x3(const Image<U>& original,
601 - const int* const filter,
602 - const int center_x, const int center_y,
603 - const int total) const {
604 - int32_t sum = 0;
605 - for (int filter_y = 0; filter_y < 3; ++filter_y) {
606 -    const int y = Clip(center_y - 1 + filter_y, 0, original.GetHeight() - 1);
607 -    for (int filter_x = 0; filter_x < 3; ++filter_x) {
608 -      const int x = Clip(center_x - 1 + filter_x, 0, original.GetWidth() - 1);
609 - sum += original[y][x] * filter[filter_y * 3 + filter_x];
610 - }
611 - }
612 - return sum / total;
613 -}
614 -
615 -template <typename T>
616 -template <typename U>
617 -inline void Image<T>::Convolve3x3(const Image<U>& original,
618 - const int32_t* const filter) {
619 - int32_t sum = 0;
620 - for (int i = 0; i < 9; ++i) {
621 - sum += abs(filter[i]);
622 - }
623 - for (int y = 0; y < height_; ++y) {
624 - for (int x = 0; x < width_; ++x) {
625 - SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum));
626 - }
627 - }
628 -}
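Note that Convolve3x3 normalizes by the sum of the absolute filter taps, so filters with negative entries are attenuated rather than applied at unit gain. A usage sketch against the Image API above (the BoxBlur3x3 helper is illustrative, not part of the original sources):

// Blurs src into dst; with all-ones taps the divisor is 9, i.e. a 3x3 mean.
inline void BoxBlur3x3(const tf_tracking::Image<uint8_t>& src,
                       tf_tracking::Image<uint8_t>* const dst) {
  static const int32_t kBox[9] = {1, 1, 1,
                                  1, 1, 1,
                                  1, 1, 1};
  dst->Convolve3x3(src, kBox);
}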
629 -
630 -template <typename T>
631 -inline void Image<T>::FromArray(const T* const pixels, const int stride,
632 - const int factor) {
633 - if (factor == 1 && stride == width_) {
634 -    // If not subsampling and the source is contiguous, one memcpy suffices.
635 - memcpy(this->image_data_, pixels, data_size_ * sizeof(T));
636 - return;
637 - }
638 -
639 - DownsampleAveraged(pixels, stride, factor);
640 -}
641 -
642 -} // namespace tf_tracking
643 -
644 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
18 -
19 -#include <stdint.h>
20 -
21 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 -
24 -// TODO(andrewharp): Make this a cast to uint32_t if/when we go unsigned for
25 -// operations.
26 -#define ZERO 0
27 -
28 -#ifdef SANITY_CHECKS
29 - #define CHECK_PIXEL(IMAGE, X, Y) {\
30 - SCHECK((IMAGE)->ValidPixel((X), (Y)), \
31 - "CHECK_PIXEL(%d,%d) in %dx%d image.", \
32 - static_cast<int>(X), static_cast<int>(Y), \
33 - (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
34 - }
35 -
36 - #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\
37 - SCHECK((IMAGE)->validInterpPixel((X), (Y)), \
38 - "CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \
39 - static_cast<float>(X), static_cast<float>(Y), \
40 - (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
41 - }
42 -#else
43 - #define CHECK_PIXEL(image, x, y) {}
44 - #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {}
45 -#endif
46 -
47 -namespace tf_tracking {
48 -
49 -#ifdef SANITY_CHECKS
50 -// Class which exists solely to provide bounds checking for array-style image
51 -// data access.
52 -template <typename T>
53 -class RowData {
54 - public:
55 - RowData(T* const row_data, const int max_col)
56 - : row_data_(row_data), max_col_(max_col) {}
57 -
58 - inline T& operator[](const int col) const {
59 - SCHECK(InRange(col, 0, max_col_),
60 - "Column out of range: %d (%d max)", col, max_col_);
61 - return row_data_[col];
62 - }
63 -
64 - inline operator T*() const {
65 - return row_data_;
66 - }
67 -
68 - private:
69 - T* const row_data_;
70 - const int max_col_;
71 -};
72 -#endif
73 -
74 -// Naive templated sorting function.
75 -template <typename T>
76 -int Comp(const void* a, const void* b) {
77 - const T val1 = *reinterpret_cast<const T*>(a);
78 - const T val2 = *reinterpret_cast<const T*>(b);
79 -
80 - if (val1 == val2) {
81 - return 0;
82 - } else if (val1 < val2) {
83 - return -1;
84 - } else {
85 - return 1;
86 - }
87 -}
88 -
89 -// TODO(andrewharp): Make explicit which operations support negative numbers or
90 -// struct/class types in image data (possibly create fast multi-dim array class
91 -// for data where pixel arithmetic does not make sense).
92 -
93 -// Image class optimized for working on numeric arrays as grayscale image data.
94 -// Supports other data types as a 2D array class, so long as no pixel math
95 -// operations are called (convolution, downsampling, etc).
96 -template <typename T>
97 -class Image {
98 - public:
99 - Image(const int width, const int height);
100 - explicit Image(const Size& size);
101 -
102 - // Constructor that creates an image from preallocated data.
103 - // Note: The image takes ownership of the data lifecycle, unless own_data is
104 - // set to false.
105 - Image(const int width, const int height, T* const image_data,
106 - const bool own_data = true);
107 -
108 - ~Image();
109 -
110 - // Extract a pixel patch from this image, starting at a subpixel location.
111 - // Uses 16:16 fixed point format for representing real values and doing the
112 - // bilinear interpolation.
113 - //
114 - // Arguments fp_x and fp_y tell the subpixel position in fixed point format,
115 - // patchwidth/patchheight give the size of the patch in pixels and
116 - // to_data must be a valid pointer to a *contiguous* destination data array.
117 - template<class DstType>
118 - bool ExtractPatchAtSubpixelFixed1616(const int fp_x,
119 - const int fp_y,
120 - const int patchwidth,
121 - const int patchheight,
122 - DstType* to_data) const;
123 -
124 - Image<T>* Crop(
125 - const int left, const int top, const int right, const int bottom) const;
126 -
127 - inline int GetWidth() const { return width_; }
128 - inline int GetHeight() const { return height_; }
129 -
130 - // Bilinearly sample a value between pixels. Values must be within the image.
131 - inline float GetPixelInterp(const float x, const float y) const;
132 -
133 -  // Bilinearly sample a pixel at a subpixel position using fixed-point
134 - // arithmetic.
135 - // Avoids float<->int conversions.
136 - // Values must be within the image.
137 - // Arguments fp_x and fp_y tell the subpixel position in
138 - // 16:16 fixed point format.
139 - //
140 - // Important: This function only makes sense for integer-valued images, such
141 - // as Image<uint8_t> or Image<int> etc.
142 - inline T GetPixelInterpFixed1616(const int fp_x_whole,
143 - const int fp_y_whole) const;
144 -
145 - // Returns true iff the pixel is in the image's boundaries.
146 - inline bool ValidPixel(const int x, const int y) const;
147 -
148 - inline BoundingBox GetContainingBox() const;
149 -
150 - inline bool Contains(const BoundingBox& bounding_box) const;
151 -
152 - inline T GetMedianValue() {
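    // Note: qsort reorders image_data_ in place, so this mutates the image.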
153 - qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp<T>);
154 - return image_data_[data_size_ >> 1];
155 - }
156 -
157 - // Returns true iff the pixel is in the image's boundaries for interpolation
158 - // purposes.
159 - // TODO(andrewharp): check in interpolation follow-up change.
160 - inline bool ValidInterpPixel(const float x, const float y) const;
161 -
162 - // Safe lookup with boundary enforcement.
163 - inline T GetPixelClipped(const int x, const int y) const {
164 - return (*this)[Clip(y, ZERO, height_less_one_)]
165 - [Clip(x, ZERO, width_less_one_)];
166 - }
167 -
168 -#ifdef SANITY_CHECKS
169 - inline RowData<T> operator[](const int row) {
170 - SCHECK(InRange(row, 0, height_less_one_),
171 - "Row out of range: %d (%d max)", row, height_less_one_);
172 - return RowData<T>(image_data_ + row * stride_, width_less_one_);
173 - }
174 -
175 - inline const RowData<T> operator[](const int row) const {
176 - SCHECK(InRange(row, 0, height_less_one_),
177 - "Row out of range: %d (%d max)", row, height_less_one_);
178 - return RowData<T>(image_data_ + row * stride_, width_less_one_);
179 - }
180 -#else
181 - inline T* operator[](const int row) {
182 - return image_data_ + row * stride_;
183 - }
184 -
185 - inline const T* operator[](const int row) const {
186 - return image_data_ + row * stride_;
187 - }
188 -#endif
189 -
190 - const T* data() const { return image_data_; }
191 -
192 - inline int stride() const { return stride_; }
193 -
194 -  // Clears image to a single byte pattern via memset; use 0 for multi-byte T.
195 - inline void Clear(const T& val) {
196 - memset(image_data_, val, sizeof(*image_data_) * data_size_);
197 - }
198 -
199 -#ifdef __ARM_NEON
200 - void Downsample2x32ColumnsNeon(const uint8_t* const original,
201 - const int stride, const int orig_x);
202 -
203 - void Downsample4x32ColumnsNeon(const uint8_t* const original,
204 - const int stride, const int orig_x);
205 -
206 - void DownsampleAveragedNeon(const uint8_t* const original, const int stride,
207 - const int factor);
208 -#endif
209 -
210 - // Naive downsampler that reduces image size by factor by averaging pixels in
211 - // blocks of size factor x factor.
212 - void DownsampleAveraged(const T* const original, const int stride,
213 - const int factor);
214 -
215 - // Naive downsampler that reduces image size by factor by averaging pixels in
216 - // blocks of size factor x factor.
217 - inline void DownsampleAveraged(const Image<T>& original, const int factor) {
218 - DownsampleAveraged(original.data(), original.GetWidth(), factor);
219 - }
220 -
221 -  // Naive downsampler that reduces image size via nearest interpolation.
222 - void DownsampleInterpolateNearest(const Image<T>& original);
223 -
224 -  // Naive downsampler that reduces image size using fixed-point bilinear
225 -  // interpolation.
226 - void DownsampleInterpolateLinear(const Image<T>& original);
227 -
228 - // Relatively efficient downsampling of an image by a factor of two with a
229 - // low-pass 3x3 smoothing operation thrown in.
230 - void DownsampleSmoothed3x3(const Image<T>& original);
231 -
232 - // Relatively efficient downsampling of an image by a factor of two with a
233 - // low-pass 5x5 smoothing operation thrown in.
234 - void DownsampleSmoothed5x5(const Image<T>& original);
235 -
236 - // Optimized Scharr filter on a single pixel in the X direction.
237 - // Scharr filters are like central-difference operators, but have more
238 - // rotational symmetry in their response because they also consider the
239 - // diagonal neighbors.
240 - template <typename U>
241 - inline T ScharrPixelX(const Image<U>& original,
242 - const int center_x, const int center_y) const;
243 -
244 -  // Optimized Scharr filter on a single pixel in the Y direction.
245 - // Scharr filters are like central-difference operators, but have more
246 - // rotational symmetry in their response because they also consider the
247 - // diagonal neighbors.
248 - template <typename U>
249 - inline T ScharrPixelY(const Image<U>& original,
250 - const int center_x, const int center_y) const;
251 -
252 - // Convolve the image with a Scharr filter in the X direction.
253 - // Much faster than an equivalent generic convolution.
254 - template <typename U>
255 - inline void ScharrX(const Image<U>& original);
256 -
257 - // Convolve the image with a Scharr filter in the Y direction.
258 - // Much faster than an equivalent generic convolution.
259 - template <typename U>
260 - inline void ScharrY(const Image<U>& original);
261 -
262 - static inline T HalfDiff(int32_t first, int32_t second) {
263 - return (second - first) / 2;
264 - }
265 -
266 - template <typename U>
267 - void DerivativeX(const Image<U>& original);
268 -
269 - template <typename U>
270 - void DerivativeY(const Image<U>& original);
271 -
272 - // Generic function for convolving pixel with 3x3 filter.
273 - // Filter pixels should be in row major order.
274 - template <typename U>
275 - inline T ConvolvePixel3x3(const Image<U>& original,
276 - const int* const filter,
277 - const int center_x, const int center_y,
278 - const int total) const;
279 -
280 - // Generic function for convolving an image with a 3x3 filter.
281 - // TODO(andrewharp): Generalize this for any size filter.
282 - template <typename U>
283 - inline void Convolve3x3(const Image<U>& original,
284 - const int32_t* const filter);
285 -
286 - // Load this image's data from a data array. The data at pixels is assumed to
287 - // have dimensions equivalent to this image's dimensions * factor.
288 - inline void FromArray(const T* const pixels, const int stride,
289 - const int factor = 1);
290 -
291 - // Copy the image back out to an appropriately sized data array.
292 - inline void ToArray(T* const pixels) const {
293 - // If not subsampling, memcpy should be faster.
294 - memcpy(pixels, this->image_data_, data_size_ * sizeof(T));
295 - }
296 -
297 - // Precompute these for efficiency's sake as they're used by a lot of
298 - // clipping code and loop code.
299 - // TODO(andrewharp): make these only accessible by other Images.
300 - const int width_less_one_;
301 - const int height_less_one_;
302 -
303 - // The raw size of the allocated data.
304 - const int data_size_;
305 -
306 - private:
307 - inline void Allocate() {
308 - image_data_ = new T[data_size_];
309 - if (image_data_ == NULL) {
310 - LOGE("Couldn't allocate image data!");
311 - }
312 - }
313 -
314 - T* image_data_;
315 -
316 - bool own_data_;
317 -
318 - const int width_;
319 - const int height_;
320 -
321 - // The image stride (offset to next row).
322 - // TODO(andrewharp): Make sure that stride is honored in all code.
323 - const int stride_;
324 -
325 - TF_DISALLOW_COPY_AND_ASSIGN(Image);
326 -};
327 -
328 -template <typename T>
329 -inline std::ostream& operator<<(std::ostream& stream, const Image<T>& image) {
330 - for (int y = 0; y < image.GetHeight(); ++y) {
331 - for (int x = 0; x < image.GetWidth(); ++x) {
332 - stream << image[y][x] << " ";
333 - }
334 - stream << std::endl;
335 - }
336 - return stream;
337 -}
338 -
339 -} // namespace tf_tracking
340 -
341 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
18 -
19 -#include <stdint.h>
20 -#include <memory>
21 -
22 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
27 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
28 -
29 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
30 -
31 -namespace tf_tracking {
32 -
33 -// Class that encapsulates all bulky processed data for a frame.
34 -class ImageData {
35 - public:
36 - explicit ImageData(const int width, const int height)
37 - : uv_frame_width_(width << 1),
38 - uv_frame_height_(height << 1),
39 - timestamp_(0),
40 - image_(width, height) {
41 - InitPyramid(width, height);
42 - ResetComputationCache();
43 - }
44 -
45 - private:
46 - void ResetComputationCache() {
47 - uv_data_computed_ = false;
48 - integral_image_computed_ = false;
49 - for (int i = 0; i < kNumPyramidLevels; ++i) {
50 - spatial_x_computed_[i] = false;
51 - spatial_y_computed_[i] = false;
52 - pyramid_sqrt2_computed_[i * 2] = false;
53 - pyramid_sqrt2_computed_[i * 2 + 1] = false;
54 - }
55 - }
56 -
57 - void InitPyramid(const int width, const int height) {
58 - int level_width = width;
59 - int level_height = height;
60 -
61 - for (int i = 0; i < kNumPyramidLevels; ++i) {
62 - pyramid_sqrt2_[i * 2] = NULL;
63 - pyramid_sqrt2_[i * 2 + 1] = NULL;
64 - spatial_x_[i] = NULL;
65 - spatial_y_[i] = NULL;
66 -
67 - level_width /= 2;
68 - level_height /= 2;
69 - }
70 -
71 - // Alias the first pyramid level to image_.
72 - pyramid_sqrt2_[0] = &image_;
73 - }
74 -
75 - public:
76 - ~ImageData() {
77 - // The first pyramid level is actually an alias to image_,
78 - // so make sure it doesn't get deleted here.
79 - pyramid_sqrt2_[0] = NULL;
80 -
81 - for (int i = 0; i < kNumPyramidLevels; ++i) {
82 - SAFE_DELETE(pyramid_sqrt2_[i * 2]);
83 - SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
84 - SAFE_DELETE(spatial_x_[i]);
85 - SAFE_DELETE(spatial_y_[i]);
86 - }
87 - }
88 -
89 - void SetData(const uint8_t* const new_frame, const int stride,
90 - const int64_t timestamp, const int downsample_factor) {
91 - SetData(new_frame, NULL, stride, timestamp, downsample_factor);
92 - }
93 -
94 - void SetData(const uint8_t* const new_frame, const uint8_t* const uv_frame,
95 - const int stride, const int64_t timestamp,
96 - const int downsample_factor) {
97 - ResetComputationCache();
98 -
99 - timestamp_ = timestamp;
100 -
101 - TimeLog("SetData!");
102 -
103 - pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
104 - pyramid_sqrt2_computed_[0] = true;
105 - TimeLog("Downsampled image");
106 -
107 - if (uv_frame != NULL) {
108 - if (u_data_.get() == NULL) {
109 - u_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
110 - v_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
111 - }
112 -
113 - GetUV(uv_frame, u_data_.get(), v_data_.get());
114 - uv_data_computed_ = true;
115 - TimeLog("Copied UV data");
116 - } else {
117 - LOGV("No uv data!");
118 - }
119 -
120 -#ifdef LOG_TIME
121 - // If profiling is enabled, precompute here to make it easier to distinguish
122 - // total costs.
123 - Precompute();
124 -#endif
125 - }
126 -
127 -  inline int64_t GetTimestamp() const { return timestamp_; }
128 -
129 - inline const Image<uint8_t>* GetImage() const {
130 - SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
131 - return pyramid_sqrt2_[0];
132 - }
133 -
134 - const Image<uint8_t>* GetPyramidSqrt2Level(const int level) const {
135 - if (!pyramid_sqrt2_computed_[level]) {
136 - SCHECK(level != 0, "Level equals 0!");
137 - if (level == 1) {
138 - const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(0);
139 - if (pyramid_sqrt2_[level] == NULL) {
140 - const int new_width =
141 - (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
142 - const int new_height =
143 - (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
144 - 2;
145 -
146 - pyramid_sqrt2_[level] = new Image<uint8_t>(new_width, new_height);
147 - }
148 - pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
149 - } else {
150 - const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(level - 2);
151 - if (pyramid_sqrt2_[level] == NULL) {
152 - pyramid_sqrt2_[level] = new Image<uint8_t>(
153 - upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
154 - }
155 - pyramid_sqrt2_[level]->DownsampleAveraged(
156 - upper_level.data(), upper_level.stride(), 2);
157 - }
158 - pyramid_sqrt2_computed_[level] = true;
159 - }
160 - return pyramid_sqrt2_[level];
161 - }
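  // In other words: even indices hold the power-of-two pyramid, and odd
  // indices hold the same levels scaled by a further 1 / sqrt(2), rounded up
  // to even dimensions. For a 320x240 frame, for example, level 0 is 320x240,
  // level 1 is 226x170 by the arithmetic above, and level 2 is 160x120.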
162 -
163 - inline const Image<int32_t>* GetSpatialX(const int level) const {
164 - if (!spatial_x_computed_[level]) {
165 - const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
166 - if (spatial_x_[level] == NULL) {
167 - spatial_x_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
168 - }
169 - spatial_x_[level]->DerivativeX(src);
170 - spatial_x_computed_[level] = true;
171 - }
172 - return spatial_x_[level];
173 - }
174 -
175 - inline const Image<int32_t>* GetSpatialY(const int level) const {
176 - if (!spatial_y_computed_[level]) {
177 - const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
178 - if (spatial_y_[level] == NULL) {
179 - spatial_y_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
180 - }
181 - spatial_y_[level]->DerivativeY(src);
182 - spatial_y_computed_[level] = true;
183 - }
184 - return spatial_y_[level];
185 - }
186 -
187 - // The integral image is currently only used for object detection, so lazily
188 - // initialize it on request.
189 - inline const IntegralImage* GetIntegralImage() const {
190 - if (integral_image_.get() == NULL) {
191 - integral_image_.reset(new IntegralImage(image_));
192 - } else if (!integral_image_computed_) {
193 - integral_image_->Recompute(image_);
194 - }
195 - integral_image_computed_ = true;
196 - return integral_image_.get();
197 - }
198 -
199 - inline const Image<uint8_t>* GetU() const {
200 - SCHECK(uv_data_computed_, "UV data not provided!");
201 - return u_data_.get();
202 - }
203 -
204 - inline const Image<uint8_t>* GetV() const {
205 - SCHECK(uv_data_computed_, "UV data not provided!");
206 - return v_data_.get();
207 - }
208 -
209 - private:
210 - void Precompute() {
211 - // Create the smoothed pyramids.
212 - for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
213 - (void) GetPyramidSqrt2Level(i);
214 - }
215 - TimeLog("Created smoothed pyramids");
216 -
217 -    // Create the interleaved sqrt(2) pyramid levels.
218 - for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
219 - (void) GetPyramidSqrt2Level(i);
220 - }
221 - TimeLog("Created smoothed sqrt pyramids");
222 -
223 - // Create the spatial derivatives for frame 1.
224 - for (int i = 0; i < kNumPyramidLevels; ++i) {
225 - (void) GetSpatialX(i);
226 - (void) GetSpatialY(i);
227 - }
228 - TimeLog("Created spatial derivatives");
229 -
230 - (void) GetIntegralImage();
231 - TimeLog("Got integral image!");
232 - }
233 -
234 - const int uv_frame_width_;
235 - const int uv_frame_height_;
236 -
237 - int64_t timestamp_;
238 -
239 - Image<uint8_t> image_;
240 -
241 - bool uv_data_computed_;
242 - std::unique_ptr<Image<uint8_t> > u_data_;
243 - std::unique_ptr<Image<uint8_t> > v_data_;
244 -
245 - mutable bool spatial_x_computed_[kNumPyramidLevels];
246 - mutable Image<int32_t>* spatial_x_[kNumPyramidLevels];
247 -
248 - mutable bool spatial_y_computed_[kNumPyramidLevels];
249 - mutable Image<int32_t>* spatial_y_[kNumPyramidLevels];
250 -
251 - // Mutable so the lazy initialization can work when this class is const.
252 - // Whether or not the integral image has been computed for the current image.
253 - mutable bool integral_image_computed_;
254 - mutable std::unique_ptr<IntegralImage> integral_image_;
255 -
256 - mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
257 - mutable Image<uint8_t>* pyramid_sqrt2_[kNumPyramidLevels * 2];
258 -
259 - TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
260 -};
261 -
262 -} // namespace tf_tracking
263 -
264 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// NEON implementations of Image methods for compatible devices. Control
17 -// should never enter this compilation unit on incompatible devices.
18 -
19 -#ifdef __ARM_NEON
20 -
21 -#include <arm_neon.h>
22 -
23 -#include <stdint.h>
24 -
25 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
27 -#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
28 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
29 -
30 -namespace tf_tracking {
31 -
32 -// This function does the bulk of the work.
33 -template <>
34 -void Image<uint8_t>::Downsample2x32ColumnsNeon(const uint8_t* const original,
35 - const int stride,
36 - const int orig_x) {
37 - // Divide input x offset by 2 to find output offset.
38 - const int new_x = orig_x >> 1;
39 -
40 - // Initial offset into top row.
41 - const uint8_t* offset = original + orig_x;
42 -
43 - // This points to the leftmost pixel of our 8 horizontally arranged
44 - // pixels in the destination data.
45 - uint8_t* ptr_dst = (*this)[0] + new_x;
46 -
47 - // Sum along vertical columns.
48 - // Process 32x2 input pixels and 16x1 output pixels per iteration.
49 - for (int new_y = 0; new_y < height_; ++new_y) {
50 - uint16x8_t accum1 = vdupq_n_u16(0);
51 - uint16x8_t accum2 = vdupq_n_u16(0);
52 -
53 -    // Go top to bottom across the two rows of input pixels that make up
54 -    // this output row.
55 - for (int row_num = 0; row_num < 2; ++row_num) {
56 - // First 16 bytes.
57 - {
58 - // Load 16 bytes of data from current offset.
59 - const uint8x16_t curr_data1 = vld1q_u8(offset);
60 -
61 - // Pairwise add and accumulate into accum vectors (16 bit to account
62 - // for values above 255).
63 - accum1 = vpadalq_u8(accum1, curr_data1);
64 - }
65 -
66 - // Second 16 bytes.
67 - {
68 - // Load 16 bytes of data from current offset.
69 - const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
70 -
71 - // Pairwise add and accumulate into accum vectors (16 bit to account
72 - // for values above 255).
73 - accum2 = vpadalq_u8(accum2, curr_data2);
74 - }
75 -
76 - // Move offset down one row.
77 - offset += stride;
78 - }
79 -
80 - // Divide by 4 (number of input pixels per output
81 - // pixel) and narrow data from 16 bits per pixel to 8 bpp.
82 - const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2);
83 - const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2);
84 -
85 - // Concatenate 8x1 pixel strips into 16x1 pixel strip.
86 - const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2);
87 -
88 - // Copy all pixels from composite 16x1 vector into output strip.
89 - vst1q_u8(ptr_dst, allpixels);
90 -
91 - ptr_dst += stride_;
92 - }
93 -}
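In scalar terms, each destination pixel above is the average of one 2x2 source block: vpadalq_u8 supplies the horizontal pair sums, the row loop adds the second row, and vqshrn_n_u16(..., 2) divides by four with saturating narrowing. A plain-C++ equivalent for a single output pixel (a sketch, not part of the original sources):

#include <stdint.h>

inline uint8_t Average2x2(const uint8_t* const src, const int stride,
                          const int x, const int y) {
  // Sum the 2x2 block with top-left corner (x, y), then divide by 4.
  const int sum = src[y * stride + x] + src[y * stride + x + 1] +
                  src[(y + 1) * stride + x] + src[(y + 1) * stride + x + 1];
  return static_cast<uint8_t>(sum >> 2);
}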
94 -
95 -// This function does the bulk of the work.
96 -template <>
97 -void Image<uint8_t>::Downsample4x32ColumnsNeon(const uint8_t* const original,
98 - const int stride,
99 - const int orig_x) {
100 - // Divide input x offset by 4 to find output offset.
101 - const int new_x = orig_x >> 2;
102 -
103 - // Initial offset into top row.
104 - const uint8_t* offset = original + orig_x;
105 -
106 - // This points to the leftmost pixel of our 8 horizontally arranged
107 - // pixels in the destination data.
108 - uint8_t* ptr_dst = (*this)[0] + new_x;
109 -
110 - // Sum along vertical columns.
111 - // Process 32x4 input pixels and 8x1 output pixels per iteration.
112 - for (int new_y = 0; new_y < height_; ++new_y) {
113 - uint16x8_t accum1 = vdupq_n_u16(0);
114 - uint16x8_t accum2 = vdupq_n_u16(0);
115 -
116 - // Go top to bottom across the four rows of input pixels that make up
117 - // this output row.
118 - for (int row_num = 0; row_num < 4; ++row_num) {
119 - // First 16 bytes.
120 - {
121 - // Load 16 bytes of data from current offset.
122 - const uint8x16_t curr_data1 = vld1q_u8(offset);
123 -
124 - // Pairwise add and accumulate into accum vectors (16 bit to account
125 - // for values above 255).
126 - accum1 = vpadalq_u8(accum1, curr_data1);
127 - }
128 -
129 - // Second 16 bytes.
130 - {
131 - // Load 16 bytes of data from current offset.
132 - const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
133 -
134 - // Pairwise add and accumulate into accum vectors (16 bit to account
135 - // for values above 255).
136 - accum2 = vpadalq_u8(accum2, curr_data2);
137 - }
138 -
139 - // Move offset down one row.
140 - offset += stride;
141 - }
142 -
143 - // Add and widen, then divide by 16 (number of input pixels per output
144 - // pixel) and narrow data from 32 bits per pixel to 16 bpp.
145 - const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4);
146 - const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4);
147 -
148 - // Combine 4x1 pixel strips into 8x1 pixel strip and narrow from
149 - // 16 bits to 8 bits per pixel.
150 - const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2));
151 -
152 - // Copy all pixels from composite 8x1 vector into output strip.
153 - vst1_u8(ptr_dst, allpixels);
154 -
155 - ptr_dst += stride_;
156 - }
157 -}
158 -
159 -
160 -// Hardware accelerated downsampling method for supported devices.
161 -// Requires that image size be a multiple of 16 pixels in each dimension,
162 -// and that downsampling be by a factor of 2 or 4.
163 -template <>
164 -void Image<uint8_t>::DownsampleAveragedNeon(const uint8_t* const original,
165 - const int stride,
166 - const int factor) {
167 - // TODO(andrewharp): stride is a bad approximation for the src image's width.
168 - // Better to pass that in directly.
169 - SCHECK(width_ * factor <= stride, "Uh oh!");
170 - const int last_starting_index = width_ * factor - 32;
171 -
172 - // We process 32 input pixels lengthwise at a time.
173 - // The output per pass of this loop is an 8 wide by downsampled height tall
174 - // pixel strip.
175 - int orig_x = 0;
176 - for (; orig_x <= last_starting_index; orig_x += 32) {
177 - if (factor == 2) {
178 - Downsample2x32ColumnsNeon(original, stride, orig_x);
179 - } else {
180 - Downsample4x32ColumnsNeon(original, stride, orig_x);
181 - }
182 - }
183 -
184 - // If a last pass is required, push it to the left enough so that it never
185 - // goes out of bounds. This will result in some extra computation on devices
186 - // whose frame widths are multiples of 16 and not 32.
187 - if (orig_x < last_starting_index + 32) {
188 - if (factor == 2) {
189 - Downsample2x32ColumnsNeon(original, stride, last_starting_index);
190 - } else {
191 - Downsample4x32ColumnsNeon(original, stride, last_starting_index);
192 - }
193 - }
194 -}
195 -
196 -
197 -// Puts the image gradient matrix about a pixel into the 2x2 float array G.
198 -// vals_x should be an array of the window x gradient values, whose indices
199 -// can be in any order but are parallel to the vals_y entries.
200 -// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
201 -void CalculateGNeon(const float* const vals_x, const float* const vals_y,
202 - const int num_vals, float* const G) {
203 - const float32_t* const arm_vals_x = (const float32_t*) vals_x;
204 - const float32_t* const arm_vals_y = (const float32_t*) vals_y;
205 -
206 - // Running sums.
207 - float32x4_t xx = vdupq_n_f32(0.0f);
208 - float32x4_t xy = vdupq_n_f32(0.0f);
209 - float32x4_t yy = vdupq_n_f32(0.0f);
210 -
211 - // Maximum index we can load 4 consecutive values from.
212 - // e.g. if there are 81 values, our last full pass can be from index 77:
213 - // 81-4=>77 (77, 78, 79, 80)
214 - const int max_i = num_vals - 4;
215 -
216 - // Defined here because we want to keep track of how many values were
217 - // processed by NEON, so that we can finish off the remainder the normal
218 - // way.
219 - int i = 0;
220 -
221 - // Process values 4 at a time, accumulating the sums of
222 - // the pixel-wise x*x, x*y, and y*y values.
223 - for (; i <= max_i; i += 4) {
224 - // Load xs
225 - float32x4_t x = vld1q_f32(arm_vals_x + i);
226 -
227 - // Multiply x*x and accumulate.
228 - xx = vmlaq_f32(xx, x, x);
229 -
230 - // Load ys
231 - float32x4_t y = vld1q_f32(arm_vals_y + i);
232 -
233 - // Multiply x*y and accumulate.
234 - xy = vmlaq_f32(xy, x, y);
235 -
236 - // Multiply y*y and accumulate.
237 - yy = vmlaq_f32(yy, y, y);
238 - }
239 -
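  // Note: the static scratch buffers below make this function non-reentrant;
  // it is not safe to call from multiple threads concurrently.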
240 - static float32_t xx_vals[4];
241 - static float32_t xy_vals[4];
242 - static float32_t yy_vals[4];
243 -
244 - vst1q_f32(xx_vals, xx);
245 - vst1q_f32(xy_vals, xy);
246 - vst1q_f32(yy_vals, yy);
247 -
248 -  // Accumulated values are stored in sets of 4; we have to manually add
249 -  // the final partial sums together.
250 - for (int j = 0; j < 4; ++j) {
251 - G[0] += xx_vals[j];
252 - G[1] += xy_vals[j];
253 - G[3] += yy_vals[j];
254 - }
255 -
256 - // Finishes off last few values (< 4) from above.
257 - for (; i < num_vals; ++i) {
258 - G[0] += Square(vals_x[i]);
259 - G[1] += vals_x[i] * vals_y[i];
260 - G[3] += Square(vals_y[i]);
261 - }
262 -
263 - // The matrix is symmetric, so this is a given.
264 - G[2] = G[1];
265 -}
266 -
267 -} // namespace tf_tracking
268 -
269 -#endif
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
18 -
19 -#include <stdint.h>
20 -
21 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
25 -
26 -
27 -namespace tf_tracking {
28 -
29 -inline void GetUV(const uint8_t* const input, Image<uint8_t>* const u,
30 - Image<uint8_t>* const v) {
31 - const uint8_t* pUV = input;
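  // Presumably NV21 (V first) on Android versus NV12 (U first) on Apple,
  // hence the swapped byte order in the branches below.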
32 -
33 - for (int row = 0; row < u->GetHeight(); ++row) {
34 - uint8_t* u_curr = (*u)[row];
35 - uint8_t* v_curr = (*v)[row];
36 - for (int col = 0; col < u->GetWidth(); ++col) {
37 -#ifdef __APPLE__
38 - *u_curr++ = *pUV++;
39 - *v_curr++ = *pUV++;
40 -#else
41 - *v_curr++ = *pUV++;
42 - *u_curr++ = *pUV++;
43 -#endif
44 - }
45 - }
46 -}
47 -
48 -// Marks every point within a circle of a given radius on the given boolean
49 -// image true.
50 -template <typename U>
51 -inline static void MarkImage(const int x, const int y, const int radius,
52 - Image<U>* const img) {
53 - SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y);
54 -
55 - // Precomputed for efficiency.
56 - const int squared_radius = Square(radius);
57 -
58 - // Mark every row in the circle.
59 - for (int d_y = 0; d_y <= radius; ++d_y) {
60 - const int squared_y_dist = Square(d_y);
61 -
62 - const int min_y = MAX(y - d_y, 0);
63 - const int max_y = MIN(y + d_y, img->height_less_one_);
64 -
65 -    // The max d_x of the circle must be greater than or equal to
66 - // radius - d_y for any positive d_y. Thus, starting from radius - d_y will
67 - // reduce the number of iterations required as compared to starting from
68 - // either 0 and counting up or radius and counting down.
69 - for (int d_x = radius - d_y; d_x <= radius; ++d_x) {
70 -      // The first time this criterion is met, we know the width of the circle at
71 - // this row (without using sqrt).
72 - if (squared_y_dist + Square(d_x) >= squared_radius) {
73 - const int min_x = MAX(x - d_x, 0);
74 - const int max_x = MIN(x + d_x, img->width_less_one_);
75 -
76 - // Mark both above and below the center row.
77 - bool* const top_row_start = (*img)[min_y] + min_x;
78 - bool* const bottom_row_start = (*img)[max_y] + min_x;
79 -
80 - const int x_width = max_x - min_x + 1;
81 - memset(top_row_start, true, sizeof(*top_row_start) * x_width);
82 - memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width);
83 -
84 - // This row is marked, time to move on to the next row.
85 - break;
86 - }
87 - }
88 - }
89 -}
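A usage sketch, assuming the Image<bool> mask type implied by the memset(..., true, ...) calls above (the helper and coordinates are illustrative):

inline void MarkExample(tf_tracking::Image<bool>* const mask) {
  mask->Clear(false);  // Start from an all-false mask.
  // Mark a radius-3 disc centered at (10, 12); the mask must be large enough
  // that ValidPixel(10, 12) holds.
  tf_tracking::MarkImage(10, 12, 3, mask);
}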
90 -
91 -#ifdef __ARM_NEON
92 -void CalculateGNeon(
93 - const float* const vals_x, const float* const vals_y,
94 - const int num_vals, float* const G);
95 -#endif
96 -
97 -// Puts the image gradient matrix about a pixel into the 2x2 float array G.
98 -// vals_x should be an array of the window x gradient values, whose indices
99 -// can be in any order but are parallel to the vals_y entries.
100 -// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
101 -inline void CalculateG(const float* const vals_x, const float* const vals_y,
102 - const int num_vals, float* const G) {
103 -#ifdef __ARM_NEON
104 - CalculateGNeon(vals_x, vals_y, num_vals, G);
105 - return;
106 -#endif
107 -
108 - // Non-accelerated version.
109 - for (int i = 0; i < num_vals; ++i) {
110 - G[0] += Square(vals_x[i]);
111 - G[1] += vals_x[i] * vals_y[i];
112 - G[3] += Square(vals_y[i]);
113 - }
114 -
115 - // The matrix is symmetric, so this is a given.
116 - G[2] = G[1];
117 -}
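G is the 2x2 structure tensor from the Lucas-Kanade equations, G = [sum(Ix*Ix), sum(Ix*Iy); sum(Ix*Iy), sum(Iy*Iy)], stored row-major in a 4-float array; that symmetry is why G[2] is copied from G[1]. Callers are expected to pass a zero-initialized G, as the g_temp usage further below illustrates. A toy standalone computation (sample values are arbitrary):

#include <cstdio>

int main() {
  const float vals_x[] = {1.0f, 2.0f, 0.5f};
  const float vals_y[] = {0.0f, 1.0f, 2.0f};
  float G[4] = {0.0f, 0.0f, 0.0f, 0.0f};  // Row-major 2x2, zeroed by caller.
  for (int i = 0; i < 3; ++i) {
    G[0] += vals_x[i] * vals_x[i];
    G[1] += vals_x[i] * vals_y[i];
    G[3] += vals_y[i] * vals_y[i];
  }
  G[2] = G[1];  // Symmetric.
  std::printf("G = [%.2f %.2f; %.2f %.2f]\n", G[0], G[1], G[2], G[3]);
  return 0;
}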
118 -
119 -inline void CalculateGInt16(const int16_t* const vals_x,
120 - const int16_t* const vals_y, const int num_vals,
121 - int* const G) {
122 - // Non-accelerated version.
123 - for (int i = 0; i < num_vals; ++i) {
124 - G[0] += Square(vals_x[i]);
125 - G[1] += vals_x[i] * vals_y[i];
126 - G[3] += Square(vals_y[i]);
127 - }
128 -
129 - // The matrix is symmetric, so this is a given.
130 - G[2] = G[1];
131 -}
132 -
133 -
134 -// Puts the image gradient matrix about a pixel into the 2x2 float array G.
135 -// Looks up interpolated pixels, then calls above method for implementation.
136 -inline void CalculateG(const int window_radius, const float center_x,
137 - const float center_y, const Image<int32_t>& I_x,
138 - const Image<int32_t>& I_y, float* const G) {
139 - SCHECK(I_x.ValidPixel(center_x, center_y), "Problem in calculateG!");
140 -
141 - // Hardcoded to allow for a max window radius of 5 (9 pixels x 9 pixels).
142 - static const int kMaxWindowRadius = 5;
143 - SCHECK(window_radius <= kMaxWindowRadius,
144 - "Window %d > %d!", window_radius, kMaxWindowRadius);
145 -
146 - // Diameter of window is 2 * radius + 1 for center pixel.
147 - static const int kWindowBufferSize =
148 - (kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1);
149 -
150 - // Preallocate buffers statically for efficiency.
151 - static int16_t vals_x[kWindowBufferSize];
152 - static int16_t vals_y[kWindowBufferSize];
153 -
154 - const int src_left_fixed = RealToFixed1616(center_x - window_radius);
155 - const int src_top_fixed = RealToFixed1616(center_y - window_radius);
156 -
157 - int16_t* vals_x_ptr = vals_x;
158 - int16_t* vals_y_ptr = vals_y;
159 -
160 - const int window_size = 2 * window_radius + 1;
161 - for (int y = 0; y < window_size; ++y) {
162 - const int fp_y = src_top_fixed + (y << 16);
163 -
164 - for (int x = 0; x < window_size; ++x) {
165 - const int fp_x = src_left_fixed + (x << 16);
166 -
167 - *vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
168 - *vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
169 - }
170 - }
171 -
172 - int32_t g_temp[] = {0, 0, 0, 0};
173 - CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp);
174 -
175 - for (int i = 0; i < 4; ++i) {
176 - G[i] = g_temp[i];
177 - }
178 -}
179 -
180 -inline float ImageCrossCorrelation(const Image<float>& image1,
181 - const Image<float>& image2,
182 - const int x_offset, const int y_offset) {
183 - SCHECK(image1.GetWidth() == image2.GetWidth() &&
184 - image1.GetHeight() == image2.GetHeight(),
185 - "Dimension mismatch! %dx%d vs %dx%d",
186 - image1.GetWidth(), image1.GetHeight(),
187 - image2.GetWidth(), image2.GetHeight());
188 -
189 - const int num_pixels = image1.GetWidth() * image1.GetHeight();
190 - const float* data1 = image1.data();
191 - const float* data2 = image2.data();
192 - return ComputeCrossCorrelation(data1, data2, num_pixels);
193 -}
194 -
195 -// Copies an arbitrary region of an image to another (floating point)
196 -// image, scaling as it goes using bilinear interpolation.
197 -inline void CopyArea(const Image<uint8_t>& image,
198 - const BoundingBox& area_to_copy,
199 - Image<float>* const patch_image) {
200 - VLOG(2) << "Copying from: " << area_to_copy << std::endl;
201 -
202 - const int patch_width = patch_image->GetWidth();
203 - const int patch_height = patch_image->GetHeight();
204 -
205 -  const float x_dist_between_samples = patch_width > 1 ?
206 -      area_to_copy.GetWidth() / (patch_width - 1) : 0;
207 -
208 -  const float y_dist_between_samples = patch_height > 1 ?
209 -      area_to_copy.GetHeight() / (patch_height - 1) : 0;
210 -
211 - for (int y_index = 0; y_index < patch_height; ++y_index) {
212 - const float sample_y =
213 - y_index * y_dist_between_samples + area_to_copy.top_;
214 -
215 - for (int x_index = 0; x_index < patch_width; ++x_index) {
216 - const float sample_x =
217 - x_index * x_dist_between_samples + area_to_copy.left_;
218 -
219 - if (image.ValidInterpPixel(sample_x, sample_y)) {
220 - // TODO(andrewharp): Do area averaging when downsampling.
221 - (*patch_image)[y_index][x_index] =
222 - image.GetPixelInterp(sample_x, sample_y);
223 - } else {
224 - (*patch_image)[y_index][x_index] = -1.0f;
225 - }
226 - }
227 - }
228 -}
229 -
230 -
231 -// Takes a floating point image and normalizes it in-place.
232 -//
233 -// First, negative values will be set to the mean of the non-negative pixels
234 -// in the image.
235 -//
236 -// Then, the resulting image will be normalized so that it has a mean value of
237 -// 0.0 and a standard deviation of 1.0.
238 -inline void NormalizeImage(Image<float>* const image) {
239 - const float* const data_ptr = image->data();
240 -
241 -  // Sum the non-negative values, counting them as we go.
242 - float running_sum = 0.0f;
243 - int num_data_gte_zero = 0;
244 - {
245 - float* const curr_data = (*image)[0];
246 - for (int i = 0; i < image->data_size_; ++i) {
247 - if (curr_data[i] >= 0.0f) {
248 - running_sum += curr_data[i];
249 - ++num_data_gte_zero;
250 - } else {
251 - curr_data[i] = -1.0f;
252 - }
253 - }
254 - }
255 -
256 - // If none of the pixels are valid, just set the entire thing to 0.0f.
257 - if (num_data_gte_zero == 0) {
258 - image->Clear(0.0f);
259 - return;
260 - }
261 -
262 - const float corrected_mean = running_sum / num_data_gte_zero;
263 -
264 - float* curr_data = (*image)[0];
265 - for (int i = 0; i < image->data_size_; ++i) {
266 - const float curr_val = *curr_data;
267 - *curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean;
268 - }
269 -
270 - const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f);
271 -
272 - if (std_dev > 0.0f) {
273 - curr_data = (*image)[0];
274 - for (int i = 0; i < image->data_size_; ++i) {
275 - *curr_data++ /= std_dev;
276 - }
277 -
278 -#ifdef SANITY_CHECKS
279 - LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev);
280 - const float correlation =
281 - ComputeCrossCorrelation(image->data(),
282 - image->data(),
283 - image->data_size_);
284 -
285 - if (std::abs(correlation - 1.0f) > EPSILON) {
286 - LOG(ERROR) << "Bad image!" << std::endl;
287 - LOG(ERROR) << *image << std::endl;
288 - }
289 -
290 - SCHECK(std::abs(correlation - 1.0f) < EPSILON,
291 - "Correlation wasn't 1.0f: %.10f", correlation);
292 -#endif
293 - }
294 -}
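A worked standalone sketch of the two passes on a four-pixel image with one invalid (negative) sample, assuming ComputeStdDev(data, size, mean) computes the standard deviation about the given mean, as the call above suggests. Invalid pixels land at exactly 0, the new mean; valid pixels end up mean-subtracted and scaled to unit standard deviation:

#include <cmath>
#include <cstdio>

int main() {
  float px[4] = {2.0f, 4.0f, 6.0f, -1.0f};
  const float mean = (px[0] + px[1] + px[2]) / 3.0f;  // Non-negatives: 4.0.
  for (int i = 0; i < 4; ++i) {
    px[i] = px[i] < 0.0f ? 0.0f : px[i] - mean;       // {-2, 0, 2, 0}.
  }
  float sq_sum = 0.0f;
  for (int i = 0; i < 4; ++i) sq_sum += px[i] * px[i];
  const float std_dev = std::sqrt(sq_sum / 4.0f);     // sqrt(2), about 0.
  for (int i = 0; i < 4; ++i) px[i] /= std_dev;
  std::printf("%.3f %.3f %.3f %.3f\n", px[0], px[1], px[2], px[3]);
  return 0;
}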
295 -
296 -} // namespace tf_tracking
297 -
298 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
18 -
19 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 -
24 -namespace tf_tracking {
25 -
26 -typedef uint8_t Code;
27 -
28 -class IntegralImage : public Image<uint32_t> {
29 - public:
30 - explicit IntegralImage(const Image<uint8_t>& image_base)
31 - : Image<uint32_t>(image_base.GetWidth(), image_base.GetHeight()) {
32 - Recompute(image_base);
33 - }
34 -
35 - IntegralImage(const int width, const int height)
36 - : Image<uint32_t>(width, height) {}
37 -
38 - void Recompute(const Image<uint8_t>& image_base) {
39 - SCHECK(image_base.GetWidth() == GetWidth() &&
40 - image_base.GetHeight() == GetHeight(), "Dimensions don't match!");
41 -
42 - // Sum along first row.
43 - {
44 - int x_sum = 0;
45 - for (int x = 0; x < image_base.GetWidth(); ++x) {
46 - x_sum += image_base[0][x];
47 - (*this)[0][x] = x_sum;
48 - }
49 - }
50 -
51 - // Sum everything else.
52 - for (int y = 1; y < image_base.GetHeight(); ++y) {
53 - uint32_t* curr_sum = (*this)[y];
54 -
55 - // Previously summed pointers.
56 - const uint32_t* up_one = (*this)[y - 1];
57 -
58 - // Current value pointer.
59 - const uint8_t* curr_delta = image_base[y];
60 -
61 - uint32_t row_till_now = 0;
62 -
63 - for (int x = 0; x < GetWidth(); ++x) {
64 - // Add the one above and the one to the left.
65 - row_till_now += *curr_delta;
66 - *curr_sum = *up_one + row_till_now;
67 -
68 - // Scoot everything along.
69 - ++curr_sum;
70 - ++up_one;
71 - ++curr_delta;
72 - }
73 - }
74 -
75 - SCHECK(VerifyData(image_base), "Images did not match!");
76 - }
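The loop above implements the usual integral-image recurrence I(x, y) = i(x, y) + I(x-1, y) + I(x, y-1) - I(x-1, y-1), with row_till_now carrying the running row sum so each pixel costs just one add beyond the row above. A minimal standalone check:

#include <cassert>

int main() {
  const int src[2][2] = {{1, 2}, {3, 4}};
  int I[2][2];
  for (int y = 0; y < 2; ++y) {
    int row = 0;  // Plays the role of row_till_now in Recompute.
    for (int x = 0; x < 2; ++x) {
      row += src[y][x];
      I[y][x] = row + (y > 0 ? I[y - 1][x] : 0);
    }
  }
  assert(I[1][1] == 1 + 2 + 3 + 4);  // Sum of everything up and to the left.
  return 0;
}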
77 -
78 - bool VerifyData(const Image<uint8_t>& image_base) {
79 - for (int y = 0; y < GetHeight(); ++y) {
80 - for (int x = 0; x < GetWidth(); ++x) {
81 - uint32_t curr_val = (*this)[y][x];
82 -
83 - if (x > 0) {
84 - curr_val -= (*this)[y][x - 1];
85 - }
86 -
87 - if (y > 0) {
88 - curr_val -= (*this)[y - 1][x];
89 - }
90 -
91 - if (x > 0 && y > 0) {
92 - curr_val += (*this)[y - 1][x - 1];
93 - }
94 -
95 - if (curr_val != image_base[y][x]) {
96 - LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]);
97 - return false;
98 - }
99 -
100 - if (GetRegionSum(x, y, x, y) != curr_val) {
101 - LOGE("Mismatch!");
102 - }
103 - }
104 - }
105 -
106 - return true;
107 - }
108 -
109 - // Returns the sum of all pixels in the specified region.
110 - inline uint32_t GetRegionSum(const int x1, const int y1, const int x2,
111 - const int y2) const {
112 - SCHECK(x1 >= 0 && y1 >= 0 &&
113 - x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(),
114 - "indices out of bounds! %d-%d / %d, %d-%d / %d, ",
115 - x1, x2, GetWidth(), y1, y2, GetHeight());
116 -
117 - const uint32_t everything = (*this)[y2][x2];
118 -
119 - uint32_t sum = everything;
120 - if (x1 > 0 && y1 > 0) {
121 - // Most common case.
122 - const uint32_t left = (*this)[y2][x1 - 1];
123 - const uint32_t top = (*this)[y1 - 1][x2];
124 - const uint32_t top_left = (*this)[y1 - 1][x1 - 1];
125 -
126 - sum = everything - left - top + top_left;
127 - SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d",
128 - everything, left, top, top_left, sum, x1, y1, x2, y2);
129 - } else if (x1 > 0) {
130 - // Flush against top of image.
131 - // Subtract out the region to the left only.
132 - const uint32_t top = (*this)[y2][x1 - 1];
133 - sum = everything - top;
134 - SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum);
135 - } else if (y1 > 0) {
136 - // Flush against left side of image.
137 - // Subtract out the region above only.
138 - const uint32_t left = (*this)[y1 - 1][x2];
139 - sum = everything - left;
140 - SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum);
141 - }
142 -
143 - SCHECK(sum >= 0, "Negative sum!");
144 -
145 - return sum;
146 - }
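The "most common case" branch is inclusion-exclusion over the integral image: sum(x1..x2, y1..y2) = I(x2, y2) - I(x1-1, y2) - I(x2, y1-1) + I(x1-1, y1-1). A worked standalone check on a 2x2 source:

#include <cassert>

int main() {
  // Integral image of the source {{1, 2}, {3, 4}}, indexed as I[y][x].
  const int I[2][2] = {{1, 3}, {4, 10}};
  // Region sum of the single source pixel at (1, 1), which holds 4:
  const int sum = I[1][1] - I[1][0] - I[0][1] + I[0][0];
  assert(sum == 4);
  return 0;
}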
147 -
148 -  // Returns the 2-bit code associated with this region, which represents
149 - // the overall gradient.
150 - inline Code GetCode(const BoundingBox& bounding_box) const {
151 - return GetCode(bounding_box.left_, bounding_box.top_,
152 - bounding_box.right_, bounding_box.bottom_);
153 - }
154 -
155 - inline Code GetCode(const int x1, const int y1,
156 - const int x2, const int y2) const {
157 - SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d",
158 - x1, y1, x2, y2);
159 -
160 - // Gradient computed vertically.
161 - const int box_height = (y2 - y1) / 2;
162 - const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height);
163 - const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2);
164 - const bool vertical_code = top_sum > bottom_sum;
165 -
166 - // Gradient computed horizontally.
167 - const int box_width = (x2 - x1) / 2;
168 - const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2);
169 - const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2);
170 - const bool horizontal_code = left_sum > right_sum;
171 -
172 - const Code final_code = (vertical_code << 1) | horizontal_code;
173 -
174 - SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)),
175 - "Invalid code! %d", final_code);
176 -
177 - // Returns a value 0-3.
178 - return final_code;
179 - }
180 -
181 - private:
182 - TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage);
183 -};
184 -
185 -} // namespace tf_tracking
186 -
187 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
18 -
19 -#include <jni.h>
20 -#include <stdint.h>
21 -
22 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 -
24 -// The JniLongField class is used to access Java fields from native code. This
25 -// technique of hiding pointers to native objects in opaque Java fields is how
26 -// the Android hardware libraries work. This reduces the number of static
27 -// native methods and makes it easier to manage the lifetime of native objects.
28 -class JniLongField {
29 - public:
30 - JniLongField(const char* field_name)
31 - : field_name_(field_name), field_ID_(0) {}
32 -
33 - int64_t get(JNIEnv* env, jobject thiz) {
34 - if (field_ID_ == 0) {
35 - jclass cls = env->GetObjectClass(thiz);
36 - CHECK_ALWAYS(cls != 0, "Unable to find class");
37 - field_ID_ = env->GetFieldID(cls, field_name_, "J");
38 - CHECK_ALWAYS(field_ID_ != 0,
39 - "Unable to find field %s. (Check proguard cfg)", field_name_);
40 - }
41 -
42 - return env->GetLongField(thiz, field_ID_);
43 - }
44 -
45 - void set(JNIEnv* env, jobject thiz, int64_t value) {
46 - if (field_ID_ == 0) {
47 - jclass cls = env->GetObjectClass(thiz);
48 - CHECK_ALWAYS(cls != 0, "Unable to find class");
49 - field_ID_ = env->GetFieldID(cls, field_name_, "J");
50 - CHECK_ALWAYS(field_ID_ != 0,
51 - "Unable to find field %s (Check proguard cfg)", field_name_);
52 - }
53 -
54 - env->SetLongField(thiz, field_ID_, value);
55 - }
56 -
57 - private:
58 - const char* const field_name_;
59 -
60 - // This is just a cache
61 - jfieldID field_ID_;
62 -};
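A usage sketch of the pattern described in the comment above; the field name, the ObjectTracker type, and the helper functions are hypothetical, not taken from this header:

#include <jni.h>
#include <stdint.h>

struct ObjectTracker;  // Hypothetical native object being wrapped.

static JniLongField object_tracker_field("nativeObjectTracker");

void SetTracker(JNIEnv* env, jobject thiz, ObjectTracker* const tracker) {
  // Hide the native pointer in the Java object's opaque long field.
  object_tracker_field.set(env, thiz, reinterpret_cast<intptr_t>(tracker));
}

ObjectTracker* GetTracker(JNIEnv* env, jobject thiz) {
  // Recover the native pointer on a later JNI call.
  return reinterpret_cast<ObjectTracker*>(
      static_cast<intptr_t>(object_tracker_field.get(env, thiz)));
}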
63 -
64 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
18 -
19 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/logging.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
25 -
26 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
27 -
28 -namespace tf_tracking {
29 -
30 -// For keeping track of keypoints.
31 -struct Keypoint {
32 - Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {}
33 - Keypoint(const float x, const float y)
34 - : pos_(x, y), score_(0.0f), type_(0) {}
35 -
36 - Point2f pos_;
37 - float score_;
38 - uint8_t type_;
39 -};
40 -
41 -inline std::ostream& operator<<(std::ostream& stream, const Keypoint& keypoint) {
42 -  return stream << "[" << keypoint.pos_ << ", " << keypoint.score_ << ", "
43 -                << static_cast<int>(keypoint.type_) << "]";
44 -}
45 -
46 -} // namespace tf_tracking
47 -
48 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// Various keypoint detecting functions.
17 -
18 -#include <float.h>
19 -
20 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
24 -
25 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
27 -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
28 -
29 -namespace tf_tracking {
30 -
31 -static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) {
32 - return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]);
33 -}
34 -
35 -void KeypointDetector::ScoreKeypoints(const ImageData& image_data,
36 - const int num_candidates,
37 - Keypoint* const candidate_keypoints) {
38 - const Image<int>& I_x = *image_data.GetSpatialX(0);
39 - const Image<int>& I_y = *image_data.GetSpatialY(0);
40 -
41 - if (config_->detect_skin) {
42 - const Image<uint8_t>& u_data = *image_data.GetU();
43 - const Image<uint8_t>& v_data = *image_data.GetV();
44 -
45 - static const int reference[] = {111, 155};
46 -
47 - // Score all the keypoints.
48 - for (int i = 0; i < num_candidates; ++i) {
49 - Keypoint* const keypoint = candidate_keypoints + i;
50 -
51 - const int x_pos = keypoint->pos_.x * 2;
52 - const int y_pos = keypoint->pos_.y * 2;
53 -
54 - const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]};
55 - keypoint->score_ =
56 - HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) /
57 - GetDistSquaredBetween(reference, curr_color);
58 - }
59 - } else {
60 - // Score all the keypoints.
61 - for (int i = 0; i < num_candidates; ++i) {
62 - Keypoint* const keypoint = candidate_keypoints + i;
63 - keypoint->score_ =
64 - HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y);
65 - }
66 - }
67 -}
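With skin detection enabled, the scoring above amounts to weighting the Harris corner response by proximity to a fixed skin-tone reference in UV space; the constants are the ones hard-coded in the branch:

$$\text{score}(k) = \frac{\text{Harris}(I_x, I_y,\, k_x, k_y)}{(u - 111)^2 + (v - 155)^2}$$

so candidates whose chroma sits near the reference color are strongly boosted, while the non-skin path assigns the raw Harris response.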
68 -
69 -// qsort comparator: sorts keypoints by descending score (0 on ties).
70 -inline int KeypointCompare(const void* const a, const void* const b) {
71 -  const float diff = reinterpret_cast<const Keypoint*>(a)->score_ -
72 -                     reinterpret_cast<const Keypoint*>(b)->score_;
73 -  return (diff < 0) ? 1 : (diff > 0) ? -1 : 0;
74 -}
75 -
76 -// Quicksorts detected keypoints by score.
77 -void KeypointDetector::SortKeypoints(const int num_candidates,
78 - Keypoint* const candidate_keypoints) const {
79 - qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare);
80 -
81 -#ifdef SANITY_CHECKS
82 - // Verify that the array got sorted.
83 - float last_score = FLT_MAX;
84 - for (int i = 0; i < num_candidates; ++i) {
85 - const float curr_score = candidate_keypoints[i].score_;
86 -
87 -    // Scores should be monotonically non-increasing (sorted descending).
88 - SCHECK(last_score >= curr_score,
89 - "Quicksort failure! %d: %.5f > %d: %.5f (%d total)",
90 - i - 1, last_score, i, curr_score, num_candidates);
91 -
92 - last_score = curr_score;
93 - }
94 -#endif
95 -}
96 -
97 -
98 -int KeypointDetector::SelectKeypointsInBox(
99 - const BoundingBox& box,
100 - const Keypoint* const candidate_keypoints,
101 - const int num_candidates,
102 - const int max_keypoints,
103 - const int num_existing_keypoints,
104 - const Keypoint* const existing_keypoints,
105 - Keypoint* const final_keypoints) const {
106 - if (max_keypoints <= 0) {
107 - return 0;
108 - }
109 -
110 -  // The minimum spacing enforced between keypoints within this box,
111 -  // roughly based on the box dimensions.
112 - const int distance =
113 - MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f);
114 -
115 - // First, mark keypoints that already happen to be inside this region. Ignore
116 - // keypoints that are outside it, however close they might be.
117 - interest_map_->Clear(false);
118 - for (int i = 0; i < num_existing_keypoints; ++i) {
119 - const Keypoint& candidate = existing_keypoints[i];
120 -
121 - const int x_pos = candidate.pos_.x;
122 - const int y_pos = candidate.pos_.y;
123 - if (box.Contains(candidate.pos_)) {
124 - MarkImage(x_pos, y_pos, distance, interest_map_.get());
125 - }
126 - }
127 -
128 - // Now, go through and check which keypoints will still fit in the box.
129 - int num_keypoints_selected = 0;
130 - for (int i = 0; i < num_candidates; ++i) {
131 - const Keypoint& candidate = candidate_keypoints[i];
132 -
133 - const int x_pos = candidate.pos_.x;
134 - const int y_pos = candidate.pos_.y;
135 -
136 - if (!box.Contains(candidate.pos_) ||
137 - !interest_map_->ValidPixel(x_pos, y_pos)) {
138 - continue;
139 - }
140 -
141 - if (!(*interest_map_)[y_pos][x_pos]) {
142 - final_keypoints[num_keypoints_selected++] = candidate;
143 - if (num_keypoints_selected >= max_keypoints) {
144 - break;
145 - }
146 - MarkImage(x_pos, y_pos, distance, interest_map_.get());
147 - }
148 - }
149 - return num_keypoints_selected;
150 -}
151 -
152 -
153 -void KeypointDetector::SelectKeypoints(
154 - const std::vector<BoundingBox>& boxes,
155 - const Keypoint* const candidate_keypoints,
156 - const int num_candidates,
157 - FramePair* const curr_change) const {
158 -  // Now select all the interesting keypoints that fall inside our boxes.
159 - curr_change->number_of_keypoints_ = 0;
160 - for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
161 - iter != boxes.end(); ++iter) {
162 - const BoundingBox bounding_box = *iter;
163 -
164 - // Count up keypoints that have already been selected, and fall within our
165 - // box.
166 - int num_keypoints_already_in_box = 0;
167 - for (int i = 0; i < curr_change->number_of_keypoints_; ++i) {
168 - if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) {
169 - ++num_keypoints_already_in_box;
170 - }
171 - }
172 -
173 - const int max_keypoints_to_find_in_box =
174 - MIN(kMaxKeypointsForObject - num_keypoints_already_in_box,
175 - kMaxKeypoints - curr_change->number_of_keypoints_);
176 -
177 - const int num_new_keypoints_in_box = SelectKeypointsInBox(
178 - bounding_box,
179 - candidate_keypoints,
180 - num_candidates,
181 - max_keypoints_to_find_in_box,
182 - curr_change->number_of_keypoints_,
183 - curr_change->frame1_keypoints_,
184 - curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_);
185 -
186 - curr_change->number_of_keypoints_ += num_new_keypoints_in_box;
187 -
188 - LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_);
189 - }
190 -}
191 -
192 -
193 -// Walks along the given circle checking for pixels above or below the center.
194 -// Returns a score, or 0 if the keypoint did not pass the criteria.
195 -//
196 -// Parameters:
197 -// circle_perimeter: the circumference in pixels of the circle.
198 -// threshold: the minimum number of contiguous pixels that must be above or
199 -// below the center value.
200 -// center_ptr: the location of the center pixel in memory.
201 -// offsets: the memory offsets of the perimeter pixels, relative to center_ptr.
202 -inline int TestCircle(const int circle_perimeter, const int threshold,
203 - const uint8_t* const center_ptr, const int* offsets) {
204 - // Get the actual value of the center pixel for easier reference later on.
205 - const int center_value = static_cast<int>(*center_ptr);
206 -
207 -  // Total number of pixels to check. We have to wrap around a bit in case
208 -  // the contiguous section spans the ends of the perimeter array.
209 - const int num_total = circle_perimeter + threshold - 1;
210 -
211 - int num_above = 0;
212 - int above_diff = 0;
213 -
214 - int num_below = 0;
215 - int below_diff = 0;
216 -
217 - // Used to tell when this is definitely not going to meet the threshold so we
218 - // can early abort.
219 - int minimum_by_now = threshold - num_total + 1;
220 -
221 - // Go through every pixel along the perimeter of the circle, and then around
222 - // again a little bit.
223 - for (int i = 0; i < num_total; ++i) {
224 - // This should be faster than mod.
225 - const int perim_index = i < circle_perimeter ? i : i - circle_perimeter;
226 -
227 - // This gets the value of the current pixel along the perimeter by using
228 - // a precomputed offset.
229 - const int curr_value =
230 - static_cast<int>(center_ptr[offsets[perim_index]]);
231 -
232 - const int difference = curr_value - center_value;
233 -
234 - if (difference > kFastDiffAmount) {
235 - above_diff += difference;
236 - ++num_above;
237 -
238 - num_below = 0;
239 - below_diff = 0;
240 -
241 - if (num_above >= threshold) {
242 - return above_diff;
243 - }
244 - } else if (difference < -kFastDiffAmount) {
245 - below_diff += difference;
246 - ++num_below;
247 -
248 - num_above = 0;
249 - above_diff = 0;
250 -
251 - if (num_below >= threshold) {
252 - return below_diff;
253 - }
254 - } else {
255 - num_above = 0;
256 - num_below = 0;
257 - above_diff = 0;
258 - below_diff = 0;
259 - }
260 -
261 - // See if there's any chance of making the threshold.
262 - if (MAX(num_above, num_below) < minimum_by_now) {
263 - // Didn't pass.
264 - return 0;
265 - }
266 - ++minimum_by_now;
267 - }
268 -
269 - // Didn't pass.
270 - return 0;
271 -}
272 -
273 -
274 -// Returns a score in the range [0.0, positive infinity) which represents the
275 -// relative likelihood of a point being a corner.
276 -float KeypointDetector::HarrisFilter(const Image<int32_t>& I_x,
277 - const Image<int32_t>& I_y, const float x,
278 - const float y) const {
279 - if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) &&
280 - I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) {
281 - // Image gradient matrix.
282 - float G[] = { 0, 0, 0, 0 };
283 - CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G);
284 -
285 - const float dx = G[0];
286 - const float dy = G[3];
287 - const float dxy = G[1];
288 -
289 - // Harris-Nobel corner score.
290 - return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN);
291 - }
292 -
293 - return 0.0f;
294 -}
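With G the 2x2 gradient matrix accumulated over the window by CalculateG (so G[0] is the windowed sum of I_x^2, G[3] the sum of I_y^2, and G[1] the sum of I_x I_y), the returned value is Noble's variant of the Harris measure, determinant over trace with FLT_MIN as the epsilon guard:

$$R = \frac{\det(G)}{\operatorname{tr}(G) + \varepsilon} = \frac{G_{00}\,G_{11} - G_{01}^{2}}{G_{00} + G_{11} + \varepsilon}$$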
295 -
296 -
297 -int KeypointDetector::AddExtraCandidatesForBoxes(
298 - const std::vector<BoundingBox>& boxes,
299 - const int max_num_keypoints,
300 - Keypoint* const keypoints) const {
301 - int num_keypoints_added = 0;
302 -
303 - for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
304 - iter != boxes.end(); ++iter) {
305 - const BoundingBox box = *iter;
306 -
307 - for (int i = 0; i < kNumToAddAsCandidates; ++i) {
308 - for (int j = 0; j < kNumToAddAsCandidates; ++j) {
309 - if (num_keypoints_added >= max_num_keypoints) {
310 - LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints);
311 - return num_keypoints_added;
312 - }
313 -
314 - Keypoint& curr_keypoint = keypoints[num_keypoints_added++];
315 - curr_keypoint.pos_ = Point2f(
316 - box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates,
317 - box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates);
318 - curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST;
319 - }
320 - }
321 - }
322 -
323 - return num_keypoints_added;
324 -}
325 -
326 -
327 -void KeypointDetector::FindKeypoints(const ImageData& image_data,
328 - const std::vector<BoundingBox>& rois,
329 - const FramePair& prev_change,
330 - FramePair* const curr_change) {
331 - // Copy keypoints from second frame of last pass to temp keypoints of this
332 - // pass.
333 - int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_);
334 -
335 - const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints;
336 - number_of_tmp_keypoints +=
337 - FindFastKeypoints(image_data, max_num_fast,
338 - tmp_keypoints_ + number_of_tmp_keypoints);
339 -
340 - TimeLog("Found FAST keypoints");
341 -
342 - if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
343 - LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints",
344 - kMaxTempKeypoints, number_of_tmp_keypoints);
345 - }
346 -
347 - if (kAddArbitraryKeypoints) {
348 - // Add some for each object prior to scoring.
349 - const int max_num_box_keypoints =
350 - kMaxTempKeypoints - number_of_tmp_keypoints;
351 - number_of_tmp_keypoints +=
352 - AddExtraCandidatesForBoxes(rois, max_num_box_keypoints,
353 - tmp_keypoints_ + number_of_tmp_keypoints);
354 - TimeLog("Added box keypoints");
355 -
356 - if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
357 - LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints",
358 - kMaxTempKeypoints, number_of_tmp_keypoints);
359 - }
360 - }
361 -
362 - // Score them...
363 - LOGV("Scoring %d keypoints!", number_of_tmp_keypoints);
364 - ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_);
365 - TimeLog("Scored keypoints");
366 -
367 - // Now pare it down a bit.
368 - SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_);
369 - TimeLog("Sorted keypoints");
370 -
371 - LOGV("%d keypoints to select from!", number_of_tmp_keypoints);
372 -
373 - SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change);
374 - TimeLog("Selected keypoints");
375 -
376 - LOGV("Picked %d (%d max) final keypoints out of %d potential.",
377 - curr_change->number_of_keypoints_,
378 - kMaxKeypoints, number_of_tmp_keypoints);
379 -}
380 -
381 -
382 -int KeypointDetector::CopyKeypoints(const FramePair& prev_change,
383 - Keypoint* const new_keypoints) {
384 - int number_of_keypoints = 0;
385 -
386 - // Caching values from last pass, just copy and compact.
387 - for (int i = 0; i < prev_change.number_of_keypoints_; ++i) {
388 - if (prev_change.optical_flow_found_keypoint_[i]) {
389 - new_keypoints[number_of_keypoints] =
390 - prev_change.frame2_keypoints_[i];
391 -
392 - new_keypoints[number_of_keypoints].score_ =
393 - prev_change.frame1_keypoints_[i].score_;
394 -
395 - ++number_of_keypoints;
396 - }
397 - }
398 -
399 - TimeLog("Copied keypoints");
400 - return number_of_keypoints;
401 -}
402 -
403 -
404 -// FAST keypoint detector.
405 -int KeypointDetector::FindFastKeypoints(const Image<uint8_t>& frame,
406 - const int quadrant,
407 - const int downsample_factor,
408 - const int max_num_keypoints,
409 - Keypoint* const keypoints) {
410 - /*
411 - // Reference for a circle of diameter 7.
412 - const int circle[] = {0, 0, 1, 1, 1, 0, 0,
413 - 0, 1, 0, 0, 0, 1, 0,
414 - 1, 0, 0, 0, 0, 0, 1,
415 - 1, 0, 0, 0, 0, 0, 1,
416 - 1, 0, 0, 0, 0, 0, 1,
417 - 0, 1, 0, 0, 0, 1, 0,
418 - 0, 0, 1, 1, 1, 0, 0};
419 - const int circle_offset[] =
420 - {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46};
421 - */
422 -
423 - // Quick test of compass directions. Any length 16 circle with a break of up
424 - // to 4 pixels will have at least 3 of these 4 pixels active.
425 - static const int short_circle_perimeter = 4;
426 - static const int short_threshold = 3;
427 - static const int short_circle_x[] = { -3, 0, +3, 0 };
428 - static const int short_circle_y[] = { 0, -3, 0, +3 };
429 -
430 - // Precompute image offsets.
431 - int short_offsets[short_circle_perimeter];
432 - for (int i = 0; i < short_circle_perimeter; ++i) {
433 - short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth();
434 - }
435 -
436 - // Large circle values.
437 - static const int full_circle_perimeter = 16;
438 - static const int full_threshold = 12;
439 - static const int full_circle_x[] =
440 - { -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 };
441 - static const int full_circle_y[] =
442 - { -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 };
443 -
444 - // Precompute image offsets.
445 - int full_offsets[full_circle_perimeter];
446 - for (int i = 0; i < full_circle_perimeter; ++i) {
447 - full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth();
448 - }
449 -
450 - const int scratch_stride = frame.stride();
451 -
452 - keypoint_scratch_->Clear(0);
453 -
454 - // Set up the bounds on the region to test based on the passed-in quadrant.
455 - const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer;
456 - const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer;
457 - const int start_x =
458 - kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width);
459 - const int start_y =
460 - kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height);
461 - const int end_x = start_x + quadrant_width;
462 - const int end_y = start_y + quadrant_height;
463 -
464 - // Loop through once to find FAST keypoint clumps.
465 - for (int img_y = start_y; img_y < end_y; ++img_y) {
466 - const uint8_t* curr_pixel_ptr = frame[img_y] + start_x;
467 -
468 - for (int img_x = start_x; img_x < end_x; ++img_x) {
469 - // Only insert it if it meets the quick minimum requirements test.
470 - if (TestCircle(short_circle_perimeter, short_threshold,
471 - curr_pixel_ptr, short_offsets) != 0) {
472 -        // Longer test for the actual keypoint score.
473 - const int fast_score = TestCircle(full_circle_perimeter,
474 - full_threshold,
475 - curr_pixel_ptr,
476 - full_offsets);
477 -
478 - // Non-zero score means the keypoint was found.
479 - if (fast_score != 0) {
480 - uint8_t* const center_ptr = (*keypoint_scratch_)[img_y] + img_x;
481 -
482 - // Increase the keypoint count on this pixel and the pixels in all
483 - // 4 cardinal directions.
484 - *center_ptr += 5;
485 - *(center_ptr - 1) += 1;
486 - *(center_ptr + 1) += 1;
487 - *(center_ptr - scratch_stride) += 1;
488 - *(center_ptr + scratch_stride) += 1;
489 - }
490 - }
491 -
492 - ++curr_pixel_ptr;
493 - } // x
494 - } // y
495 -
496 - TimeLog("Found FAST keypoints.");
497 -
498 - int num_keypoints = 0;
499 - // Loop through again and Harris filter pixels in the center of clumps.
500 - // We can shrink the window by 1 pixel on every side.
501 - for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) {
502 - const uint8_t* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x;
503 -
504 - for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) {
505 - if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) {
506 - Keypoint* const keypoint = keypoints + num_keypoints;
507 - keypoint->pos_ = Point2f(
508 - img_x * downsample_factor, img_y * downsample_factor);
509 - keypoint->score_ = 0;
510 - keypoint->type_ = KEYPOINT_TYPE_FAST;
511 -
512 - ++num_keypoints;
513 - if (num_keypoints >= max_num_keypoints) {
514 - return num_keypoints;
515 - }
516 - }
517 -
518 - ++curr_pixel_ptr;
519 - } // x
520 - } // y
521 -
522 - TimeLog("Picked FAST keypoints.");
523 -
524 - return num_keypoints;
525 -}
526 -
527 -int KeypointDetector::FindFastKeypoints(const ImageData& image_data,
528 - const int max_num_keypoints,
529 - Keypoint* const keypoints) {
530 - int downsample_factor = 1;
531 - int num_found = 0;
532 -
533 - // TODO(andrewharp): Get this working for multiple image scales.
534 - for (int i = 0; i < 1; ++i) {
535 - const Image<uint8_t>& frame = *image_data.GetPyramidSqrt2Level(i);
536 - num_found += FindFastKeypoints(
537 - frame, fast_quadrant_,
538 - downsample_factor, max_num_keypoints, keypoints + num_found);
539 - downsample_factor *= 2;
540 - }
541 -
542 - // Increment the current quadrant.
543 - fast_quadrant_ = (fast_quadrant_ + 1) % 4;
544 -
545 - return num_found;
546 -}
547 -
548 -} // namespace tf_tracking
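A small self-contained sketch of the quadrant staggering driven by fast_quadrant_ above; the quadrant mapping follows directly from the start_x/start_y math in FindFastKeypoints:

```cpp
#include <cstdio>

// Demonstrates the quadrant cycle: one quadrant is scanned per frame, so
// the full image is re-scanned for FAST keypoints every four frames.
int main() {
  int fast_quadrant = 0;  // mirrors KeypointDetector::fast_quadrant_
  for (int frame = 1; frame <= 8; ++frame) {
    const bool right_half = (fast_quadrant % 2) != 0;  // start_x branch
    const bool bottom_half = fast_quadrant >= 2;       // start_y branch
    std::printf("frame %d scans the %s-%s quadrant\n", frame,
                bottom_half ? "bottom" : "top",
                right_half ? "right" : "left");
    fast_quadrant = (fast_quadrant + 1) % 4;  // as done after each scan
  }
  return 0;
}
```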
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
18 -
19 -#include <stdint.h>
20 -#include <vector>
21 -
22 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
26 -
27 -namespace tf_tracking {
28 -
29 -struct Keypoint;
30 -
31 -class KeypointDetector {
32 - public:
33 - explicit KeypointDetector(const KeypointDetectorConfig* const config)
34 - : config_(config),
35 - keypoint_scratch_(new Image<uint8_t>(config_->image_size)),
36 - interest_map_(new Image<bool>(config_->image_size)),
37 - fast_quadrant_(0) {
38 - interest_map_->Clear(false);
39 - }
40 -
41 - ~KeypointDetector() {}
42 -
43 - // Finds a new set of keypoints for the current frame, picked from the current
44 - // set of keypoints and also from a set discovered via a keypoint detector.
45 - // Special attention is applied to make sure that keypoints are distributed
46 - // within the supplied ROIs.
47 - void FindKeypoints(const ImageData& image_data,
48 - const std::vector<BoundingBox>& rois,
49 - const FramePair& prev_change,
50 - FramePair* const curr_change);
51 -
52 - private:
53 - // Compute the corneriness of a point in the image.
54 - float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y,
55 - const float x, const float y) const;
56 -
57 -  // Adds a kNumToAddAsCandidates^2 grid of candidate keypoints to each of
58 -  // the given boxes, stopping once max_num_keypoints total have been added.
59 - int AddExtraCandidatesForBoxes(
60 - const std::vector<BoundingBox>& boxes,
61 - const int max_num_keypoints,
62 - Keypoint* const keypoints) const;
63 -
64 - // Scan the frame for potential keypoints using the FAST keypoint detector.
65 - // Quadrant is an argument 0-3 which refers to the quadrant of the image in
66 - // which to detect keypoints.
67 - int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant,
68 - const int downsample_factor,
69 - const int max_num_keypoints, Keypoint* const keypoints);
70 -
71 - int FindFastKeypoints(const ImageData& image_data,
72 - const int max_num_keypoints,
73 - Keypoint* const keypoints);
74 -
75 - // Score a bunch of candidate keypoints. Assigns the scores to the input
76 - // candidate_keypoints array entries.
77 - void ScoreKeypoints(const ImageData& image_data,
78 - const int num_candidates,
79 - Keypoint* const candidate_keypoints);
80 -
81 - void SortKeypoints(const int num_candidates,
82 - Keypoint* const candidate_keypoints) const;
83 -
84 - // Selects a set of keypoints falling within the supplied box such that the
85 - // most highly rated keypoints are picked first, and so that none of them are
86 - // too close together.
87 - int SelectKeypointsInBox(
88 - const BoundingBox& box,
89 - const Keypoint* const candidate_keypoints,
90 - const int num_candidates,
91 - const int max_keypoints,
92 - const int num_existing_keypoints,
93 - const Keypoint* const existing_keypoints,
94 - Keypoint* const final_keypoints) const;
95 -
96 - // Selects from the supplied sorted keypoint pool a set of keypoints that will
97 - // best cover the given set of boxes, such that each box is covered at a
98 - // resolution proportional to its size.
99 - void SelectKeypoints(
100 - const std::vector<BoundingBox>& boxes,
101 - const Keypoint* const candidate_keypoints,
102 - const int num_candidates,
103 - FramePair* const frame_change) const;
104 -
105 - // Copies and compacts the found keypoints in the second frame of prev_change
106 - // into the array at new_keypoints.
107 - static int CopyKeypoints(const FramePair& prev_change,
108 - Keypoint* const new_keypoints);
109 -
110 - const KeypointDetectorConfig* const config_;
111 -
112 - // Scratch memory for keypoint candidacy detection and non-max suppression.
113 - std::unique_ptr<Image<uint8_t> > keypoint_scratch_;
114 -
115 -  // Mask of image regions already claimed by keypoints, used to keep
116 -  // selected keypoints spaced apart.
116 - std::unique_ptr<Image<bool> > interest_map_;
117 -
118 - // The current quadrant of the image to detect FAST keypoints in.
119 - // Keypoint detection is staggered for performance reasons. Every four frames
120 - // a full scan of the frame will have been performed.
121 - int fast_quadrant_;
122 -
123 - Keypoint tmp_keypoints_[kMaxTempKeypoints];
124 -};
125 -
126 -} // namespace tf_tracking
127 -
128 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
1 -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#include "tensorflow/examples/android/jni/object_tracking/logging.h"
17 -
18 -#ifdef STANDALONE_DEMO_LIB
19 -
20 -#include <android/log.h>
21 -#include <stdlib.h>
22 -#include <time.h>
23 -#include <iostream>
24 -#include <sstream>
25 -
26 -LogMessage::LogMessage(const char* fname, int line, int severity)
27 - : fname_(fname), line_(line), severity_(severity) {}
28 -
29 -void LogMessage::GenerateLogMessage() {
30 - int android_log_level;
31 - switch (severity_) {
32 - case INFO:
33 - android_log_level = ANDROID_LOG_INFO;
34 - break;
35 - case WARNING:
36 - android_log_level = ANDROID_LOG_WARN;
37 - break;
38 - case ERROR:
39 - android_log_level = ANDROID_LOG_ERROR;
40 - break;
41 - case FATAL:
42 - android_log_level = ANDROID_LOG_FATAL;
43 - break;
44 - default:
45 - if (severity_ < INFO) {
46 - android_log_level = ANDROID_LOG_VERBOSE;
47 - } else {
48 - android_log_level = ANDROID_LOG_ERROR;
49 - }
50 - break;
51 - }
52 -
53 - std::stringstream ss;
54 - const char* const partial_name = strrchr(fname_, '/');
55 - ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_
56 - << " " << str();
57 - __android_log_write(android_log_level, "native", ss.str().c_str());
58 -
59 - // Also log to stderr (for standalone Android apps).
60 - std::cerr << "native : " << ss.str() << std::endl;
61 -
62 - // Android logging at level FATAL does not terminate execution, so abort()
63 - // is still required to stop the program.
64 - if (severity_ == FATAL) {
65 - abort();
66 - }
67 -}
68 -
69 -namespace {
70 -
71 -// Parse log level (int64) from environment variable (char*)
72 -int64_t LogLevelStrToInt(const char* tf_env_var_val) {
73 - if (tf_env_var_val == nullptr) {
74 - return 0;
75 - }
76 -
77 - // Ideally we would use env_var / safe_strto64, but it is
78 - // hard to use here without pulling in a lot of dependencies,
79 -  // so we use std::istringstream instead.
80 - std::string min_log_level(tf_env_var_val);
81 - std::istringstream ss(min_log_level);
82 - int64_t level;
83 - if (!(ss >> level)) {
84 - // Invalid vlog level setting, set level to default (0)
85 - level = 0;
86 - }
87 -
88 - return level;
89 -}
90 -
91 -int64_t MinLogLevelFromEnv() {
92 - const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL");
93 - return LogLevelStrToInt(tf_env_var_val);
94 -}
95 -
96 -int64_t MinVLogLevelFromEnv() {
97 - const char* tf_env_var_val = getenv("TF_CPP_MIN_VLOG_LEVEL");
98 - return LogLevelStrToInt(tf_env_var_val);
99 -}
100 -
101 -} // namespace
102 -
103 -LogMessage::~LogMessage() {
104 - // Read the min log level once during the first call to logging.
105 - static int64_t min_log_level = MinLogLevelFromEnv();
106 - if (TF_PREDICT_TRUE(severity_ >= min_log_level)) GenerateLogMessage();
107 -}
108 -
109 -int64_t LogMessage::MinVLogLevel() {
110 - static const int64_t min_vlog_level = MinVLogLevelFromEnv();
111 - return min_vlog_level;
112 -}
113 -
114 -LogMessageFatal::LogMessageFatal(const char* file, int line)
115 -    : LogMessage(file, line, FATAL) {}
116 -LogMessageFatal::~LogMessageFatal() {
117 - // abort() ensures we don't return (we promised we would not via
118 - // ATTRIBUTE_NORETURN).
119 - GenerateLogMessage();
120 - abort();
121 -}
122 -
123 -void LogString(const char* fname, int line, int severity,
124 - const std::string& message) {
125 - LogMessage(fname, line, severity) << message;
126 -}
127 -
128 -void LogPrintF(const int severity, const char* format, ...) {
129 - char message[1024];
130 - va_list argptr;
131 - va_start(argptr, format);
132 - vsnprintf(message, 1024, format, argptr);
133 - va_end(argptr);
134 - __android_log_write(severity, "native", message);
135 -
136 - // Also log to stderr (for standalone Android apps).
137 - std::cerr << "native : " << message << std::endl;
138 -}
139 -
140 -#endif
1 -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
18 -
19 -#include <android/log.h>
20 -#include <string.h>
21 -#include <ostream>
22 -#include <sstream>
23 -#include <string>
24 -
25 -// Allow this library to be built without depending on TensorFlow by
26 -// defining STANDALONE_DEMO_LIB. Otherwise TensorFlow headers will be
27 -// used.
28 -#ifdef STANDALONE_DEMO_LIB
29 -
30 -// A macro to disallow the copy constructor and operator= functions
31 -// This is usually placed in the private: declarations for a class.
32 -#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \
33 - TypeName(const TypeName&) = delete; \
34 - void operator=(const TypeName&) = delete
35 -
36 -#if defined(COMPILER_GCC3)
37 -#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0))
38 -#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
39 -#else
40 -#define TF_PREDICT_FALSE(x) (x)
41 -#define TF_PREDICT_TRUE(x) (x)
42 -#endif
43 -
44 -// Log levels equivalent to those defined by
45 -// third_party/tensorflow/core/platform/logging.h
46 -const int INFO = 0; // base_logging::INFO;
47 -const int WARNING = 1; // base_logging::WARNING;
48 -const int ERROR = 2; // base_logging::ERROR;
49 -const int FATAL = 3; // base_logging::FATAL;
50 -const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES;
51 -
52 -class LogMessage : public std::basic_ostringstream<char> {
53 - public:
54 - LogMessage(const char* fname, int line, int severity);
55 - ~LogMessage();
56 -
57 - // Returns the minimum log level for VLOG statements.
58 - // E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output,
59 - // but VLOG(3) will not. Defaults to 0.
60 - static int64_t MinVLogLevel();
61 -
62 - protected:
63 - void GenerateLogMessage();
64 -
65 - private:
66 - const char* fname_;
67 - int line_;
68 - int severity_;
69 -};
70 -
71 -// LogMessageFatal ensures the process will exit in failure after
72 -// logging this message.
73 -class LogMessageFatal : public LogMessage {
74 - public:
75 - LogMessageFatal(const char* file, int line);
76 - ~LogMessageFatal();
77 -};
78 -
79 -#define _TF_LOG_INFO \
80 -  LogMessage(__FILE__, __LINE__, INFO)
81 -#define _TF_LOG_WARNING \
82 -  LogMessage(__FILE__, __LINE__, WARNING)
83 -#define _TF_LOG_ERROR \
84 -  LogMessage(__FILE__, __LINE__, ERROR)
85 -#define _TF_LOG_FATAL \
86 -  LogMessageFatal(__FILE__, __LINE__)
87 -
88 -#define _TF_LOG_QFATAL _TF_LOG_FATAL
89 -
90 -#define LOG(severity) _TF_LOG_##severity
91 -
92 -#define VLOG_IS_ON(lvl) ((lvl) <= LogMessage::MinVLogLevel())
93 -
94 -#define VLOG(lvl) \
95 - if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \
96 -    LogMessage(__FILE__, __LINE__, INFO)
97 -
98 -void LogPrintF(const int severity, const char* format, ...);
99 -
100 -// Support for printf style logging.
101 -#define LOGV(...)
102 -#define LOGD(...)
103 -#define LOGI(...) LogPrintF(ANDROID_LOG_INFO, __VA_ARGS__);
104 -#define LOGW(...) LogPrintF(ANDROID_LOG_WARN, __VA_ARGS__);
105 -#define LOGE(...) LogPrintF(ANDROID_LOG_ERROR, __VA_ARGS__);
106 -
107 -#else
108 -
109 -#include "tensorflow/core/lib/strings/stringprintf.h"
110 -#include "tensorflow/core/platform/logging.h"
111 -
112 -// Support for printf style logging.
113 -#define LOGV(...)
114 -#define LOGD(...)
115 -#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__);
116 -#define LOGW(...) LOG(WARNING) << tensorflow::strings::Printf(__VA_ARGS__);
117 -#define LOGE(...) LOG(ERROR) << tensorflow::strings::Printf(__VA_ARGS__);
118 -
119 -#endif
120 -
121 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
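Putting the two halves together, a minimal standalone-build usage sketch. The macros and severity constants are the ones defined above; note that TF_CPP_MIN_LOG_LEVEL is read only once, on the first log call:

```cpp
#include <cstdlib>  // setenv (POSIX; available on Android)

// Assumes logging.h above is included and STANDALONE_DEMO_LIB is defined.
int main() {
  // Must be set before the first LOG() call, since the minimum level is
  // cached in a function-local static inside ~LogMessage().
  setenv("TF_CPP_MIN_LOG_LEVEL", "1", /*overwrite=*/1);

  LOG(INFO) << "suppressed: INFO (0) is below the minimum level (1)";
  LOG(WARNING) << "emitted at WARNING";
  LOGI("printf-style: %d keypoints", 42);  // LogPrintF logs unconditionally
  return 0;
}
```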
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// NOTE: no native object detectors are currently provided or used by the code
17 -// in this directory. This class remains mainly for historical reasons.
18 -// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
19 -
20 -#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
21 -
22 -namespace tf_tracking {
23 -
24 -// This is here so that the vtable gets created properly.
25 -ObjectDetectorBase::~ObjectDetectorBase() {}
26 -
27 -} // namespace tf_tracking
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// NOTE: no native object detectors are currently provided or used by the code
17 -// in this directory. This class remains mainly for historical reasons.
18 -// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
19 -
20 -// Defines the ObjectDetector class that is the main interface for detecting
21 -// ObjectModelBases in frames.
22 -
23 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
24 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
25 -
26 -#include <float.h>
27 -#include <map>
28 -#include <memory>
29 -#include <sstream>
30 -#include <string>
31 -#include <vector>
32 -
33 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
34 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
35 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
36 -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
37 -#ifdef __RENDER_OPENGL__
38 -#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
39 -#endif
40 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
41 -
42 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
43 -#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
44 -#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
45 -
46 -namespace tf_tracking {
47 -
48 -// Adds BoundingSquares to a vector such that the first square added is centered
49 -// in the position given and of square_size, and the remaining squares are added
50 -// concentrically, scaling down by scale_factor until the minimum threshold
51 -// size is passed.
52 -// Squares that do not fall completely within image_bounds will not be added.
53 -static inline void FillWithSquares(
54 - const BoundingBox& image_bounds,
55 - const BoundingBox& position,
56 - const float starting_square_size,
57 - const float smallest_square_size,
58 - const float scale_factor,
59 - std::vector<BoundingSquare>* const squares) {
60 - BoundingSquare descriptor_area =
61 - GetCenteredSquare(position, starting_square_size);
62 -
63 - SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor);
64 -
65 - // Use a do/while loop to ensure that at least one descriptor is created.
66 - do {
67 - if (image_bounds.Contains(descriptor_area.ToBoundingBox())) {
68 - squares->push_back(descriptor_area);
69 - }
70 - descriptor_area.Scale(scale_factor);
71 - } while (descriptor_area.size_ >= smallest_square_size - EPSILON);
72 - LOGV("Created %zu squares starting from size %.2f to min size %.2f "
73 - "using scale factor: %.2f",
74 - squares->size(), starting_square_size, smallest_square_size,
75 - scale_factor);
76 -}
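For concreteness, a hedged usage sketch. The BoundingBox constructor arguments assume the (left, top, right, bottom) form from geom.h, and all numbers are purely illustrative:

```cpp
// Illustrative values: a 640x480 frame and a previously tracked position.
const tf_tracking::BoundingBox image_bounds(0.0f, 0.0f, 640.0f, 480.0f);
const tf_tracking::BoundingBox position(100.0f, 120.0f, 180.0f, 200.0f);

std::vector<tf_tracking::BoundingSquare> squares;
// Starting at 80px and shrinking 25% per step, this pushes squares of size
// 80, 60, 45, ~33.8, and ~25.3 before falling below the 20px floor
// (squares not fully inside image_bounds are skipped).
tf_tracking::FillWithSquares(image_bounds, position, 80.0f, 20.0f, 0.75f,
                             &squares);
```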
77 -
78 -
79 -// Represents a potential detection of a specific ObjectExemplar and Descriptor
80 -// at a specific position in the image.
81 -class Detection {
82 - public:
83 - explicit Detection(const ObjectModelBase* const object_model,
84 - const MatchScore match_score,
85 - const BoundingBox& bounding_box)
86 - : object_model_(object_model),
87 - match_score_(match_score),
88 - bounding_box_(bounding_box) {}
89 -
90 - Detection(const Detection& other)
91 - : object_model_(other.object_model_),
92 - match_score_(other.match_score_),
93 - bounding_box_(other.bounding_box_) {}
94 -
95 - virtual ~Detection() {}
96 -
97 - inline BoundingBox GetObjectBoundingBox() const {
98 - return bounding_box_;
99 - }
100 -
101 - inline MatchScore GetMatchScore() const {
102 - return match_score_;
103 - }
104 -
105 - inline const ObjectModelBase* GetObjectModel() const {
106 - return object_model_;
107 - }
108 -
109 - inline bool Intersects(const Detection& other) {
110 -    // The boxes overlap unless at least one of the four edge axes
111 -    // separates them.
111 - return bounding_box_.Intersects(other.bounding_box_);
112 - }
113 -
114 - struct Comp {
115 - inline bool operator()(const Detection& a, const Detection& b) const {
116 - return a.match_score_ > b.match_score_;
117 - }
118 - };
119 -
120 - // TODO(andrewharp): add accessors to update these instead.
121 - const ObjectModelBase* object_model_;
122 - MatchScore match_score_;
123 - BoundingBox bounding_box_;
124 -};
125 -
126 -inline std::ostream& operator<<(std::ostream& stream,
127 - const Detection& detection) {
128 - const BoundingBox actual_area = detection.GetObjectBoundingBox();
129 - stream << actual_area;
130 - return stream;
131 -}
132 -
133 -class ObjectDetectorBase {
134 - public:
135 - explicit ObjectDetectorBase(const ObjectDetectorConfig* const config)
136 - : config_(config),
137 - image_data_(NULL) {}
138 -
139 - virtual ~ObjectDetectorBase();
140 -
141 -  // Sets the current image data. Subsequent calls to detection methods use
142 -  // the image data last set here.
143 - inline void SetImageData(const ImageData* const image_data) {
144 - image_data_ = image_data;
145 - }
146 -
147 - // Main entry point into the detection algorithm.
148 - // Scans the frame for candidates, tweaks them, and fills in the
149 - // given std::vector of Detection objects with acceptable matches.
150 - virtual void Detect(const std::vector<BoundingSquare>& positions,
151 - std::vector<Detection>* const detections) const = 0;
152 -
153 - virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0;
154 -
155 - virtual void DeleteObjectModel(const std::string& name) = 0;
156 -
157 - virtual void GetObjectModels(
158 - std::vector<const ObjectModelBase*>* models) const = 0;
159 -
160 -  // Updates the given model from an exemplar at the given position, in the
161 -  // context of the last frame passed to NextFrame.
162 -  // No update occurs if there's no room for a descriptor to be created in
163 -  // the example area, or if the example area is not completely contained
164 -  // within the frame.
165 - virtual void UpdateModel(const Image<uint8_t>& base_image,
166 - const IntegralImage& integral_image,
167 - const BoundingBox& bounding_box, const bool locked,
168 - ObjectModelBase* model) const = 0;
169 -
170 - virtual void Draw() const = 0;
171 -
172 - virtual bool AllowSpontaneousDetections() = 0;
173 -
174 - protected:
175 - const std::unique_ptr<const ObjectDetectorConfig> config_;
176 -
177 - // The latest frame data, upon which all detections will be performed.
178 - // Not owned by this object, just provided for reference by ObjectTracker
179 - // via SetImageData().
180 - const ImageData* image_data_;
181 -
182 - private:
183 - TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase);
184 -};
185 -
186 -template <typename ModelType>
187 -class ObjectDetector : public ObjectDetectorBase {
188 - public:
189 - explicit ObjectDetector(const ObjectDetectorConfig* const config)
190 - : ObjectDetectorBase(config) {}
191 -
192 - virtual ~ObjectDetector() {
193 - typename std::map<std::string, ModelType*>::const_iterator it =
194 - object_models_.begin();
195 - for (; it != object_models_.end(); ++it) {
196 - ModelType* model = it->second;
197 - delete model;
198 - }
199 - }
200 -
201 - virtual void DeleteObjectModel(const std::string& name) {
202 - ModelType* model = object_models_[name];
203 - CHECK_ALWAYS(model != NULL, "Model was null!");
204 - object_models_.erase(name);
205 - SAFE_DELETE(model);
206 - }
207 -
208 - virtual void GetObjectModels(
209 - std::vector<const ObjectModelBase*>* models) const {
210 - typename std::map<std::string, ModelType*>::const_iterator it =
211 - object_models_.begin();
212 - for (; it != object_models_.end(); ++it) {
213 - models->push_back(it->second);
214 - }
215 - }
216 -
217 - virtual bool AllowSpontaneousDetections() {
218 - return false;
219 - }
220 -
221 - protected:
222 - std::map<std::string, ModelType*> object_models_;
223 -
224 - private:
225 - TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector);
226 -};
227 -
228 -} // namespace tf_tracking
229 -
230 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// NOTE: no native object detectors are currently provided or used by the code
17 -// in this directory. This class remains mainly for historical reasons.
18 -// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
19 -
20 -// Contains ObjectModelBase declaration.
21 -
22 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
23 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
24 -
25 -#ifdef __RENDER_OPENGL__
26 -#include <GLES/gl.h>
27 -#include <GLES/glext.h>
28 -#endif
29 -
30 -#include <vector>
31 -
32 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
33 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
34 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
35 -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
36 -#ifdef __RENDER_OPENGL__
37 -#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
38 -#endif
39 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
40 -
41 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
42 -#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
43 -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
44 -
45 -namespace tf_tracking {
46 -
47 -// The ObjectModelBase class represents all the known appearance information for
48 -// an object. It is not a specific instance of the object in the world,
49 -// but just the general appearance information that enables detection. An
50 -// ObjectModelBase can be reused across multiple instances of TrackedObject.
51 -class ObjectModelBase {
52 - public:
53 - ObjectModelBase(const std::string& name) : name_(name) {}
54 -
55 - virtual ~ObjectModelBase() {}
56 -
57 - // Called when the next step in an ongoing track occurs.
58 - virtual void TrackStep(const BoundingBox& position,
59 - const Image<uint8_t>& image,
60 - const IntegralImage& integral_image,
61 - const bool authoritative) {}
62 -
63 - // Called when an object track is lost.
64 - virtual void TrackLost() {}
65 -
66 - // Called when an object track is confirmed as legitimate.
67 - virtual void TrackConfirmed() {}
68 -
69 - virtual float GetMaxCorrelation(const Image<float>& patch_image) const = 0;
70 -
71 - virtual MatchScore GetMatchScore(
72 - const BoundingBox& position, const ImageData& image_data) const = 0;
73 -
74 - virtual void Draw(float* const depth) const = 0;
75 -
76 - inline const std::string& GetName() const {
77 - return name_;
78 - }
79 -
80 - protected:
81 - const std::string name_;
82 -
83 - private:
84 - TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase);
85 -};
86 -
87 -template <typename DetectorType>
88 -class ObjectModel : public ObjectModelBase {
89 - public:
90 - ObjectModel<DetectorType>(const DetectorType* const detector,
91 - const std::string& name)
92 - : ObjectModelBase(name), detector_(detector) {}
93 -
94 - protected:
95 - const DetectorType* const detector_;
96 -
97 - TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel<DetectorType>);
98 -};
99 -
100 -} // namespace tf_tracking
101 -
102 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifdef __RENDER_OPENGL__
17 -#include <GLES/gl.h>
18 -#include <GLES/glext.h>
19 -#endif
20 -
21 -#include <cinttypes>
22 -#include <map>
23 -#include <string>
24 -
25 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
27 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
28 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
29 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
30 -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
31 -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
32 -#include "tensorflow/examples/android/jni/object_tracking/logging.h"
33 -#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
34 -#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
35 -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
36 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
37 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
38 -
39 -namespace tf_tracking {
40 -
41 -ObjectTracker::ObjectTracker(const TrackerConfig* const config,
42 - ObjectDetectorBase* const detector)
43 - : config_(config),
44 - frame_width_(config->image_size.width),
45 - frame_height_(config->image_size.height),
46 - curr_time_(0),
47 - num_frames_(0),
48 - flow_cache_(&config->flow_config),
49 - keypoint_detector_(&config->keypoint_detector_config),
50 - curr_num_frame_pairs_(0),
51 - first_frame_index_(0),
52 - frame1_(new ImageData(frame_width_, frame_height_)),
53 - frame2_(new ImageData(frame_width_, frame_height_)),
54 - detector_(detector),
55 - num_detected_(0) {
56 - for (int i = 0; i < kNumFrames; ++i) {
57 - frame_pairs_[i].Init(-1, -1);
58 - }
59 -}
60 -
61 -
62 -ObjectTracker::~ObjectTracker() {
63 - for (TrackedObjectMap::iterator iter = objects_.begin();
64 - iter != objects_.end(); iter++) {
65 - TrackedObject* object = iter->second;
66 - SAFE_DELETE(object);
67 - }
68 -}
69 -
70 -
71 -// Finds the correspondences for all the points in the current pair of frames.
72 -// Stores the results in the given FramePair.
73 -void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const {
74 -  // Mark all keypoints as not found until optical flow confirms them below.
75 - memset(frame_pair->optical_flow_found_keypoint_, false,
76 - sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints);
77 - TimeLog("Cleared old found keypoints");
78 -
79 - int num_keypoints_found = 0;
80 -
81 - // For every keypoint...
82 - for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) {
83 - Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat;
84 - Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat;
85 -
86 - if (flow_cache_.FindNewPositionOfPoint(
87 - keypoint1->pos_.x, keypoint1->pos_.y,
88 - &keypoint2->pos_.x, &keypoint2->pos_.y)) {
89 - frame_pair->optical_flow_found_keypoint_[i_feat] = true;
90 - ++num_keypoints_found;
91 - }
92 - }
93 -
94 - TimeLog("Found correspondences");
95 -
96 - LOGV("Found %d of %d keypoint correspondences",
97 - num_keypoints_found, frame_pair->number_of_keypoints_);
98 -}
99 -
100 -void ObjectTracker::NextFrame(const uint8_t* const new_frame,
101 - const uint8_t* const uv_frame,
102 - const int64_t timestamp,
103 - const float* const alignment_matrix_2x3) {
104 - IncrementFrameIndex();
105 - LOGV("Received frame %d", num_frames_);
106 -
107 - FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0);
108 - curr_change->Init(curr_time_, timestamp);
109 -
110 - CHECK_ALWAYS(curr_time_ < timestamp,
111 - "Timestamp must monotonically increase! Went from %" PRId64
112 - " to %" PRId64 " on frame %d.",
113 - curr_time_, timestamp, num_frames_);
114 -
115 - curr_time_ = timestamp;
116 -
117 - // Swap the frames.
118 - frame1_.swap(frame2_);
119 -
120 - frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1);
121 -
122 - if (detector_.get() != NULL) {
123 - detector_->SetImageData(frame2_.get());
124 - }
125 -
126 - flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3);
127 -
128 - if (num_frames_ == 1) {
129 -    // This is the first frame, so there is nothing to track from yet.
130 - return;
131 - }
132 -
133 - if (config_->always_track || objects_.size() > 0) {
134 - LOGV("Tracking %zu targets", objects_.size());
135 - ComputeKeypoints(true);
136 - TimeLog("Keypoints computed!");
137 -
138 - FindCorrespondences(curr_change);
139 - TimeLog("Flow computed!");
140 -
141 - TrackObjects();
142 - }
143 - TimeLog("Targets tracked!");
144 -
145 - if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) {
146 - DetectTargets();
147 - }
148 - TimeLog("Detected objects.");
149 -}
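As a usage note, here is a hypothetical per-frame driver for the method above. The function name and the identity alignment matrix are illustrative; the only hard requirement visible in this code is that timestamps strictly increase:

```cpp
#include <cstdint>

// Assumes object_tracker.h above is included. Called once per camera frame
// with the Y plane, the interleaved UV plane, and a monotonically
// increasing timestamp (enforced by the CHECK_ALWAYS in NextFrame).
void OnCameraFrame(tf_tracking::ObjectTracker* const tracker,
                   const uint8_t* const y_plane,
                   const uint8_t* const uv_plane,
                   const int64_t timestamp) {
  // Identity 2x3 alignment, i.e. no inter-frame camera motion assumed.
  const float alignment_matrix_2x3[6] = {1.0f, 0.0f, 0.0f,
                                         0.0f, 1.0f, 0.0f};
  tracker->NextFrame(y_plane, uv_plane, timestamp, alignment_matrix_2x3);
}
```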
150 -
151 -TrackedObject* ObjectTracker::MaybeAddObject(
152 - const std::string& id, const Image<uint8_t>& source_image,
153 - const BoundingBox& bounding_box, const ObjectModelBase* object_model) {
154 - // Train the detector if this is a new object.
155 - if (objects_.find(id) != objects_.end()) {
156 - return objects_[id];
157 - }
158 -
159 - // Need to get a non-const version of the model, or create a new one if it
160 - // wasn't given.
161 - ObjectModelBase* model = NULL;
162 - if (detector_ != NULL) {
163 - // If a detector is registered, then this new object must have a model.
164 - CHECK_ALWAYS(object_model != NULL, "No model given!");
165 - model = detector_->CreateObjectModel(object_model->GetName());
166 - }
167 - TrackedObject* const object =
168 - new TrackedObject(id, source_image, bounding_box, model);
169 -
170 - objects_[id] = object;
171 - return object;
172 -}
173 -
174 -void ObjectTracker::RegisterNewObjectWithAppearance(
175 - const std::string& id, const uint8_t* const new_frame,
176 - const BoundingBox& bounding_box) {
177 - ObjectModelBase* object_model = NULL;
178 -
179 - Image<uint8_t> image(frame_width_, frame_height_);
180 - image.FromArray(new_frame, frame_width_, 1);
181 -
182 - if (detector_ != NULL) {
183 - object_model = detector_->CreateObjectModel(id);
184 - CHECK_ALWAYS(object_model != NULL, "Null object model!");
185 -
186 - const IntegralImage integral_image(image);
187 - object_model->TrackStep(bounding_box, image, integral_image, true);
188 - }
189 -
190 - // Create an object at this position.
191 - CHECK_ALWAYS(!HaveObject(id), "Already have this object!");
192 - if (objects_.find(id) == objects_.end()) {
193 - TrackedObject* const object =
194 - MaybeAddObject(id, image, bounding_box, object_model);
195 - CHECK_ALWAYS(object != NULL, "Object not created!");
196 - }
197 -}
198 -
199 -void ObjectTracker::SetPreviousPositionOfObject(const std::string& id,
200 - const BoundingBox& bounding_box,
201 - const int64_t timestamp) {
202 - CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
203 - CHECK_ALWAYS(timestamp <= curr_time_,
204 - "Timestamp too great! %" PRId64 " vs %" PRId64, timestamp,
205 - curr_time_);
206 -
207 - TrackedObject* const object = GetObject(id);
208 -
209 - // Track this bounding box from the past to the current time.
210 - const BoundingBox current_position = TrackBox(bounding_box, timestamp);
211 -
212 - object->UpdatePosition(current_position, curr_time_, *frame2_, false);
213 -
214 - VLOG(2) << "Set tracked position for " << id << " to " << bounding_box
215 - << std::endl;
216 -}
217 -
218 -
219 -void ObjectTracker::SetCurrentPositionOfObject(
220 - const std::string& id, const BoundingBox& bounding_box) {
221 - SetPreviousPositionOfObject(id, bounding_box, curr_time_);
222 -}
223 -
224 -
225 -void ObjectTracker::ForgetTarget(const std::string& id) {
226 - LOGV("Forgetting object %s", id.c_str());
227 - TrackedObject* const object = GetObject(id);
228 - delete object;
229 - objects_.erase(id);
230 -
231 - if (detector_ != NULL) {
232 - detector_->DeleteObjectModel(id);
233 - }
234 -}
235 -
236 -int ObjectTracker::GetKeypointsPacked(uint16_t* const out_data,
237 - const float scale) const {
238 - const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
239 - uint16_t* curr_data = out_data;
240 - int num_keypoints = 0;
241 -
242 - for (int i = 0; i < change.number_of_keypoints_; ++i) {
243 - if (change.optical_flow_found_keypoint_[i]) {
244 - ++num_keypoints;
245 - const Point2f& point1 = change.frame1_keypoints_[i].pos_;
246 - *curr_data++ = RealToFixed115(point1.x * scale);
247 - *curr_data++ = RealToFixed115(point1.y * scale);
248 -
249 - const Point2f& point2 = change.frame2_keypoints_[i].pos_;
250 - *curr_data++ = RealToFixed115(point2.x * scale);
251 - *curr_data++ = RealToFixed115(point2.y * scale);
252 - }
253 - }
254 -
255 - return num_keypoints;
256 -}
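For reference, a minimal sketch of the packing RealToFixed115 above appears to perform, under the assumption that "11.5" means 11 integer and 5 fractional bits (i.e. scale by 32); the helper names are hypothetical and not part of this file:

    #include <cmath>
    #include <cstdint>

    // Hypothetical encode/decode pair for an assumed 11.5 fixed-point format.
    inline uint16_t EncodeFixed115(const float value) {
      return static_cast<uint16_t>(std::lround(value * 32.0f));
    }
    inline float DecodeFixed115(const uint16_t fixed) {
      return static_cast<float>(fixed) / 32.0f;
    }

    // Example: 100.25f encodes to 3208 and decodes back to 100.25f exactly.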
257 -
258 -
259 -int ObjectTracker::GetKeypoints(const bool only_found,
260 - float* const out_data) const {
261 - int curr_keypoint = 0;
262 - const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
263 -
264 - for (int i = 0; i < change.number_of_keypoints_; ++i) {
265 - if (!only_found || change.optical_flow_found_keypoint_[i]) {
266 - const int base = curr_keypoint * kKeypointStep;
267 - out_data[base + 0] = change.frame1_keypoints_[i].pos_.x;
268 - out_data[base + 1] = change.frame1_keypoints_[i].pos_.y;
269 -
270 - out_data[base + 2] =
271 - change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f;
272 - out_data[base + 3] = change.frame2_keypoints_[i].pos_.x;
273 - out_data[base + 4] = change.frame2_keypoints_[i].pos_.y;
274 -
275 - out_data[base + 5] = change.frame1_keypoints_[i].score_;
276 - out_data[base + 6] = change.frame1_keypoints_[i].type_;
277 - ++curr_keypoint;
278 - }
279 - }
280 -
281 - LOGV("Got %d keypoints.", curr_keypoint);
282 -
283 - return curr_keypoint;
284 -}
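A sketch of how a caller might walk the returned flat array; it assumes kKeypointStep == 7, matching the seven floats written per keypoint above:

    #include <cstdio>

    // Hypothetical consumer of GetKeypoints() output; assumes kKeypointStep == 7.
    void PrintKeypoints(const float* const out_data, const int num_keypoints) {
      for (int i = 0; i < num_keypoints; ++i) {
        // Record layout: [x1 y1 found x2 y2 score type].
        const float* const kp = out_data + i * 7;
        printf("(%.1f,%.1f) -> (%.1f,%.1f) found=%d score=%.2f type=%d\n",
               kp[0], kp[1], kp[3], kp[4], kp[2] > 0.0f, kp[5],
               static_cast<int>(kp[6]));
      }
    }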
285 -
286 -
287 -BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
288 - const FramePair& frame_pair) const {
289 - float translation_x;
290 - float translation_y;
291 -
292 - float scale_x;
293 - float scale_y;
294 -
295 - BoundingBox tracked_box(region);
296 - frame_pair.AdjustBox(
297 - tracked_box, &translation_x, &translation_y, &scale_x, &scale_y);
298 -
299 - tracked_box.Shift(Point2f(translation_x, translation_y));
300 -
301 - if (scale_x > 0 && scale_y > 0) {
302 - tracked_box.Scale(scale_x, scale_y);
303 - }
304 - return tracked_box;
305 -}
306 -
307 -BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
308 - const int64_t timestamp) const {
309 - CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
310 - CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!");
311 -
312 - // Anything that ended before the requested timestamp is of no concern to us.
313 - bool found_it = false;
314 - int num_frames_back = -1;
315 - for (int i = 0; i < curr_num_frame_pairs_; ++i) {
316 - const FramePair& frame_pair =
317 - frame_pairs_[GetNthIndexFromEnd(i)];
318 -
319 - if (frame_pair.end_time_ <= timestamp) {
320 - num_frames_back = i - 1;
321 -
322 - if (num_frames_back > 0) {
323 - LOGV("Went %d out of %d frames before finding frame. (index: %d)",
324 - num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i));
325 - }
326 -
327 - found_it = true;
328 - break;
329 - }
330 - }
331 -
332 - if (!found_it) {
333 - LOGW("History did not go back far enough! %" PRId64 " vs %" PRId64,
334 - frame_pairs_[GetNthIndexFromEnd(0)].end_time_ -
335 - frame_pairs_[GetNthIndexFromStart(0)].end_time_,
336 - frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp);
337 - }
338 -
339 - // Loop over all the frames in the queue, tracking the accumulated delta
340 - // of the point from frame to frame. It's possible the point could
341 - // go out of frame, but keep tracking as best we can, using points near
342 - // the edge of the screen where it went out of bounds.
343 - BoundingBox tracked_box(region);
344 - for (int i = num_frames_back; i >= 0; --i) {
345 - const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)];
346 - SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!");
347 - tracked_box = TrackBox(tracked_box, frame_pair);
348 - }
349 - return tracked_box;
350 -}
351 -
352 -
353 -// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4
354 -// 3d transformation matrix.
355 -inline void Convert3x3To4x4(
356 - const float* const in_matrix, float* const out_matrix) {
357 - // X
358 - out_matrix[0] = in_matrix[0];
359 - out_matrix[1] = in_matrix[3];
360 - out_matrix[2] = 0.0f;
361 - out_matrix[3] = 0.0f;
362 -
363 - // Y
364 - out_matrix[4] = in_matrix[1];
365 - out_matrix[5] = in_matrix[4];
366 - out_matrix[6] = 0.0f;
367 - out_matrix[7] = 0.0f;
368 -
369 - // Z
370 - out_matrix[8] = 0.0f;
371 - out_matrix[9] = 0.0f;
372 - out_matrix[10] = 1.0f;
373 - out_matrix[11] = 0.0f;
374 -
375 - // Translation
376 - out_matrix[12] = in_matrix[2];
377 - out_matrix[13] = in_matrix[5];
378 - out_matrix[14] = 0.0f;
379 - out_matrix[15] = 1.0f;
380 -}
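A small worked example of the conversion: a pure 2D translation in row-major 3x3 form lands in the fourth column of the column-major 4x4, where glMultMatrixf expects it.

    // Row-major 3x3 [a b tx; c d ty; 0 0 1] for a translation by (5, 7):
    const float translate_3x3[9] = { 1.0f, 0.0f, 5.0f,
                                     0.0f, 1.0f, 7.0f,
                                     0.0f, 0.0f, 1.0f };
    float gl_4x4[16];
    Convert3x3To4x4(translate_3x3, gl_4x4);
    // Column-major result: gl_4x4[12] == 5.0f, gl_4x4[13] == 7.0f, and the
    // upper-left 2x2 block is transposed into columns (indices 0, 1, 4, 5).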
381 -
382 -
383 -void ObjectTracker::Draw(const int canvas_width, const int canvas_height,
384 - const float* const frame_to_canvas) const {
385 -#ifdef __RENDER_OPENGL__
386 - glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
387 -
388 - glMatrixMode(GL_PROJECTION);
389 - glLoadIdentity();
390 -
391 - glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f);
392 -
393 - // To make Y go the right direction (0 at top of frame).
394 - glScalef(1.0f, -1.0f, 1.0f);
395 - glTranslatef(0.0f, -canvas_height, 0.0f);
396 -
397 - glMatrixMode(GL_MODELVIEW);
398 - glLoadIdentity();
399 -
400 - glPushMatrix();
401 -
402 - // Apply the frame to canvas transformation.
403 - static GLfloat transformation[16];
404 - Convert3x3To4x4(frame_to_canvas, transformation);
405 - glMultMatrixf(transformation);
406 -
407 - // Draw tracked object bounding boxes.
408 - for (TrackedObjectMap::const_iterator iter = objects_.begin();
409 - iter != objects_.end(); ++iter) {
410 - TrackedObject* tracked_object = iter->second;
411 - tracked_object->Draw();
412 - }
413 -
414 - static const bool kRenderDebugPyramid = false;
415 - if (kRenderDebugPyramid) {
416 - glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
417 - for (int i = 0; i < kNumPyramidLevels * 2; ++i) {
418 - Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw();
419 - }
420 - }
421 -
422 - static const bool kRenderDebugDerivative = false;
423 - if (kRenderDebugDerivative) {
424 - glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
425 - for (int i = 0; i < kNumPyramidLevels; ++i) {
426 - const Image<int32_t>& dx = *frame1_->GetSpatialX(i);
427 - Image<uint8_t> render_image(dx.GetWidth(), dx.GetHeight());
428 - for (int y = 0; y < dx.GetHeight(); ++y) {
429 - const int32_t* dx_ptr = dx[y];
430 - uint8_t* dst_ptr = render_image[y];
431 - for (int x = 0; x < dx.GetWidth(); ++x) {
432 - *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255);
433 - }
434 - }
435 -
436 - Sprite(render_image).Draw();
437 - }
438 - }
439 -
440 - if (detector_ != NULL) {
441 - glDisable(GL_CULL_FACE);
442 - detector_->Draw();
443 - }
444 - glPopMatrix();
445 -#endif
446 -}
447 -
448 -static void AddQuadrants(const BoundingBox& box,
449 - std::vector<BoundingBox>* boxes) {
450 - const Point2f center = box.GetCenter();
451 -
452 - float x1 = box.left_;
453 - float x2 = center.x;
454 - float x3 = box.right_;
455 -
456 - float y1 = box.top_;
457 - float y2 = center.y;
458 - float y3 = box.bottom_;
459 -
460 - // Upper left.
461 - boxes->push_back(BoundingBox(x1, y1, x2, y2));
462 -
463 - // Upper right.
464 - boxes->push_back(BoundingBox(x2, y1, x3, y2));
465 -
466 - // Bottom left.
467 - boxes->push_back(BoundingBox(x1, y2, x2, y3));
468 -
469 - // Bottom right.
470 - boxes->push_back(BoundingBox(x2, y2, x3, y3));
471 -
472 - // Whole thing.
473 - boxes->push_back(box);
474 -}
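To make the geometry concrete, AddQuadrants turns one box into five, with the original box last:

    std::vector<BoundingBox> boxes;
    AddQuadrants(BoundingBox(0, 0, 100, 100), &boxes);
    // boxes now holds:
    //   (0, 0, 50, 50)       upper left
    //   (50, 0, 100, 50)     upper right
    //   (0, 50, 50, 100)     bottom left
    //   (50, 50, 100, 100)   bottom right
    //   (0, 0, 100, 100)     the whole thing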
475 -
476 -void ObjectTracker::ComputeKeypoints(const bool cached_ok) {
477 - const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)];
478 - FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)];
479 -
480 - std::vector<BoundingBox> boxes;
481 -
482 - for (TrackedObjectMap::iterator object_iter = objects_.begin();
483 - object_iter != objects_.end(); ++object_iter) {
484 - BoundingBox box = object_iter->second->GetPosition();
485 - box.Scale(config_->object_box_scale_factor_for_features,
486 - config_->object_box_scale_factor_for_features);
487 - AddQuadrants(box, &boxes);
488 - }
489 -
490 - AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes);
491 -
492 - keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change);
493 -}
494 -
495 -
496 -// Given a detection, finds the existing tracked object that best overlaps it.
497 -// Returns false on a collision with an object it may not claim; else sets *match.
498 -bool ObjectTracker::GetBestObjectForDetection(
499 - const Detection& detection, TrackedObject** match) const {
500 - TrackedObject* best_match = NULL;
501 - float best_overlap = -FLT_MAX;
502 -
503 - LOGV("Looking for matches in %zu objects!", objects_.size());
504 - for (TrackedObjectMap::const_iterator object_iter = objects_.begin();
505 - object_iter != objects_.end(); ++object_iter) {
506 - TrackedObject* const tracked_object = object_iter->second;
507 -
508 - const float overlap = tracked_object->GetPosition().PascalScore(
509 - detection.GetObjectBoundingBox());
510 -
511 - if (!detector_->AllowSpontaneousDetections() &&
512 - (detection.GetObjectModel() != tracked_object->GetModel())) {
513 - if (overlap > 0.0f) {
514 - return false;
515 - }
516 - continue;
517 - }
518 -
519 - const float jump_distance =
520 - (tracked_object->GetPosition().GetCenter() -
521 - detection.GetObjectBoundingBox().GetCenter()).LengthSquared();
522 -
523 - const float allowed_distance =
524 - tracked_object->GetAllowableDistanceSquared();
525 -
526 - LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f",
527 - jump_distance, allowed_distance, overlap);
528 -
529 - // TODO(andrewharp): No need to do this verification twice, eliminate
530 - // one of the score checks (the other being in OnDetection).
531 - if (jump_distance < allowed_distance &&
532 - overlap > best_overlap &&
533 - tracked_object->GetMatchScore() + kMatchScoreBuffer <
534 - detection.GetMatchScore()) {
535 - best_match = tracked_object;
536 - best_overlap = overlap;
537 - } else if (overlap > 0.0f) {
538 - return false;
539 - }
540 - }
541 -
542 - *match = best_match;
543 - return true;
544 -}
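PascalScore above presumably implements the PASCAL VOC overlap criterion, i.e. intersection-over-union of the two boxes. A self-contained sketch under that assumption (the real implementation lives with BoundingBox; this stand-in is illustrative only):

    #include <algorithm>

    // Hypothetical IoU for axis-aligned boxes: 1.0 for identical boxes,
    // 0.0 for disjoint ones.
    float IntersectionOverUnion(const BoundingBox& a, const BoundingBox& b) {
      const float ix = std::max(0.0f, std::min(a.right_, b.right_) -
                                          std::max(a.left_, b.left_));
      const float iy = std::max(0.0f, std::min(a.bottom_, b.bottom_) -
                                          std::max(a.top_, b.top_));
      const float intersection = ix * iy;
      const float union_area = a.GetWidth() * a.GetHeight() +
                               b.GetWidth() * b.GetHeight() - intersection;
      return union_area > 0.0f ? intersection / union_area : 0.0f;
    }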
545 -
546 -
547 -void ObjectTracker::ProcessDetections(
548 - std::vector<Detection>* const detections) {
549 - LOGV("Initial detection done, iterating over %zu detections now.",
550 - detections->size());
551 -
552 - const bool spontaneous_detections_allowed =
553 - detector_->AllowSpontaneousDetections();
554 - for (std::vector<Detection>::const_iterator it = detections->begin();
555 - it != detections->end(); ++it) {
556 - const Detection& detection = *it;
557 - SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()),
558 - "Frame does not contain bounding box!");
559 -
560 - TrackedObject* best_match = NULL;
561 -
562 - const bool no_collisions =
563 - GetBestObjectForDetection(detection, &best_match);
564 -
565 - // Need to get a non-const version of the model, or create a new one if it
566 - // wasn't given.
567 - ObjectModelBase* model =
568 - const_cast<ObjectModelBase*>(detection.GetObjectModel());
569 -
570 - if (best_match != NULL) {
571 - if (model != best_match->GetModel()) {
572 - CHECK_ALWAYS(detector_->AllowSpontaneousDetections(),
573 - "Model for object changed but spontaneous detections not allowed!");
574 - }
575 - best_match->OnDetection(model,
576 - detection.GetObjectBoundingBox(),
577 - detection.GetMatchScore(),
578 - curr_time_, *frame2_);
579 - } else if (no_collisions && spontaneous_detections_allowed) {
580 - if (detection.GetMatchScore() > kMinimumMatchScore) {
581 - LOGV("No match, adding it!");
582 - const ObjectModelBase* model = detection.GetObjectModel();
583 - std::ostringstream ss;
584 - // TODO(andrewharp): Generate this in a more general fashion.
585 - ss << "hand_" << num_detected_++;
586 - std::string object_name = ss.str();
587 - MaybeAddObject(object_name, *frame2_->GetImage(),
588 - detection.GetObjectBoundingBox(), model);
589 - }
590 - }
591 - }
592 -}
593 -
594 -
595 -void ObjectTracker::DetectTargets() {
596 - // Detect all object model types that we're currently tracking.
597 - std::vector<const ObjectModelBase*> object_models;
598 - detector_->GetObjectModels(&object_models);
599 - if (object_models.size() == 0) {
600 - LOGV("No objects to search for, aborting.");
601 - return;
602 - }
603 -
604 - LOGV("Trying to detect %zu models", object_models.size());
605 -
606 - LOGV("Creating test vector!");
607 - std::vector<BoundingSquare> positions;
608 -
609 - for (TrackedObjectMap::iterator object_iter = objects_.begin();
610 - object_iter != objects_.end(); ++object_iter) {
611 - TrackedObject* const tracked_object = object_iter->second;
612 -
613 -#if DEBUG_PREDATOR
614 - positions.push_back(GetCenteredSquare(
615 - frame2_->GetImage()->GetContainingBox(), 32.0f));
616 -#else
617 - const BoundingBox& position = tracked_object->GetPosition();
618 -
619 - const float square_size = MAX(
620 - kScanMinSquareSize / (kLastKnownPositionScaleFactor *
621 - kLastKnownPositionScaleFactor),
622 - MIN(position.GetWidth(),
623 - position.GetHeight())) / kLastKnownPositionScaleFactor;
624 -
625 - FillWithSquares(frame2_->GetImage()->GetContainingBox(),
626 - tracked_object->GetPosition(),
627 - square_size,
628 - kScanMinSquareSize,
629 - kLastKnownPositionScaleFactor,
630 - &positions);
631 -#endif
632 -  }
633 -
634 - LOGV("Created test vector!");
635 -
636 - std::vector<Detection> detections;
637 - LOGV("Detecting!");
638 - detector_->Detect(positions, &detections);
639 - LOGV("Found %zu detections", detections.size());
640 -
641 - TimeLog("Finished detection.");
642 -
643 - ProcessDetections(&detections);
644 -
645 - TimeLog("iterated over detections");
646 -
647 - LOGV("Done detecting!");
648 -}
649 -
650 -
651 -void ObjectTracker::TrackObjects() {
652 - // TODO(andrewharp): Correlation should be allowed to remove objects too.
653 - const bool automatic_removal_allowed = detector_.get() != NULL ?
654 - detector_->AllowSpontaneousDetections() : false;
655 -
656 - LOGV("Tracking %zu objects!", objects_.size());
657 - std::vector<std::string> dead_objects;
658 - for (TrackedObjectMap::iterator iter = objects_.begin();
659 - iter != objects_.end(); iter++) {
660 - TrackedObject* object = iter->second;
661 - const BoundingBox tracked_position = TrackBox(
662 - object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]);
663 - object->UpdatePosition(tracked_position, curr_time_, *frame2_, false);
664 -
665 - if (automatic_removal_allowed &&
666 - object->GetNumConsecutiveFramesBelowThreshold() >
667 - kMaxNumDetectionFailures * 5) {
668 - dead_objects.push_back(iter->first);
669 - }
670 - }
671 -
672 - if (detector_ != NULL && automatic_removal_allowed) {
673 - for (std::vector<std::string>::iterator iter = dead_objects.begin();
674 - iter != dead_objects.end(); iter++) {
675 - LOGE("Removing object! %s", iter->c_str());
676 - ForgetTarget(*iter);
677 - }
678 - }
679 - TimeLog("Tracked all objects.");
680 -
681 - LOGV("%zu objects tracked!", objects_.size());
682 -}
683 -
684 -} // namespace tf_tracking
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
18 -
19 -#include <map>
20 -#include <string>
21 -
22 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
23 -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/logging.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
27 -
28 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
29 -#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
30 -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
31 -#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
32 -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
33 -#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
34 -
35 -namespace tf_tracking {
36 -
37 -typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;
38 -
39 -inline std::ostream& operator<<(std::ostream& stream,
40 - const TrackedObjectMap& map) {
41 - for (TrackedObjectMap::const_iterator iter = map.begin();
42 - iter != map.end(); ++iter) {
43 - const TrackedObject& tracked_object = *iter->second;
44 - const std::string& key = iter->first;
45 - stream << key << ": " << tracked_object;
46 - }
47 - return stream;
48 -}
49 -
50 -
51 -// ObjectTracker is the highest-level class in the tracking/detection framework.
52 -// It handles basic image processing, keypoint detection, keypoint tracking,
53 -// object tracking, and object detection/relocalization.
54 -class ObjectTracker {
55 - public:
56 - ObjectTracker(const TrackerConfig* const config,
57 - ObjectDetectorBase* const detector);
58 - virtual ~ObjectTracker();
59 -
60 - virtual void NextFrame(const uint8_t* const new_frame,
61 - const int64_t timestamp,
62 - const float* const alignment_matrix_2x3) {
63 - NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
64 - }
65 -
66 - // Called upon the arrival of a new frame of raw data.
67 - // Does all image processing, keypoint detection, and object
68 - // tracking/detection for registered objects.
69 - // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
70 - // represents the main transformation that has happened between the last
71 - // and the current frame.
72 - // Argument align_level is the pyramid level (where 0 == finest) that
73 - // the matrix is valid for.
74 - virtual void NextFrame(const uint8_t* const new_frame,
75 - const uint8_t* const uv_frame, const int64_t timestamp,
76 - const float* const alignment_matrix_2x3);
77 -
78 - virtual void RegisterNewObjectWithAppearance(const std::string& id,
79 - const uint8_t* const new_frame,
80 - const BoundingBox& bounding_box);
81 -
82 - // Updates the position of a tracked object, given that it was known to be at
83 - // a certain position at some point in the past.
84 - virtual void SetPreviousPositionOfObject(const std::string& id,
85 - const BoundingBox& bounding_box,
86 - const int64_t timestamp);
87 -
88 - // Sets the current position of the object in the most recent frame provided.
89 - virtual void SetCurrentPositionOfObject(const std::string& id,
90 - const BoundingBox& bounding_box);
91 -
92 - // Tells the ObjectTracker to stop tracking a target.
93 - void ForgetTarget(const std::string& id);
94 -
95 - // Fills the given out_data buffer with the latest detected keypoint
96 - // correspondences, first scaled by scale_factor (to adjust for downsampling
97 - // that may have occurred elsewhere), then packed in a fixed-point format.
98 - int GetKeypointsPacked(uint16_t* const out_data,
99 - const float scale_factor) const;
100 -
101 - // Copy the keypoint arrays after computeFlow is called.
102 - // out_data should be at least kMaxKeypoints * kKeypointStep long.
103 -  // Currently, its format is [x1 y1 found x2 y2 score type] repeated N times,
104 -  // where N is the number of keypoints tracked. N is returned as the result.
105 - int GetKeypoints(const bool only_found, float* const out_data) const;
106 -
107 - // Returns the current position of a box, given that it was at a certain
108 - // position at the given time.
109 - BoundingBox TrackBox(const BoundingBox& region,
110 - const int64_t timestamp) const;
111 -
112 - // Returns the number of frames that have been passed to NextFrame().
113 - inline int GetNumFrames() const {
114 - return num_frames_;
115 - }
116 -
117 - inline bool HaveObject(const std::string& id) const {
118 - return objects_.find(id) != objects_.end();
119 - }
120 -
121 - // Returns the TrackedObject associated with the given id.
122 - inline const TrackedObject* GetObject(const std::string& id) const {
123 - TrackedObjectMap::const_iterator iter = objects_.find(id);
124 - CHECK_ALWAYS(iter != objects_.end(),
125 - "Unknown object key! \"%s\"", id.c_str());
126 - TrackedObject* const object = iter->second;
127 - return object;
128 - }
129 -
130 - // Returns the TrackedObject associated with the given id.
131 - inline TrackedObject* GetObject(const std::string& id) {
132 - TrackedObjectMap::iterator iter = objects_.find(id);
133 - CHECK_ALWAYS(iter != objects_.end(),
134 - "Unknown object key! \"%s\"", id.c_str());
135 - TrackedObject* const object = iter->second;
136 - return object;
137 - }
138 -
139 - bool IsObjectVisible(const std::string& id) const {
140 - SCHECK(HaveObject(id), "Don't have this object.");
141 -
142 - const TrackedObject* object = GetObject(id);
143 - return object->IsVisible();
144 - }
145 -
146 - virtual void Draw(const int canvas_width, const int canvas_height,
147 - const float* const frame_to_canvas) const;
148 -
149 - protected:
150 - // Creates a new tracked object at the given position.
151 - // If an object model is provided, then that model will be associated with the
152 - // object. If not, a new model may be created from the appearance at the
153 - // initial position and registered with the object detector.
154 - virtual TrackedObject* MaybeAddObject(const std::string& id,
155 - const Image<uint8_t>& image,
156 - const BoundingBox& bounding_box,
157 - const ObjectModelBase* object_model);
158 -
159 - // Find the keypoints in the frame before the current frame.
160 - // If only one frame exists, keypoints will be found in that frame.
161 - void ComputeKeypoints(const bool cached_ok = false);
162 -
163 - // Finds the correspondences for all the points in the current pair of frames.
164 - // Stores the results in the given FramePair.
165 - void FindCorrespondences(FramePair* const curr_change) const;
166 -
167 - inline int GetNthIndexFromEnd(const int offset) const {
168 - return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
169 - }
170 -
171 - BoundingBox TrackBox(const BoundingBox& region,
172 - const FramePair& frame_pair) const;
173 -
174 - inline void IncrementFrameIndex() {
175 - // Move the current framechange index up.
176 - ++num_frames_;
177 - ++curr_num_frame_pairs_;
178 -
179 - // If we've got too many, push up the start of the queue.
180 - if (curr_num_frame_pairs_ > kNumFrames) {
181 - first_frame_index_ = GetNthIndexFromStart(1);
182 - --curr_num_frame_pairs_;
183 - }
184 - }
185 -
186 - inline int GetNthIndexFromStart(const int offset) const {
187 - SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
188 - "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_);
189 - return (first_frame_index_ + offset) % kNumFrames;
190 - }
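The two index helpers above implement a fixed-size ring buffer over frame_pairs_. A quick walkthrough, assuming kNumFrames == 4 purely for brevity (the real constant is defined in config.h):

    // first_frame_index_ == 2, curr_num_frame_pairs_ == 4:
    //   GetNthIndexFromStart(0..3) -> slots 2, 3, 0, 1 (oldest to newest)
    //   GetNthIndexFromEnd(0)      -> GetNthIndexFromStart(3) -> slot 1
    // When the next pair arrives, IncrementFrameIndex() bumps
    // first_frame_index_ to 3, silently retiring the oldest pair.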
191 -
192 - void TrackObjects();
193 -
194 - const std::unique_ptr<const TrackerConfig> config_;
195 -
196 - const int frame_width_;
197 - const int frame_height_;
198 -
199 - int64_t curr_time_;
200 -
201 - int num_frames_;
202 -
203 - TrackedObjectMap objects_;
204 -
205 - FlowCache flow_cache_;
206 -
207 - KeypointDetector keypoint_detector_;
208 -
209 - int curr_num_frame_pairs_;
210 - int first_frame_index_;
211 -
212 - std::unique_ptr<ImageData> frame1_;
213 - std::unique_ptr<ImageData> frame2_;
214 -
215 - FramePair frame_pairs_[kNumFrames];
216 -
217 - std::unique_ptr<ObjectDetectorBase> detector_;
218 -
219 - int num_detected_;
220 -
221 - private:
222 - void TrackTarget(TrackedObject* const object);
223 -
224 - bool GetBestObjectForDetection(
225 - const Detection& detection, TrackedObject** match) const;
226 -
227 - void ProcessDetections(std::vector<Detection>* const detections);
228 -
229 - void DetectTargets();
230 -
231 - // Temp object used in ObjectTracker::CreateNewExample.
232 - mutable std::vector<BoundingSquare> squares;
233 -
234 - friend std::ostream& operator<<(std::ostream& stream,
235 - const ObjectTracker& tracker);
236 -
237 - TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
238 -};
239 -
240 -inline std::ostream& operator<<(std::ostream& stream,
241 - const ObjectTracker& tracker) {
242 - stream << "Frame size: " << tracker.frame_width_ << "x"
243 - << tracker.frame_height_ << std::endl;
244 -
245 - stream << "Num frames: " << tracker.num_frames_ << std::endl;
246 -
247 - stream << "Curr time: " << tracker.curr_time_ << std::endl;
248 -
249 - const int first_frame_index = tracker.GetNthIndexFromStart(0);
250 - const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];
251 -
252 - const int last_frame_index = tracker.GetNthIndexFromEnd(0);
253 - const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];
254 -
255 - stream << "first frame: " << first_frame_index << ","
256 - << first_frame_pair.end_time_ << " "
257 - << "last frame: " << last_frame_index << ","
258 - << last_frame_pair.end_time_ << " diff: "
259 - << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
260 - << std::endl;
261 -
262 - stream << "Tracked targets:";
263 - stream << tracker.objects_;
264 -
265 - return stream;
266 -}
267 -
268 -} // namespace tf_tracking
269 -
270 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#include <android/log.h>
17 -#include <jni.h>
18 -#include <stdint.h>
19 -#include <stdlib.h>
20 -#include <string.h>
21 -#include <cstdint>
22 -
23 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
27 -
28 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
29 -#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
30 -
31 -namespace tf_tracking {
32 -
33 -#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
34 - Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT
35 -
36 -JniLongField object_tracker_field("nativeObjectTracker");
37 -
38 -ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
39 - ObjectTracker* const object_tracker =
40 - reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
41 - CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
42 - return object_tracker;
43 -}
44 -
45 -void set_object_tracker(JNIEnv* env, jobject thiz,
46 - const ObjectTracker* object_tracker) {
47 - object_tracker_field.set(env, thiz,
48 - reinterpret_cast<intptr_t>(object_tracker));
49 -}
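These two helpers round-trip a native pointer through a long field on the Java object. The underlying pattern, independent of the JniLongField wrapper, is just a pair of casts; a minimal sketch:

    // 'tracker' is a previously allocated ObjectTracker*. jlong is always
    // 64 bits, so it can hold any pointer value on 32- and 64-bit ABIs.
    const jlong handle = reinterpret_cast<intptr_t>(tracker);        // store
    ObjectTracker* const restored =
        reinterpret_cast<ObjectTracker*>(handle);                    // load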
50 -
51 -#ifdef __cplusplus
52 -extern "C" {
53 -#endif
54 -JNIEXPORT
55 -void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
56 - jint width, jint height,
57 - jboolean always_track);
58 -
59 -JNIEXPORT
60 -void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
61 - jobject thiz);
62 -
63 -JNIEXPORT
64 -void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
65 - JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
66 - jfloat x2, jfloat y2, jbyteArray frame_data);
67 -
68 -JNIEXPORT
69 -void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
70 - JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
71 - jfloat x2, jfloat y2, jlong timestamp);
72 -
73 -JNIEXPORT
74 -void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
75 - JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
76 - jfloat x2, jfloat y2);
77 -
78 -JNIEXPORT
79 -jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
80 - jstring object_id);
81 -
82 -JNIEXPORT
83 -jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
84 - jobject thiz,
85 - jstring object_id);
86 -
87 -JNIEXPORT
88 -jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
89 - jobject thiz,
90 - jstring object_id);
91 -
92 -JNIEXPORT
93 -jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
94 - jobject thiz,
95 - jstring object_id);
96 -
97 -JNIEXPORT
98 -jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
99 - jstring object_id);
100 -
101 -JNIEXPORT
102 -void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
103 - JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);
104 -
105 -JNIEXPORT
106 -void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
107 - jbyteArray y_data,
108 - jbyteArray uv_data,
109 - jlong timestamp,
110 - jfloatArray vg_matrix_2x3);
111 -
112 -JNIEXPORT
113 -void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
114 - jstring object_id);
115 -
116 -JNIEXPORT
117 -jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
118 - JNIEnv* env, jobject thiz, jfloat scale_factor);
119 -
120 -JNIEXPORT
121 -jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
122 - JNIEnv* env, jobject thiz, jboolean only_found_);
123 -
124 -JNIEXPORT
125 -void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
126 - JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
127 - jfloat position_y1, jfloat position_x2, jfloat position_y2,
128 - jfloatArray delta);
129 -
130 -JNIEXPORT
131 -void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
132 - jint view_width,
133 - jint view_height,
134 - jfloatArray delta);
135 -
136 -JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
137 - JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
138 - jbyteArray input, jint factor, jbyteArray output);
139 -
140 -#ifdef __cplusplus
141 -}
142 -#endif
143 -
144 -JNIEXPORT
145 -void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
146 - jint width, jint height,
147 - jboolean always_track) {
148 - LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
149 - const Size image_size(width, height);
150 - TrackerConfig* const tracker_config = new TrackerConfig(image_size);
151 - tracker_config->always_track = always_track;
152 -
153 -  // No detector is wired in here, so tracking runs without detection support.
154 - ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
155 - set_object_tracker(env, thiz, tracker);
156 - LOGI("Initialized!");
157 -
158 - CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
159 - "Failure to set hand tracker!");
160 -}
161 -
162 -JNIEXPORT
163 -void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
164 - jobject thiz) {
165 - delete get_object_tracker(env, thiz);
166 - set_object_tracker(env, thiz, NULL);
167 -}
168 -
169 -JNIEXPORT
170 -void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
171 - JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
172 - jfloat x2, jfloat y2, jbyteArray frame_data) {
173 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
174 -
175 - LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
176 - x2, y2);
177 -
178 - jboolean iCopied = JNI_FALSE;
179 -
180 - // Copy image into currFrame.
181 - jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
182 -
183 - BoundingBox bounding_box(x1, y1, x2, y2);
184 - get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
185 - id_str, reinterpret_cast<const uint8_t*>(pixels), bounding_box);
186 -
187 - env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
188 -
189 - env->ReleaseStringUTFChars(object_id, id_str);
190 -}
191 -
192 -JNIEXPORT
193 -void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
194 - JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
195 - jfloat x2, jfloat y2, jlong timestamp) {
196 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
197 -
198 - LOGI(
199 - "Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
200 - " at time %lld",
201 - id_str, x1, y1, x2, y2, static_cast<long long>(timestamp));
202 -
203 - get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
204 - id_str, BoundingBox(x1, y1, x2, y2), timestamp);
205 -
206 - env->ReleaseStringUTFChars(object_id, id_str);
207 -}
208 -
209 -JNIEXPORT
210 -void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
211 - JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
212 - jfloat x2, jfloat y2) {
213 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
214 -
215 - LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
216 - x2, y2);
217 -
218 - get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
219 - id_str, BoundingBox(x1, y1, x2, y2));
220 -
221 - env->ReleaseStringUTFChars(object_id, id_str);
222 -}
223 -
224 -JNIEXPORT
225 -jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
226 - jstring object_id) {
227 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
228 -
229 - const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
230 - env->ReleaseStringUTFChars(object_id, id_str);
231 - return haveObject;
232 -}
233 -
234 -JNIEXPORT
235 -jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
236 - jobject thiz,
237 - jstring object_id) {
238 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
239 -
240 - const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
241 - env->ReleaseStringUTFChars(object_id, id_str);
242 - return visible;
243 -}
244 -
245 -JNIEXPORT
246 -jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
247 - jobject thiz,
248 - jstring object_id) {
249 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
250 - const TrackedObject* const object =
251 - get_object_tracker(env, thiz)->GetObject(id_str);
252 - env->ReleaseStringUTFChars(object_id, id_str);
253 - jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
254 - return model_name;
255 -}
256 -
257 -JNIEXPORT
258 -jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
259 - jobject thiz,
260 - jstring object_id) {
261 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
262 -
263 - const float correlation =
264 - get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
265 - env->ReleaseStringUTFChars(object_id, id_str);
266 - return correlation;
267 -}
268 -
269 -JNIEXPORT
270 -jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
271 - jstring object_id) {
272 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
273 -
274 - const float match_score =
275 - get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
276 - env->ReleaseStringUTFChars(object_id, id_str);
277 - return match_score;
278 -}
279 -
280 -JNIEXPORT
281 -void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
282 - JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
283 - jboolean iCopied = JNI_FALSE;
284 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
285 -
286 - const BoundingBox bounding_box =
287 - get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
288 - env->ReleaseStringUTFChars(object_id, id_str);
289 -
290 - jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
291 - bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
292 - env->ReleaseFloatArrayElements(rect_array, rect, 0);
293 -}
294 -
295 -JNIEXPORT
296 -void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
297 - jbyteArray y_data,
298 - jbyteArray uv_data,
299 - jlong timestamp,
300 - jfloatArray vg_matrix_2x3) {
301 - TimeLog("Starting object tracker");
302 -
303 - jboolean iCopied = JNI_FALSE;
304 -
305 - float vision_gyro_matrix_array[6];
306 - jfloat* jmat = NULL;
307 -
308 - if (vg_matrix_2x3 != NULL) {
309 - // Copy the alignment matrix into a float array.
310 - jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
311 - for (int i = 0; i < 6; ++i) {
312 - vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
313 - }
314 - }
315 - // Copy image into currFrame.
316 - jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
317 - jbyte* uv_pixels =
318 - uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
319 -
320 - TimeLog("Got elements");
321 -
322 - // Add the frame to the object tracker object.
323 - get_object_tracker(env, thiz)->NextFrame(
324 - reinterpret_cast<uint8_t*>(pixels), reinterpret_cast<uint8_t*>(uv_pixels),
325 - timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
326 -
327 - env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
328 -
329 - if (uv_data != NULL) {
330 - env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
331 - }
332 -
333 - if (vg_matrix_2x3 != NULL) {
334 - env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
335 - }
336 -
337 - TimeLog("Released elements");
338 -
339 - PrintTimeLog();
340 - ResetTimeLog();
341 -}
342 -
343 -JNIEXPORT
344 -void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
345 - jstring object_id) {
346 - const char* const id_str = env->GetStringUTFChars(object_id, 0);
347 -
348 - get_object_tracker(env, thiz)->ForgetTarget(id_str);
349 -
350 - env->ReleaseStringUTFChars(object_id, id_str);
351 -}
352 -
353 -JNIEXPORT
354 -jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
355 - JNIEnv* env, jobject thiz, jboolean only_found) {
356 - jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
357 -
358 - const int number_of_keypoints =
359 - get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
360 -
361 - // Create and return the array that will be passed back to Java.
362 - jfloatArray keypoints =
363 - env->NewFloatArray(number_of_keypoints * kKeypointStep);
364 - if (keypoints == NULL) {
365 - LOGE("null array!");
366 - return NULL;
367 - }
368 - env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
369 - keypoint_arr);
370 -
371 - return keypoints;
372 -}
373 -
374 -JNIEXPORT
375 -jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
376 - JNIEnv* env, jobject thiz, jfloat scale_factor) {
377 -  // Each keypoint is two (x, y) pairs stored as uint16_t values, 2 bytes each.
378 - const int bytes_per_keypoint = sizeof(uint16_t) * 2 * 2;
379 - jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
380 -
381 - const int number_of_keypoints =
382 - get_object_tracker(env, thiz)->GetKeypointsPacked(
383 - reinterpret_cast<uint16_t*>(keypoint_arr), scale_factor);
384 -
385 - // Create and return the array that will be passed back to Java.
386 - jbyteArray keypoints =
387 - env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
388 -
389 - if (keypoints == NULL) {
390 - LOGE("null array!");
391 - return NULL;
392 - }
393 -
394 - env->SetByteArrayRegion(
395 - keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
396 -
397 - return keypoints;
398 -}
399 -
400 -JNIEXPORT
401 -void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
402 - JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
403 - jfloat position_y1, jfloat position_x2, jfloat position_y2,
404 - jfloatArray delta) {
405 - jfloat point_arr[4];
406 -
407 - const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
408 - BoundingBox(position_x1, position_y1, position_x2, position_y2),
409 - timestamp);
410 -
411 - new_position.CopyToArray(point_arr);
412 - env->SetFloatArrayRegion(delta, 0, 4, point_arr);
413 -}
414 -
415 -JNIEXPORT
416 -void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
417 - JNIEnv* env, jobject thiz, jint view_width, jint view_height,
418 - jfloatArray frame_to_canvas_arr) {
419 - ObjectTracker* object_tracker = get_object_tracker(env, thiz);
420 - if (object_tracker != NULL) {
421 - jfloat* frame_to_canvas =
422 - env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
423 -
424 - object_tracker->Draw(view_width, view_height, frame_to_canvas);
425 - env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
426 - JNI_ABORT);
427 - }
428 -}
429 -
430 -JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
431 - JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
432 - jbyteArray input, jint factor, jbyteArray output) {
433 - if (input == NULL || output == NULL) {
434 - LOGW("Received null arrays, hopefully this is a test!");
435 - return;
436 - }
437 -
438 - jbyte* const input_array = env->GetByteArrayElements(input, 0);
439 - jbyte* const output_array = env->GetByteArrayElements(output, 0);
440 -
441 - {
442 - tf_tracking::Image<uint8_t> full_image(
443 - width, height, reinterpret_cast<uint8_t*>(input_array), false);
444 -
445 - const int new_width = (width + factor - 1) / factor;
446 - const int new_height = (height + factor - 1) / factor;
447 -
448 - tf_tracking::Image<uint8_t> downsampled_image(
449 - new_width, new_height, reinterpret_cast<uint8_t*>(output_array), false);
450 -
451 - downsampled_image.DownsampleAveraged(
452 - reinterpret_cast<uint8_t*>(input_array), row_stride, factor);
453 - }
454 -
455 - env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
456 - env->ReleaseByteArrayElements(output, output_array, 0);
457 -}
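Note the (width + factor - 1) / factor idiom above: integer ceiling division, so dimensions that are not exact multiples of factor still round up to a large-enough output image. For example:

    // width == 641, factor == 4:
    //   new_width = (641 + 4 - 1) / 4 = 644 / 4 = 161
    //   (plain 641 / 4 would truncate to 160 and drop the last column)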
458 -
459 -} // namespace tf_tracking
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#include <math.h>
17 -
18 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
19 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
20 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
21 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 -
24 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
27 -#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
28 -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
29 -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
30 -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
31 -
32 -namespace tf_tracking {
33 -
34 -OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config)
35 - : config_(config),
36 - frame1_(NULL),
37 - frame2_(NULL),
38 - working_size_(config->image_size) {}
39 -
40 -
41 -void OpticalFlow::NextFrame(const ImageData* const image_data) {
42 - // Special case for the first frame: make sure the image ends up in
43 - // frame1_ so that keypoint detection can be done on it if desired.
44 - frame1_ = (frame1_ == NULL) ? image_data : frame2_;
45 - frame2_ = image_data;
46 -}
47 -
48 -
49 -// Static heart of the optical flow computation:
50 -// the Lucas-Kanade algorithm, run at a single pyramid level.
51 -bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8_t>& img_I,
52 - const Image<uint8_t>& img_J,
53 - const Image<int32_t>& I_x,
54 - const Image<int32_t>& I_y, const float p_x,
55 - const float p_y, float* out_g_x,
56 - float* out_g_y) {
57 - float g_x = *out_g_x;
58 - float g_y = *out_g_y;
59 - // Get values for frame 1. They remain constant through the inner
60 - // iteration loop.
61 - float vals_I[kFlowArraySize];
62 - float vals_I_x[kFlowArraySize];
63 - float vals_I_y[kFlowArraySize];
64 -
65 - const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1;
66 - const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize);
67 -
68 -#if USE_FIXED_POINT_FLOW
69 - const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1;
70 - const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1;
71 -#else
72 - const float real_x_max = I_x.width_less_one_ - EPSILON;
73 - const float real_y_max = I_x.height_less_one_ - EPSILON;
74 -#endif
75 -
76 - // Get the window around the original point.
77 - const float src_left_real = p_x - kWindowSizeFloat;
78 - const float src_top_real = p_y - kWindowSizeFloat;
79 - float* vals_I_ptr = vals_I;
80 - float* vals_I_x_ptr = vals_I_x;
81 - float* vals_I_y_ptr = vals_I_y;
82 -#if USE_FIXED_POINT_FLOW
83 - // Source integer coordinates.
84 - const int src_left_fixed = RealToFixed1616(src_left_real);
85 - const int src_top_fixed = RealToFixed1616(src_top_real);
86 -
87 - for (int y = 0; y < kPatchSize; ++y) {
88 - const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max);
89 -
90 - for (int x = 0; x < kPatchSize; ++x) {
91 - const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max);
92 -
93 - *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y);
94 - *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
95 - *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
96 - }
97 - }
98 -#else
99 - for (int y = 0; y < kPatchSize; ++y) {
100 - const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max);
101 -
102 - for (int x = 0; x < kPatchSize; ++x) {
103 - const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max);
104 -
105 - *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos);
106 - *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos);
107 - *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos);
108 - }
109 - }
110 -#endif
111 -
112 - // Compute the spatial gradient matrix about point p.
113 - float G[] = { 0, 0, 0, 0 };
114 - CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G);
115 -
116 - // Find the inverse of G.
117 - float G_inv[4];
118 - if (!Invert2x2(G, G_inv)) {
119 - return false;
120 - }
121 -
122 -#if NORMALIZE
123 - const float mean_I = ComputeMean(vals_I, kFlowArraySize);
124 - const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I);
125 -#endif
126 -
127 - // Iterate kNumIterations times or until we converge.
128 - for (int iteration = 0; iteration < kNumIterations; ++iteration) {
129 - // Get values for frame 2.
130 - float vals_J[kFlowArraySize];
131 -
132 - // Get the window around the destination point.
133 - const float left_real = p_x + g_x - kWindowSizeFloat;
134 - const float top_real = p_y + g_y - kWindowSizeFloat;
135 - float* vals_J_ptr = vals_J;
136 -#if USE_FIXED_POINT_FLOW
137 - // The top-left sub-pixel is set for the current iteration (in 16:16
138 - // fixed). This is constant over one iteration.
139 - const int left_fixed = RealToFixed1616(left_real);
140 - const int top_fixed = RealToFixed1616(top_real);
141 -
142 - for (int win_y = 0; win_y < kPatchSize; ++win_y) {
143 - const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max);
144 - for (int win_x = 0; win_x < kPatchSize; ++win_x) {
145 - const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max);
146 - *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y);
147 - }
148 - }
149 -#else
150 - for (int win_y = 0; win_y < kPatchSize; ++win_y) {
151 - const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max);
152 - for (int win_x = 0; win_x < kPatchSize; ++win_x) {
153 - const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max);
154 - *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos);
155 - }
156 - }
157 -#endif
158 -
159 -#if NORMALIZE
160 - const float mean_J = ComputeMean(vals_J, kFlowArraySize);
161 - const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J);
162 -
163 - // TODO(andrewharp): Probably better to completely detect and handle the
164 - // "corner case" where the patch is fully outside the image diagonally.
165 - const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f;
166 -#endif
167 -
168 - // Compute image mismatch vector.
169 - float b_x = 0.0f;
170 - float b_y = 0.0f;
171 -
172 - vals_I_ptr = vals_I;
173 - vals_J_ptr = vals_J;
174 - vals_I_x_ptr = vals_I_x;
175 - vals_I_y_ptr = vals_I_y;
176 -
177 - for (int win_y = 0; win_y < kPatchSize; ++win_y) {
178 - for (int win_x = 0; win_x < kPatchSize; ++win_x) {
179 -#if NORMALIZE
180 - // Normalized Image difference.
181 - const float dI =
182 - (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio;
183 -#else
184 - const float dI = *vals_I_ptr++ - *vals_J_ptr++;
185 -#endif
186 - b_x += dI * *vals_I_x_ptr++;
187 - b_y += dI * *vals_I_y_ptr++;
188 - }
189 - }
190 -
191 - // Optical flow... solve n = G^-1 * b
192 - const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y);
193 - const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y);
194 -
195 - // Update best guess with residual displacement from this level and
196 - // iteration.
197 - g_x += n_x;
198 - g_y += n_y;
199 -
200 - // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y);
201 -
202 - // Abort early if we're already below the threshold.
203 - if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) {
204 - break;
205 - }
206 - } // Iteration.
207 -
208 - // Copy value back into output.
209 - *out_g_x = g_x;
210 - *out_g_y = g_y;
211 - return true;
212 -}
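The update at the heart of the loop solves the 2x2 normal equations n = G^-1 * b. A sketch of what CalculateG and Invert2x2 are assumed to compute (their real definitions live in utils.h); the _Sketch suffixes mark these as illustrative stand-ins:

    // G accumulates the gradient outer products over the patch, row-major.
    // The caller zero-initializes G before passing it in.
    void CalculateG_Sketch(const float* vals_x, const float* vals_y,
                           const int n, float* G) {
      for (int i = 0; i < n; ++i) {
        G[0] += vals_x[i] * vals_x[i];   // sum Ix*Ix
        G[1] += vals_x[i] * vals_y[i];   // sum Ix*Iy
        G[3] += vals_y[i] * vals_y[i];   // sum Iy*Iy
      }
      G[2] = G[1];  // symmetric
    }

    // Standard closed-form 2x2 inverse; fails on (near-)singular G, which
    // corresponds to a textureless or edge-only patch.
    bool Invert2x2_Sketch(const float* G, float* G_inv) {
      const float det = G[0] * G[3] - G[1] * G[2];
      if (fabsf(det) < 1e-8f) return false;
      const float inv_det = 1.0f / det;
      G_inv[0] =  G[3] * inv_det;
      G_inv[1] = -G[1] * inv_det;
      G_inv[2] = -G[2] * inv_det;
      G_inv[3] =  G[0] * inv_det;
      return true;
    }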
213 -
214 -
215 -// Pointwise flow using translational 2-DOF ESM (Efficient Second-order
216 -// Minimization).
216 -bool OpticalFlow::FindFlowAtPoint_ESM(
217 - const Image<uint8_t>& img_I, const Image<uint8_t>& img_J,
218 - const Image<int32_t>& I_x, const Image<int32_t>& I_y,
219 - const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x,
220 - const float p_y, float* out_g_x, float* out_g_y) {
221 - float g_x = *out_g_x;
222 - float g_y = *out_g_y;
223 - const float area_inv = 1.0f / static_cast<float>(kFlowArraySize);
224 -
225 - // Get values for frame 1. They remain constant through the inner
226 - // iteration loop.
227 - uint8_t vals_I[kFlowArraySize];
228 - uint8_t vals_J[kFlowArraySize];
229 - int16_t src_gradient_x[kFlowArraySize];
230 - int16_t src_gradient_y[kFlowArraySize];
231 -
232 - // TODO(rspring): try out the IntegerPatchAlign() method once
233 - // the code for that is in ../common.
234 - const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize);
235 - const int src_left_fixed = RealToFixed1616(p_x - wsize_float);
236 - const int src_top_fixed = RealToFixed1616(p_y - wsize_float);
237 - const int patch_size = 2 * kFlowIntegrationWindowSize + 1;
238 -
239 - // Create the keypoint template patch from a subpixel location.
240 - if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
241 - patch_size, patch_size, vals_I) ||
242 - !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
243 - patch_size, patch_size,
244 - src_gradient_x) ||
245 - !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
246 - patch_size, patch_size,
247 - src_gradient_y)) {
248 - return false;
249 - }
250 -
251 - int bright_offset = 0;
252 - int sum_diff = 0;
253 -
254 - // The top-left sub-pixel is set for the current iteration (in 16:16 fixed).
255 - // This is constant over one iteration.
256 - int left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
257 - int top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
258 -
259 -  // Truncating yields the top-left-most pixel that will be touched.
260 - int left_trunc = left_fixed >> 16;
261 - int top_trunc = top_fixed >> 16;
262 -
263 - // Compute an initial brightness offset.
264 - if (kDoBrightnessNormalize &&
265 - left_trunc >= 0 && top_trunc >= 0 &&
266 - (left_trunc + patch_size) < img_J.width_less_one_ &&
267 - (top_trunc + patch_size) < img_J.height_less_one_) {
268 - int templ_index = 0;
269 - const uint8_t* j_row = img_J[top_trunc] + left_trunc;
270 -
271 - const int j_stride = img_J.stride();
272 -
273 - for (int y = 0; y < patch_size; ++y, j_row += j_stride) {
274 - for (int x = 0; x < patch_size; ++x) {
275 - sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++];
276 - }
277 - }
278 -
279 - bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv);
280 - }
281 -
282 - // Iterate kNumIterations times or until we go out of image.
283 - for (int iteration = 0; iteration < kNumIterations; ++iteration) {
284 - int jtj[3] = { 0, 0, 0 };
285 - int jtr[2] = { 0, 0 };
286 - sum_diff = 0;
287 -
288 - // Extract the target image values.
289 - // Extract the gradient from the target image patch and accumulate to
290 - // the gradient of the source image patch.
291 - if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed,
292 - patch_size, patch_size,
293 - vals_J)) {
294 - break;
295 - }
296 -
297 - const uint8_t* templ_row = vals_I;
298 - const uint8_t* extract_row = vals_J;
299 - const int16_t* src_dx_row = src_gradient_x;
300 - const int16_t* src_dy_row = src_gradient_y;
301 -
302 - for (int y = 0; y < patch_size; ++y, templ_row += patch_size,
303 - src_dx_row += patch_size, src_dy_row += patch_size,
304 - extract_row += patch_size) {
305 - const int fp_y = top_fixed + (y << 16);
306 - for (int x = 0; x < patch_size; ++x) {
307 - const int fp_x = left_fixed + (x << 16);
308 - int32_t target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y);
309 - int32_t target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y);
310 -
311 - // Combine the two Jacobians.
312 - // Right-shift by one to account for the fact that we add
313 - // two Jacobians.
314 - int32_t dx = (src_dx_row[x] + target_dx) >> 1;
315 - int32_t dy = (src_dy_row[x] + target_dy) >> 1;
316 -
317 - // The current residual b - h(q) == extracted - (template + offset)
318 - int32_t diff = static_cast<int32_t>(extract_row[x]) -
319 - static_cast<int32_t>(templ_row[x]) - bright_offset;
320 -
321 - jtj[0] += dx * dx;
322 - jtj[1] += dx * dy;
323 - jtj[2] += dy * dy;
324 -
325 - jtr[0] += dx * diff;
326 - jtr[1] += dy * diff;
327 -
328 - sum_diff += diff;
329 - }
330 - }
331 -
332 - const float jtr1_float = static_cast<float>(jtr[0]);
333 - const float jtr2_float = static_cast<float>(jtr[1]);
334 -
335 - // Add some baseline stability to the system.
336 - jtj[0] += kEsmRegularizer;
337 - jtj[2] += kEsmRegularizer;
338 -
339 - const int64_t prod1 = static_cast<int64_t>(jtj[0]) * jtj[2];
340 - const int64_t prod2 = static_cast<int64_t>(jtj[1]) * jtj[1];
341 -
342 - // One ESM step.
343 - const float jtj_1[4] = { static_cast<float>(jtj[2]),
344 - static_cast<float>(-jtj[1]),
345 - static_cast<float>(-jtj[1]),
346 - static_cast<float>(jtj[0]) };
347 - const double det_inv = 1.0 / static_cast<double>(prod1 - prod2);
348 -
349 - g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float);
350 - g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float);
351 -
352 - if (kDoBrightnessNormalize) {
353 - bright_offset +=
354 - static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f);
355 - }
356 -
357 - // Update top left position.
358 - left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
359 - top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
360 -
361 - left_trunc = left_fixed >> 16;
362 - top_trunc = top_fixed >> 16;
363 -
364 - // Abort iterations if we go out of borders.
365 - if (left_trunc < 0 || top_trunc < 0 ||
366 - (left_trunc + patch_size) >= J_x.width_less_one_ ||
367 - (top_trunc + patch_size) >= J_y.height_less_one_) {
368 - break;
369 - }
370 - } // Iteration.
371 -
372 - // Copy value back into output.
373 - *out_g_x = g_x;
374 - *out_g_y = g_y;
375 - return true;
376 -}
377 -
378 -
379 -bool OpticalFlow::FindFlowAtPointReversible(
380 - const int level, const float u_x, const float u_y,
381 - const bool reverse_flow,
382 - float* flow_x, float* flow_y) const {
383 - const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_;
384 - const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_;
385 -
386 - // Images I (prev) and J (next).
387 - const Image<uint8_t>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2);
388 - const Image<uint8_t>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2);
389 -
390 - // Computed gradients.
391 - const Image<int32_t>& I_x = *frame_a.GetSpatialX(level);
392 - const Image<int32_t>& I_y = *frame_a.GetSpatialY(level);
393 - const Image<int32_t>& J_x = *frame_b.GetSpatialX(level);
394 - const Image<int32_t>& J_y = *frame_b.GetSpatialY(level);
395 -
396 - // Shrink factor from original.
397 - const float shrink_factor = (1 << level);
398 -
399 - // Image position vector (p := u^l), scaled for this level.
400 - const float scaled_p_x = u_x / shrink_factor;
401 - const float scaled_p_y = u_y / shrink_factor;
402 -
403 - float scaled_flow_x = *flow_x / shrink_factor;
404 - float scaled_flow_y = *flow_y / shrink_factor;
405 -
406 - // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level,
407 - // scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y);
408 -
409 - const bool success = kUseEsm ?
410 - FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y,
411 - scaled_p_x, scaled_p_y,
412 - &scaled_flow_x, &scaled_flow_y) :
413 - FindFlowAtPoint_LK(img_I, img_J, I_x, I_y,
414 - scaled_p_x, scaled_p_y,
415 - &scaled_flow_x, &scaled_flow_y);
416 -
417 - *flow_x = scaled_flow_x * shrink_factor;
418 - *flow_y = scaled_flow_y * shrink_factor;
419 -
420 - return success;
421 -}
422 -
423 -
424 -bool OpticalFlow::FindFlowAtPointSingleLevel(
425 - const int level,
426 - const float u_x, const float u_y,
427 - const bool filter_by_fb_error,
428 - float* flow_x, float* flow_y) const {
429 - if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) {
430 - return false;
431 - }
432 -
433 - if (filter_by_fb_error) {
434 - const float new_position_x = u_x + *flow_x;
435 - const float new_position_y = u_y + *flow_y;
436 -
437 - float reverse_flow_x = 0.0f;
438 - float reverse_flow_y = 0.0f;
439 -
440 - // Now find the backwards flow and confirm it lines up with the original
441 - // starting point.
442 - if (!FindFlowAtPointReversible(level, new_position_x, new_position_y,
443 - true,
444 - &reverse_flow_x, &reverse_flow_y)) {
445 - LOGE("Backward error!");
446 - return false;
447 - }
448 -
449 - const float discrepancy_length =
450 - sqrtf(Square(*flow_x + reverse_flow_x) +
451 - Square(*flow_y + reverse_flow_y));
452 -
453 - const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y));
454 -
455 - return discrepancy_length <
456 - (kMaxForwardBackwardErrorAllowed * flow_length);
457 - }
458 -
459 - return true;
460 -}
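
The forward-backward filter above accepts a flow only if tracking back from the displaced point roughly returns to the starting point. A minimal numeric sketch of that acceptance test, not from the original sources; the flow values are made up and kMaxForwardBackwardErrorAllowed stands in for the config constant.

#include <cmath>
#include <cstdio>

int main() {
  const float kMaxForwardBackwardErrorAllowed = 0.5f;  // hypothetical value
  const float flow_x = 3.0f, flow_y = -1.0f;                  // forward flow
  const float reverse_flow_x = -2.8f, reverse_flow_y = 1.1f;  // backward flow
  // If forward and backward flows cancel, their sum is near zero.
  const float discrepancy = std::hypot(flow_x + reverse_flow_x,
                                       flow_y + reverse_flow_y);
  const float flow_length = std::hypot(flow_x, flow_y);
  const bool ok = discrepancy < kMaxForwardBackwardErrorAllowed * flow_length;
  std::printf("discrepancy=%.3f, threshold=%.3f -> %s\n", discrepancy,
              kMaxForwardBackwardErrorAllowed * flow_length,
              ok ? "accepted" : "rejected");
  return 0;
}
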
461 -
462 -
463 -// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm.
464 -// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details.
465 -bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y,
466 - const bool filter_by_fb_error,
467 - float* flow_x, float* flow_y) const {
468 - const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment,
469 - kNumPyramidLevels - kNumCacheLevels);
470 -
471 - // For every level in the pyramid, update the coordinates of the best match.
472 - for (int l = max_level - 1; l >= 0; --l) {
473 - if (!FindFlowAtPointSingleLevel(l, u_x, u_y,
474 - filter_by_fb_error, flow_x, flow_y)) {
475 - return false;
476 - }
477 - }
478 -
479 - return true;
480 -}
481 -
482 -} // namespace tf_tracking
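
Since FindFlowAtPointReversible rescales global coordinates by shrink_factor = 2^level, the pyramidal loop above refines the same point from the coarsest level down to the finest. A small sketch of that coordinate mapping, not from the original sources; the point, flow, and level count are hypothetical.

#include <cstdio>

int main() {
  const float u_x = 200.0f, u_y = 120.0f;  // global coordinates
  const float flow_x = 8.0f, flow_y = -4.0f;  // running flow estimate
  for (int level = 3; level >= 0; --level) {
    const float shrink_factor = static_cast<float>(1 << level);
    std::printf("level %d: point (%.1f, %.1f), flow (%.2f, %.2f)\n", level,
                u_x / shrink_factor, u_y / shrink_factor,
                flow_x / shrink_factor, flow_y / shrink_factor);
  }
  return 0;
}
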
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
18 -
19 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
20 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
21 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
22 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23 -
24 -#include "tensorflow/examples/android/jni/object_tracking/config.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
27 -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
28 -
29 -namespace tf_tracking {
30 -
31 -class FlowCache;
32 -
33 -// Class encapsulating all the data and logic necessary for performing optical
34 -// flow.
35 -class OpticalFlow {
36 - public:
37 - explicit OpticalFlow(const OpticalFlowConfig* const config);
38 -
39 - // Add a new frame to the optical flow. Will update all the non-keypoint
40 - // related member variables.
41 - //
42 -  // image_data should wrap a buffer of grayscale values, one byte per
43 -  // pixel, at the original frame_width and frame_height used to initialize
44 -  // the OpticalFlow object. Downsampling will be handled internally.
45 -  //
46 -  // The frame's timestamp, in milliseconds, is carried by image_data; later
47 -  // calls to this and other methods will be relative to it.
48 - void NextFrame(const ImageData* const image_data);
49 -
50 - // An implementation of the Lucas-Kanade Optical Flow algorithm.
51 - static bool FindFlowAtPoint_LK(const Image<uint8_t>& img_I,
52 - const Image<uint8_t>& img_J,
53 - const Image<int32_t>& I_x,
54 - const Image<int32_t>& I_y, const float p_x,
55 - const float p_y, float* out_g_x,
56 - float* out_g_y);
57 -
58 - // Pointwise flow using translational 2dof ESM.
59 - static bool FindFlowAtPoint_ESM(
60 - const Image<uint8_t>& img_I, const Image<uint8_t>& img_J,
61 - const Image<int32_t>& I_x, const Image<int32_t>& I_y,
62 - const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x,
63 - const float p_y, float* out_g_x, float* out_g_y);
64 -
65 - // Finds the flow using a specific level, in either direction.
66 - // If reversed, the coordinates are in the context of the latest
67 - // frame, not the frame before it.
68 - // All coordinates used in parameters are global, not scaled.
69 - bool FindFlowAtPointReversible(
70 - const int level, const float u_x, const float u_y,
71 - const bool reverse_flow,
72 -      float* flow_x, float* flow_y) const;
73 -
74 - // Finds the flow using a specific level, filterable by forward-backward
75 - // error. All coordinates used in parameters are global, not scaled.
76 - bool FindFlowAtPointSingleLevel(const int level,
77 - const float u_x, const float u_y,
78 - const bool filter_by_fb_error,
79 - float* flow_x, float* flow_y) const;
80 -
81 - // Pyramidal optical-flow using all levels.
82 - bool FindFlowAtPointPyramidal(const float u_x, const float u_y,
83 - const bool filter_by_fb_error,
84 - float* flow_x, float* flow_y) const;
85 -
86 - private:
87 - const OpticalFlowConfig* const config_;
88 -
89 - const ImageData* frame1_;
90 - const ImageData* frame2_;
91 -
92 - // Size of the internally allocated images (after original is downsampled).
93 - const Size working_size_;
94 -
95 - TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow);
96 -};
97 -
98 -} // namespace tf_tracking
99 -
100 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
18 -
19 -#ifdef __RENDER_OPENGL__
20 -
21 -#include <GLES/gl.h>
22 -#include <GLES/glext.h>
23 -
24 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
26 -
27 -namespace tf_tracking {
28 -
29 -// This class encapsulates the logic necessary to load and render image data
30 -// at the same aspect ratio as the original source.
31 -class Sprite {
32 - public:
33 - // Only create Sprites when you have an OpenGl context.
34 - explicit Sprite(const Image<uint8_t>& image) { LoadTexture(image, NULL); }
35 -
36 - Sprite(const Image<uint8_t>& image, const BoundingBox* const area) {
37 - LoadTexture(image, area);
38 - }
39 -
40 - // Also, try to only delete a Sprite when holding an OpenGl context.
41 - ~Sprite() {
42 - glDeleteTextures(1, &texture_);
43 - }
44 -
45 - inline int GetWidth() const {
46 - return actual_width_;
47 - }
48 -
49 - inline int GetHeight() const {
50 - return actual_height_;
51 - }
52 -
53 - // Draw the sprite at 0,0 - original width/height in the current reference
54 - // frame. Any transformations desired must be applied before calling this
55 - // function.
56 - void Draw() const {
57 - const float float_width = static_cast<float>(actual_width_);
58 - const float float_height = static_cast<float>(actual_height_);
59 -
60 - // Where it gets rendered to.
61 - const float vertices[] = { 0.0f, 0.0f, 0.0f,
62 - 0.0f, float_height, 0.0f,
63 - float_width, 0.0f, 0.0f,
64 - float_width, float_height, 0.0f,
65 - };
66 -
67 - // The coordinates the texture gets drawn from.
68 - const float max_x = float_width / texture_width_;
69 - const float max_y = float_height / texture_height_;
70 - const float textureVertices[] = {
71 - 0, 0,
72 - 0, max_y,
73 - max_x, 0,
74 - max_x, max_y,
75 - };
76 -
77 - glEnable(GL_TEXTURE_2D);
78 - glBindTexture(GL_TEXTURE_2D, texture_);
79 -
80 - glEnableClientState(GL_VERTEX_ARRAY);
81 - glEnableClientState(GL_TEXTURE_COORD_ARRAY);
82 -
83 - glVertexPointer(3, GL_FLOAT, 0, vertices);
84 - glTexCoordPointer(2, GL_FLOAT, 0, textureVertices);
85 -
86 - glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
87 -
88 - glDisableClientState(GL_VERTEX_ARRAY);
89 - glDisableClientState(GL_TEXTURE_COORD_ARRAY);
90 - }
91 -
92 - private:
93 - inline int GetNextPowerOfTwo(const int number) const {
94 - int power_of_two = 1;
95 - while (power_of_two < number) {
96 - power_of_two *= 2;
97 - }
98 - return power_of_two;
99 - }
100 -
101 - // TODO(andrewharp): Allow sprites to have their textures reloaded.
102 - void LoadTexture(const Image<uint8_t>& texture_source,
103 - const BoundingBox* const area) {
104 - glEnable(GL_TEXTURE_2D);
105 -
106 - glGenTextures(1, &texture_);
107 -
108 - glBindTexture(GL_TEXTURE_2D, texture_);
109 -
110 - int left = 0;
111 - int top = 0;
112 -
113 - if (area != NULL) {
114 - // If a sub-region was provided to pull the texture from, use that.
115 - left = area->left_;
116 - top = area->top_;
117 - actual_width_ = area->GetWidth();
118 - actual_height_ = area->GetHeight();
119 - } else {
120 - actual_width_ = texture_source.GetWidth();
121 - actual_height_ = texture_source.GetHeight();
122 - }
123 -
124 - // The textures must be a power of two, so find the sizes that are large
125 - // enough to contain the image data.
126 - texture_width_ = GetNextPowerOfTwo(actual_width_);
127 - texture_height_ = GetNextPowerOfTwo(actual_height_);
128 -
129 - bool allocated_data = false;
130 - uint8_t* texture_data;
131 -
132 - // Except in the lucky case where we're not using a sub-region of the
133 - // original image AND the source data has dimensions that are power of two,
134 - // care must be taken to copy data at the appropriate source and destination
135 - // strides so that the final block can be copied directly into texture
136 - // memory.
137 - // TODO(andrewharp): Figure out if data can be pulled directly from the
138 - // source image with some alignment modifications.
139 - if (left != 0 || top != 0 ||
140 - actual_width_ != texture_source.GetWidth() ||
141 - actual_height_ != texture_source.GetHeight()) {
142 - texture_data = new uint8_t[actual_width_ * actual_height_];
143 -
144 - for (int y = 0; y < actual_height_; ++y) {
145 - memcpy(texture_data + actual_width_ * y, texture_source[top + y] + left,
146 - actual_width_ * sizeof(uint8_t));
147 - }
148 - allocated_data = true;
149 - } else {
150 - // Cast away const-ness because for some reason glTexSubImage2D wants
151 - // a non-const data pointer.
152 - texture_data = const_cast<uint8_t*>(texture_source.data());
153 - }
154 -
155 - glTexImage2D(GL_TEXTURE_2D,
156 - 0,
157 - GL_LUMINANCE,
158 - texture_width_,
159 - texture_height_,
160 - 0,
161 - GL_LUMINANCE,
162 - GL_UNSIGNED_BYTE,
163 - NULL);
164 -
165 - glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
166 - glTexSubImage2D(GL_TEXTURE_2D,
167 - 0,
168 - 0,
169 - 0,
170 - actual_width_,
171 - actual_height_,
172 - GL_LUMINANCE,
173 - GL_UNSIGNED_BYTE,
174 - texture_data);
175 -
176 - if (allocated_data) {
177 -      delete[] texture_data;  // Allocated with new[] above.
178 - }
179 -
180 - glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
181 - }
182 -
183 - // The id for the texture on the GPU.
184 - GLuint texture_;
185 -
186 - // The width and height to be used for display purposes, referring to the
187 - // dimensions of the original texture.
188 - int actual_width_;
189 - int actual_height_;
190 -
191 - // The allocated dimensions of the texture data, which must be powers of 2.
192 - int texture_width_;
193 - int texture_height_;
194 -
195 - TF_DISALLOW_COPY_AND_ASSIGN(Sprite);
196 -};
197 -
198 -} // namespace tf_tracking
199 -
200 -#endif // __RENDER_OPENGL__
201 -
202 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
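
GetNextPowerOfTwo above rounds texture dimensions up by repeated doubling. For reference, the same result can be computed without a loop count proportional to the exponent by smearing the high bit; this is a sketch, not part of the original sources, and NextPowerOfTwoBits is a hypothetical standalone variant valid for positive 32-bit inputs.

#include <cstdio>

// Rounds a positive value up to the next power of two by propagating the
// highest set bit into all lower positions, then adding one.
static int NextPowerOfTwoBits(int number) {
  unsigned v = static_cast<unsigned>(number) - 1;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  return static_cast<int>(v + 1);
}

int main() {
  for (int n : {1, 3, 64, 100, 257}) {
    std::printf("%4d -> %4d\n", n, NextPowerOfTwoBits(n));
  }
  return 0;
}
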
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
17 -
18 -#ifdef LOG_TIME
19 -// Storage for logging functionality.
20 -int num_time_logs = 0;
21 -LogEntry time_logs[NUM_LOGS];
22 -
23 -int num_avg_entries = 0;
24 -AverageEntry avg_entries[NUM_LOGS];
25 -#endif
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// Utility functions for performance profiling.
17 -
18 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
19 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
20 -
21 -#include <stdint.h>
22 -
23 -#include "tensorflow/examples/android/jni/object_tracking/logging.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
25 -
26 -#ifdef LOG_TIME
27 -
28 -// Blend constant for running average.
29 -#define ALPHA 0.98f
30 -#define NUM_LOGS 100
31 -
32 -struct LogEntry {
33 - const char* id;
34 - int64_t time_stamp;
35 -};
36 -
37 -struct AverageEntry {
38 - const char* id;
39 - float average_duration;
40 -};
41 -
42 -// Storage for keeping track of this frame's values.
43 -extern int num_time_logs;
44 -extern LogEntry time_logs[NUM_LOGS];
45 -
46 -// Storage for keeping track of average values (each entry may not be printed
47 -// out each frame).
48 -extern AverageEntry avg_entries[NUM_LOGS];
49 -extern int num_avg_entries;
50 -
51 -// Call this at the start of a logging phase.
52 -inline static void ResetTimeLog() {
53 - num_time_logs = 0;
54 -}
55 -
56 -
57 -// Log a message to be printed out when PrintTimeLog is called, along with the
58 -// amount of time in ms that has passed since the last call to this function.
59 -inline static void TimeLog(const char* const str) {
60 - LOGV("%s", str);
61 - if (num_time_logs >= NUM_LOGS) {
62 - LOGE("Out of log entries!");
63 - return;
64 - }
65 -
66 - time_logs[num_time_logs].id = str;
67 - time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos();
68 - ++num_time_logs;
69 -}
70 -
71 -
72 -inline static float Blend(float old_val, float new_val) {
73 - return ALPHA * old_val + (1.0f - ALPHA) * new_val;
74 -}
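
With ALPHA = 0.98, Blend above is an exponential moving average: each new sample moves the average by only 2%, so a one-frame spike is heavily damped. A small sketch of that behavior, not from the original sources; the sample durations are made up.

#include <cstdio>

int main() {
  const float kAlpha = 0.98f;
  float avg = 5.0f;  // hypothetical steady-state duration in ms
  const float samples[] = {5.0f, 5.1f, 20.0f, 5.0f};  // one spike
  for (float s : samples) {
    avg = kAlpha * avg + (1.0f - kAlpha) * s;
    std::printf("sample %.1f -> average %.3f\n", s, avg);
  }
  return 0;
}
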
75 -
76 -
77 -inline static float UpdateAverage(const char* str, const float new_val) {
78 - for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) {
79 - AverageEntry* const entry = avg_entries + entry_num;
80 - if (str == entry->id) {
81 - entry->average_duration = Blend(entry->average_duration, new_val);
82 - return entry->average_duration;
83 - }
84 - }
85 -
86 - if (num_avg_entries >= NUM_LOGS) {
87 -    LOGE("Too many log entries!");
88 -    return new_val;  // Avoid writing past the end of avg_entries below.
89 -  }
89 -
90 - // If it wasn't there already, add it.
91 - avg_entries[num_avg_entries].id = str;
92 - avg_entries[num_avg_entries].average_duration = new_val;
93 - ++num_avg_entries;
94 -
95 - return new_val;
96 -}
97 -
98 -
99 -// Prints out all the TimeLog statements in chronological order with the
100 -// interval that passed between subsequent statements. The total time between
101 -// the first and last statements is printed last.
102 -inline static void PrintTimeLog() {
103 - LogEntry* last_time = time_logs;
104 -
105 - float average_running_total = 0.0f;
106 -
107 - for (int i = 0; i < num_time_logs; ++i) {
108 - LogEntry* const this_time = time_logs + i;
109 -
110 - const float curr_time =
111 - (this_time->time_stamp - last_time->time_stamp) / 1000000.0f;
112 -
113 - const float avg_time = UpdateAverage(this_time->id, curr_time);
114 - average_running_total += avg_time;
115 -
116 - LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time);
117 - last_time = this_time;
118 - }
119 -
120 - const float total_time =
121 - (last_time->time_stamp - time_logs->time_stamp) / 1000000.0f;
122 -
123 - LOGD("TOTAL TIME: %6.3fms %6.4fms\n",
124 - total_time, average_running_total);
125 - LOGD(" ");
126 -}
127 -#else
128 -inline static void ResetTimeLog() {}
129 -
130 -inline static void TimeLog(const char* const str) {
131 - LOGV("%s", str);
132 -}
133 -
134 -inline static void PrintTimeLog() {}
135 -#endif
136 -
137 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
17 -
18 -namespace tf_tracking {
19 -
20 -static const float kInitialDistance = 20.0f;
21 -
22 -static void InitNormalized(const Image<uint8_t>& src_image,
23 - const BoundingBox& position,
24 - Image<float>* const dst_image) {
25 - BoundingBox scaled_box(position);
26 - CopyArea(src_image, scaled_box, dst_image);
27 - NormalizeImage(dst_image);
28 -}
29 -
30 -TrackedObject::TrackedObject(const std::string& id, const Image<uint8_t>& image,
31 - const BoundingBox& bounding_box,
32 - ObjectModelBase* const model)
33 - : id_(id),
34 - last_known_position_(bounding_box),
35 - last_detection_position_(bounding_box),
36 - position_last_computed_time_(-1),
37 - object_model_(model),
38 - last_detection_thumbnail_(kNormalizedThumbnailSize,
39 - kNormalizedThumbnailSize),
40 - last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize),
41 - tracked_correlation_(0.0f),
42 - tracked_match_score_(0.0),
43 - num_consecutive_frames_below_threshold_(0),
44 - allowable_detection_distance_(Square(kInitialDistance)) {
45 - InitNormalized(image, bounding_box, &last_detection_thumbnail_);
46 -}
47 -
48 -TrackedObject::~TrackedObject() {}
49 -
50 -void TrackedObject::UpdatePosition(const BoundingBox& new_position,
51 - const int64_t timestamp,
52 - const ImageData& image_data,
53 - const bool authoritative) {
54 - last_known_position_ = new_position;
55 - position_last_computed_time_ = timestamp;
56 -
57 - InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_);
58 -
59 - const float last_localization_correlation = ComputeCrossCorrelation(
60 - last_detection_thumbnail_.data(),
61 - last_frame_thumbnail_.data(),
62 - last_frame_thumbnail_.data_size_);
63 - LOGV("Tracked correlation to last localization: %.6f",
64 - last_localization_correlation);
65 -
66 - // Correlation to object model, if it exists.
67 - if (object_model_ != NULL) {
68 - tracked_correlation_ =
69 - object_model_->GetMaxCorrelation(last_frame_thumbnail_);
70 - LOGV("Tracked correlation to model: %.6f",
71 - tracked_correlation_);
72 -
73 - tracked_match_score_ =
74 - object_model_->GetMatchScore(new_position, image_data);
75 - LOGV("Tracked match score with model: %.6f",
76 - tracked_match_score_.value);
77 - } else {
78 - // If there's no model to check against, set the tracked correlation to
79 - // simply be the correlation to the last set position.
80 - tracked_correlation_ = last_localization_correlation;
81 - tracked_match_score_ = MatchScore(0.0f);
82 - }
83 -
84 - // Determine if it's still being tracked.
85 - if (tracked_correlation_ >= kMinimumCorrelationForTracking &&
86 - tracked_match_score_ >= kMinimumMatchScore) {
87 - num_consecutive_frames_below_threshold_ = 0;
88 -
89 - if (object_model_ != NULL) {
90 - object_model_->TrackStep(last_known_position_, *image_data.GetImage(),
91 - *image_data.GetIntegralImage(), authoritative);
92 - }
93 - } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) {
94 - if (num_consecutive_frames_below_threshold_ < 1000) {
95 - LOGD("Tracked match score is way too low (%.6f), aborting track.",
96 - tracked_match_score_.value);
97 - }
98 -
99 -    // Add an absurd number of missed frames so that all heuristics will
100 - // consider it a lost track.
101 - num_consecutive_frames_below_threshold_ += 1000;
102 -
103 - if (object_model_ != NULL) {
104 - object_model_->TrackLost();
105 - }
106 - } else {
107 - ++num_consecutive_frames_below_threshold_;
108 - allowable_detection_distance_ *= 1.1f;
109 - }
110 -}
111 -
112 -void TrackedObject::OnDetection(ObjectModelBase* const model,
113 - const BoundingBox& detection_position,
114 - const MatchScore match_score,
115 - const int64_t timestamp,
116 - const ImageData& image_data) {
117 - const float overlap = detection_position.PascalScore(last_known_position_);
118 - if (overlap > kPositionOverlapThreshold) {
119 - // If the position agreement with the current tracked position is good
120 - // enough, lock all the current unlocked examples.
121 - object_model_->TrackConfirmed();
122 - num_consecutive_frames_below_threshold_ = 0;
123 - }
124 -
125 - // Before relocalizing, make sure the new proposed position is better than
126 - // the existing position by a small amount to prevent thrashing.
127 - if (match_score <= tracked_match_score_ + kMatchScoreBuffer) {
128 - LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f",
129 - match_score.value, tracked_match_score_.value,
130 - kMatchScoreBuffer.value);
131 - return;
132 - }
133 -
134 - LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to "
135 - "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f",
136 - last_known_position_.left_, last_known_position_.top_,
137 - last_known_position_.GetWidth(), last_known_position_.GetHeight(),
138 - detection_position.left_, detection_position.top_,
139 - detection_position.GetWidth(), detection_position.GetHeight(),
140 - match_score.value, tracked_match_score_.value);
141 -
142 - if (overlap < kPositionOverlapThreshold) {
143 - // The path might be good, it might be bad, but it's no longer a path
144 - // since we're moving the box to a new position, so just nuke it from
145 - // orbit to be safe.
146 - object_model_->TrackLost();
147 - }
148 -
149 - object_model_ = model;
150 -
151 - // Reset the last detected appearance.
152 - InitNormalized(
153 - *image_data.GetImage(), detection_position, &last_detection_thumbnail_);
154 -
155 - num_consecutive_frames_below_threshold_ = 0;
156 - last_detection_position_ = detection_position;
157 -
158 - UpdatePosition(detection_position, timestamp, image_data, false);
159 - allowable_detection_distance_ = Square(kInitialDistance);
160 -}
161 -
162 -} // namespace tf_tracking
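
Note the gating behavior above: allowable_detection_distance_ starts at Square(kInitialDistance) = 400 and grows by 10% for every frame below threshold, so the matching gate widens the longer the track is uncertain, then snaps back to 400 on a successful relocalization. A small sketch of that growth, not from the original sources.

#include <cmath>
#include <cstdio>

int main() {
  float allowable_sq = 20.0f * 20.0f;  // Square(kInitialDistance)
  for (int frame = 1; frame <= 5; ++frame) {
    allowable_sq *= 1.1f;  // one below-threshold frame
    std::printf("frame %d: allowed radius ~%.1f px\n", frame,
                std::sqrt(allowable_sq));
  }
  return 0;
}
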
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
18 -
19 -#ifdef __RENDER_OPENGL__
20 -#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
21 -#endif
22 -#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
23 -
24 -namespace tf_tracking {
25 -
26 -// A TrackedObject is a specific instance of an ObjectModel, with a known
27 -// position in the world.
28 -// It provides the last known position and number of recent detection failures,
29 -// in addition to the more general appearance data associated with the object
30 -// class (which is in ObjectModel).
31 -// TODO(andrewharp): Make getters/setters follow styleguide.
32 -class TrackedObject {
33 - public:
34 - TrackedObject(const std::string& id, const Image<uint8_t>& image,
35 - const BoundingBox& bounding_box, ObjectModelBase* const model);
36 -
37 - ~TrackedObject();
38 -
39 - void UpdatePosition(const BoundingBox& new_position, const int64_t timestamp,
40 - const ImageData& image_data, const bool authoritative);
41 -
42 - // This method is called when the tracked object is detected at a
43 - // given position, and allows the associated Model to grow and/or prune
44 - // itself based on where the detection occurred.
45 - void OnDetection(ObjectModelBase* const model,
46 - const BoundingBox& detection_position,
47 - const MatchScore match_score, const int64_t timestamp,
48 - const ImageData& image_data);
49 -
50 - // Called when there's no detection of the tracked object. This will cause
51 - // a tracking failure after enough consecutive failures if the area under
52 - // the current bounding box also doesn't meet a minimum correlation threshold
53 - // with the model.
54 - void OnDetectionFailure() {}
55 -
56 - inline bool IsVisible() const {
57 - return tracked_correlation_ >= kMinimumCorrelationForTracking ||
58 - num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures;
59 - }
60 -
61 - inline float GetCorrelation() {
62 - return tracked_correlation_;
63 - }
64 -
65 - inline MatchScore GetMatchScore() {
66 - return tracked_match_score_;
67 - }
68 -
69 - inline BoundingBox GetPosition() const {
70 - return last_known_position_;
71 - }
72 -
73 - inline BoundingBox GetLastDetectionPosition() const {
74 - return last_detection_position_;
75 - }
76 -
77 - inline const ObjectModelBase* GetModel() const {
78 - return object_model_;
79 - }
80 -
81 - inline const std::string& GetName() const {
82 - return id_;
83 - }
84 -
85 - inline void Draw() const {
86 -#ifdef __RENDER_OPENGL__
87 - if (tracked_correlation_ < kMinimumCorrelationForTracking) {
88 - glColor4f(MAX(0.0f, -tracked_correlation_),
89 - MAX(0.0f, tracked_correlation_),
90 - 0.0f,
91 - 1.0f);
92 - } else {
93 - glColor4f(MAX(0.0f, -tracked_correlation_),
94 - MAX(0.0f, tracked_correlation_),
95 - 1.0f,
96 - 1.0f);
97 - }
98 -
99 - // Render the box itself.
100 - BoundingBox temp_box(last_known_position_);
101 - DrawBox(temp_box);
102 -
103 -    // Render a slightly expanded box outside this one.
104 - const float kBufferSize = 1.0f;
105 - temp_box.left_ -= kBufferSize;
106 - temp_box.top_ -= kBufferSize;
107 - temp_box.right_ += kBufferSize;
108 - temp_box.bottom_ += kBufferSize;
109 - DrawBox(temp_box);
110 -
111 -    // Render one inside as well (in case the actual box is hidden).
112 -    temp_box.left_ += 2.0f * kBufferSize;
113 -    temp_box.top_ += 2.0f * kBufferSize;
114 -    temp_box.right_ -= 2.0f * kBufferSize;
115 -    temp_box.bottom_ -= 2.0f * kBufferSize;
116 - DrawBox(temp_box);
117 -#endif
118 - }
119 -
120 - // Get current object's num_consecutive_frames_below_threshold_.
121 - inline int64_t GetNumConsecutiveFramesBelowThreshold() {
122 - return num_consecutive_frames_below_threshold_;
123 - }
124 -
125 - // Reset num_consecutive_frames_below_threshold_ to 0.
126 - inline void resetNumConsecutiveFramesBelowThreshold() {
127 - num_consecutive_frames_below_threshold_ = 0;
128 - }
129 -
130 - inline float GetAllowableDistanceSquared() const {
131 - return allowable_detection_distance_;
132 - }
133 -
134 - private:
135 - // The unique id used throughout the system to identify this
136 - // tracked object.
137 - const std::string id_;
138 -
139 - // The last known position of the object.
140 - BoundingBox last_known_position_;
141 -
142 -  // The position at which this object was last detected.
143 - BoundingBox last_detection_position_;
144 -
145 - // When the position was last computed.
146 - int64_t position_last_computed_time_;
147 -
148 - // The object model this tracked object is representative of.
149 - ObjectModelBase* object_model_;
150 -
151 - Image<float> last_detection_thumbnail_;
152 -
153 - Image<float> last_frame_thumbnail_;
154 -
155 - // The correlation of the object model with the preview frame at its last
156 - // tracked position.
157 - float tracked_correlation_;
158 -
159 - MatchScore tracked_match_score_;
160 -
161 - // The number of consecutive frames that the tracked position for this object
162 - // has been under the correlation threshold.
163 - int num_consecutive_frames_below_threshold_;
164 -
165 - float allowable_detection_distance_;
166 -
167 - friend std::ostream& operator<<(std::ostream& stream,
168 - const TrackedObject& tracked_object);
169 -
170 - TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject);
171 -};
172 -
173 -inline std::ostream& operator<<(std::ostream& stream,
174 - const TrackedObject& tracked_object) {
175 - stream << tracked_object.id_
176 - << " " << tracked_object.last_known_position_
177 - << " " << tracked_object.position_last_computed_time_
178 - << " " << tracked_object.num_consecutive_frames_below_threshold_
179 - << " " << tracked_object.object_model_
180 - << " " << tracked_object.tracked_correlation_;
181 - return stream;
182 -}
183 -
184 -} // namespace tf_tracking
185 -
186 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
18 -
19 -#include <math.h>
20 -#include <stdint.h>
21 -#include <stdlib.h>
22 -#include <time.h>
23 -
24 -#include <cmath> // for std::abs(float)
25 -
26 -#ifndef HAVE_CLOCK_GETTIME
27 -// Use gettimeofday() instead of clock_gettime().
28 -#include <sys/time.h>
29 -#endif // ifdef HAVE_CLOCK_GETTIME
30 -
31 -#include "tensorflow/examples/android/jni/object_tracking/logging.h"
32 -
33 -// TODO(andrewharp): clean up these macros to use the codebase standard.
34 -
35 -// A very small number, generally used as the tolerance for accumulated
36 -// floating point errors in bounds-checks.
37 -#define EPSILON 0.00001f
38 -
39 -#define SAFE_DELETE(pointer) {\
40 - if ((pointer) != NULL) {\
41 - LOGV("Safe deleting pointer: %s", #pointer);\
42 - delete (pointer);\
43 - (pointer) = NULL;\
44 - } else {\
45 - LOGV("Pointer already null: %s", #pointer);\
46 - }\
47 -}
48 -
49 -
50 -#ifdef __GOOGLE__
51 -
52 -#define CHECK_ALWAYS(condition, format, ...) {\
53 - CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
54 -}
55 -
56 -#define SCHECK(condition, format, ...) {\
57 - DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
58 -}
59 -
60 -#else
61 -
62 -#define CHECK_ALWAYS(condition, format, ...) {\
63 - if (!(condition)) {\
64 - LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\
65 - abort();\
66 - }\
67 -}
68 -
69 -#ifdef SANITY_CHECKS
70 -#define SCHECK(condition, format, ...) {\
71 - CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\
72 -}
73 -#else
74 -#define SCHECK(condition, format, ...) {}
75 -#endif // SANITY_CHECKS
76 -
77 -#endif // __GOOGLE__
78 -
79 -
80 -#ifndef MAX
81 -#define MAX(a, b) (((a) > (b)) ? (a) : (b))
82 -#endif
83 -#ifndef MIN
84 -#define MIN(a, b) (((a) > (b)) ? (b) : (a))
85 -#endif
86 -
87 -inline static int64_t CurrentThreadTimeNanos() {
88 -#ifdef HAVE_CLOCK_GETTIME
89 - struct timespec tm;
90 - clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm);
91 - return tm.tv_sec * 1000000000LL + tm.tv_nsec;
92 -#else
93 - struct timeval tv;
94 - gettimeofday(&tv, NULL);
95 -  return tv.tv_sec * 1000000000LL + tv.tv_usec * 1000LL;
96 -#endif
97 -}
98 -
99 -inline static int64_t CurrentRealTimeMillis() {
100 -#ifdef HAVE_CLOCK_GETTIME
101 - struct timespec tm;
102 - clock_gettime(CLOCK_MONOTONIC, &tm);
103 - return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL;
104 -#else
105 - struct timeval tv;
106 - gettimeofday(&tv, NULL);
107 - return tv.tv_sec * 1000 + tv.tv_usec / 1000;
108 -#endif
109 -}
110 -
111 -
112 -template<typename T>
113 -inline static T Square(const T a) {
114 - return a * a;
115 -}
116 -
117 -
118 -template<typename T>
119 -inline static T Clip(const T a, const T floor, const T ceil) {
120 - SCHECK(ceil >= floor, "Bounds mismatch!");
121 - return (a <= floor) ? floor : ((a >= ceil) ? ceil : a);
122 -}
123 -
124 -
125 -// Note: truncates toward zero, so this is only a true floor for a >= 0.
126 -template<typename T>
127 -inline static int Floor(const T a) {
128 -  return static_cast<int>(a);
129 -}
130 -
131 -
132 -// Note: returns Floor(a) + 1 even when a is an exact integer, so this is a
133 -// conservative upper bound rather than a true ceiling.
134 -template<typename T>
135 -inline static int Ceil(const T a) {
136 -  return Floor(a) + 1;
137 -}
135 -
136 -
137 -template<typename T>
138 -inline static bool InRange(const T a, const T min, const T max) {
139 - return (a >= min) && (a <= max);
140 -}
141 -
142 -
143 -inline static bool ValidIndex(const int a, const int max) {
144 - return (a >= 0) && (a < max);
145 -}
146 -
147 -
148 -inline bool NearlyEqual(const float a, const float b, const float tolerance) {
149 - return std::abs(a - b) < tolerance;
150 -}
151 -
152 -
153 -inline bool NearlyEqual(const float a, const float b) {
154 - return NearlyEqual(a, b, EPSILON);
155 -}
156 -
157 -
158 -template<typename T>
159 -inline static int Round(const float a) {
160 -  return static_cast<int>((a - floor(a)) > 0.5f ? ceil(a) : floor(a));
161 -}
162 -
163 -
164 -template<typename T>
165 -inline static void Swap(T* const a, T* const b) {
166 - // Cache out the VALUE of what's at a.
167 - T tmp = *a;
168 - *a = *b;
169 -
170 - *b = tmp;
171 -}
172 -
173 -
174 -static inline float randf() {
175 - return rand() / static_cast<float>(RAND_MAX);
176 -}
177 -
178 -static inline float randf(const float min_value, const float max_value) {
179 - return randf() * (max_value - min_value) + min_value;
180 -}
181 -
182 -static inline uint16_t RealToFixed115(const float real_number) {
183 -  SCHECK(InRange(real_number, 0.0f, 2047.9f),
184 -         "Value out of range! %.2f", real_number);
185 -
186 - static const float kMult = 32.0f;
187 - const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
188 - return static_cast<uint16_t>(real_number * kMult + round_add);
189 -}
190 -
191 -static inline float FixedToFloat115(const uint16_t fp_number) {
192 - const float kDiv = 32.0f;
193 - return (static_cast<float>(fp_number) / kDiv);
194 -}
195 -
196 -static inline int RealToFixed1616(const float real_number) {
197 - static const float kMult = 65536.0f;
198 - SCHECK(InRange(real_number, -kMult, kMult),
199 - "Value out of range! %.2f", real_number);
200 -
201 - const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
202 - return static_cast<int>(real_number * kMult + round_add);
203 -}
204 -
205 -static inline float FixedToFloat1616(const int fp_number) {
206 - const float kDiv = 65536.0f;
207 - return (static_cast<float>(fp_number) / kDiv);
208 -}
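
The 16.16 helpers above quantize to steps of 1/65536 and round half away from zero. A standalone round-trip sketch, not from the original sources; the Demo functions inline the same arithmetic so the snippet compiles on its own.

#include <cstdio>

static int RealToFixed1616Demo(const float real_number) {
  const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
  return static_cast<int>(real_number * 65536.0f + round_add);
}

static float FixedToFloat1616Demo(const int fp_number) {
  return static_cast<float>(fp_number) / 65536.0f;
}

int main() {
  const float values[] = {0.0f, 1.5f, -3.25f, 123.456f};
  for (float v : values) {
    const int fp = RealToFixed1616Demo(v);
    // Round-trip error is bounded by half of 1/65536.
    std::printf("%10.5f -> %11d -> %10.5f\n", v, fp, FixedToFloat1616Demo(fp));
  }
  return 0;
}
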
209 -
210 -template<typename T>
211 -// produces numbers in range [0,2*M_PI] (rather than -PI,PI)
212 -inline T FastAtan2(const T y, const T x) {
213 - static const T coeff_1 = (T)(M_PI / 4.0);
214 - static const T coeff_2 = (T)(3.0 * coeff_1);
215 - const T abs_y = fabs(y);
216 - T angle;
217 - if (x >= 0) {
218 - T r = (x - abs_y) / (x + abs_y);
219 - angle = coeff_1 - coeff_1 * r;
220 - } else {
221 - T r = (x + abs_y) / (abs_y - x);
222 - angle = coeff_2 - coeff_1 * r;
223 - }
224 -  static const T kTwoPi = (T)(2.0 * M_PI);
225 -  return y < 0 ? kTwoPi - angle : angle;
226 -}
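
FastAtan2 above is the classic first-order polynomial approximation, with worst-case error on the order of 0.07 radians. A quick comparison against libm, not from the original sources; it assumes the FastAtan2 template defined just above is in scope, and maps atan2's (-PI, PI] result into [0, 2*PI) for a like-for-like check.

#include <cmath>
#include <cstdio>

int main() {
  const double xs[] = {1.0, 0.0, -1.0, -1.0};
  const double ys[] = {1.0, -1.0, 1.0, -0.5};
  for (int i = 0; i < 4; ++i) {
    double exact = std::atan2(ys[i], xs[i]);
    if (exact < 0.0) exact += 2.0 * M_PI;
    std::printf("fast=%.4f exact=%.4f\n", FastAtan2(ys[i], xs[i]), exact);
  }
  return 0;
}
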
227 -
228 -#define NELEMS(X) (sizeof(X) / sizeof(X[0]))
229 -
230 -namespace tf_tracking {
231 -
232 -#ifdef __ARM_NEON
233 -float ComputeMeanNeon(const float* const values, const int num_vals);
234 -
235 -float ComputeStdDevNeon(const float* const values, const int num_vals,
236 - const float mean);
237 -
238 -float ComputeWeightedMeanNeon(const float* const values,
239 - const float* const weights, const int num_vals);
240 -
241 -float ComputeCrossCorrelationNeon(const float* const values1,
242 - const float* const values2,
243 - const int num_vals);
244 -#endif
245 -
246 -inline float ComputeMeanCpu(const float* const values, const int num_vals) {
247 - // Get mean.
248 - float sum = values[0];
249 - for (int i = 1; i < num_vals; ++i) {
250 - sum += values[i];
251 - }
252 - return sum / static_cast<float>(num_vals);
253 -}
254 -
255 -
256 -inline float ComputeMean(const float* const values, const int num_vals) {
257 - return
258 -#ifdef __ARM_NEON
259 - (num_vals >= 8) ? ComputeMeanNeon(values, num_vals) :
260 -#endif
261 - ComputeMeanCpu(values, num_vals);
262 -}
263 -
264 -
265 -inline float ComputeStdDevCpu(const float* const values,
266 - const int num_vals,
267 - const float mean) {
268 - // Get Std dev.
269 - float squared_sum = 0.0f;
270 - for (int i = 0; i < num_vals; ++i) {
271 - squared_sum += Square(values[i] - mean);
272 - }
273 - return sqrt(squared_sum / static_cast<float>(num_vals));
274 -}
275 -
276 -
277 -inline float ComputeStdDev(const float* const values,
278 - const int num_vals,
279 - const float mean) {
280 - return
281 -#ifdef __ARM_NEON
282 - (num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) :
283 -#endif
284 - ComputeStdDevCpu(values, num_vals, mean);
285 -}
286 -
287 -
288 -// TODO(andrewharp): Accelerate with NEON.
289 -inline float ComputeWeightedMean(const float* const values,
290 - const float* const weights,
291 - const int num_vals) {
292 - float sum = 0.0f;
293 - float total_weight = 0.0f;
294 - for (int i = 0; i < num_vals; ++i) {
295 - sum += values[i] * weights[i];
296 - total_weight += weights[i];
297 - }
298 -  return sum / total_weight;  // Normalize by the total weight.
299 -}
300 -
301 -
302 -inline float ComputeCrossCorrelationCpu(const float* const values1,
303 - const float* const values2,
304 - const int num_vals) {
305 - float sxy = 0.0f;
306 - for (int offset = 0; offset < num_vals; ++offset) {
307 - sxy += values1[offset] * values2[offset];
308 - }
309 -
310 - const float cross_correlation = sxy / num_vals;
311 -
312 - return cross_correlation;
313 -}
314 -
315 -
316 -inline float ComputeCrossCorrelation(const float* const values1,
317 - const float* const values2,
318 - const int num_vals) {
319 - return
320 -#ifdef __ARM_NEON
321 - (num_vals >= 8) ? ComputeCrossCorrelationNeon(values1, values2, num_vals)
322 - :
323 -#endif
324 - ComputeCrossCorrelationCpu(values1, values2, num_vals);
325 -}
326 -
327 -
328 -inline void NormalizeNumbers(float* const values, const int num_vals) {
329 - // Find the mean and then subtract so that the new mean is 0.0.
330 - const float mean = ComputeMean(values, num_vals);
331 - VLOG(2) << "Mean is " << mean;
332 - float* curr_data = values;
333 - for (int i = 0; i < num_vals; ++i) {
334 - *curr_data -= mean;
335 - curr_data++;
336 - }
337 -
338 - // Now divide by the std deviation so the new standard deviation is 1.0.
339 - // The numbers might all be identical (and thus shifted to 0.0 now),
340 - // so only scale by the standard deviation if this is not the case.
341 - const float std_dev = ComputeStdDev(values, num_vals, 0.0f);
342 - if (std_dev > 0.0f) {
343 - VLOG(2) << "Std dev is " << std_dev;
344 - curr_data = values;
345 - for (int i = 0; i < num_vals; ++i) {
346 - *curr_data /= std_dev;
347 - curr_data++;
348 - }
349 - }
350 -}
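
NormalizeNumbers above shifts to zero mean and scales to unit standard deviation, which is what turns the plain dot product in ComputeCrossCorrelation into a normalized cross-correlation: a signal correlated with itself then scores exactly 1. A self-contained sketch, not from the original sources, repeating the same steps on a toy array.

#include <cmath>
#include <cstdio>

int main() {
  float a[] = {1.0f, 2.0f, 3.0f, 4.0f};
  const int n = 4;
  // Zero-mean.
  float mean = 0.0f;
  for (float v : a) mean += v;
  mean /= n;
  for (float& v : a) v -= mean;
  // Unit standard deviation.
  float sq = 0.0f;
  for (float v : a) sq += v * v;
  const float std_dev = std::sqrt(sq / n);
  for (float& v : a) v /= std_dev;
  // Self-correlation is 1.0 up to float rounding.
  float sxy = 0.0f;
  for (int i = 0; i < n; ++i) sxy += a[i] * a[i];
  std::printf("self correlation: %f\n", sxy / n);
  return 0;
}
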
351 -
352 -
353 -// Returns the determinant of a 2x2 matrix.
354 -template<class T>
355 -inline T FindDeterminant2x2(const T* const a) {
356 - // Determinant: (ad - bc)
357 - return a[0] * a[3] - a[1] * a[2];
358 -}
359 -
360 -
361 -// Finds the inverse of a 2x2 matrix.
362 -// Returns true upon success, false if the matrix is not invertible.
363 -template<class T>
364 -inline bool Invert2x2(const T* const a, float* const a_inv) {
365 - const float det = static_cast<float>(FindDeterminant2x2(a));
366 - if (fabs(det) < EPSILON) {
367 - return false;
368 - }
369 - const float inv_det = 1.0f / det;
370 -
371 - a_inv[0] = inv_det * static_cast<float>(a[3]); // d
372 - a_inv[1] = inv_det * static_cast<float>(-a[1]); // -b
373 - a_inv[2] = inv_det * static_cast<float>(-a[2]); // -c
374 - a_inv[3] = inv_det * static_cast<float>(a[0]); // a
375 -
376 - return true;
377 -}
378 -
379 -} // namespace tf_tracking
380 -
381 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// NEON implementations of Image methods for compatible devices. Control
17 -// should never enter this compilation unit on incompatible devices.
18 -
19 -#ifdef __ARM_NEON
20 -
21 -#include <arm_neon.h>
22 -
23 -#include "tensorflow/examples/android/jni/object_tracking/geom.h"
24 -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
25 -#include "tensorflow/examples/android/jni/object_tracking/image.h"
26 -#include "tensorflow/examples/android/jni/object_tracking/utils.h"
27 -
28 -namespace tf_tracking {
29 -
30 -inline static float GetSum(const float32x4_t& values) {
31 -  float32_t summed_values[4];  // Not static: keep this thread-safe.
32 - vst1q_f32(summed_values, values);
33 - return summed_values[0]
34 - + summed_values[1]
35 - + summed_values[2]
36 - + summed_values[3];
37 -}
38 -
39 -
40 -float ComputeMeanNeon(const float* const values, const int num_vals) {
41 - SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
42 -
43 - const float32_t* const arm_vals = (const float32_t* const) values;
44 - float32x4_t accum = vdupq_n_f32(0.0f);
45 -
46 - int offset = 0;
47 - for (; offset <= num_vals - 4; offset += 4) {
48 - accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset]));
49 - }
50 -
51 - // Pull the accumulated values into a single variable.
52 - float sum = GetSum(accum);
53 -
54 - // Get the remaining 1 to 3 values.
55 - for (; offset < num_vals; ++offset) {
56 - sum += values[offset];
57 - }
58 -
59 - const float mean_neon = sum / static_cast<float>(num_vals);
60 -
61 -#ifdef SANITY_CHECKS
62 - const float mean_cpu = ComputeMeanCpu(values, num_vals);
63 - SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals),
64 - "Neon mismatch with CPU mean! %.10f vs %.10f",
65 - mean_neon, mean_cpu);
66 -#endif
67 -
68 - return mean_neon;
69 -}
70 -
71 -
72 -float ComputeStdDevNeon(const float* const values,
73 - const int num_vals, const float mean) {
74 - SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
75 -
76 - const float32_t* const arm_vals = (const float32_t* const) values;
77 - const float32x4_t mean_vec = vdupq_n_f32(-mean);
78 -
79 - float32x4_t accum = vdupq_n_f32(0.0f);
80 -
81 - int offset = 0;
82 - for (; offset <= num_vals - 4; offset += 4) {
83 - const float32x4_t deltas =
84 - vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset]));
85 -
86 - accum = vmlaq_f32(accum, deltas, deltas);
87 - }
88 -
89 - // Pull the accumulated values into a single variable.
90 - float squared_sum = GetSum(accum);
91 -
92 - // Get the remaining 1 to 3 values.
93 - for (; offset < num_vals; ++offset) {
94 - squared_sum += Square(values[offset] - mean);
95 - }
96 -
97 - const float std_dev_neon = sqrt(squared_sum / static_cast<float>(num_vals));
98 -
99 -#ifdef SANITY_CHECKS
100 - const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean);
101 - SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals),
102 - "Neon mismatch with CPU std dev! %.10f vs %.10f",
103 - std_dev_neon, std_dev_cpu);
104 -#endif
105 -
106 - return std_dev_neon;
107 -}
108 -
109 -
110 -float ComputeCrossCorrelationNeon(const float* const values1,
111 - const float* const values2,
112 - const int num_vals) {
113 - SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
114 -
115 - const float32_t* const arm_vals1 = (const float32_t* const) values1;
116 - const float32_t* const arm_vals2 = (const float32_t* const) values2;
117 -
118 - float32x4_t accum = vdupq_n_f32(0.0f);
119 -
120 - int offset = 0;
121 - for (; offset <= num_vals - 4; offset += 4) {
122 - accum = vmlaq_f32(accum,
123 - vld1q_f32(&arm_vals1[offset]),
124 - vld1q_f32(&arm_vals2[offset]));
125 - }
126 -
127 - // Pull the accumulated values into a single variable.
128 - float sxy = GetSum(accum);
129 -
130 - // Get the remaining 1 to 3 values.
131 - for (; offset < num_vals; ++offset) {
132 - sxy += values1[offset] * values2[offset];
133 - }
134 -
135 - const float cross_correlation_neon = sxy / num_vals;
136 -
137 -#ifdef SANITY_CHECKS
138 - const float cross_correlation_cpu =
139 - ComputeCrossCorrelationCpu(values1, values2, num_vals);
140 - SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu,
141 - EPSILON * num_vals),
142 - "Neon mismatch with CPU cross correlation! %.10f vs %.10f",
143 - cross_correlation_neon, cross_correlation_cpu);
144 -#endif
145 -
146 - return cross_correlation_neon;
147 -}
148 -
149 -} // namespace tf_tracking
150 -
151 -#endif // __ARM_NEON
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// These utility functions allow for the conversion of RGB data to YUV data.
17 -
18 -#include "tensorflow/examples/android/jni/rgb2yuv.h"
19 -
20 -static inline void WriteYUV(const int x, const int y, const int width,
21 - const int r8, const int g8, const int b8,
22 - uint8_t* const pY, uint8_t* const pUV) {
23 - // Using formulas from http://msdn.microsoft.com/en-us/library/ms893078
24 - *pY = ((66 * r8 + 129 * g8 + 25 * b8 + 128) >> 8) + 16;
25 -
26 - // Odd widths get rounded up so that UV blocks on the side don't get cut off.
27 - const int blocks_per_row = (width + 1) / 2;
28 -
29 - // 2 bytes per UV block
30 - const int offset = 2 * (((y / 2) * blocks_per_row + (x / 2)));
31 -
32 - // U and V are the average values of all 4 pixels in the block.
33 - if (!(x & 1) && !(y & 1)) {
34 - // Explicitly clear the block if this is the first pixel in it.
35 - pUV[offset] = 0;
36 - pUV[offset + 1] = 0;
37 - }
38 -
39 - // V (with divide by 4 factored in)
40 -#ifdef __APPLE__
41 - const int u_offset = 0;
42 - const int v_offset = 1;
43 -#else
44 - const int u_offset = 1;
45 - const int v_offset = 0;
46 -#endif
47 - pUV[offset + v_offset] += ((112 * r8 - 94 * g8 - 18 * b8 + 128) >> 10) + 32;
48 -
49 - // U (with divide by 4 factored in)
50 - pUV[offset + u_offset] += ((-38 * r8 - 74 * g8 + 112 * b8 + 128) >> 10) + 32;
51 -}
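
The UV arithmetic above folds the 2x2 block average into the per-pixel terms: the coefficients are the float ones scaled by 256, so >> 10 divides by 256 for the coefficient scale and by 4 for the four pixels, and the +32 bias is one quarter of the usual +128. A check, not from the original sources, that a uniform gray block accumulates exactly the unbiased chroma value 128.

#include <cstdio>

int main() {
  const int r8 = 100, g8 = 100, b8 = 100;  // gray pixel, so V should be 128
  int v = 0;
  for (int i = 0; i < 4; ++i) {
    // Same per-pixel V contribution as WriteYUV above.
    v += ((112 * r8 - 94 * g8 - 18 * b8 + 128) >> 10) + 32;
  }
  std::printf("V for a gray 2x2 block: %d (expected 128)\n", v);
  return 0;
}
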
52 -
53 -void ConvertARGB8888ToYUV420SP(const uint32_t* const input,
54 - uint8_t* const output, int width, int height) {
55 - uint8_t* pY = output;
56 - uint8_t* pUV = output + (width * height);
57 - const uint32_t* in = input;
58 -
59 - for (int y = 0; y < height; y++) {
60 - for (int x = 0; x < width; x++) {
61 - const uint32_t rgb = *in++;
62 -#ifdef __APPLE__
63 - const int nB = (rgb >> 8) & 0xFF;
64 - const int nG = (rgb >> 16) & 0xFF;
65 - const int nR = (rgb >> 24) & 0xFF;
66 -#else
67 - const int nR = (rgb >> 16) & 0xFF;
68 - const int nG = (rgb >> 8) & 0xFF;
69 - const int nB = rgb & 0xFF;
70 -#endif
71 - WriteYUV(x, y, width, nR, nG, nB, pY++, pUV);
72 - }
73 - }
74 -}
75 -
76 -void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output,
77 - const int width, const int height) {
78 - uint8_t* pY = output;
79 - uint8_t* pUV = output + (width * height);
80 - const uint16_t* in = input;
81 -
82 - for (int y = 0; y < height; y++) {
83 - for (int x = 0; x < width; x++) {
84 - const uint32_t rgb = *in++;
85 -
86 - const int r5 = ((rgb >> 11) & 0x1F);
87 - const int g6 = ((rgb >> 5) & 0x3F);
88 - const int b5 = (rgb & 0x1F);
89 -
90 - // Shift left, then fill in the empty low bits with a copy of the high
91 - // bits so we can stretch across the entire 0 - 255 range.
92 - const int r8 = r5 << 3 | r5 >> 2;
93 - const int g8 = g6 << 2 | g6 >> 4;
94 - const int b8 = b5 << 3 | b5 >> 2;
95 -
96 - WriteYUV(x, y, width, r8, g8, b8, pY++, pUV);
97 - }
98 - }
99 -}
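
The bit-replication above is what makes the 565-to-888 expansion exact at the range endpoints: shifting left leaves zero low bits, and copying the high bits into them maps 0 to 0 and the channel maximum to 255. A tiny sketch, not from the original sources, showing the red-channel endpoints.

#include <cstdio>

int main() {
  for (int r5 = 0; r5 <= 31; r5 += 31) {  // endpoints of the 5-bit range
    const int r8 = (r5 << 3) | (r5 >> 2);
    std::printf("%2d -> %3d\n", r5, r8);
  }
  return 0;
}
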
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
17 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
18 -
19 -#include <stdint.h>
20 -
21 -#ifdef __cplusplus
22 -extern "C" {
23 -#endif
24 -
25 -void ConvertARGB8888ToYUV420SP(const uint32_t* const input,
26 - uint8_t* const output, int width, int height);
27 -
28 -void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output,
29 - const int width, const int height);
30 -
31 -#ifdef __cplusplus
32 -}
33 -#endif
34 -
35 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
1 -VERS_1.0 {
2 - # Export JNI symbols.
3 - global:
4 - Java_*;
5 - JNI_OnLoad;
6 - JNI_OnUnload;
7 -
8 - # Hide everything else.
9 - local:
10 - *;
11 -};
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// This is a collection of routines which converts various YUV image formats
17 -// to ARGB.
18 -
19 -#include "tensorflow/examples/android/jni/yuv2rgb.h"
20 -
21 -#ifndef MAX
22 -#define MAX(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a > _b ? _a : _b; })
23 -#endif
24 -#ifndef MIN
25 -#define MIN(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a < _b ? _a : _b; })
26 -#endif
25 -
26 -// This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
27 -// are normalized to eight bits.
28 -static const int kMaxChannelValue = 262143;
29 -
30 -static inline uint32_t YUV2RGB(int nY, int nU, int nV) {
31 - nY -= 16;
32 - nU -= 128;
33 - nV -= 128;
34 - if (nY < 0) nY = 0;
35 -
36 - // This is the floating point equivalent. We do the conversion in integer
37 - // because some Android devices do not have floating point in hardware.
38 - // nR = (int)(1.164 * nY + 2.018 * nU);
39 - // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
40 - // nB = (int)(1.164 * nY + 1.596 * nV);
41 -
42 - int nR = 1192 * nY + 1634 * nV;
43 - int nG = 1192 * nY - 833 * nV - 400 * nU;
44 - int nB = 1192 * nY + 2066 * nU;
45 -
46 - nR = MIN(kMaxChannelValue, MAX(0, nR));
47 - nG = MIN(kMaxChannelValue, MAX(0, nG));
48 - nB = MIN(kMaxChannelValue, MAX(0, nB));
49 -
50 - nR = (nR >> 10) & 0xff;
51 - nG = (nG >> 10) & 0xff;
52 - nB = (nB >> 10) & 0xff;
53 -
54 - return 0xff000000 | (nR << 16) | (nG << 8) | nB;
55 -}
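
The integer coefficients above are the commented float ones scaled by 1024 (1.164 * 1024 ~ 1192, 1.596 * 1024 ~ 1634, 0.813 * 1024 ~ 833, 0.391 * 1024 ~ 400, 2.018 * 1024 ~ 2066), which is why results are shifted right by 10 before clamping to eight bits. A check, not from the original sources, that a neutral pixel stays neutral: with U = V = 128 the chroma terms vanish and only 1.164 * (Y - 16) remains.

#include <cstdio>

int main() {
  const int nY = 126, nU = 128, nV = 128;  // neutral gray input
  const int y = nY - 16, u = nU - 128, v = nV - 128;
  const int nR = 1192 * y + 1634 * v;
  const int nG = 1192 * y - 833 * v - 400 * u;
  const int nB = 1192 * y + 2066 * u;
  std::printf("RGB = (%d, %d, %d)\n", (nR >> 10) & 0xff, (nG >> 10) & 0xff,
              (nB >> 10) & 0xff);  // prints (128, 128, 128)
  return 0;
}
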
56 -
57 -// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by
58 -// separate u and v planes with arbitrary row and column strides,
59 -// containing 8 bit 2x2 subsampled chroma samples.
60 -// Converts to a packed ARGB 32 bit output of the same pixel dimensions.
61 -void ConvertYUV420ToARGB8888(const uint8_t* const yData,
62 - const uint8_t* const uData,
63 - const uint8_t* const vData, uint32_t* const output,
64 - const int width, const int height,
65 - const int y_row_stride, const int uv_row_stride,
66 - const int uv_pixel_stride) {
67 - uint32_t* out = output;
68 -
69 - for (int y = 0; y < height; y++) {
70 - const uint8_t* pY = yData + y_row_stride * y;
71 -
72 - const int uv_row_start = uv_row_stride * (y >> 1);
73 - const uint8_t* pU = uData + uv_row_start;
74 - const uint8_t* pV = vData + uv_row_start;
75 -
76 - for (int x = 0; x < width; x++) {
77 - const int uv_offset = (x >> 1) * uv_pixel_stride;
78 - *out++ = YUV2RGB(pY[x], pU[uv_offset], pV[uv_offset]);
79 - }
80 - }
81 -}
82 -
83 -// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an
84 -// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples,
85 -// except the interleave order of U and V is reversed. Converts to a packed
86 -// ARGB 32 bit output of the same pixel dimensions.
87 -void ConvertYUV420SPToARGB8888(const uint8_t* const yData,
88 - const uint8_t* const uvData,
89 - uint32_t* const output, const int width,
90 - const int height) {
91 - const uint8_t* pY = yData;
92 - const uint8_t* pUV = uvData;
93 - uint32_t* out = output;
94 -
95 - for (int y = 0; y < height; y++) {
96 - for (int x = 0; x < width; x++) {
97 - int nY = *pY++;
98 - int offset = (y >> 1) * width + 2 * (x >> 1);
99 -#ifdef __APPLE__
100 - int nU = pUV[offset];
101 - int nV = pUV[offset + 1];
102 -#else
103 - int nV = pUV[offset];
104 - int nU = pUV[offset + 1];
105 -#endif
106 -
107 - *out++ = YUV2RGB(nY, nU, nV);
108 - }
109 - }
110 -}
111 -
112 -// The same as above, but downsamples each dimension to half size.
113 -void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input,
114 - uint32_t* const output, int width,
115 - int height) {
116 - const uint8_t* pY = input;
117 - const uint8_t* pUV = input + (width * height);
118 - uint32_t* out = output;
119 - int stride = width;
120 - width >>= 1;
121 - height >>= 1;
122 -
123 - for (int y = 0; y < height; y++) {
124 - for (int x = 0; x < width; x++) {
125 - int nY = (pY[0] + pY[1] + pY[stride] + pY[stride + 1]) >> 2;
126 - pY += 2;
127 -#ifdef __APPLE__
128 - int nU = *pUV++;
129 - int nV = *pUV++;
130 -#else
131 - int nV = *pUV++;
132 - int nU = *pUV++;
133 -#endif
134 -
135 - *out++ = YUV2RGB(nY, nU, nV);
136 - }
137 - pY += stride;
138 - }
139 -}
140 -
141 -// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an
142 -// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples,
143 -// except the interleave order of U and V is reversed. Converts to a packed
144 -// RGB 565 bit output of the same pixel dimensions.
145 -void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output,
146 - const int width, const int height) {
147 - const uint8_t* pY = input;
148 - const uint8_t* pUV = input + (width * height);
149 - uint16_t* out = output;
150 -
151 - for (int y = 0; y < height; y++) {
152 - for (int x = 0; x < width; x++) {
153 - int nY = *pY++;
154 - int offset = (y >> 1) * width + 2 * (x >> 1);
155 -#ifdef __APPLE__
156 - int nU = pUV[offset];
157 - int nV = pUV[offset + 1];
158 -#else
159 - int nV = pUV[offset];
160 - int nU = pUV[offset + 1];
161 -#endif
162 -
163 - nY -= 16;
164 - nU -= 128;
165 - nV -= 128;
166 - if (nY < 0) nY = 0;
167 -
168 - // This is the floating point equivalent. We do the conversion in integer
169 - // because some Android devices do not have floating point in hardware.
170 -      // nR = (int)(1.164 * nY + 1.596 * nV);
171 -      // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
172 -      // nB = (int)(1.164 * nY + 2.018 * nU);
173 -
174 - int nR = 1192 * nY + 1634 * nV;
175 - int nG = 1192 * nY - 833 * nV - 400 * nU;
176 - int nB = 1192 * nY + 2066 * nU;
177 -
178 - nR = MIN(kMaxChannelValue, MAX(0, nR));
179 - nG = MIN(kMaxChannelValue, MAX(0, nG));
180 - nB = MIN(kMaxChannelValue, MAX(0, nB));
181 -
182 - // Shift more than for ARGB8888 and apply appropriate bitmask.
183 - nR = (nR >> 13) & 0x1f;
184 - nG = (nG >> 12) & 0x3f;
185 - nB = (nB >> 13) & 0x1f;
186 -
187 - // R is high 5 bits, G is middle 6 bits, and B is low 5 bits.
188 - *out++ = (nR << 11) | (nG << 5) | nB;
189 - }
190 - }
191 -}
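
As a quick, standalone illustration of the 5-6-5 packing described at the end of ConvertYUV420SPToRGB565 (plain bit arithmetic, assuming nothing beyond the comments above):

    public class Rgb565Demo {
      // Keep the top 5/6/5 bits of each 8-bit channel; R high, G middle, B low.
      static int pack565(int r, int g, int b) {
        return ((r >> 3) << 11) | ((g >> 2) << 5) | (b >> 3);
      }

      // Expand back to 8 bits per channel, replicating high bits into low bits.
      static int[] unpack565(int p) {
        int r = (p >> 11) & 0x1f, g = (p >> 5) & 0x3f, b = p & 0x1f;
        return new int[] {(r << 3) | (r >> 2), (g << 2) | (g >> 4), (b << 3) | (b >> 2)};
      }

      public static void main(String[] args) {
        int p = pack565(200, 100, 50);
        System.out.printf("packed=0x%04x unpacked=%s%n",
            p, java.util.Arrays.toString(unpack565(p)));
      }
    }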
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -// This is a collection of routines which converts various YUV image formats
17 -// to (A)RGB.
18 -
19 -#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
20 -#define TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
21 -
22 -#include <stdint.h>
23 -
24 -#ifdef __cplusplus
25 -extern "C" {
26 -#endif
27 -
28 -void ConvertYUV420ToARGB8888(const uint8_t* const yData,
29 - const uint8_t* const uData,
30 - const uint8_t* const vData, uint32_t* const output,
31 - const int width, const int height,
32 - const int y_row_stride, const int uv_row_stride,
33 - const int uv_pixel_stride);
34 -
35 -// Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width
36 -// and height. The input and output must already be allocated and non-null.
37 -// For efficiency, no error checking is performed.
38 -void ConvertYUV420SPToARGB8888(const uint8_t* const pY,
39 - const uint8_t* const pUV, uint32_t* const output,
40 - const int width, const int height);
41 -
42 -// The same as above, but downsamples each dimension to half size.
43 -void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input,
44 - uint32_t* const output, int width,
45 - int height);
46 -
47 -// Converts YUV420 semi-planar data to RGB 565 data using the supplied width
48 -// and height. The input and output must already be allocated and non-null.
49 -// For efficiency, no error checking is performed.
50 -void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output,
51 - const int width, const int height);
52 -
53 -#ifdef __cplusplus
54 -}
55 -#endif
56 -
57 -#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
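
Since the header insists that callers allocate input and output themselves, the sizes these semi-planar routines expect are worth spelling out: a full-resolution Y plane followed by one interleaved chroma byte pair per 2x2 block, i.e. width * height * 3 / 2 input bytes for even dimensions, plus one output element per pixel (32-bit for ARGB, 16-bit for 565). A small Java sketch of the arithmetic:

    public class Yuv420SpSizes {
      public static void main(String[] args) {
        int width = 640, height = 480; // 2x2 chroma subsampling implies even dimensions

        int yBytes = width * height;                  // one luma byte per pixel
        int uvBytes = (width / 2) * (height / 2) * 2; // one interleaved V/U pair per 2x2 block
        int argbInts = width * height;                // one packed 32-bit pixel out

        System.out.printf("YUV420SP input: %d bytes (Y=%d, VU=%d); ARGB output: %d ints%n",
            yBytes + uvBytes, yBytes, uvBytes, argbInts);
      }
    }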
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<set xmlns:android="http://schemas.android.com/apk/res/android"
17 - android:ordering="sequentially">
18 - <objectAnimator
19 - android:propertyName="backgroundColor"
20 - android:duration="375"
21 - android:valueFrom="0x00b3ccff"
22 - android:valueTo="0xffb3ccff"
23 - android:valueType="colorType"/>
24 - <objectAnimator
25 - android:propertyName="backgroundColor"
26 - android:duration="375"
27 - android:valueFrom="0xffb3ccff"
28 - android:valueTo="0x00b3ccff"
29 - android:valueType="colorType"/>
30 -</set>
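
The two sequential animators above give a 750 ms fade-in/fade-out pulse of the highlight color. For reference, a hedged sketch of how such a resource is typically driven from code; the resource name R.animator.color_animation and the helper class are assumptions, not part of this change:

    import android.animation.AnimatorInflater;
    import android.animation.AnimatorSet;
    import android.view.View;

    public final class PulseHelper {
      // Inflates the <set> above and plays it against the view's backgroundColor.
      public static void pulse(View target) {
        AnimatorSet set = (AnimatorSet) AnimatorInflater.loadAnimator(
            target.getContext(), R.animator.color_animation); // hypothetical resource name
        set.setTarget(target);
        set.start();
      }
    }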
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle" >
17 - <solid android:color="#00000000" />
18 - <stroke android:width="1dip" android:color="#cccccc" />
19 -</shape>
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 - xmlns:tools="http://schemas.android.com/tools"
18 - android:id="@+id/container"
19 - android:layout_width="match_parent"
20 - android:layout_height="match_parent"
21 - android:background="#000"
22 - tools:context="org.tensorflow.demo.CameraActivity" />
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<FrameLayout
17 - xmlns:android="http://schemas.android.com/apk/res/android"
18 - xmlns:app="http://schemas.android.com/apk/res-auto"
19 - xmlns:tools="http://schemas.android.com/tools"
20 - android:layout_width="match_parent"
21 - android:layout_height="match_parent"
22 - tools:context="org.tensorflow.demo.SpeechActivity">
23 -
24 - <TextView
25 - android:layout_width="wrap_content"
26 - android:layout_height="wrap_content"
27 - android:text="Say one of the words below!"
28 - android:id="@+id/textView"
29 - android:textAlignment="center"
30 - android:layout_gravity="top"
31 -        android:textSize="24sp"
32 - android:layout_marginTop="10dp"
33 - android:layout_marginLeft="10dp"
34 - />
35 -
36 - <ListView
37 - android:id="@+id/list_view"
38 - android:layout_width="240dp"
39 - android:layout_height="wrap_content"
40 - android:background="@drawable/border"
41 - android:layout_gravity="top|center_horizontal"
42 - android:textAlignment="center"
43 - android:layout_marginTop="100dp"
44 - />
45 -
46 - <Button
47 - android:id="@+id/quit"
48 - android:layout_width="wrap_content"
49 - android:layout_height="wrap_content"
50 - android:text="Quit"
51 - android:layout_gravity="bottom|center_horizontal"
52 - android:layout_marginBottom="10dp"
53 - />
54 -
55 -</FrameLayout>
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 - android:layout_width="match_parent"
18 - android:layout_height="match_parent">
19 -
20 - <org.tensorflow.demo.AutoFitTextureView
21 - android:id="@+id/texture"
22 - android:layout_width="wrap_content"
23 - android:layout_height="wrap_content"
24 - android:layout_alignParentBottom="true" />
25 -
26 - <org.tensorflow.demo.RecognitionScoreView
27 - android:id="@+id/results"
28 - android:layout_width="match_parent"
29 - android:layout_height="112dp"
30 - android:layout_alignParentTop="true" />
31 -
32 - <org.tensorflow.demo.OverlayView
33 - android:id="@+id/debug_overlay"
34 - android:layout_width="match_parent"
35 - android:layout_height="match_parent"
36 - android:layout_alignParentBottom="true" />
37 -
38 -</RelativeLayout>
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 - android:orientation="vertical"
18 - android:layout_width="match_parent"
19 - android:layout_height="match_parent">
20 - <org.tensorflow.demo.AutoFitTextureView
21 - android:id="@+id/texture"
22 - android:layout_width="wrap_content"
23 - android:layout_height="wrap_content"
24 - android:layout_alignParentTop="true" />
25 -
26 - <RelativeLayout
27 - android:id="@+id/black"
28 - android:layout_width="match_parent"
29 - android:layout_height="match_parent"
30 - android:background="#FF000000" />
31 -
32 - <GridView
33 - android:id="@+id/grid_layout"
34 - android:numColumns="7"
35 - android:stretchMode="columnWidth"
36 - android:layout_alignParentBottom="true"
37 - android:layout_width="match_parent"
38 - android:layout_height="wrap_content" />
39 -
40 - <org.tensorflow.demo.OverlayView
41 - android:id="@+id/overlay"
42 - android:layout_width="match_parent"
43 - android:layout_height="match_parent"
44 - android:layout_alignParentTop="true" />
45 -
46 - <org.tensorflow.demo.OverlayView
47 - android:id="@+id/debug_overlay"
48 - android:layout_width="match_parent"
49 - android:layout_height="match_parent"
50 - android:layout_alignParentTop="true" />
51 -</RelativeLayout>
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
17 - android:layout_width="match_parent"
18 - android:layout_height="match_parent">
19 -
20 - <org.tensorflow.demo.AutoFitTextureView
21 - android:id="@+id/texture"
22 - android:layout_width="wrap_content"
23 - android:layout_height="wrap_content"/>
24 -
25 - <org.tensorflow.demo.OverlayView
26 - android:id="@+id/tracking_overlay"
27 - android:layout_width="match_parent"
28 - android:layout_height="match_parent"/>
29 -
30 - <org.tensorflow.demo.OverlayView
31 - android:id="@+id/debug_overlay"
32 - android:layout_width="match_parent"
33 - android:layout_height="match_parent"/>
34 -</FrameLayout>
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<TextView
17 - xmlns:android="http://schemas.android.com/apk/res/android"
18 - android:id="@+id/list_text_item"
19 - android:layout_width="match_parent"
20 - android:layout_height="wrap_content"
21 - android:text="TextView"
22 -    android:textSize="24sp"
23 - android:textAlignment="center"
24 - android:gravity="center_horizontal"
25 - />
1 -<!--
2 - Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 - -->
16 -
17 -<resources>
18 -
19 - <!-- Semantic definitions -->
20 -
21 - <dimen name="horizontal_page_margin">@dimen/margin_huge</dimen>
22 - <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
23 -
24 -</resources>
1 -<!--
2 - Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 - -->
16 -
17 -<resources>
18 -
19 - <style name="Widget.SampleMessage">
20 - <item name="android:textAppearance">?android:textAppearanceLarge</item>
21 - <item name="android:lineSpacingMultiplier">1.2</item>
22 - <item name="android:shadowDy">-6.5</item>
23 - </style>
24 -
25 -</resources>
1 -<?xml version="1.0" encoding="utf-8"?>
2 -<resources>
3 -
4 - <!--
5 - Base application theme for API 11+. This theme completely replaces
6 - AppBaseTheme from res/values/styles.xml on API 11+ devices.
7 - -->
8 - <style name="AppBaseTheme" parent="android:Theme.Holo.Light">
9 - <!-- API 11 theme customizations can go here. -->
10 - </style>
11 -
12 - <style name="FullscreenTheme" parent="android:Theme.Holo">
13 - <item name="android:actionBarStyle">@style/FullscreenActionBarStyle</item>
14 - <item name="android:windowActionBarOverlay">true</item>
15 - <item name="android:windowBackground">@null</item>
16 - <item name="metaButtonBarStyle">?android:attr/buttonBarStyle</item>
17 - <item name="metaButtonBarButtonStyle">?android:attr/buttonBarButtonStyle</item>
18 - </style>
19 -
20 - <style name="FullscreenActionBarStyle" parent="android:Widget.Holo.ActionBar">
21 - <!-- <item name="android:background">@color/black_overlay</item> -->
22 - </style>
23 -
24 -</resources>
1 -<!--
2 - Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 - -->
16 -
17 -<resources>
18 -
19 - <!-- Activity themes -->
20 - <style name="Theme.Base" parent="android:Theme.Holo.Light" />
21 -
22 -</resources>
1 -<resources>
2 -
3 - <!--
4 - Base application theme for API 14+. This theme completely replaces
5 - AppBaseTheme from BOTH res/values/styles.xml and
6 - res/values-v11/styles.xml on API 14+ devices.
7 - -->
8 - <style name="AppBaseTheme" parent="android:Theme.Holo.Light.DarkActionBar">
9 - <!-- API 14 theme customizations can go here. -->
10 - </style>
11 -
12 -</resources>
1 -<?xml version="1.0" encoding="UTF-8"?>
2 -<!--
3 - Copyright 2013 The TensorFlow Authors. All Rights Reserved.
4 -
5 - Licensed under the Apache License, Version 2.0 (the "License");
6 - you may not use this file except in compliance with the License.
7 - You may obtain a copy of the License at
8 -
9 - http://www.apache.org/licenses/LICENSE-2.0
10 -
11 - Unless required by applicable law or agreed to in writing, software
12 - distributed under the License is distributed on an "AS IS" BASIS,
13 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 - See the License for the specific language governing permissions and
15 - limitations under the License.
16 --->
17 -
18 -<resources>
19 -
20 -
21 -</resources>
1 -<?xml version="1.0" encoding="UTF-8"?>
2 -<!--
3 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4 -
5 - Licensed under the Apache License, Version 2.0 (the "License");
6 - you may not use this file except in compliance with the License.
7 - You may obtain a copy of the License at
8 -
9 - http://www.apache.org/licenses/LICENSE-2.0
10 -
11 - Unless required by applicable law or agreed to in writing, software
12 - distributed under the License is distributed on an "AS IS" BASIS,
13 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 - See the License for the specific language governing permissions and
15 - limitations under the License.
16 --->
17 -
18 -<resources>
19 -
20 - <!-- Activity themes -->
21 - <style name="Theme.Base" parent="android:Theme.Material.Light">
22 - </style>
23 -
24 -</resources>
1 -<resources>
2 -
3 - <!--
4 - Declare custom theme attributes that allow changing which styles are
5 - used for button bars depending on the API level.
6 - ?android:attr/buttonBarStyle is new as of API 11 so this is
7 - necessary to support previous API levels.
8 - -->
9 - <declare-styleable name="ButtonBarContainerTheme">
10 - <attr name="metaButtonBarStyle" format="reference" />
11 - <attr name="metaButtonBarButtonStyle" format="reference" />
12 - </declare-styleable>
13 -
14 -</resources>
1 -<?xml version="1.0" encoding="UTF-8"?>
2 -<!--
3 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
4 -
5 - Licensed under the Apache License, Version 2.0 (the "License");
6 - you may not use this file except in compliance with the License.
7 - You may obtain a copy of the License at
8 -
9 - http://www.apache.org/licenses/LICENSE-2.0
10 -
11 - Unless required by applicable law or agreed to in writing, software
12 - distributed under the License is distributed on an "AS IS" BASIS,
13 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 - See the License for the specific language governing permissions and
15 - limitations under the License.
16 --->
17 -
18 -<resources>
19 - <string name="app_name">TensorFlow Demo</string>
20 - <string name="activity_name_classification">TF Classify</string>
21 - <string name="activity_name_detection">TF Detect</string>
22 - <string name="activity_name_stylize">TF Stylize</string>
23 - <string name="activity_name_speech">TF Speech</string>
24 -</resources>
1 -<?xml version="1.0" encoding="utf-8"?>
2 -<!--
3 - Copyright 2015 The TensorFlow Authors. All Rights Reserved.
4 -
5 - Licensed under the Apache License, Version 2.0 (the "License");
6 - you may not use this file except in compliance with the License.
7 - You may obtain a copy of the License at
8 -
9 - http://www.apache.org/licenses/LICENSE-2.0
10 -
11 - Unless required by applicable law or agreed to in writing, software
12 - distributed under the License is distributed on an "AS IS" BASIS,
13 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 - See the License for the specific language governing permissions and
15 - limitations under the License.
16 --->
17 -<resources>
18 - <color name="control_background">#cc4285f4</color>
19 -</resources>
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<resources>
17 - <string name="description_info">Info</string>
18 - <string name="request_permission">This sample needs camera permission.</string>
19 - <string name="camera_error">This device doesn\'t support Camera2 API.</string>
20 -</resources>
1 -<?xml version="1.0" encoding="utf-8"?><!--
2 - Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 --->
16 -<resources>
17 - <style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" />
18 -</resources>
1 -<!--
2 - Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 - -->
16 -
17 -<resources>
18 -
19 - <!-- Define standard dimensions to comply with Holo-style grids and rhythm. -->
20 -
21 - <dimen name="margin_tiny">4dp</dimen>
22 - <dimen name="margin_small">8dp</dimen>
23 - <dimen name="margin_medium">16dp</dimen>
24 - <dimen name="margin_large">32dp</dimen>
25 - <dimen name="margin_huge">64dp</dimen>
26 -
27 - <!-- Semantic definitions -->
28 -
29 - <dimen name="horizontal_page_margin">@dimen/margin_medium</dimen>
30 - <dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
31 -
32 -</resources>
1 -<!--
2 - Copyright 2013 The TensorFlow Authors. All Rights Reserved.
3 -
4 - Licensed under the Apache License, Version 2.0 (the "License");
5 - you may not use this file except in compliance with the License.
6 - You may obtain a copy of the License at
7 -
8 - http://www.apache.org/licenses/LICENSE-2.0
9 -
10 - Unless required by applicable law or agreed to in writing, software
11 - distributed under the License is distributed on an "AS IS" BASIS,
12 - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - See the License for the specific language governing permissions and
14 - limitations under the License.
15 - -->
16 -
17 -<resources>
18 -
19 - <!-- Activity themes -->
20 -
21 - <style name="Theme.Base" parent="android:Theme.Light" />
22 -
23 - <style name="Theme.Sample" parent="Theme.Base" />
24 -
25 - <style name="AppTheme" parent="Theme.Sample" />
26 - <!-- Widget styling -->
27 -
28 - <style name="Widget" />
29 -
30 - <style name="Widget.SampleMessage">
31 - <item name="android:textAppearance">?android:textAppearanceMedium</item>
32 - <item name="android:lineSpacingMultiplier">1.1</item>
33 - </style>
34 -
35 - <style name="Widget.SampleMessageTile">
36 - <item name="android:background">@drawable/tile</item>
37 - <item name="android:shadowColor">#7F000000</item>
38 - <item name="android:shadowDy">-3.5</item>
39 - <item name="android:shadowRadius">2</item>
40 - </style>
41 -
42 -</resources>
1 -/*
2 - * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -package org.tensorflow.demo;
18 -
19 -import android.content.Context;
20 -import android.util.AttributeSet;
21 -import android.view.TextureView;
22 -
23 -/**
24 - * A {@link TextureView} that can be adjusted to a specified aspect ratio.
25 - */
26 -public class AutoFitTextureView extends TextureView {
27 - private int ratioWidth = 0;
28 - private int ratioHeight = 0;
29 -
30 - public AutoFitTextureView(final Context context) {
31 - this(context, null);
32 - }
33 -
34 - public AutoFitTextureView(final Context context, final AttributeSet attrs) {
35 - this(context, attrs, 0);
36 - }
37 -
38 - public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) {
39 - super(context, attrs, defStyle);
40 - }
41 -
42 - /**
43 - * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio
44 -   * calculated from the parameters. Note that only the ratio of the parameters matters; that
45 -   * is, calling setAspectRatio(2, 3) and setAspectRatio(4, 6) yield the same result.
46 - *
47 - * @param width Relative horizontal size
48 - * @param height Relative vertical size
49 - */
50 - public void setAspectRatio(final int width, final int height) {
51 - if (width < 0 || height < 0) {
52 - throw new IllegalArgumentException("Size cannot be negative.");
53 - }
54 - ratioWidth = width;
55 - ratioHeight = height;
56 - requestLayout();
57 - }
58 -
59 - @Override
60 - protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
61 - super.onMeasure(widthMeasureSpec, heightMeasureSpec);
62 - final int width = MeasureSpec.getSize(widthMeasureSpec);
63 - final int height = MeasureSpec.getSize(heightMeasureSpec);
64 - if (0 == ratioWidth || 0 == ratioHeight) {
65 - setMeasuredDimension(width, height);
66 - } else {
67 - if (width < height * ratioWidth / ratioHeight) {
68 - setMeasuredDimension(width, width * ratioHeight / ratioWidth);
69 - } else {
70 - setMeasuredDimension(height * ratioWidth / ratioHeight, height);
71 - }
72 - }
73 - }
74 -}
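
To make the onMeasure() branching above concrete: the view keeps whichever dimension is the limiting one and derives the other from the stored ratio. The same arithmetic, extracted into a standalone sketch with no Android dependencies (values chosen only for illustration):

    public class AspectFitDemo {
      // Mirrors the else-branch of AutoFitTextureView.onMeasure().
      static int[] fit(int width, int height, int ratioW, int ratioH) {
        if (width < height * ratioW / ratioH) {
          return new int[] {width, width * ratioH / ratioW};   // width-limited
        }
        return new int[] {height * ratioW / ratioH, height};   // height-limited
      }

      public static void main(String[] args) {
        // A 4:3 preview measured inside a 1080x1920 portrait surface becomes 1080x810.
        int[] size = fit(1080, 1920, 4, 3);
        System.out.println(size[0] + "x" + size[1]);
      }
    }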
1 -/*
2 - * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -package org.tensorflow.demo;
18 -
19 -import android.Manifest;
20 -import android.app.Activity;
21 -import android.app.Fragment;
22 -import android.content.Context;
23 -import android.content.pm.PackageManager;
24 -import android.hardware.Camera;
25 -import android.hardware.camera2.CameraAccessException;
26 -import android.hardware.camera2.CameraCharacteristics;
27 -import android.hardware.camera2.CameraManager;
28 -import android.hardware.camera2.params.StreamConfigurationMap;
29 -import android.media.Image;
30 -import android.media.Image.Plane;
31 -import android.media.ImageReader;
32 -import android.media.ImageReader.OnImageAvailableListener;
33 -import android.os.Build;
34 -import android.os.Bundle;
35 -import android.os.Handler;
36 -import android.os.HandlerThread;
37 -import android.os.Trace;
38 -import android.util.Size;
39 -import android.view.KeyEvent;
40 -import android.view.Surface;
41 -import android.view.WindowManager;
42 -import android.widget.Toast;
43 -import java.nio.ByteBuffer;
44 -import org.tensorflow.demo.env.ImageUtils;
45 -import org.tensorflow.demo.env.Logger;
46 -import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
47 -
48 -public abstract class CameraActivity extends Activity
49 - implements OnImageAvailableListener, Camera.PreviewCallback {
50 - private static final Logger LOGGER = new Logger();
51 -
52 - private static final int PERMISSIONS_REQUEST = 1;
53 -
54 - private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA;
55 - private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
56 -
57 - private boolean debug = false;
58 -
59 - private Handler handler;
60 - private HandlerThread handlerThread;
61 - private boolean useCamera2API;
62 - private boolean isProcessingFrame = false;
63 - private byte[][] yuvBytes = new byte[3][];
64 - private int[] rgbBytes = null;
65 - private int yRowStride;
66 -
67 - protected int previewWidth = 0;
68 - protected int previewHeight = 0;
69 -
70 - private Runnable postInferenceCallback;
71 - private Runnable imageConverter;
72 -
73 - @Override
74 - protected void onCreate(final Bundle savedInstanceState) {
75 - LOGGER.d("onCreate " + this);
76 - super.onCreate(null);
77 - getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
78 -
79 - setContentView(R.layout.activity_camera);
80 -
81 - if (hasPermission()) {
82 - setFragment();
83 - } else {
84 - requestPermission();
85 - }
86 - }
87 -
88 - private byte[] lastPreviewFrame;
89 -
90 - protected int[] getRgbBytes() {
91 - imageConverter.run();
92 - return rgbBytes;
93 - }
94 -
95 - protected int getLuminanceStride() {
96 - return yRowStride;
97 - }
98 -
99 - protected byte[] getLuminance() {
100 - return yuvBytes[0];
101 - }
102 -
103 - /**
104 - * Callback for android.hardware.Camera API
105 - */
106 - @Override
107 - public void onPreviewFrame(final byte[] bytes, final Camera camera) {
108 - if (isProcessingFrame) {
109 - LOGGER.w("Dropping frame!");
110 - return;
111 - }
112 -
113 - try {
114 - // Initialize the storage bitmaps once when the resolution is known.
115 - if (rgbBytes == null) {
116 - Camera.Size previewSize = camera.getParameters().getPreviewSize();
117 - previewHeight = previewSize.height;
118 - previewWidth = previewSize.width;
119 - rgbBytes = new int[previewWidth * previewHeight];
120 - onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90);
121 - }
122 - } catch (final Exception e) {
123 - LOGGER.e(e, "Exception!");
124 - return;
125 - }
126 -
127 - isProcessingFrame = true;
128 - lastPreviewFrame = bytes;
129 - yuvBytes[0] = bytes;
130 - yRowStride = previewWidth;
131 -
132 - imageConverter =
133 - new Runnable() {
134 - @Override
135 - public void run() {
136 - ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes);
137 - }
138 - };
139 -
140 - postInferenceCallback =
141 - new Runnable() {
142 - @Override
143 - public void run() {
144 - camera.addCallbackBuffer(bytes);
145 - isProcessingFrame = false;
146 - }
147 - };
148 - processImage();
149 - }
150 -
151 - /**
152 - * Callback for Camera2 API
153 - */
154 - @Override
155 - public void onImageAvailable(final ImageReader reader) {
156 -    // We need to wait until we have a size from onPreviewSizeChosen.
157 - if (previewWidth == 0 || previewHeight == 0) {
158 - return;
159 - }
160 - if (rgbBytes == null) {
161 - rgbBytes = new int[previewWidth * previewHeight];
162 - }
163 - try {
164 - final Image image = reader.acquireLatestImage();
165 -
166 - if (image == null) {
167 - return;
168 - }
169 -
170 - if (isProcessingFrame) {
171 - image.close();
172 - return;
173 - }
174 - isProcessingFrame = true;
175 - Trace.beginSection("imageAvailable");
176 - final Plane[] planes = image.getPlanes();
177 - fillBytes(planes, yuvBytes);
178 - yRowStride = planes[0].getRowStride();
179 - final int uvRowStride = planes[1].getRowStride();
180 - final int uvPixelStride = planes[1].getPixelStride();
181 -
182 - imageConverter =
183 - new Runnable() {
184 - @Override
185 - public void run() {
186 - ImageUtils.convertYUV420ToARGB8888(
187 - yuvBytes[0],
188 - yuvBytes[1],
189 - yuvBytes[2],
190 - previewWidth,
191 - previewHeight,
192 - yRowStride,
193 - uvRowStride,
194 - uvPixelStride,
195 - rgbBytes);
196 - }
197 - };
198 -
199 - postInferenceCallback =
200 - new Runnable() {
201 - @Override
202 - public void run() {
203 - image.close();
204 - isProcessingFrame = false;
205 - }
206 - };
207 -
208 - processImage();
209 - } catch (final Exception e) {
210 - LOGGER.e(e, "Exception!");
211 - Trace.endSection();
212 - return;
213 - }
214 - Trace.endSection();
215 - }
216 -
217 - @Override
218 - public synchronized void onStart() {
219 - LOGGER.d("onStart " + this);
220 - super.onStart();
221 - }
222 -
223 - @Override
224 - public synchronized void onResume() {
225 - LOGGER.d("onResume " + this);
226 - super.onResume();
227 -
228 - handlerThread = new HandlerThread("inference");
229 - handlerThread.start();
230 - handler = new Handler(handlerThread.getLooper());
231 - }
232 -
233 - @Override
234 - public synchronized void onPause() {
235 - LOGGER.d("onPause " + this);
236 -
237 - if (!isFinishing()) {
238 - LOGGER.d("Requesting finish");
239 - finish();
240 - }
241 -
242 - handlerThread.quitSafely();
243 - try {
244 - handlerThread.join();
245 - handlerThread = null;
246 - handler = null;
247 - } catch (final InterruptedException e) {
248 - LOGGER.e(e, "Exception!");
249 - }
250 -
251 - super.onPause();
252 - }
253 -
254 - @Override
255 - public synchronized void onStop() {
256 - LOGGER.d("onStop " + this);
257 - super.onStop();
258 - }
259 -
260 - @Override
261 - public synchronized void onDestroy() {
262 - LOGGER.d("onDestroy " + this);
263 - super.onDestroy();
264 - }
265 -
266 - protected synchronized void runInBackground(final Runnable r) {
267 - if (handler != null) {
268 - handler.post(r);
269 - }
270 - }
271 -
272 - @Override
273 - public void onRequestPermissionsResult(
274 - final int requestCode, final String[] permissions, final int[] grantResults) {
275 - if (requestCode == PERMISSIONS_REQUEST) {
276 -      if (grantResults.length >= 2
277 - && grantResults[0] == PackageManager.PERMISSION_GRANTED
278 - && grantResults[1] == PackageManager.PERMISSION_GRANTED) {
279 - setFragment();
280 - } else {
281 - requestPermission();
282 - }
283 - }
284 - }
285 -
286 - private boolean hasPermission() {
287 - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
288 - return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED &&
289 - checkSelfPermission(PERMISSION_STORAGE) == PackageManager.PERMISSION_GRANTED;
290 - } else {
291 - return true;
292 - }
293 - }
294 -
295 - private void requestPermission() {
296 - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
297 - if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA) ||
298 - shouldShowRequestPermissionRationale(PERMISSION_STORAGE)) {
299 - Toast.makeText(CameraActivity.this,
300 - "Camera AND storage permission are required for this demo", Toast.LENGTH_LONG).show();
301 - }
302 - requestPermissions(new String[] {PERMISSION_CAMERA, PERMISSION_STORAGE}, PERMISSIONS_REQUEST);
303 - }
304 - }
305 -
306 - // Returns true if the device supports the required hardware level, or better.
307 - private boolean isHardwareLevelSupported(
308 - CameraCharacteristics characteristics, int requiredLevel) {
309 - int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL);
310 - if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) {
311 - return requiredLevel == deviceLevel;
312 - }
313 - // deviceLevel is not LEGACY, can use numerical sort
314 - return requiredLevel <= deviceLevel;
315 - }
316 -
317 - private String chooseCamera() {
318 - final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE);
319 - try {
320 - for (final String cameraId : manager.getCameraIdList()) {
321 - final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
322 -
323 - // We don't use a front facing camera in this sample.
324 - final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING);
325 - if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) {
326 - continue;
327 - }
328 -
329 - final StreamConfigurationMap map =
330 - characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
331 -
332 - if (map == null) {
333 - continue;
334 - }
335 -
336 - // Fallback to camera1 API for internal cameras that don't have full support.
337 - // This should help with legacy situations where using the camera2 API causes
338 - // distorted or otherwise broken previews.
339 -        useCamera2API = (facing != null && facing == CameraCharacteristics.LENS_FACING_EXTERNAL)
340 - || isHardwareLevelSupported(characteristics,
341 - CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL);
342 - LOGGER.i("Camera API lv2?: %s", useCamera2API);
343 - return cameraId;
344 - }
345 - } catch (CameraAccessException e) {
346 - LOGGER.e(e, "Not allowed to access camera");
347 - }
348 -
349 - return null;
350 - }
351 -
352 - protected void setFragment() {
353 - String cameraId = chooseCamera();
354 - if (cameraId == null) {
355 - Toast.makeText(this, "No Camera Detected", Toast.LENGTH_SHORT).show();
356 -      finish();
357 -      return; // finish() is asynchronous; don't fall through with a null cameraId.
358 -    }
359 - Fragment fragment;
360 - if (useCamera2API) {
361 - CameraConnectionFragment camera2Fragment =
362 - CameraConnectionFragment.newInstance(
363 - new CameraConnectionFragment.ConnectionCallback() {
364 - @Override
365 - public void onPreviewSizeChosen(final Size size, final int rotation) {
366 - previewHeight = size.getHeight();
367 - previewWidth = size.getWidth();
368 - CameraActivity.this.onPreviewSizeChosen(size, rotation);
369 - }
370 - },
371 - this,
372 - getLayoutId(),
373 - getDesiredPreviewFrameSize());
374 -
375 - camera2Fragment.setCamera(cameraId);
376 - fragment = camera2Fragment;
377 - } else {
378 - fragment =
379 - new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize());
380 - }
381 -
382 - getFragmentManager()
383 - .beginTransaction()
384 - .replace(R.id.container, fragment)
385 - .commit();
386 - }
387 -
388 - protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) {
389 -    // Because of the variable row stride it's not possible to know in
390 -    // advance the actual necessary dimensions of the YUV planes.
391 - for (int i = 0; i < planes.length; ++i) {
392 - final ByteBuffer buffer = planes[i].getBuffer();
393 - if (yuvBytes[i] == null) {
394 - LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity());
395 - yuvBytes[i] = new byte[buffer.capacity()];
396 - }
397 - buffer.get(yuvBytes[i]);
398 - }
399 - }
400 -
401 - public boolean isDebug() {
402 - return debug;
403 - }
404 -
405 - public void requestRender() {
406 - final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay);
407 - if (overlay != null) {
408 - overlay.postInvalidate();
409 - }
410 - }
411 -
412 - public void addCallback(final OverlayView.DrawCallback callback) {
413 - final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay);
414 - if (overlay != null) {
415 - overlay.addCallback(callback);
416 - }
417 - }
418 -
419 - public void onSetDebug(final boolean debug) {}
420 -
421 - @Override
422 - public boolean onKeyDown(final int keyCode, final KeyEvent event) {
423 - if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP
424 - || keyCode == KeyEvent.KEYCODE_BUTTON_L1 || keyCode == KeyEvent.KEYCODE_DPAD_CENTER) {
425 - debug = !debug;
426 - requestRender();
427 - onSetDebug(debug);
428 - return true;
429 - }
430 - return super.onKeyDown(keyCode, event);
431 - }
432 -
433 - protected void readyForNextImage() {
434 - if (postInferenceCallback != null) {
435 - postInferenceCallback.run();
436 - }
437 - }
438 -
439 - protected int getScreenOrientation() {
440 - switch (getWindowManager().getDefaultDisplay().getRotation()) {
441 - case Surface.ROTATION_270:
442 - return 270;
443 - case Surface.ROTATION_180:
444 - return 180;
445 - case Surface.ROTATION_90:
446 - return 90;
447 - default:
448 - return 0;
449 - }
450 - }
451 -
452 - protected abstract void processImage();
453 -
454 - protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
455 - protected abstract int getLayoutId();
456 - protected abstract Size getDesiredPreviewFrameSize();
457 -}
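
One non-obvious detail in chooseCamera() is the LEGACY special case inside isHardwareLevelSupported(): the INFO_SUPPORTED_HARDWARE_LEVEL constants order LIMITED (0) < FULL (1) < LEVEL_3 (3) by capability, but LEGACY is 2 even though it is the weakest level, so a plain numeric comparison would wrongly rank LEGACY above FULL. A standalone sketch of the same rule, with the documented constant values hard-coded purely for illustration:

    public class HardwareLevelDemo {
      // Documented values of CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_*.
      static final int LIMITED = 0, FULL = 1, LEGACY = 2, LEVEL_3 = 3;

      // Same logic as CameraActivity.isHardwareLevelSupported().
      static boolean supported(int deviceLevel, int requiredLevel) {
        if (deviceLevel == LEGACY) {
          return requiredLevel == deviceLevel; // LEGACY satisfies only LEGACY
        }
        return requiredLevel <= deviceLevel;   // elsewhere numeric order is capability order
      }

      public static void main(String[] args) {
        System.out.println(supported(LEGACY, FULL));  // false, despite 1 <= 2
        System.out.println(supported(LEVEL_3, FULL)); // true
      }
    }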
1 -/*
2 - * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -package org.tensorflow.demo;
18 -
19 -import android.app.Activity;
20 -import android.app.AlertDialog;
21 -import android.app.Dialog;
22 -import android.app.DialogFragment;
23 -import android.app.Fragment;
24 -import android.content.Context;
25 -import android.content.DialogInterface;
26 -import android.content.res.Configuration;
27 -import android.graphics.ImageFormat;
28 -import android.graphics.Matrix;
29 -import android.graphics.RectF;
30 -import android.graphics.SurfaceTexture;
31 -import android.hardware.camera2.CameraAccessException;
32 -import android.hardware.camera2.CameraCaptureSession;
33 -import android.hardware.camera2.CameraCharacteristics;
34 -import android.hardware.camera2.CameraDevice;
35 -import android.hardware.camera2.CameraManager;
36 -import android.hardware.camera2.CaptureRequest;
37 -import android.hardware.camera2.CaptureResult;
38 -import android.hardware.camera2.TotalCaptureResult;
39 -import android.hardware.camera2.params.StreamConfigurationMap;
40 -import android.media.ImageReader;
41 -import android.media.ImageReader.OnImageAvailableListener;
42 -import android.os.Bundle;
43 -import android.os.Handler;
44 -import android.os.HandlerThread;
45 -import android.text.TextUtils;
46 -import android.util.Size;
47 -import android.util.SparseIntArray;
48 -import android.view.LayoutInflater;
49 -import android.view.Surface;
50 -import android.view.TextureView;
51 -import android.view.View;
52 -import android.view.ViewGroup;
53 -import android.widget.Toast;
54 -import java.util.ArrayList;
55 -import java.util.Arrays;
56 -import java.util.Collections;
57 -import java.util.Comparator;
58 -import java.util.List;
59 -import java.util.concurrent.Semaphore;
60 -import java.util.concurrent.TimeUnit;
61 -import org.tensorflow.demo.env.Logger;
62 -import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
63 -
64 -public class CameraConnectionFragment extends Fragment {
65 - private static final Logger LOGGER = new Logger();
66 -
67 - /**
68 -   * The camera preview size will be chosen to be the smallest frame by pixel size capable of
69 -   * containing a square whose side is max(min(desired width, height), MINIMUM_PREVIEW_SIZE).
70 - */
71 - private static final int MINIMUM_PREVIEW_SIZE = 320;
72 -
73 - /**
74 - * Conversion from screen rotation to JPEG orientation.
75 - */
76 - private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
77 - private static final String FRAGMENT_DIALOG = "dialog";
78 -
79 - static {
80 - ORIENTATIONS.append(Surface.ROTATION_0, 90);
81 - ORIENTATIONS.append(Surface.ROTATION_90, 0);
82 - ORIENTATIONS.append(Surface.ROTATION_180, 270);
83 - ORIENTATIONS.append(Surface.ROTATION_270, 180);
84 - }
85 -
86 - /**
87 - * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a
88 - * {@link TextureView}.
89 - */
90 - private final TextureView.SurfaceTextureListener surfaceTextureListener =
91 - new TextureView.SurfaceTextureListener() {
92 - @Override
93 - public void onSurfaceTextureAvailable(
94 - final SurfaceTexture texture, final int width, final int height) {
95 - openCamera(width, height);
96 - }
97 -
98 - @Override
99 - public void onSurfaceTextureSizeChanged(
100 - final SurfaceTexture texture, final int width, final int height) {
101 - configureTransform(width, height);
102 - }
103 -
104 - @Override
105 - public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) {
106 - return true;
107 - }
108 -
109 - @Override
110 - public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
111 - };
112 -
113 - /**
114 - * Callback for Activities to use to initialize their data once the
115 - * selected preview size is known.
116 - */
117 - public interface ConnectionCallback {
118 - void onPreviewSizeChosen(Size size, int cameraRotation);
119 - }
120 -
121 - /**
122 - * ID of the current {@link CameraDevice}.
123 - */
124 - private String cameraId;
125 -
126 - /**
127 - * An {@link AutoFitTextureView} for camera preview.
128 - */
129 - private AutoFitTextureView textureView;
130 -
131 - /**
132 - * A {@link CameraCaptureSession } for camera preview.
133 - */
134 - private CameraCaptureSession captureSession;
135 -
136 - /**
137 - * A reference to the opened {@link CameraDevice}.
138 - */
139 - private CameraDevice cameraDevice;
140 -
141 - /**
142 - * The rotation in degrees of the camera sensor from the display.
143 - */
144 - private Integer sensorOrientation;
145 -
146 - /**
147 - * The {@link android.util.Size} of camera preview.
148 - */
149 - private Size previewSize;
150 -
151 - /**
152 - * {@link android.hardware.camera2.CameraDevice.StateCallback}
153 - * is called when {@link CameraDevice} changes its state.
154 - */
155 - private final CameraDevice.StateCallback stateCallback =
156 - new CameraDevice.StateCallback() {
157 - @Override
158 - public void onOpened(final CameraDevice cd) {
159 - // This method is called when the camera is opened. We start camera preview here.
160 - cameraOpenCloseLock.release();
161 - cameraDevice = cd;
162 - createCameraPreviewSession();
163 - }
164 -
165 - @Override
166 - public void onDisconnected(final CameraDevice cd) {
167 - cameraOpenCloseLock.release();
168 - cd.close();
169 - cameraDevice = null;
170 - }
171 -
172 - @Override
173 - public void onError(final CameraDevice cd, final int error) {
174 - cameraOpenCloseLock.release();
175 - cd.close();
176 - cameraDevice = null;
177 - final Activity activity = getActivity();
178 - if (null != activity) {
179 - activity.finish();
180 - }
181 - }
182 - };
183 -
184 - /**
185 - * An additional thread for running tasks that shouldn't block the UI.
186 - */
187 - private HandlerThread backgroundThread;
188 -
189 - /**
190 - * A {@link Handler} for running tasks in the background.
191 - */
192 - private Handler backgroundHandler;
193 -
194 - /**
195 - * An {@link ImageReader} that handles preview frame capture.
196 - */
197 - private ImageReader previewReader;
198 -
199 - /**
200 - * {@link android.hardware.camera2.CaptureRequest.Builder} for the camera preview
201 - */
202 - private CaptureRequest.Builder previewRequestBuilder;
203 -
204 - /**
205 - * {@link CaptureRequest} generated by {@link #previewRequestBuilder}
206 - */
207 - private CaptureRequest previewRequest;
208 -
209 - /**
210 - * A {@link Semaphore} to prevent the app from exiting before closing the camera.
211 - */
212 - private final Semaphore cameraOpenCloseLock = new Semaphore(1);
213 -
214 - /**
215 - * A {@link OnImageAvailableListener} to receive frames as they are available.
216 - */
217 - private final OnImageAvailableListener imageListener;
218 -
219 - /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */
220 - private final Size inputSize;
221 -
222 - /**
223 - * The layout identifier to inflate for this Fragment.
224 - */
225 - private final int layout;
226 -
227 -
228 - private final ConnectionCallback cameraConnectionCallback;
229 -
230 - private CameraConnectionFragment(
231 - final ConnectionCallback connectionCallback,
232 - final OnImageAvailableListener imageListener,
233 - final int layout,
234 - final Size inputSize) {
235 - this.cameraConnectionCallback = connectionCallback;
236 - this.imageListener = imageListener;
237 - this.layout = layout;
238 - this.inputSize = inputSize;
239 - }
240 -
241 - /**
242 - * Shows a {@link Toast} on the UI thread.
243 - *
244 - * @param text The message to show
245 - */
246 - private void showToast(final String text) {
247 - final Activity activity = getActivity();
248 - if (activity != null) {
249 - activity.runOnUiThread(
250 - new Runnable() {
251 - @Override
252 - public void run() {
253 - Toast.makeText(activity, text, Toast.LENGTH_SHORT).show();
254 - }
255 - });
256 - }
257 - }
258 -
259 - /**
260 -   * Given {@code choices} of {@code Size}s supported by a camera, chooses the exact requested
261 -   * size if available, otherwise the smallest one whose width and height both reach the minimum.
262 - *
263 - * @param choices The list of sizes that the camera supports for the intended output class
264 - * @param width The minimum desired width
265 - * @param height The minimum desired height
266 - * @return The optimal {@code Size}, or an arbitrary one if none were big enough
267 - */
268 - protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) {
269 - final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE);
270 - final Size desiredSize = new Size(width, height);
271 -
272 - // Collect the supported resolutions that are at least as big as the preview Surface
273 - boolean exactSizeFound = false;
274 - final List<Size> bigEnough = new ArrayList<Size>();
275 - final List<Size> tooSmall = new ArrayList<Size>();
276 - for (final Size option : choices) {
277 - if (option.equals(desiredSize)) {
278 - // Set the size but don't return yet so that remaining sizes will still be logged.
279 - exactSizeFound = true;
280 - }
281 -
282 - if (option.getHeight() >= minSize && option.getWidth() >= minSize) {
283 - bigEnough.add(option);
284 - } else {
285 - tooSmall.add(option);
286 - }
287 - }
288 -
289 - LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize);
290 - LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]");
291 - LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]");
292 -
293 - if (exactSizeFound) {
294 - LOGGER.i("Exact size match found.");
295 - return desiredSize;
296 - }
297 -
298 - // Pick the smallest of those, assuming we found any
299 - if (bigEnough.size() > 0) {
300 - final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea());
301 - LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight());
302 - return chosenSize;
303 - } else {
304 - LOGGER.e("Couldn't find any suitable preview size");
305 - return choices[0];
306 - }
307 - }
308 -
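A worked example of the selection rule above, assuming the stock demo's `MINIMUM_PREVIEW_SIZE` of 320 (the candidate sizes here are hypothetical):

```java
// With a desired 640x480 input, minSize = max(min(640, 480), 320) = 480, so a
// candidate is "big enough" only if both dimensions are >= 480.
final Size[] choices = {
  new Size(320, 240),  // too small: both dimensions below 480
  new Size(1280, 720), // big enough, but largest by area
  new Size(640, 480),  // big enough and an exact match
};
final Size chosen = CameraConnectionFragment.chooseOptimalSize(choices, 640, 480);
// chosen == 640x480: the exact match wins; otherwise the smallest-by-area
// entry of the "big enough" list would be returned.
```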
309 - public static CameraConnectionFragment newInstance(
310 - final ConnectionCallback callback,
311 - final OnImageAvailableListener imageListener,
312 - final int layout,
313 - final Size inputSize) {
314 - return new CameraConnectionFragment(callback, imageListener, layout, inputSize);
315 - }
316 -
317 - @Override
318 - public View onCreateView(
319 - final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) {
320 - return inflater.inflate(layout, container, false);
321 - }
322 -
323 - @Override
324 - public void onViewCreated(final View view, final Bundle savedInstanceState) {
325 - textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
326 - }
327 -
328 - @Override
329 - public void onActivityCreated(final Bundle savedInstanceState) {
330 - super.onActivityCreated(savedInstanceState);
331 - }
332 -
333 - @Override
334 - public void onResume() {
335 - super.onResume();
336 - startBackgroundThread();
337 -
338 - // When the screen is turned off and turned back on, the SurfaceTexture is already
339 - // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
340 - // a camera and start preview from here (otherwise, we wait until the surface is ready in
341 - // the SurfaceTextureListener).
342 - if (textureView.isAvailable()) {
343 - openCamera(textureView.getWidth(), textureView.getHeight());
344 - } else {
345 - textureView.setSurfaceTextureListener(surfaceTextureListener);
346 - }
347 - }
348 -
349 - @Override
350 - public void onPause() {
351 - closeCamera();
352 - stopBackgroundThread();
353 - super.onPause();
354 - }
355 -
356 - public void setCamera(String cameraId) {
357 - this.cameraId = cameraId;
358 - }
359 -
360 - /**
361 - * Sets up member variables related to camera.
362 - */
363 - private void setUpCameraOutputs() {
364 - final Activity activity = getActivity();
365 - final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
366 - try {
367 - final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
368 -
369 - final StreamConfigurationMap map =
370 - characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
371 -
372 - // For still image captures, we use the largest available size.
373 - final Size largest =
374 - Collections.max(
375 - Arrays.asList(map.getOutputSizes(ImageFormat.YUV_420_888)),
376 - new CompareSizesByArea());
377 -
378 - sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION);
379 -
380 - // Danger, W.R.! Attempting to use too large a preview size could exceed the camera
381 - // bus' bandwidth limitation, resulting in gorgeous previews but the storage of
382 - // garbage capture data.
383 - previewSize =
384 - chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class),
385 - inputSize.getWidth(),
386 - inputSize.getHeight());
387 -
388 - // We fit the aspect ratio of TextureView to the size of preview we picked.
389 - final int orientation = getResources().getConfiguration().orientation;
390 - if (orientation == Configuration.ORIENTATION_LANDSCAPE) {
391 - textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight());
392 - } else {
393 - textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth());
394 - }
395 - } catch (final CameraAccessException e) {
396 - LOGGER.e(e, "Exception!");
397 - } catch (final NullPointerException e) {
398 - // Currently an NPE is thrown when the Camera2API is used but not supported on the
399 - // device this code runs.
400 - // TODO(andrewharp): abstract ErrorDialog/RuntimeException handling out into new method and
401 - // reuse throughout app.
402 - ErrorDialog.newInstance(getString(R.string.camera_error))
403 - .show(getChildFragmentManager(), FRAGMENT_DIALOG);
404 - throw new RuntimeException(getString(R.string.camera_error));
405 - }
406 -
407 - cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation);
408 - }
409 -
410 - /**
411 - * Opens the camera specified by {@link CameraConnectionFragment#cameraId}.
412 - */
413 - private void openCamera(final int width, final int height) {
414 - setUpCameraOutputs();
415 - configureTransform(width, height);
416 - final Activity activity = getActivity();
417 - final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
418 - try {
419 - if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) {
420 - throw new RuntimeException("Time out waiting to lock camera opening.");
421 - }
422 - manager.openCamera(cameraId, stateCallback, backgroundHandler);
423 - } catch (final CameraAccessException e) {
424 - LOGGER.e(e, "Exception!");
425 - } catch (final InterruptedException e) {
426 - throw new RuntimeException("Interrupted while trying to lock camera opening.", e);
427 - }
428 - }
429 -
430 - /**
431 - * Closes the current {@link CameraDevice}.
432 - */
433 - private void closeCamera() {
434 - try {
435 - cameraOpenCloseLock.acquire();
436 - if (null != captureSession) {
437 - captureSession.close();
438 - captureSession = null;
439 - }
440 - if (null != cameraDevice) {
441 - cameraDevice.close();
442 - cameraDevice = null;
443 - }
444 - if (null != previewReader) {
445 - previewReader.close();
446 - previewReader = null;
447 - }
448 - } catch (final InterruptedException e) {
449 - throw new RuntimeException("Interrupted while trying to lock camera closing.", e);
450 - } finally {
451 - cameraOpenCloseLock.release();
452 - }
453 - }
454 -
455 - /**
456 - * Starts a background thread and its {@link Handler}.
457 - */
458 - private void startBackgroundThread() {
459 - backgroundThread = new HandlerThread("ImageListener");
460 - backgroundThread.start();
461 - backgroundHandler = new Handler(backgroundThread.getLooper());
462 - }
463 -
464 - /**
465 - * Stops the background thread and its {@link Handler}.
466 - */
467 - private void stopBackgroundThread() {
468 - backgroundThread.quitSafely();
469 - try {
470 - backgroundThread.join();
471 - backgroundThread = null;
472 - backgroundHandler = null;
473 - } catch (final InterruptedException e) {
474 - LOGGER.e(e, "Exception!");
475 - }
476 - }
477 -
478 - private final CameraCaptureSession.CaptureCallback captureCallback =
479 - new CameraCaptureSession.CaptureCallback() {
480 - @Override
481 - public void onCaptureProgressed(
482 - final CameraCaptureSession session,
483 - final CaptureRequest request,
484 - final CaptureResult partialResult) {}
485 -
486 - @Override
487 - public void onCaptureCompleted(
488 - final CameraCaptureSession session,
489 - final CaptureRequest request,
490 - final TotalCaptureResult result) {}
491 - };
492 -
493 - /**
494 - * Creates a new {@link CameraCaptureSession} for camera preview.
495 - */
496 - private void createCameraPreviewSession() {
497 - try {
498 - final SurfaceTexture texture = textureView.getSurfaceTexture();
499 - assert texture != null;
500 -
501 - // We configure the size of default buffer to be the size of camera preview we want.
502 - texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight());
503 -
504 - // This is the output Surface we need to start preview.
505 - final Surface surface = new Surface(texture);
506 -
507 - // We set up a CaptureRequest.Builder with the output Surface.
508 - previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
509 - previewRequestBuilder.addTarget(surface);
510 -
511 - LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight());
512 -
513 - // Create the reader for the preview frames.
514 - previewReader =
515 - ImageReader.newInstance(
516 - previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
517 -
518 - previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
519 - previewRequestBuilder.addTarget(previewReader.getSurface());
520 -
521 - // Here, we create a CameraCaptureSession for camera preview.
522 - cameraDevice.createCaptureSession(
523 - Arrays.asList(surface, previewReader.getSurface()),
524 - new CameraCaptureSession.StateCallback() {
525 -
526 - @Override
527 - public void onConfigured(final CameraCaptureSession cameraCaptureSession) {
528 - // The camera is already closed
529 - if (null == cameraDevice) {
530 - return;
531 - }
532 -
533 - // When the session is ready, we start displaying the preview.
534 - captureSession = cameraCaptureSession;
535 - try {
536 - // Auto focus should be continuous for camera preview.
537 - previewRequestBuilder.set(
538 - CaptureRequest.CONTROL_AF_MODE,
539 - CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE);
540 - // Flash is automatically enabled when necessary.
541 - previewRequestBuilder.set(
542 - CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH);
543 -
544 - // Finally, we start displaying the camera preview.
545 - previewRequest = previewRequestBuilder.build();
546 - captureSession.setRepeatingRequest(
547 - previewRequest, captureCallback, backgroundHandler);
548 - } catch (final CameraAccessException e) {
549 - LOGGER.e(e, "Exception!");
550 - }
551 - }
552 -
553 - @Override
554 - public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) {
555 - showToast("Failed");
556 - }
557 - },
558 - null);
559 - } catch (final CameraAccessException e) {
560 - LOGGER.e(e, "Exception!");
561 - }
562 - }
563 -
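Each repeating request above targets two surfaces: the on-screen `TextureView` and the YUV `ImageReader` whose listener feeds inference. A minimal listener sketch (illustrative only; the app's real listener lives in the host activity) showing the buffer discipline the two-deep reader requires:

```java
// With maxImages == 2, every acquired Image must be closed promptly, or the
// camera pipeline stalls waiting for a free buffer.
final OnImageAvailableListener listener =
    new OnImageAvailableListener() {
      @Override
      public void onImageAvailable(final ImageReader reader) {
        final Image image = reader.acquireLatestImage();
        if (image == null) {
          return; // No frame ready yet.
        }
        try {
          // ... convert the YUV_420_888 planes into model input here ...
        } finally {
          image.close(); // Always hand the buffer back to the reader.
        }
      }
    };
```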
564 - /**
565 - * Configures the necessary {@link android.graphics.Matrix} transformation to {@code textureView}.
566 - * This method should be called after the camera preview size is determined in
567 - * setUpCameraOutputs and after the size of {@code textureView} is fixed.
568 - *
569 - * @param viewWidth The width of {@code textureView}
570 - * @param viewHeight The height of {@code textureView}
571 - */
572 - private void configureTransform(final int viewWidth, final int viewHeight) {
573 - final Activity activity = getActivity();
574 - if (null == textureView || null == previewSize || null == activity) {
575 - return;
576 - }
577 - final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation();
578 - final Matrix matrix = new Matrix();
579 - final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight);
580 - final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth());
581 - final float centerX = viewRect.centerX();
582 - final float centerY = viewRect.centerY();
583 - if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) {
584 - bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY());
585 - matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL);
586 - final float scale =
587 - Math.max(
588 - (float) viewHeight / previewSize.getHeight(),
589 - (float) viewWidth / previewSize.getWidth());
590 - matrix.postScale(scale, scale, centerX, centerY);
591 - matrix.postRotate(90 * (rotation - 2), centerX, centerY);
592 - } else if (Surface.ROTATION_180 == rotation) {
593 - matrix.postRotate(180, centerX, centerY);
594 - }
595 - textureView.setTransform(matrix);
596 - }
597 -
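Worked arithmetic for the rotation branch above: Android's constants are `Surface.ROTATION_0 == 0`, `ROTATION_90 == 1`, `ROTATION_180 == 2`, and `ROTATION_270 == 3`, so `90 * (rotation - 2)` evaluates to -90 degrees for ROTATION_90 and +90 degrees for ROTATION_270, undoing the display rotation after the buffer rectangle has been scaled to fill the view.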
598 - /**
599 - * Compares two {@code Size}s based on their areas.
600 - */
601 - static class CompareSizesByArea implements Comparator<Size> {
602 - @Override
603 - public int compare(final Size lhs, final Size rhs) {
604 - // We cast here to ensure the multiplications won't overflow
605 - return Long.signum(
606 - (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight());
607 - }
608 - }
609 -
610 - /**
611 - * Shows an error message dialog.
612 - */
613 - public static class ErrorDialog extends DialogFragment {
614 - private static final String ARG_MESSAGE = "message";
615 -
616 - public static ErrorDialog newInstance(final String message) {
617 - final ErrorDialog dialog = new ErrorDialog();
618 - final Bundle args = new Bundle();
619 - args.putString(ARG_MESSAGE, message);
620 - dialog.setArguments(args);
621 - return dialog;
622 - }
623 -
624 - @Override
625 - public Dialog onCreateDialog(final Bundle savedInstanceState) {
626 - final Activity activity = getActivity();
627 - return new AlertDialog.Builder(activity)
628 - .setMessage(getArguments().getString(ARG_MESSAGE))
629 - .setPositiveButton(
630 - android.R.string.ok,
631 - new DialogInterface.OnClickListener() {
632 - @Override
633 - public void onClick(final DialogInterface dialogInterface, final int i) {
634 - activity.finish();
635 - }
636 - })
637 - .create();
638 - }
639 - }
640 -}
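For context, a sketch of how a host activity wires this fragment up (the callback signature and layout id come from this file; `R.id.container`, `imageListener`, and `cameraId` are placeholders):

```java
// Hypothetical host-activity wiring for CameraConnectionFragment.
final CameraConnectionFragment fragment =
    CameraConnectionFragment.newInstance(
        new CameraConnectionFragment.ConnectionCallback() {
          @Override
          public void onPreviewSizeChosen(final Size size, final int rotation) {
            // Allocate frame bitmaps and transforms for the chosen size here.
          }
        },
        imageListener, // an OnImageAvailableListener receiving preview frames
        R.layout.camera_connection_fragment,
        new Size(640, 480)); // desired model input size
fragment.setCamera(cameraId);
getFragmentManager().beginTransaction().replace(R.id.container, fragment).commit();
```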
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import android.graphics.Bitmap;
19 -import android.graphics.RectF;
20 -import java.util.List;
21 -
22 -/**
23 - * Generic interface for interacting with different recognition engines.
24 - */
25 -public interface Classifier {
26 - /**
27 - * An immutable result returned by a Classifier describing what was recognized.
28 - */
29 - public class Recognition {
30 - /**
31 - * A unique identifier for what has been recognized. Specific to the class, not the instance of
32 - * the object.
33 - */
34 - private final String id;
35 -
36 - /**
37 - * Display name for the recognition.
38 - */
39 - private final String title;
40 -
41 - /**
42 - * A sortable score for how good the recognition is relative to others. Higher should be better.
43 - */
44 - private final Float confidence;
45 -
46 - /** Optional location within the source image for the location of the recognized object. */
47 - private RectF location;
48 -
49 - public Recognition(
50 - final String id, final String title, final Float confidence, final RectF location) {
51 - this.id = id;
52 - this.title = title;
53 - this.confidence = confidence;
54 - this.location = location;
55 - }
56 -
57 - public String getId() {
58 - return id;
59 - }
60 -
61 - public String getTitle() {
62 - return title;
63 - }
64 -
65 - public Float getConfidence() {
66 - return confidence;
67 - }
68 -
69 - public RectF getLocation() {
70 - return new RectF(location);
71 - }
72 -
73 - public void setLocation(RectF location) {
74 - this.location = location;
75 - }
76 -
77 - @Override
78 - public String toString() {
79 - String resultString = "";
80 - if (id != null) {
81 - resultString += "[" + id + "] ";
82 - }
83 -
84 - if (title != null) {
85 - resultString += title + " ";
86 - }
87 -
88 - if (confidence != null) {
89 - resultString += String.format("(%.1f%%) ", confidence * 100.0f);
90 - }
91 -
92 - if (location != null) {
93 - resultString += location + " ";
94 - }
95 -
96 - return resultString.trim();
97 - }
98 - }
99 -
100 - List<Recognition> recognizeImage(Bitmap bitmap);
101 -
102 - void enableStatLogging(final boolean debug);
103 -
104 - String getStatString();
105 -
106 - void close();
107 -}
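To make the contract concrete, a throwaway stub implementation (purely illustrative, not part of the app):

```java
// A no-op Classifier that returns one fixed Recognition covering the whole
// bitmap, useful only for exercising the UI without a real model.
public class StubClassifier implements Classifier {
  @Override
  public List<Recognition> recognizeImage(final Bitmap bitmap) {
    return java.util.Collections.singletonList(
        new Recognition(
            "0", "nothing", 1.0f, new RectF(0, 0, bitmap.getWidth(), bitmap.getHeight())));
  }

  @Override
  public void enableStatLogging(final boolean debug) {}

  @Override
  public String getStatString() {
    return "";
  }

  @Override
  public void close() {}
}
```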
1 -/*
2 - * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -package org.tensorflow.demo;
18 -
19 -import android.graphics.Bitmap;
20 -import android.graphics.Bitmap.Config;
21 -import android.graphics.Canvas;
22 -import android.graphics.Matrix;
23 -import android.graphics.Paint;
24 -import android.graphics.Typeface;
25 -import android.media.ImageReader.OnImageAvailableListener;
26 -import android.os.SystemClock;
27 -import android.util.Size;
28 -import android.util.TypedValue;
29 -import android.view.Display;
30 -import android.view.Surface;
31 -import java.util.List;
32 -import java.util.Vector;
33 -import org.tensorflow.demo.OverlayView.DrawCallback;
34 -import org.tensorflow.demo.env.BorderedText;
35 -import org.tensorflow.demo.env.ImageUtils;
36 -import org.tensorflow.demo.env.Logger;
37 -import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
38 -
39 -public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {
40 - private static final Logger LOGGER = new Logger();
41 -
42 - protected static final boolean SAVE_PREVIEW_BITMAP = false;
43 -
44 - private ResultsView resultsView;
45 -
46 - private Bitmap rgbFrameBitmap = null;
47 - private Bitmap croppedBitmap = null;
48 - private Bitmap cropCopyBitmap = null;
49 -
50 - private long lastProcessingTimeMs;
51 -
52 - // These are the settings for the original v1 Inception model. If you want to
53 - // use a model that's been produced from the TensorFlow for Poets codelab,
54 - // you'll need to set IMAGE_SIZE = 299, IMAGE_MEAN = 128, IMAGE_STD = 128,
55 - // INPUT_NAME = "Mul", and OUTPUT_NAME = "final_result".
56 - // You'll also need to update the MODEL_FILE and LABEL_FILE paths to point to
57 - // the ones you produced.
58 - //
59 - // To use v3 Inception model, strip the DecodeJpeg Op from your retrained
60 - // model first:
61 - //
62 - // python strip_unused.py \
63 - // --input_graph=<retrained-pb-file> \
64 - // --output_graph=<your-stripped-pb-file> \
65 - // --input_node_names="Mul" \
66 - // --output_node_names="final_result" \
67 - // --input_binary=true
68 - private static final int INPUT_SIZE = 224;
69 - private static final int IMAGE_MEAN = 117;
70 - private static final float IMAGE_STD = 1;
71 - private static final String INPUT_NAME = "input";
72 - private static final String OUTPUT_NAME = "output";
73 -
74 -
75 - private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
76 - private static final String LABEL_FILE =
77 - "file:///android_asset/imagenet_comp_graph_label_strings.txt";
78 -
79 -
80 - private static final boolean MAINTAIN_ASPECT = true;
81 -
82 - private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
83 -
84 -
85 - private Integer sensorOrientation;
86 - private Classifier classifier;
87 - private Matrix frameToCropTransform;
88 - private Matrix cropToFrameTransform;
89 -
90 -
91 - private BorderedText borderedText;
92 -
93 -
94 - @Override
95 - protected int getLayoutId() {
96 - return R.layout.camera_connection_fragment;
97 - }
98 -
99 - @Override
100 - protected Size getDesiredPreviewFrameSize() {
101 - return DESIRED_PREVIEW_SIZE;
102 - }
103 -
104 - private static final float TEXT_SIZE_DIP = 10;
105 -
106 - @Override
107 - public void onPreviewSizeChosen(final Size size, final int rotation) {
108 - final float textSizePx = TypedValue.applyDimension(
109 - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
110 - borderedText = new BorderedText(textSizePx);
111 - borderedText.setTypeface(Typeface.MONOSPACE);
112 -
113 - classifier =
114 - TensorFlowImageClassifier.create(
115 - getAssets(),
116 - MODEL_FILE,
117 - LABEL_FILE,
118 - INPUT_SIZE,
119 - IMAGE_MEAN,
120 - IMAGE_STD,
121 - INPUT_NAME,
122 - OUTPUT_NAME);
123 -
124 - previewWidth = size.getWidth();
125 - previewHeight = size.getHeight();
126 -
127 - sensorOrientation = rotation - getScreenOrientation();
128 - LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
129 -
130 - LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
131 - rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
132 - croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
133 -
134 - frameToCropTransform = ImageUtils.getTransformationMatrix(
135 - previewWidth, previewHeight,
136 - INPUT_SIZE, INPUT_SIZE,
137 - sensorOrientation, MAINTAIN_ASPECT);
138 -
139 - cropToFrameTransform = new Matrix();
140 - frameToCropTransform.invert(cropToFrameTransform);
141 -
142 - addCallback(
143 - new DrawCallback() {
144 - @Override
145 - public void drawCallback(final Canvas canvas) {
146 - renderDebug(canvas);
147 - }
148 - });
149 - }
150 -
151 - @Override
152 - protected void processImage() {
153 - rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
154 - final Canvas canvas = new Canvas(croppedBitmap);
155 - canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
156 -
157 - // For examining the actual TF input.
158 - if (SAVE_PREVIEW_BITMAP) {
159 - ImageUtils.saveBitmap(croppedBitmap);
160 - }
161 - runInBackground(
162 - new Runnable() {
163 - @Override
164 - public void run() {
165 - final long startTime = SystemClock.uptimeMillis();
166 - final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
167 - lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
168 - LOGGER.i("Detect: %s", results);
169 - cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
170 - if (resultsView == null) {
171 - resultsView = (ResultsView) findViewById(R.id.results);
172 - }
173 - resultsView.setResults(results);
174 - requestRender();
175 - readyForNextImage();
176 - }
177 - });
178 - }
179 -
180 - @Override
181 - public void onSetDebug(boolean debug) {
182 - classifier.enableStatLogging(debug);
183 - }
184 -
185 - private void renderDebug(final Canvas canvas) {
186 - if (!isDebug()) {
187 - return;
188 - }
189 - final Bitmap copy = cropCopyBitmap;
190 - if (copy != null) {
191 - final Matrix matrix = new Matrix();
192 - final float scaleFactor = 2;
193 - matrix.postScale(scaleFactor, scaleFactor);
194 - matrix.postTranslate(
195 - canvas.getWidth() - copy.getWidth() * scaleFactor,
196 - canvas.getHeight() - copy.getHeight() * scaleFactor);
197 - canvas.drawBitmap(copy, matrix, new Paint());
198 -
199 - final Vector<String> lines = new Vector<String>();
200 - if (classifier != null) {
201 - String statString = classifier.getStatString();
202 - String[] statLines = statString.split("\n");
203 - for (String line : statLines) {
204 - lines.add(line);
205 - }
206 - }
207 -
208 - lines.add("Frame: " + previewWidth + "x" + previewHeight);
209 - lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
210 - lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
211 - lines.add("Rotation: " + sensorOrientation);
212 - lines.add("Inference time: " + lastProcessingTimeMs + "ms");
213 -
214 - borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
215 - }
216 - }
217 -}
1 -/*
2 - * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -package org.tensorflow.demo;
18 -
19 -import android.graphics.Bitmap;
20 -import android.graphics.Bitmap.Config;
21 -import android.graphics.Canvas;
22 -import android.graphics.Color;
23 -import android.graphics.Matrix;
24 -import android.graphics.Paint;
25 -import android.graphics.Paint.Style;
26 -import android.graphics.RectF;
27 -import android.graphics.Typeface;
28 -import android.media.ImageReader.OnImageAvailableListener;
29 -import android.os.SystemClock;
30 -import android.util.Size;
31 -import android.util.TypedValue;
32 -import android.view.Display;
33 -import android.view.Surface;
34 -import android.widget.Toast;
35 -import java.io.IOException;
36 -import java.util.LinkedList;
37 -import java.util.List;
38 -import java.util.Vector;
39 -import org.tensorflow.demo.OverlayView.DrawCallback;
40 -import org.tensorflow.demo.env.BorderedText;
41 -import org.tensorflow.demo.env.ImageUtils;
42 -import org.tensorflow.demo.env.Logger;
43 -import org.tensorflow.demo.tracking.MultiBoxTracker;
44 -import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
45 -
46 -/**
47 - * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
48 - * objects.
49 - */
50 -public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
51 - private static final Logger LOGGER = new Logger();
52 -
53 - // Configuration values for the prepackaged multibox model.
54 - private static final int MB_INPUT_SIZE = 224;
55 - private static final int MB_IMAGE_MEAN = 128;
56 - private static final float MB_IMAGE_STD = 128;
57 - private static final String MB_INPUT_NAME = "ResizeBilinear";
58 - private static final String MB_OUTPUT_LOCATIONS_NAME = "output_locations/Reshape";
59 - private static final String MB_OUTPUT_SCORES_NAME = "output_scores/Reshape";
60 - private static final String MB_MODEL_FILE = "file:///android_asset/multibox_model.pb";
61 - private static final String MB_LOCATION_FILE =
62 - "file:///android_asset/multibox_location_priors.txt";
63 -
64 - private static final int TF_OD_API_INPUT_SIZE = 300;
65 - private static final String TF_OD_API_MODEL_FILE =
66 - "file:///android_asset/ssd_mobilenet_v1_android_export.pb";
67 - private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
68 -
69 - // Configuration values for YOLO (this build points YOLO_MODEL_FILE at yolov3.pb). Note that the graph is not included with TensorFlow and
70 - // must be manually placed in the assets/ directory by the user.
71 - // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via
72 - // DarkFlow (https://github.com/thtrieu/darkflow). Sample command:
73 - // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise
74 - private static final String YOLO_MODEL_FILE = "file:///android_asset/yolov3.pb";
75 - private static final int YOLO_INPUT_SIZE = 416;
76 - private static final String YOLO_INPUT_NAME = "input";
77 - private static final String YOLO_OUTPUT_NAMES = "output";
78 - private static final int YOLO_BLOCK_SIZE = 32;
79 -
80 - // Which detection model to use: by default uses Tensorflow Object Detection API frozen
81 - // checkpoints. Optionally use legacy Multibox (trained using an older version of the API)
82 - // or YOLO.
83 - private enum DetectorMode {
84 - TF_OD_API, MULTIBOX, YOLO;
85 - }
86 - private static final DetectorMode MODE = DetectorMode.YOLO;
87 -
88 - // Minimum detection confidence to track a detection.
89 - private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f;
90 - private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f;
91 - private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f;
92 -
93 - private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO;
94 -
95 - private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
96 -
97 - private static final boolean SAVE_PREVIEW_BITMAP = false;
98 - private static final float TEXT_SIZE_DIP = 10;
99 -
100 - private Integer sensorOrientation;
101 -
102 - private Classifier detector;
103 -
104 - private long lastProcessingTimeMs;
105 - private Bitmap rgbFrameBitmap = null;
106 - private Bitmap croppedBitmap = null;
107 - private Bitmap cropCopyBitmap = null;
108 -
109 - private boolean computingDetection = false;
110 -
111 - private long timestamp = 0;
112 -
113 - private Matrix frameToCropTransform;
114 - private Matrix cropToFrameTransform;
115 -
116 - private MultiBoxTracker tracker;
117 -
118 - private byte[] luminanceCopy;
119 -
120 - private BorderedText borderedText;
121 - @Override
122 - public void onPreviewSizeChosen(final Size size, final int rotation) {
123 - final float textSizePx =
124 - TypedValue.applyDimension(
125 - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
126 - borderedText = new BorderedText(textSizePx);
127 - borderedText.setTypeface(Typeface.MONOSPACE);
128 -
129 - tracker = new MultiBoxTracker(this);
130 -
131 - int cropSize = TF_OD_API_INPUT_SIZE;
132 - if (MODE == DetectorMode.YOLO) {
133 - detector =
134 - TensorFlowYoloDetector.create(
135 - getAssets(),
136 - YOLO_MODEL_FILE,
137 - YOLO_INPUT_SIZE,
138 - YOLO_INPUT_NAME,
139 - YOLO_OUTPUT_NAMES,
140 - YOLO_BLOCK_SIZE);
141 - cropSize = YOLO_INPUT_SIZE;
142 - } else if (MODE == DetectorMode.MULTIBOX) {
143 - detector =
144 - TensorFlowMultiBoxDetector.create(
145 - getAssets(),
146 - MB_MODEL_FILE,
147 - MB_LOCATION_FILE,
148 - MB_IMAGE_MEAN,
149 - MB_IMAGE_STD,
150 - MB_INPUT_NAME,
151 - MB_OUTPUT_LOCATIONS_NAME,
152 - MB_OUTPUT_SCORES_NAME);
153 - cropSize = MB_INPUT_SIZE;
154 - } else {
155 - try {
156 - detector = TensorFlowObjectDetectionAPIModel.create(
157 - getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
158 - cropSize = TF_OD_API_INPUT_SIZE;
159 - } catch (final IOException e) {
160 - LOGGER.e(e, "Exception initializing classifier!");
161 - Toast toast =
162 - Toast.makeText(
163 - getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
164 - toast.show();
165 - finish();
166 - }
167 - }
168 -
169 - previewWidth = size.getWidth();
170 - previewHeight = size.getHeight();
171 -
172 - sensorOrientation = rotation - getScreenOrientation();
173 - LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
174 -
175 - LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
176 - rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
177 - croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
178 -
179 - frameToCropTransform =
180 - ImageUtils.getTransformationMatrix(
181 - previewWidth, previewHeight,
182 - cropSize, cropSize,
183 - sensorOrientation, MAINTAIN_ASPECT);
184 -
185 - cropToFrameTransform = new Matrix();
186 - frameToCropTransform.invert(cropToFrameTransform);
187 -
188 - trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
189 - trackingOverlay.addCallback(
190 - new DrawCallback() {
191 - @Override
192 - public void drawCallback(final Canvas canvas) {
193 - tracker.draw(canvas);
194 - if (isDebug()) {
195 - tracker.drawDebug(canvas);
196 - }
197 - }
198 - });
199 -
200 - addCallback(
201 - new DrawCallback() {
202 - @Override
203 - public void drawCallback(final Canvas canvas) {
204 - if (!isDebug()) {
205 - return;
206 - }
207 - final Bitmap copy = cropCopyBitmap;
208 - if (copy == null) {
209 - return;
210 - }
211 -
212 - final int backgroundColor = Color.argb(100, 0, 0, 0);
213 - canvas.drawColor(backgroundColor);
214 -
215 - final Matrix matrix = new Matrix();
216 - final float scaleFactor = 2;
217 - matrix.postScale(scaleFactor, scaleFactor);
218 - matrix.postTranslate(
219 - canvas.getWidth() - copy.getWidth() * scaleFactor,
220 - canvas.getHeight() - copy.getHeight() * scaleFactor);
221 - canvas.drawBitmap(copy, matrix, new Paint());
222 -
223 - final Vector<String> lines = new Vector<String>();
224 - if (detector != null) {
225 - final String statString = detector.getStatString();
226 - final String[] statLines = statString.split("\n");
227 - for (final String line : statLines) {
228 - lines.add(line);
229 - }
230 - }
231 - lines.add("");
232 -
233 - lines.add("Frame: " + previewWidth + "x" + previewHeight);
234 - lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
235 - lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
236 - lines.add("Rotation: " + sensorOrientation);
237 - lines.add("Inference time: " + lastProcessingTimeMs + "ms");
238 -
239 - borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
240 - }
241 - });
242 - }
243 -
244 - OverlayView trackingOverlay;
245 -
246 - @Override
247 - protected void processImage() {
248 - ++timestamp;
249 - final long currTimestamp = timestamp;
250 - byte[] originalLuminance = getLuminance();
251 - tracker.onFrame(
252 - previewWidth,
253 - previewHeight,
254 - getLuminanceStride(),
255 - sensorOrientation,
256 - originalLuminance,
257 - timestamp);
258 - trackingOverlay.postInvalidate();
259 -
260 - // No mutex needed as this method is not reentrant.
261 - if (computingDetection) {
262 - readyForNextImage();
263 - return;
264 - }
265 - computingDetection = true;
266 - LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
267 -
268 - rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
269 -
270 - if (luminanceCopy == null) {
271 - luminanceCopy = new byte[originalLuminance.length];
272 - }
273 - System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length);
274 - readyForNextImage();
275 -
276 - final Canvas canvas = new Canvas(croppedBitmap);
277 - canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
278 - // For examining the actual TF input.
279 - if (SAVE_PREVIEW_BITMAP) {
280 - ImageUtils.saveBitmap(croppedBitmap);
281 - }
282 -
283 - runInBackground(
284 - new Runnable() {
285 - @Override
286 - public void run() {
287 - LOGGER.i("Running detection on image " + currTimestamp);
288 - final long startTime = SystemClock.uptimeMillis();
289 - final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
290 - lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
291 -
292 - cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
293 - final Canvas canvas = new Canvas(cropCopyBitmap);
294 - final Paint paint = new Paint();
295 - paint.setColor(Color.RED);
296 - paint.setStyle(Style.STROKE);
297 - paint.setStrokeWidth(2.0f);
298 -
299 - float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
300 - switch (MODE) {
301 - case TF_OD_API:
302 - minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
303 - break;
304 - case MULTIBOX:
305 - minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX;
306 - break;
307 - case YOLO:
308 - minimumConfidence = MINIMUM_CONFIDENCE_YOLO;
309 - break;
310 - }
311 -
312 - final List<Classifier.Recognition> mappedRecognitions =
313 - new LinkedList<Classifier.Recognition>();
314 -
315 - for (final Classifier.Recognition result : results) {
316 - final RectF location = result.getLocation();
317 - if (location != null && result.getConfidence() >= minimumConfidence) {
318 - canvas.drawRect(location, paint);
319 -
320 - cropToFrameTransform.mapRect(location);
321 - result.setLocation(location);
322 - mappedRecognitions.add(result);
323 - }
324 - }
325 -
326 - tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp);
327 - trackingOverlay.postInvalidate();
328 -
329 - requestRender();
330 - computingDetection = false;
331 - }
332 - });
333 - }
334 -
335 - @Override
336 - protected int getLayoutId() {
337 - return R.layout.camera_connection_fragment_tracking;
338 - }
339 -
340 - @Override
341 - protected Size getDesiredPreviewFrameSize() {
342 - return DESIRED_PREVIEW_SIZE;
343 - }
344 -
345 - @Override
346 - public void onSetDebug(final boolean debug) {
347 - detector.enableStatLogging(debug);
348 - }
349 -}
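A worked example of the coordinate mapping in `processImage` above (numbers chosen for illustration): detections arrive in crop space, e.g. the 416x416 YOLO input, and must be mapped back into the 640x480 preview frame before the tracker sees them.

```java
// Box reported by the detector in 416x416 crop coordinates...
final RectF location = new RectF(104, 104, 312, 312);
// ...mapped in place into 640x480 preview-frame coordinates, the space the
// MultiBoxTracker and the tracking overlay draw in.
cropToFrameTransform.mapRect(location);
```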
1 -package org.tensorflow.demo;
2 -
3 -/*
4 - * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 - *
6 - * Licensed under the Apache License, Version 2.0 (the "License");
7 - * you may not use this file except in compliance with the License.
8 - * You may obtain a copy of the License at
9 - *
10 - * http://www.apache.org/licenses/LICENSE-2.0
11 - *
12 - * Unless required by applicable law or agreed to in writing, software
13 - * distributed under the License is distributed on an "AS IS" BASIS,
14 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 - * See the License for the specific language governing permissions and
16 - * limitations under the License.
17 - */
18 -
19 -import android.app.Fragment;
20 -import android.graphics.SurfaceTexture;
21 -import android.hardware.Camera;
22 -import android.hardware.Camera.CameraInfo;
23 -import android.os.Bundle;
24 -import android.os.Handler;
25 -import android.os.HandlerThread;
26 -import android.util.Size;
27 -import android.util.SparseIntArray;
28 -import android.view.LayoutInflater;
29 -import android.view.Surface;
30 -import android.view.TextureView;
31 -import android.view.View;
32 -import android.view.ViewGroup;
33 -import java.io.IOException;
34 -import java.util.List;
35 -import org.tensorflow.demo.env.ImageUtils;
36 -import org.tensorflow.demo.env.Logger;
37 -import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
38 -
39 -public class LegacyCameraConnectionFragment extends Fragment {
40 - private Camera camera;
41 - private static final Logger LOGGER = new Logger();
42 - private Camera.PreviewCallback imageListener;
43 - private Size desiredSize;
44 -
45 - /**
46 - * The layout identifier to inflate for this Fragment.
47 - */
48 - private int layout;
49 -
50 - public LegacyCameraConnectionFragment(
51 - final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) {
52 - this.imageListener = imageListener;
53 - this.layout = layout;
54 - this.desiredSize = desiredSize;
55 - }
56 -
57 - /**
58 - * Conversion from screen rotation to JPEG orientation.
59 - */
60 - private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
61 -
62 - static {
63 - ORIENTATIONS.append(Surface.ROTATION_0, 90);
64 - ORIENTATIONS.append(Surface.ROTATION_90, 0);
65 - ORIENTATIONS.append(Surface.ROTATION_180, 270);
66 - ORIENTATIONS.append(Surface.ROTATION_270, 180);
67 - }
68 -
69 - /**
70 - * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a
71 - * {@link TextureView}.
72 - */
73 - private final TextureView.SurfaceTextureListener surfaceTextureListener =
74 - new TextureView.SurfaceTextureListener() {
75 - @Override
76 - public void onSurfaceTextureAvailable(
77 - final SurfaceTexture texture, final int width, final int height) {
78 -
79 - int index = getCameraId();
80 - camera = Camera.open(index);
81 -
82 - try {
83 - Camera.Parameters parameters = camera.getParameters();
84 - List<String> focusModes = parameters.getSupportedFocusModes();
85 - if (focusModes != null
86 - && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) {
87 - parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE);
88 - }
89 - List<Camera.Size> cameraSizes = parameters.getSupportedPreviewSizes();
90 - Size[] sizes = new Size[cameraSizes.size()];
91 - int i = 0;
92 - for (Camera.Size size : cameraSizes) {
93 - sizes[i++] = new Size(size.width, size.height);
94 - }
95 - Size previewSize =
96 - CameraConnectionFragment.chooseOptimalSize(
97 - sizes, desiredSize.getWidth(), desiredSize.getHeight());
98 - parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight());
99 - camera.setDisplayOrientation(90);
100 - camera.setParameters(parameters);
101 - camera.setPreviewTexture(texture);
102 - } catch (IOException exception) {
103 - camera.release();
104 - }
105 -
106 - camera.setPreviewCallbackWithBuffer(imageListener);
107 - Camera.Size s = camera.getParameters().getPreviewSize();
108 - camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]);
109 -
110 - textureView.setAspectRatio(s.height, s.width);
111 -
112 - camera.startPreview();
113 - }
114 -
115 - @Override
116 - public void onSurfaceTextureSizeChanged(
117 - final SurfaceTexture texture, final int width, final int height) {}
118 -
119 - @Override
120 - public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) {
121 - return true;
122 - }
123 -
124 - @Override
125 - public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
126 - };
127 -
128 - /**
129 - * An {@link AutoFitTextureView} for camera preview.
130 - */
131 - private AutoFitTextureView textureView;
132 -
133 - /**
134 - * An additional thread for running tasks that shouldn't block the UI.
135 - */
136 - private HandlerThread backgroundThread;
137 -
138 - @Override
139 - public View onCreateView(
140 - final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) {
141 - return inflater.inflate(layout, container, false);
142 - }
143 -
144 - @Override
145 - public void onViewCreated(final View view, final Bundle savedInstanceState) {
146 - textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
147 - }
148 -
149 - @Override
150 - public void onActivityCreated(final Bundle savedInstanceState) {
151 - super.onActivityCreated(savedInstanceState);
152 - }
153 -
154 - @Override
155 - public void onResume() {
156 - super.onResume();
157 - startBackgroundThread();
158 - // When the screen is turned off and turned back on, the SurfaceTexture is already
159 - // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
160 - // a camera and start preview from here (otherwise, we wait until the surface is ready in
161 - // the SurfaceTextureListener).
162 -
163 - if (textureView.isAvailable() && camera != null) {
164 - camera.startPreview(); // camera is null after stopCamera() in onPause; the guard avoids an NPE here
165 - } else {
166 - textureView.setSurfaceTextureListener(surfaceTextureListener);
167 - }
168 - }
169 -
170 - @Override
171 - public void onPause() {
172 - stopCamera();
173 - stopBackgroundThread();
174 - super.onPause();
175 - }
176 -
177 - /**
178 - * Starts a background thread and its {@link Handler}.
179 - */
180 - private void startBackgroundThread() {
181 - backgroundThread = new HandlerThread("CameraBackground");
182 - backgroundThread.start();
183 - }
184 -
185 - /**
186 - * Stops the background thread and its {@link Handler}.
187 - */
188 - private void stopBackgroundThread() {
189 - backgroundThread.quitSafely();
190 - try {
191 - backgroundThread.join();
192 - backgroundThread = null;
193 - } catch (final InterruptedException e) {
194 - LOGGER.e(e, "Exception!");
195 - }
196 - }
197 -
198 - protected void stopCamera() {
199 - if (camera != null) {
200 - camera.stopPreview();
201 - camera.setPreviewCallback(null);
202 - camera.release();
203 - camera = null;
204 - }
205 - }
206 -
207 - private int getCameraId() {
208 - CameraInfo ci = new CameraInfo();
209 - for (int i = 0; i < Camera.getNumberOfCameras(); i++) {
210 - Camera.getCameraInfo(i, ci);
211 - if (ci.facing == CameraInfo.CAMERA_FACING_BACK)
212 - return i;
213 - }
214 - return -1; // No camera found
215 - }
216 -}
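Worked arithmetic for the callback buffer above, assuming `ImageUtils.getYUVByteSize` computes the standard NV21 footprint: a 640x480 preview needs 640 * 480 = 307,200 bytes for the Y plane plus 307,200 / 2 = 153,600 bytes for the interleaved VU plane, i.e. 460,800 bytes per buffer.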
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import android.content.Context;
19 -import android.graphics.Canvas;
20 -import android.util.AttributeSet;
21 -import android.view.View;
22 -import java.util.LinkedList;
23 -import java.util.List;
24 -
25 -/**
26 - * A simple View providing a render callback to other classes.
27 - */
28 -public class OverlayView extends View {
29 - private final List<DrawCallback> callbacks = new LinkedList<DrawCallback>();
30 -
31 - public OverlayView(final Context context, final AttributeSet attrs) {
32 - super(context, attrs);
33 - }
34 -
35 - /**
36 - * Interface defining the callback for client classes.
37 - */
38 - public interface DrawCallback {
39 - public void drawCallback(final Canvas canvas);
40 - }
41 -
42 - public void addCallback(final DrawCallback callback) {
43 - callbacks.add(callback);
44 - }
45 -
46 - @Override
47 - public synchronized void draw(final Canvas canvas) {
48 - for (final DrawCallback callback : callbacks) {
49 - callback.drawCallback(canvas);
50 - }
51 - }
52 -}
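Typical usage, condensed from the activities above (the overlay id matches the tracking layout; the callback body is a placeholder):

```java
// Register a render callback, then invalidate to schedule a redraw; draw()
// will run every registered callback in order.
final OverlayView overlay = (OverlayView) findViewById(R.id.tracking_overlay);
overlay.addCallback(
    new OverlayView.DrawCallback() {
      @Override
      public void drawCallback(final Canvas canvas) {
        // Draw tracked boxes or debug text onto the overlay here.
      }
    });
overlay.postInvalidate();
```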
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import android.content.Context;
19 -import android.graphics.Canvas;
20 -import android.graphics.Paint;
21 -import android.util.AttributeSet;
22 -import android.util.TypedValue;
23 -import android.view.View;
24 -
25 -import org.tensorflow.demo.Classifier.Recognition;
26 -
27 -import java.util.List;
28 -
29 -public class RecognitionScoreView extends View implements ResultsView {
30 - private static final float TEXT_SIZE_DIP = 24;
31 - private List<Recognition> results;
32 - private final float textSizePx;
33 - private final Paint fgPaint;
34 - private final Paint bgPaint;
35 -
36 - public RecognitionScoreView(final Context context, final AttributeSet set) {
37 - super(context, set);
38 -
39 - textSizePx =
40 - TypedValue.applyDimension(
41 - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
42 - fgPaint = new Paint();
43 - fgPaint.setTextSize(textSizePx);
44 -
45 - bgPaint = new Paint();
46 - bgPaint.setColor(0xcc4285f4);
47 - }
48 -
49 - @Override
50 - public void setResults(final List<Recognition> results) {
51 - this.results = results;
52 - postInvalidate();
53 - }
54 -
55 - @Override
56 - public void onDraw(final Canvas canvas) {
57 - final int x = 10;
58 - int y = (int) (fgPaint.getTextSize() * 1.5f);
59 -
60 - canvas.drawPaint(bgPaint);
61 -
62 - if (results != null) {
63 - for (final Recognition recog : results) {
64 - canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint);
65 - y += fgPaint.getTextSize() * 1.5f;
66 - }
67 - }
68 - }
69 -}
1 -/*
2 - * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -package org.tensorflow.demo;
18 -
19 -import android.util.Log;
20 -import android.util.Pair;
21 -import java.util.ArrayDeque;
22 -import java.util.ArrayList;
23 -import java.util.Arrays;
24 -import java.util.Deque;
25 -import java.util.List;
26 -
27 -/** Reads in results from an instantaneous audio recognition model and smooths them over time. */
28 -public class RecognizeCommands {
29 - // Configuration settings.
30 - private List<String> labels = new ArrayList<String>();
31 - private long averageWindowDurationMs;
32 - private float detectionThreshold;
33 - private int suppressionMs;
34 - private int minimumCount;
35 - private long minimumTimeBetweenSamplesMs;
36 -
37 - // Working variables.
38 - private Deque<Pair<Long, float[]>> previousResults = new ArrayDeque<Pair<Long, float[]>>();
39 - private String previousTopLabel;
40 - private int labelsCount;
41 - private long previousTopLabelTime;
42 - private float previousTopLabelScore;
43 -
44 - private static final String SILENCE_LABEL = "_silence_";
45 - private static final long MINIMUM_TIME_FRACTION = 4;
46 -
47 - public RecognizeCommands(
48 - List<String> inLabels,
49 - long inAverageWindowDurationMs,
50 - float inDetectionThreshold,
51 - int inSuppressionMS,
52 - int inMinimumCount,
53 - long inMinimumTimeBetweenSamplesMS) {
54 - labels = inLabels;
55 - averageWindowDurationMs = inAverageWindowDurationMs;
56 - detectionThreshold = inDetectionThreshold;
57 - suppressionMs = inSuppressionMS;
58 - minimumCount = inMinimumCount;
59 - labelsCount = inLabels.size();
60 - previousTopLabel = SILENCE_LABEL;
61 - previousTopLabelTime = Long.MIN_VALUE;
62 - previousTopLabelScore = 0.0f;
63 - minimumTimeBetweenSamplesMs = inMinimumTimeBetweenSamplesMS;
64 - }
65 -
66 - /** Holds information about what's been recognized. */
67 - public static class RecognitionResult {
68 - public final String foundCommand;
69 - public final float score;
70 - public final boolean isNewCommand;
71 -
72 - public RecognitionResult(String inFoundCommand, float inScore, boolean inIsNewCommand) {
73 - foundCommand = inFoundCommand;
74 - score = inScore;
75 - isNewCommand = inIsNewCommand;
76 - }
77 - }
78 -
79 - private static class ScoreForSorting implements Comparable<ScoreForSorting> {
80 - public final float score;
81 - public final int index;
82 -
83 - public ScoreForSorting(float inScore, int inIndex) {
84 - score = inScore;
85 - index = inIndex;
86 - }
87 -
88 - @Override
89 - public int compareTo(ScoreForSorting other) {
90 - if (this.score > other.score) {
91 - return -1;
92 - } else if (this.score < other.score) {
93 - return 1;
94 - } else {
95 - return 0;
96 - }
97 - }
98 - }
99 -
100 - public RecognitionResult processLatestResults(float[] currentResults, long currentTimeMS) {
101 - if (currentResults.length != labelsCount) {
102 - throw new RuntimeException(
103 - "The results for recognition should contain "
104 - + labelsCount
105 - + " elements, but there are "
106 - + currentResults.length);
107 - }
108 -
109 - if ((!previousResults.isEmpty()) && (currentTimeMS < previousResults.getFirst().first)) {
110 - throw new RuntimeException(
111 - "You must feed results in increasing time order, but received a timestamp of "
112 - + currentTimeMS
113 - + " that was earlier than the previous one of "
114 - + previousResults.getFirst().first);
115 - }
116 -
117 - final int howManyResults = previousResults.size();
118 - // Ignore any results that are coming in too frequently.
119 - if (howManyResults > 1) {
120 - final long timeSinceMostRecent = currentTimeMS - previousResults.getLast().first;
121 - if (timeSinceMostRecent < minimumTimeBetweenSamplesMs) {
122 - return new RecognitionResult(previousTopLabel, previousTopLabelScore, false);
123 - }
124 - }
125 -
126 - // Add the latest results to the end of the queue.
127 - previousResults.addLast(new Pair<Long, float[]>(currentTimeMS, currentResults));
128 -
129 - // Prune any earlier results that are too old for the averaging window.
130 - final long timeLimit = currentTimeMS - averageWindowDurationMs;
131 - while (previousResults.getFirst().first < timeLimit) {
132 - previousResults.removeFirst();
133 - }
134 -
135 - // If there are too few results, assume the result will be unreliable and
136 - // bail.
137 - final long earliestTime = previousResults.getFirst().first;
138 - final long samplesDuration = currentTimeMS - earliestTime;
139 - if ((howManyResults < minimumCount)
140 - || (samplesDuration < (averageWindowDurationMs / MINIMUM_TIME_FRACTION))) {
141 - Log.v("RecognizeResult", "Too few results");
142 - return new RecognitionResult(previousTopLabel, 0.0f, false);
143 - }
144 -
145 - // Calculate the average score across all the results in the window.
146 - float[] averageScores = new float[labelsCount];
147 - for (Pair<Long, float[]> previousResult : previousResults) {
148 - final float[] scoresTensor = previousResult.second;
149 - int i = 0;
150 - while (i < scoresTensor.length) {
151 - averageScores[i] += scoresTensor[i] / howManyResults;
152 - ++i;
153 - }
154 - }
155 -
156 - // Sort the averaged results in descending score order.
157 - ScoreForSorting[] sortedAverageScores = new ScoreForSorting[labelsCount];
158 - for (int i = 0; i < labelsCount; ++i) {
159 - sortedAverageScores[i] = new ScoreForSorting(averageScores[i], i);
160 - }
161 - Arrays.sort(sortedAverageScores);
162 -
163 - // See if the latest top score is enough to trigger a detection.
164 - final int currentTopIndex = sortedAverageScores[0].index;
165 - final String currentTopLabel = labels.get(currentTopIndex);
166 - final float currentTopScore = sortedAverageScores[0].score;
167 - // If we've recently had another label trigger, assume one that occurs too
168 - // soon afterwards is a bad result.
169 - long timeSinceLastTop;
170 - if (previousTopLabel.equals(SILENCE_LABEL) || (previousTopLabelTime == Long.MIN_VALUE)) {
171 - timeSinceLastTop = Long.MAX_VALUE;
172 - } else {
173 - timeSinceLastTop = currentTimeMS - previousTopLabelTime;
174 - }
175 - boolean isNewCommand;
176 - if ((currentTopScore > detectionThreshold) && (timeSinceLastTop > suppressionMs)) {
177 - previousTopLabel = currentTopLabel;
178 - previousTopLabelTime = currentTimeMS;
179 - previousTopLabelScore = currentTopScore;
180 - isNewCommand = true;
181 - } else {
182 - isNewCommand = false;
183 - }
184 - return new RecognitionResult(currentTopLabel, currentTopScore, isNewCommand);
185 - }
186 -}
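A usage sketch matching the constants `SpeechActivity` declares below (the label list and score array are placeholders):

```java
// Construct the smoother with SpeechActivity's settings and feed it one
// softmax frame per inference pass.
final RecognizeCommands recognizeCommands =
    new RecognizeCommands(
        labels, // e.g. the lines of conv_actions_labels.txt
        500,    // AVERAGE_WINDOW_DURATION_MS
        0.70f,  // DETECTION_THRESHOLD
        1500,   // SUPPRESSION_MS
        3,      // MINIMUM_COUNT
        30);    // MINIMUM_TIME_BETWEEN_SAMPLES_MS
final RecognizeCommands.RecognitionResult result =
    recognizeCommands.processLatestResults(outputScores, System.currentTimeMillis());
if (result.isNewCommand) {
  // result.foundCommand crossed the threshold and was not suppressed.
}
```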
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import org.tensorflow.demo.Classifier.Recognition;
19 -
20 -import java.util.List;
21 -
22 -public interface ResultsView {
23 - public void setResults(final List<Recognition> results);
24 -}
1 -/*
2 - * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -/* Demonstrates how to run an audio recognition model in Android.
18 -
19 -This example loads a simple speech recognition model trained by the tutorial at
20 -https://www.tensorflow.org/tutorials/audio_training
21 -
22 -The model files should be downloaded automatically from the TensorFlow website,
23 -but if you have a custom model you can update the LABEL_FILENAME and
24 -MODEL_FILENAME constants to point to your own files.
25 -
26 -The example application displays a list view with all of the known audio labels,
27 -and highlights each one when it thinks it has detected one through the
28 -microphone. The averaging of results to give a more reliable signal happens in
29 -the RecognizeCommands helper class.
30 -*/
31 -
32 -package org.tensorflow.demo;
33 -
34 -import android.animation.AnimatorInflater;
35 -import android.animation.AnimatorSet;
36 -import android.app.Activity;
37 -import android.content.pm.PackageManager;
38 -import android.media.AudioFormat;
39 -import android.media.AudioRecord;
40 -import android.media.MediaRecorder;
41 -import android.os.Build;
42 -import android.os.Bundle;
43 -import android.util.Log;
44 -import android.view.View;
45 -import android.widget.ArrayAdapter;
46 -import android.widget.Button;
47 -import android.widget.ListView;
48 -import java.io.BufferedReader;
49 -import java.io.IOException;
50 -import java.io.InputStreamReader;
51 -import java.util.ArrayList;
52 -import java.util.List;
53 -import java.util.concurrent.locks.ReentrantLock;
54 -import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
55 -import org.tensorflow.demo.R;
56 -
57 -/**
58 - * An activity that listens for audio and then uses a TensorFlow model to detect particular classes,
59 - * by default a small set of action words.
60 - */
61 -public class SpeechActivity extends Activity {
62 -
63 - // Constants that control the behavior of the recognition code and model
64 - // settings. See the audio recognition tutorial for a detailed explanation of
65 - // all these, but you should customize them to match your training settings if
66 - // you are running your own model.
67 - private static final int SAMPLE_RATE = 16000;
68 - private static final int SAMPLE_DURATION_MS = 1000;
69 - private static final int RECORDING_LENGTH = (int) (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000);
70 - private static final long AVERAGE_WINDOW_DURATION_MS = 500;
71 - private static final float DETECTION_THRESHOLD = 0.70f;
72 - private static final int SUPPRESSION_MS = 1500;
73 - private static final int MINIMUM_COUNT = 3;
74 - private static final long MINIMUM_TIME_BETWEEN_SAMPLES_MS = 30;
75 - private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_labels.txt";
76 - private static final String MODEL_FILENAME = "file:///android_asset/conv_actions_frozen.pb";
77 - private static final String INPUT_DATA_NAME = "decoded_sample_data:0";
78 - private static final String SAMPLE_RATE_NAME = "decoded_sample_data:1";
79 - private static final String OUTPUT_SCORES_NAME = "labels_softmax";
80 -
81 - // UI elements.
82 - private static final int REQUEST_RECORD_AUDIO = 13;
83 - private Button quitButton;
84 - private ListView labelsListView;
85 - private static final String LOG_TAG = SpeechActivity.class.getSimpleName();
86 -
87 - // Working variables.
88 - short[] recordingBuffer = new short[RECORDING_LENGTH];
89 - int recordingOffset = 0;
90 - boolean shouldContinue = true;
91 - private Thread recordingThread;
92 - boolean shouldContinueRecognition = true;
93 - private Thread recognitionThread;
94 - private final ReentrantLock recordingBufferLock = new ReentrantLock();
95 - private TensorFlowInferenceInterface inferenceInterface;
96 - private List<String> labels = new ArrayList<String>();
97 - private List<String> displayedLabels = new ArrayList<>();
98 - private RecognizeCommands recognizeCommands = null;
99 -
100 - @Override
101 - protected void onCreate(Bundle savedInstanceState) {
102 - // Set up the UI.
103 - super.onCreate(savedInstanceState);
104 - setContentView(R.layout.activity_speech);
105 - quitButton = (Button) findViewById(R.id.quit);
106 - quitButton.setOnClickListener(
107 - new View.OnClickListener() {
108 - @Override
109 - public void onClick(View view) {
110 - moveTaskToBack(true);
111 - android.os.Process.killProcess(android.os.Process.myPid());
112 - System.exit(1);
113 - }
114 - });
115 - labelsListView = (ListView) findViewById(R.id.list_view);
116 -
117 - // Load the labels for the model, but only display those that don't start
118 - // with an underscore.
119 - String actualFilename = LABEL_FILENAME.split("file:///android_asset/")[1];
120 - Log.i(LOG_TAG, "Reading labels from: " + actualFilename);
121 - BufferedReader br = null;
122 - try {
123 - br = new BufferedReader(new InputStreamReader(getAssets().open(actualFilename)));
124 - String line;
125 - while ((line = br.readLine()) != null) {
126 - labels.add(line);
127 - if (line.charAt(0) != '_') {
128 - displayedLabels.add(line.substring(0, 1).toUpperCase() + line.substring(1));
129 - }
130 - }
131 - br.close();
132 - } catch (IOException e) {
133 - throw new RuntimeException("Problem reading label file!", e);
134 - }
135 -
136 - // Build a list view based on these labels.
137 - ArrayAdapter<String> arrayAdapter =
138 - new ArrayAdapter<String>(this, R.layout.list_text_item, displayedLabels);
139 - labelsListView.setAdapter(arrayAdapter);
140 -
141 - // Set up an object to smooth recognition results to increase accuracy.
142 - recognizeCommands =
143 - new RecognizeCommands(
144 - labels,
145 - AVERAGE_WINDOW_DURATION_MS,
146 - DETECTION_THRESHOLD,
147 - SUPPRESSION_MS,
148 - MINIMUM_COUNT,
149 - MINIMUM_TIME_BETWEEN_SAMPLES_MS);
150 -
151 - // Load the TensorFlow model.
152 - inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME);
153 -
154 -    // Start the recording and recognition threads (retried from onRequestPermissionsResult on M+).
155 - requestMicrophonePermission();
156 - startRecording();
157 - startRecognition();
158 - }
159 -
160 - private void requestMicrophonePermission() {
161 - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
162 - requestPermissions(
163 - new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
164 - }
165 - }
166 -
167 - @Override
168 - public void onRequestPermissionsResult(
169 - int requestCode, String[] permissions, int[] grantResults) {
170 - if (requestCode == REQUEST_RECORD_AUDIO
171 - && grantResults.length > 0
172 - && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
173 - startRecording();
174 - startRecognition();
175 - }
176 - }
177 -
178 - public synchronized void startRecording() {
179 - if (recordingThread != null) {
180 - return;
181 - }
182 - shouldContinue = true;
183 - recordingThread =
184 - new Thread(
185 - new Runnable() {
186 - @Override
187 - public void run() {
188 - record();
189 - }
190 - });
191 - recordingThread.start();
192 - }
193 -
194 - public synchronized void stopRecording() {
195 - if (recordingThread == null) {
196 - return;
197 - }
198 - shouldContinue = false;
199 - recordingThread = null;
200 - }
201 -
202 - private void record() {
203 - android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);
204 -
205 - // Estimate the buffer size we'll need for this device.
206 - int bufferSize =
207 - AudioRecord.getMinBufferSize(
208 - SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT);
209 - if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) {
210 - bufferSize = SAMPLE_RATE * 2;
211 - }
212 - short[] audioBuffer = new short[bufferSize / 2];
213 -
214 - AudioRecord record =
215 - new AudioRecord(
216 - MediaRecorder.AudioSource.DEFAULT,
217 - SAMPLE_RATE,
218 - AudioFormat.CHANNEL_IN_MONO,
219 - AudioFormat.ENCODING_PCM_16BIT,
220 - bufferSize);
221 -
222 -    if (record.getState() != AudioRecord.STATE_INITIALIZED) {
223 -      Log.e(LOG_TAG, "Audio Record can't initialize!");
224 -      stopRecording();  // Clear the stale thread so startRecording() can retry.
225 -      return;
226 -    }
227 - record.startRecording();
228 -
229 - Log.v(LOG_TAG, "Start recording");
230 -
231 - // Loop, gathering audio data and copying it to a round-robin buffer.
232 - while (shouldContinue) {
233 - int numberRead = record.read(audioBuffer, 0, audioBuffer.length);
234 - int maxLength = recordingBuffer.length;
235 - int newRecordingOffset = recordingOffset + numberRead;
236 - int secondCopyLength = Math.max(0, newRecordingOffset - maxLength);
237 - int firstCopyLength = numberRead - secondCopyLength;
238 - // We store off all the data for the recognition thread to access. The ML
239 - // thread will copy out of this buffer into its own, while holding the
240 - // lock, so this should be thread safe.
241 - recordingBufferLock.lock();
242 - try {
243 - System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength);
244 - System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength);
245 - recordingOffset = newRecordingOffset % maxLength;
246 - } finally {
247 - recordingBufferLock.unlock();
248 - }
249 - }
250 -
251 - record.stop();
252 - record.release();
253 - }
254 -
255 - public synchronized void startRecognition() {
256 - if (recognitionThread != null) {
257 - return;
258 - }
259 - shouldContinueRecognition = true;
260 - recognitionThread =
261 - new Thread(
262 - new Runnable() {
263 - @Override
264 - public void run() {
265 - recognize();
266 - }
267 - });
268 - recognitionThread.start();
269 - }
270 -
271 - public synchronized void stopRecognition() {
272 - if (recognitionThread == null) {
273 - return;
274 - }
275 - shouldContinueRecognition = false;
276 - recognitionThread = null;
277 - }
278 -
279 - private void recognize() {
280 - Log.v(LOG_TAG, "Start recognition");
281 -
282 - short[] inputBuffer = new short[RECORDING_LENGTH];
283 - float[] floatInputBuffer = new float[RECORDING_LENGTH];
284 - float[] outputScores = new float[labels.size()];
285 - String[] outputScoresNames = new String[] {OUTPUT_SCORES_NAME};
286 - int[] sampleRateList = new int[] {SAMPLE_RATE};
287 -
288 - // Loop, grabbing recorded data and running the recognition model on it.
289 - while (shouldContinueRecognition) {
290 - // The recording thread places data in this round-robin buffer, so lock to
291 - // make sure there's no writing happening and then copy it to our own
292 - // local version.
293 - recordingBufferLock.lock();
294 - try {
295 - int maxLength = recordingBuffer.length;
296 - int firstCopyLength = maxLength - recordingOffset;
297 - int secondCopyLength = recordingOffset;
298 - System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, firstCopyLength);
299 - System.arraycopy(recordingBuffer, 0, inputBuffer, firstCopyLength, secondCopyLength);
300 - } finally {
301 - recordingBufferLock.unlock();
302 - }
303 -
304 - // We need to feed in float values between -1.0f and 1.0f, so divide the
305 -      // signed 16-bit inputs by the maximum short value (32767).
306 - for (int i = 0; i < RECORDING_LENGTH; ++i) {
307 - floatInputBuffer[i] = inputBuffer[i] / 32767.0f;
308 - }
309 -
310 - // Run the model.
311 - inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList);
312 - inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1);
313 - inferenceInterface.run(outputScoresNames);
314 - inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores);
315 -
316 - // Use the smoother to figure out if we've had a real recognition event.
317 - long currentTime = System.currentTimeMillis();
318 - final RecognizeCommands.RecognitionResult result =
319 - recognizeCommands.processLatestResults(outputScores, currentTime);
320 -
321 - runOnUiThread(
322 - new Runnable() {
323 - @Override
324 - public void run() {
325 - // If we do have a new command, highlight the right list entry.
326 - if (!result.foundCommand.startsWith("_") && result.isNewCommand) {
327 - int labelIndex = -1;
328 - for (int i = 0; i < labels.size(); ++i) {
329 - if (labels.get(i).equals(result.foundCommand)) {
330 - labelIndex = i;
331 - }
332 - }
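          // Offset by 2 below: the underscore-prefixed labels (by default
          // "_silence_" and "_unknown_") are loaded but not displayed.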
333 - final View labelView = labelsListView.getChildAt(labelIndex - 2);
334 -
335 - AnimatorSet colorAnimation =
336 - (AnimatorSet)
337 - AnimatorInflater.loadAnimator(
338 - SpeechActivity.this, R.animator.color_animation);
339 - colorAnimation.setTarget(labelView);
340 - colorAnimation.start();
341 - }
342 - }
343 - });
344 - try {
345 - // We don't need to run too frequently, so snooze for a bit.
346 - Thread.sleep(MINIMUM_TIME_BETWEEN_SAMPLES_MS);
347 - } catch (InterruptedException e) {
348 - // Ignore
349 - }
350 - }
351 -
352 - Log.v(LOG_TAG, "End recognition");
353 - }
354 -}
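The subtlest piece of SpeechActivity is the round-robin buffer shared between the recording and recognition threads. The same wraparound arithmetic, extracted into a standalone helper for clarity (a sketch, not part of the original file; like the loop in record(), it assumes read <= ring.length):

// Write `read` samples from `chunk` into the ring buffer starting at `offset`,
// wrapping past the end, and return the new offset. Mirrors record() above.
static int writeWrapped(final short[] ring, final int offset, final short[] chunk, final int read) {
  final int newOffset = offset + read;
  final int secondCopyLength = Math.max(0, newOffset - ring.length);  // Wraps to the front.
  final int firstCopyLength = read - secondCopyLength;                // Fits before the end.
  System.arraycopy(chunk, 0, ring, offset, firstCopyLength);
  System.arraycopy(chunk, firstCopyLength, ring, 0, secondCopyLength);
  return newOffset % ring.length;
}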
1 -/*
2 - * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 - *
4 - * Licensed under the Apache License, Version 2.0 (the "License");
5 - * you may not use this file except in compliance with the License.
6 - * You may obtain a copy of the License at
7 - *
8 - * http://www.apache.org/licenses/LICENSE-2.0
9 - *
10 - * Unless required by applicable law or agreed to in writing, software
11 - * distributed under the License is distributed on an "AS IS" BASIS,
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 - * See the License for the specific language governing permissions and
14 - * limitations under the License.
15 - */
16 -
17 -package org.tensorflow.demo;
18 -
19 -import android.app.UiModeManager;
20 -import android.content.Context;
21 -import android.content.res.AssetManager;
22 -import android.content.res.Configuration;
23 -import android.graphics.Bitmap;
24 -import android.graphics.Bitmap.Config;
25 -import android.graphics.BitmapFactory;
26 -import android.graphics.Canvas;
27 -import android.graphics.Color;
28 -import android.graphics.Matrix;
29 -import android.graphics.Paint;
30 -import android.graphics.Paint.Style;
31 -import android.graphics.Rect;
32 -import android.graphics.Typeface;
33 -import android.media.ImageReader.OnImageAvailableListener;
34 -import android.os.Bundle;
35 -import android.os.SystemClock;
36 -import android.util.DisplayMetrics;
37 -import android.util.Size;
38 -import android.util.TypedValue;
39 -import android.view.Display;
40 -import android.view.KeyEvent;
41 -import android.view.MotionEvent;
42 -import android.view.View;
43 -import android.view.View.OnClickListener;
44 -import android.view.View.OnTouchListener;
45 -import android.view.ViewGroup;
46 -import android.widget.BaseAdapter;
47 -import android.widget.Button;
48 -import android.widget.GridView;
49 -import android.widget.ImageView;
50 -import android.widget.RelativeLayout;
51 -import android.widget.Toast;
52 -import java.io.IOException;
53 -import java.io.InputStream;
54 -import java.util.ArrayList;
55 -import java.util.Collections;
56 -import java.util.Vector;
57 -import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
58 -import org.tensorflow.demo.OverlayView.DrawCallback;
59 -import org.tensorflow.demo.env.BorderedText;
60 -import org.tensorflow.demo.env.ImageUtils;
61 -import org.tensorflow.demo.env.Logger;
62 -import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
63 -
64 -/**
65 - * Sample activity that stylizes the camera preview according to "A Learned Representation For
66 - * Artistic Style" (https://arxiv.org/abs/1610.07629)
67 - */
68 -public class StylizeActivity extends CameraActivity implements OnImageAvailableListener {
69 - private static final Logger LOGGER = new Logger();
70 -
71 - private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb";
72 - private static final String INPUT_NODE = "input";
73 - private static final String STYLE_NODE = "style_num";
74 - private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid";
75 - private static final int NUM_STYLES = 26;
76 -
77 - private static final boolean SAVE_PREVIEW_BITMAP = false;
78 -
79 -  // Whether to actively manipulate non-selected sliders so that the sum of activations always appears
80 - // to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless.
81 - private static final boolean NORMALIZE_SLIDERS = true;
82 -
83 - private static final float TEXT_SIZE_DIP = 12;
84 -
85 - private static final boolean DEBUG_MODEL = false;
86 -
87 - private static final int[] SIZES = {128, 192, 256, 384, 512, 720};
88 -
89 - private static final Size DESIRED_PREVIEW_SIZE = new Size(1280, 720);
90 -
91 - // Start at a medium size, but let the user step up through smaller sizes so they don't get
92 - // immediately stuck processing a large image.
93 - private int desiredSizeIndex = -1;
94 - private int desiredSize = 256;
95 - private int initializedSize = 0;
96 -
97 - private Integer sensorOrientation;
98 -
99 - private long lastProcessingTimeMs;
100 - private Bitmap rgbFrameBitmap = null;
101 - private Bitmap croppedBitmap = null;
102 - private Bitmap cropCopyBitmap = null;
103 -
104 - private final float[] styleVals = new float[NUM_STYLES];
105 - private int[] intValues;
106 - private float[] floatValues;
107 -
108 - private int frameNum = 0;
109 -
110 - private Bitmap textureCopyBitmap;
111 -
112 - private Matrix frameToCropTransform;
113 - private Matrix cropToFrameTransform;
114 -
115 - private BorderedText borderedText;
116 -
117 - private TensorFlowInferenceInterface inferenceInterface;
118 -
119 - private int lastOtherStyle = 1;
120 -
121 - private boolean allZero = false;
122 -
123 - private ImageGridAdapter adapter;
124 - private GridView grid;
125 -
126 - private final OnTouchListener gridTouchAdapter =
127 - new OnTouchListener() {
128 - ImageSlider slider = null;
129 -
130 - @Override
131 - public boolean onTouch(final View v, final MotionEvent event) {
132 - switch (event.getActionMasked()) {
133 - case MotionEvent.ACTION_DOWN:
134 - for (int i = 0; i < NUM_STYLES; ++i) {
135 - final ImageSlider child = adapter.items[i];
136 - final Rect rect = new Rect();
137 - child.getHitRect(rect);
138 - if (rect.contains((int) event.getX(), (int) event.getY())) {
139 - slider = child;
140 -                slider.setHighlighted(true);
141 - }
142 - }
143 - break;
144 -
145 - case MotionEvent.ACTION_MOVE:
146 - if (slider != null) {
147 - final Rect rect = new Rect();
148 - slider.getHitRect(rect);
149 -
150 - final float newSliderVal =
151 - (float)
152 - Math.min(
153 - 1.0,
154 - Math.max(
155 - 0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight()));
156 -
157 - setStyle(slider, newSliderVal);
158 - }
159 - break;
160 -
161 - case MotionEvent.ACTION_UP:
162 - if (slider != null) {
163 -            slider.setHighlighted(false);
164 - slider = null;
165 - }
166 - break;
167 -
168 - default: // fall out
169 -
170 - }
171 - return true;
172 - }
173 - };
174 -
175 - @Override
176 - public void onCreate(final Bundle savedInstanceState) {
177 - super.onCreate(savedInstanceState);
178 - }
179 -
180 - @Override
181 - protected int getLayoutId() {
182 - return R.layout.camera_connection_fragment_stylize;
183 - }
184 -
185 - @Override
186 - protected Size getDesiredPreviewFrameSize() {
187 - return DESIRED_PREVIEW_SIZE;
188 - }
189 -
190 - public static Bitmap getBitmapFromAsset(final Context context, final String filePath) {
191 - final AssetManager assetManager = context.getAssets();
192 -
193 - Bitmap bitmap = null;
194 - try {
195 - final InputStream inputStream = assetManager.open(filePath);
196 - bitmap = BitmapFactory.decodeStream(inputStream);
197 - } catch (final IOException e) {
198 - LOGGER.e("Error opening bitmap!", e);
199 - }
200 -
201 - return bitmap;
202 - }
203 -
204 - private class ImageSlider extends ImageView {
205 - private float value = 0.0f;
206 -    private boolean highlighted = false;
207 -
208 - private final Paint boxPaint;
209 - private final Paint linePaint;
210 -
211 - public ImageSlider(final Context context) {
212 - super(context);
213 - value = 0.0f;
214 -
215 - boxPaint = new Paint();
216 - boxPaint.setColor(Color.BLACK);
217 - boxPaint.setAlpha(128);
218 -
219 - linePaint = new Paint();
220 - linePaint.setColor(Color.WHITE);
221 - linePaint.setStrokeWidth(10.0f);
222 - linePaint.setStyle(Style.STROKE);
223 - }
224 -
225 - @Override
226 - public void onDraw(final Canvas canvas) {
227 - super.onDraw(canvas);
228 - final float y = (1.0f - value) * canvas.getHeight();
229 -
230 - // If all sliders are zero, don't bother shading anything.
231 - if (!allZero) {
232 - canvas.drawRect(0, 0, canvas.getWidth(), y, boxPaint);
233 - }
234 -
235 - if (value > 0.0f) {
236 - canvas.drawLine(0, y, canvas.getWidth(), y, linePaint);
237 - }
238 -
239 -      if (highlighted) {
240 - canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint);
241 - }
242 - }
243 -
244 - @Override
245 - protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
246 - super.onMeasure(widthMeasureSpec, heightMeasureSpec);
247 - setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
248 - }
249 -
250 - public void setValue(final float value) {
251 - this.value = value;
252 - postInvalidate();
253 - }
254 -
255 -    public void setHighlighted(final boolean highlighted) {
256 -      this.highlighted = highlighted;
257 - this.postInvalidate();
258 - }
259 - }
260 -
261 - private class ImageGridAdapter extends BaseAdapter {
262 - final ImageSlider[] items = new ImageSlider[NUM_STYLES];
263 - final ArrayList<Button> buttons = new ArrayList<>();
264 -
265 - {
266 - final Button sizeButton =
267 - new Button(StylizeActivity.this) {
268 - @Override
269 - protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
270 - super.onMeasure(widthMeasureSpec, heightMeasureSpec);
271 - setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
272 - }
273 - };
274 - sizeButton.setText("" + desiredSize);
275 - sizeButton.setOnClickListener(
276 - new OnClickListener() {
277 - @Override
278 - public void onClick(final View v) {
279 - desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length;
280 - desiredSize = SIZES[desiredSizeIndex];
281 - sizeButton.setText("" + desiredSize);
282 - sizeButton.postInvalidate();
283 - }
284 - });
285 -
286 - final Button saveButton =
287 - new Button(StylizeActivity.this) {
288 - @Override
289 - protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
290 - super.onMeasure(widthMeasureSpec, heightMeasureSpec);
291 - setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
292 - }
293 - };
294 - saveButton.setText("save");
295 - saveButton.setTextSize(12);
296 -
297 - saveButton.setOnClickListener(
298 - new OnClickListener() {
299 - @Override
300 - public void onClick(final View v) {
301 - if (textureCopyBitmap != null) {
302 - // TODO(andrewharp): Save as jpeg with guaranteed unique filename.
303 - ImageUtils.saveBitmap(textureCopyBitmap, "stylized" + frameNum + ".png");
304 - Toast.makeText(
305 - StylizeActivity.this,
306 - "Saved image to: /sdcard/tensorflow/" + "stylized" + frameNum + ".png",
307 - Toast.LENGTH_LONG)
308 - .show();
309 - }
310 - }
311 - });
312 -
313 - buttons.add(sizeButton);
314 - buttons.add(saveButton);
315 -
316 - for (int i = 0; i < NUM_STYLES; ++i) {
317 - LOGGER.v("Creating item %d", i);
318 -
319 - if (items[i] == null) {
320 - final ImageSlider slider = new ImageSlider(StylizeActivity.this);
321 - final Bitmap bm =
322 - getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg");
323 - slider.setImageBitmap(bm);
324 -
325 - items[i] = slider;
326 - }
327 - }
328 - }
329 -
330 - @Override
331 - public int getCount() {
332 - return buttons.size() + NUM_STYLES;
333 - }
334 -
335 - @Override
336 - public Object getItem(final int position) {
337 - if (position < buttons.size()) {
338 - return buttons.get(position);
339 - } else {
340 - return items[position - buttons.size()];
341 - }
342 - }
343 -
344 - @Override
345 - public long getItemId(final int position) {
346 - return getItem(position).hashCode();
347 - }
348 -
349 - @Override
350 - public View getView(final int position, final View convertView, final ViewGroup parent) {
351 - if (convertView != null) {
352 - return convertView;
353 - }
354 - return (View) getItem(position);
355 - }
356 - }
357 -
358 - @Override
359 - public void onPreviewSizeChosen(final Size size, final int rotation) {
360 - final float textSizePx = TypedValue.applyDimension(
361 - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
362 - borderedText = new BorderedText(textSizePx);
363 - borderedText.setTypeface(Typeface.MONOSPACE);
364 -
365 - inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
366 -
367 - previewWidth = size.getWidth();
368 - previewHeight = size.getHeight();
369 -
370 - final Display display = getWindowManager().getDefaultDisplay();
371 - final int screenOrientation = display.getRotation();
372 -
373 - LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
374 -
375 - sensorOrientation = rotation + screenOrientation;
376 -
377 - addCallback(
378 - new DrawCallback() {
379 - @Override
380 - public void drawCallback(final Canvas canvas) {
381 - renderDebug(canvas);
382 - }
383 - });
384 -
385 - adapter = new ImageGridAdapter();
386 - grid = (GridView) findViewById(R.id.grid_layout);
387 - grid.setAdapter(adapter);
388 - grid.setOnTouchListener(gridTouchAdapter);
389 -
390 - // Change UI on Android TV
391 - UiModeManager uiModeManager = (UiModeManager) getSystemService(UI_MODE_SERVICE);
392 - if (uiModeManager.getCurrentModeType() == Configuration.UI_MODE_TYPE_TELEVISION) {
393 - DisplayMetrics displayMetrics = new DisplayMetrics();
394 - getWindowManager().getDefaultDisplay().getMetrics(displayMetrics);
395 - int styleSelectorHeight = displayMetrics.heightPixels;
396 - int styleSelectorWidth = displayMetrics.widthPixels - styleSelectorHeight;
397 - RelativeLayout.LayoutParams layoutParams = new RelativeLayout.LayoutParams(styleSelectorWidth, ViewGroup.LayoutParams.MATCH_PARENT);
398 -
399 -      // Calculate the number of styles per row so that all of them fit on screen without scrolling.
400 - int numOfStylePerRow = 3;
401 - while (styleSelectorWidth / numOfStylePerRow * Math.ceil((float) (adapter.getCount() - 2) / numOfStylePerRow) > styleSelectorHeight) {
402 - numOfStylePerRow++;
403 - }
404 - grid.setNumColumns(numOfStylePerRow);
405 - layoutParams.addRule(RelativeLayout.ALIGN_PARENT_RIGHT);
406 - grid.setLayoutParams(layoutParams);
407 - adapter.buttons.clear();
408 - }
409 -
410 - setStyle(adapter.items[0], 1.0f);
411 - }
412 -
413 - private void setStyle(final ImageSlider slider, final float value) {
414 - slider.setValue(value);
415 -
416 - if (NORMALIZE_SLIDERS) {
417 -      // Slider values correspond directly to the input tensor values, and normalization is
418 -      // maintained visually by rescaling the non-selected sliders.
419 - float otherSum = 0.0f;
420 -
421 - for (int i = 0; i < NUM_STYLES; ++i) {
422 - if (adapter.items[i] != slider) {
423 - otherSum += adapter.items[i].value;
424 - }
425 - }
426 -
427 - if (otherSum > 0.0) {
428 - float highestOtherVal = 0;
429 - final float factor = otherSum > 0.0f ? (1.0f - value) / otherSum : 0.0f;
430 - for (int i = 0; i < NUM_STYLES; ++i) {
431 - final ImageSlider child = adapter.items[i];
432 - if (child == slider) {
433 - continue;
434 - }
435 - final float newVal = child.value * factor;
436 - child.setValue(newVal > 0.01f ? newVal : 0.0f);
437 -
438 - if (child.value > highestOtherVal) {
439 - lastOtherStyle = i;
440 - highestOtherVal = child.value;
441 - }
442 - }
443 - } else {
444 - // Everything else is 0, so just pick a suitable slider to push up when the
445 - // selected one goes down.
446 - if (adapter.items[lastOtherStyle] == slider) {
447 - lastOtherStyle = (lastOtherStyle + 1) % NUM_STYLES;
448 - }
449 - adapter.items[lastOtherStyle].setValue(1.0f - value);
450 - }
451 - }
452 -
453 - final boolean lastAllZero = allZero;
454 - float sum = 0.0f;
455 - for (int i = 0; i < NUM_STYLES; ++i) {
456 - sum += adapter.items[i].value;
457 - }
458 - allZero = sum == 0.0f;
459 -
460 - // Now update the values used for the input tensor. If nothing is set, mix in everything
461 - // equally. Otherwise everything is normalized to sum to 1.0.
462 - for (int i = 0; i < NUM_STYLES; ++i) {
463 - styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum;
464 -
465 - if (lastAllZero != allZero) {
466 - adapter.items[i].postInvalidate();
467 - }
468 - }
469 - }
470 -
471 - private void resetPreviewBuffers() {
472 - croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
473 -
474 - frameToCropTransform = ImageUtils.getTransformationMatrix(
475 - previewWidth, previewHeight,
476 - desiredSize, desiredSize,
477 - sensorOrientation, true);
478 -
479 - cropToFrameTransform = new Matrix();
480 - frameToCropTransform.invert(cropToFrameTransform);
481 - intValues = new int[desiredSize * desiredSize];
482 - floatValues = new float[desiredSize * desiredSize * 3];
483 - initializedSize = desiredSize;
484 - }
485 -
486 - @Override
487 - protected void processImage() {
488 - if (desiredSize != initializedSize) {
489 - LOGGER.i(
490 - "Initializing at size preview size %dx%d, stylize size %d",
491 - previewWidth, previewHeight, desiredSize);
492 -
493 - rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
494 - croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
495 - frameToCropTransform = ImageUtils.getTransformationMatrix(
496 - previewWidth, previewHeight,
497 - desiredSize, desiredSize,
498 - sensorOrientation, true);
499 -
500 - cropToFrameTransform = new Matrix();
501 - frameToCropTransform.invert(cropToFrameTransform);
502 - intValues = new int[desiredSize * desiredSize];
503 - floatValues = new float[desiredSize * desiredSize * 3];
504 - initializedSize = desiredSize;
505 - }
506 - rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
507 - final Canvas canvas = new Canvas(croppedBitmap);
508 - canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
509 -
510 - // For examining the actual TF input.
511 - if (SAVE_PREVIEW_BITMAP) {
512 - ImageUtils.saveBitmap(croppedBitmap);
513 - }
514 -
515 - runInBackground(
516 - new Runnable() {
517 - @Override
518 - public void run() {
519 - cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
520 - final long startTime = SystemClock.uptimeMillis();
521 - stylizeImage(croppedBitmap);
522 - lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
523 - textureCopyBitmap = Bitmap.createBitmap(croppedBitmap);
524 - requestRender();
525 - readyForNextImage();
526 - }
527 - });
528 - if (desiredSize != initializedSize) {
529 - resetPreviewBuffers();
530 - }
531 - }
532 -
533 - private void stylizeImage(final Bitmap bitmap) {
534 - ++frameNum;
535 - bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
536 -
537 - if (DEBUG_MODEL) {
538 - // Create a white square that steps through a black background 1 pixel per frame.
539 - final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth();
540 - final int centerY = bitmap.getHeight() / 2;
541 - final int squareSize = 10;
542 - for (int i = 0; i < intValues.length; ++i) {
543 - final int x = i % bitmap.getWidth();
544 -        final int y = i / bitmap.getWidth();  // Row index: divide by width (the crop is square, but dividing by height was a latent bug).
545 - final float val =
546 - Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f;
547 - floatValues[i * 3] = val;
548 - floatValues[i * 3 + 1] = val;
549 - floatValues[i * 3 + 2] = val;
550 - }
551 - } else {
552 - for (int i = 0; i < intValues.length; ++i) {
553 - final int val = intValues[i];
554 - floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
555 - floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
556 - floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
557 - }
558 - }
559 -
560 - // Copy the input data into TensorFlow.
561 - LOGGER.i("Width: %s , Height: %s", bitmap.getWidth(), bitmap.getHeight());
562 - inferenceInterface.feed(
563 - INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3);
564 - inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES);
565 -
566 - inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug());
567 - inferenceInterface.fetch(OUTPUT_NODE, floatValues);
568 -
569 - for (int i = 0; i < intValues.length; ++i) {
570 - intValues[i] =
571 - 0xFF000000
572 - | (((int) (floatValues[i * 3] * 255)) << 16)
573 - | (((int) (floatValues[i * 3 + 1] * 255)) << 8)
574 - | ((int) (floatValues[i * 3 + 2] * 255));
575 - }
576 -
577 - bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
578 - }
579 -
580 - private void renderDebug(final Canvas canvas) {
581 - // TODO(andrewharp): move result display to its own View instead of using debug overlay.
582 - final Bitmap texture = textureCopyBitmap;
583 - if (texture != null) {
584 - final Matrix matrix = new Matrix();
585 - final float scaleFactor =
586 - DEBUG_MODEL
587 - ? 4.0f
588 - : Math.min(
589 - (float) canvas.getWidth() / texture.getWidth(),
590 - (float) canvas.getHeight() / texture.getHeight());
591 - matrix.postScale(scaleFactor, scaleFactor);
592 - canvas.drawBitmap(texture, matrix, new Paint());
593 - }
594 -
595 - if (!isDebug()) {
596 - return;
597 - }
598 -
599 - final Bitmap copy = cropCopyBitmap;
600 - if (copy == null) {
601 - return;
602 - }
603 -
604 - canvas.drawColor(0x55000000);
605 -
606 - final Matrix matrix = new Matrix();
607 - final float scaleFactor = 2;
608 - matrix.postScale(scaleFactor, scaleFactor);
609 - matrix.postTranslate(
610 - canvas.getWidth() - copy.getWidth() * scaleFactor,
611 - canvas.getHeight() - copy.getHeight() * scaleFactor);
612 - canvas.drawBitmap(copy, matrix, new Paint());
613 -
614 - final Vector<String> lines = new Vector<>();
615 -
616 - final String[] statLines = inferenceInterface.getStatString().split("\n");
617 - Collections.addAll(lines, statLines);
618 -
619 - lines.add("");
620 -
621 - lines.add("Frame: " + previewWidth + "x" + previewHeight);
622 - lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
623 - lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
624 - lines.add("Rotation: " + sensorOrientation);
625 - lines.add("Inference time: " + lastProcessingTimeMs + "ms");
626 - lines.add("Desired size: " + desiredSize);
627 - lines.add("Initialized size: " + initializedSize);
628 -
629 - borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
630 - }
631 -
632 - @Override
633 - public boolean onKeyDown(int keyCode, KeyEvent event) {
634 - int moveOffset = 0;
635 - switch (keyCode) {
636 - case KeyEvent.KEYCODE_DPAD_LEFT:
637 - moveOffset = -1;
638 - break;
639 - case KeyEvent.KEYCODE_DPAD_RIGHT:
640 - moveOffset = 1;
641 - break;
642 - case KeyEvent.KEYCODE_DPAD_UP:
643 - moveOffset = -1 * grid.getNumColumns();
644 - break;
645 - case KeyEvent.KEYCODE_DPAD_DOWN:
646 - moveOffset = grid.getNumColumns();
647 - break;
648 - default:
649 - return super.onKeyDown(keyCode, event);
650 - }
651 -
652 -    // Find the style with the highest current weight.
653 -    int currentSelect = 0;
654 -    float highestValue = 0;
655 -    for (int i = 0; i < NUM_STYLES; i++) {
656 - if (adapter.items[i].value > highestValue) {
657 - currentSelect = i;
658 - highestValue = adapter.items[i].value;
659 - }
660 - }
661 -    setStyle(adapter.items[(currentSelect + moveOffset + NUM_STYLES) % NUM_STYLES], 1);
662 -
663 - return true;
664 - }
665 -}
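The weight bookkeeping in setStyle() reduces to a simple rule: set slider k to v, then multiply every other slider by (1 - v) / otherSum so the weights keep summing to 1.0. The same rule on a bare float array (a sketch; the activity additionally pushes one other slider up when all the others are zero):

// Set w[k] to v and rescale the remaining weights so the array still sums to 1.0.
static void setWeight(final float[] w, final int k, final float v) {
  float otherSum = 0.0f;
  for (int i = 0; i < w.length; ++i) {
    if (i != k) {
      otherSum += w[i];
    }
  }
  final float factor = otherSum > 0.0f ? (1.0f - v) / otherSum : 0.0f;
  for (int i = 0; i < w.length; ++i) {
    w[i] = (i == k) ? v : w[i] * factor;
  }
}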
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import android.content.res.AssetManager;
19 -import android.graphics.Bitmap;
20 -import android.os.Trace;
21 -import android.util.Log;
22 -import java.io.BufferedReader;
23 -import java.io.IOException;
24 -import java.io.InputStreamReader;
25 -import java.util.ArrayList;
26 -import java.util.Comparator;
27 -import java.util.List;
28 -import java.util.PriorityQueue;
29 -import java.util.Vector;
30 -import org.tensorflow.Operation;
31 -import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
32 -
33 -/** A classifier specialized to label images using TensorFlow. */
34 -public class TensorFlowImageClassifier implements Classifier {
35 - private static final String TAG = "TensorFlowImageClassifier";
36 -
37 - // Only return this many results with at least this confidence.
38 - private static final int MAX_RESULTS = 3;
39 - private static final float THRESHOLD = 0.1f;
40 -
41 - // Config values.
42 - private String inputName;
43 - private String outputName;
44 - private int inputSize;
45 - private int imageMean;
46 - private float imageStd;
47 -
48 - // Pre-allocated buffers.
49 - private Vector<String> labels = new Vector<String>();
50 - private int[] intValues;
51 - private float[] floatValues;
52 - private float[] outputs;
53 - private String[] outputNames;
54 -
55 - private boolean logStats = false;
56 -
57 - private TensorFlowInferenceInterface inferenceInterface;
58 -
59 - private TensorFlowImageClassifier() {}
60 -
61 - /**
62 - * Initializes a native TensorFlow session for classifying images.
63 - *
64 - * @param assetManager The asset manager to be used to load assets.
65 - * @param modelFilename The filepath of the model GraphDef protocol buffer.
66 - * @param labelFilename The filepath of label file for classes.
67 - * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
68 - * @param imageMean The assumed mean of the image values.
69 - * @param imageStd The assumed std of the image values.
70 - * @param inputName The label of the image input node.
71 - * @param outputName The label of the output node.
72 -   * @throws RuntimeException if the label file cannot be read.
73 - */
74 - public static Classifier create(
75 - AssetManager assetManager,
76 - String modelFilename,
77 - String labelFilename,
78 - int inputSize,
79 - int imageMean,
80 - float imageStd,
81 - String inputName,
82 - String outputName) {
83 - TensorFlowImageClassifier c = new TensorFlowImageClassifier();
84 - c.inputName = inputName;
85 - c.outputName = outputName;
86 -
87 - // Read the label names into memory.
88 - // TODO(andrewharp): make this handle non-assets.
89 - String actualFilename = labelFilename.split("file:///android_asset/")[1];
90 - Log.i(TAG, "Reading labels from: " + actualFilename);
91 - BufferedReader br = null;
92 - try {
93 - br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
94 - String line;
95 - while ((line = br.readLine()) != null) {
96 - c.labels.add(line);
97 - }
98 - br.close();
99 - } catch (IOException e) {
100 -      throw new RuntimeException("Problem reading label file!", e);
101 - }
102 -
103 - c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
104 -
105 - // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
106 - final Operation operation = c.inferenceInterface.graphOperation(outputName);
107 - final int numClasses = (int) operation.output(0).shape().size(1);
108 - Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
109 -
110 - // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
111 - // the placeholder node for input in the graphdef typically used does not specify a shape, so it
112 - // must be passed in as a parameter.
113 - c.inputSize = inputSize;
114 - c.imageMean = imageMean;
115 - c.imageStd = imageStd;
116 -
117 - // Pre-allocate buffers.
118 - c.outputNames = new String[] {outputName};
119 - c.intValues = new int[inputSize * inputSize];
120 - c.floatValues = new float[inputSize * inputSize * 3];
121 - c.outputs = new float[numClasses];
122 -
123 - return c;
124 - }
125 -
126 - @Override
127 - public List<Recognition> recognizeImage(final Bitmap bitmap) {
128 - // Log this method so that it can be analyzed with systrace.
129 - Trace.beginSection("recognizeImage");
130 -
131 - Trace.beginSection("preprocessBitmap");
132 - // Preprocess the image data from 0-255 int to normalized float based
133 - // on the provided parameters.
134 - bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
135 - for (int i = 0; i < intValues.length; ++i) {
136 - final int val = intValues[i];
137 - floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
138 - floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
139 - floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
140 - }
141 - Trace.endSection();
142 -
143 - // Copy the input data into TensorFlow.
144 - Trace.beginSection("feed");
145 - inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
146 - Trace.endSection();
147 -
148 - // Run the inference call.
149 - Trace.beginSection("run");
150 - inferenceInterface.run(outputNames, logStats);
151 - Trace.endSection();
152 -
153 - // Copy the output Tensor back into the output array.
154 - Trace.beginSection("fetch");
155 - inferenceInterface.fetch(outputName, outputs);
156 - Trace.endSection();
157 -
158 - // Find the best classifications.
159 - PriorityQueue<Recognition> pq =
160 - new PriorityQueue<Recognition>(
161 - 3,
162 - new Comparator<Recognition>() {
163 - @Override
164 - public int compare(Recognition lhs, Recognition rhs) {
165 - // Intentionally reversed to put high confidence at the head of the queue.
166 - return Float.compare(rhs.getConfidence(), lhs.getConfidence());
167 - }
168 - });
169 - for (int i = 0; i < outputs.length; ++i) {
170 - if (outputs[i] > THRESHOLD) {
171 - pq.add(
172 - new Recognition(
173 - "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
174 - }
175 - }
176 - final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
177 - int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
178 - for (int i = 0; i < recognitionsSize; ++i) {
179 - recognitions.add(pq.poll());
180 - }
181 - Trace.endSection(); // "recognizeImage"
182 - return recognitions;
183 - }
184 -
185 - @Override
186 - public void enableStatLogging(boolean logStats) {
187 - this.logStats = logStats;
188 - }
189 -
190 - @Override
191 - public String getStatString() {
192 - return inferenceInterface.getStatString();
193 - }
194 -
195 - @Override
196 - public void close() {
197 - inferenceInterface.close();
198 - }
199 -}
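Typical wiring for this class, e.g. from an activity's onCreate(). The asset names, node names, and preprocessing constants below are placeholders, not values taken from this diff:

// Hypothetical setup; substitute the real filenames and node names of your model.
Classifier classifier =
    TensorFlowImageClassifier.create(
        getAssets(),
        "file:///android_asset/my_model.pb",    // placeholder model file
        "file:///android_asset/my_labels.txt",  // placeholder label file
        224,   // inputSize: the model's expected square input
        117,   // imageMean
        1.0f,  // imageStd
        "input",          // placeholder input node name
        "final_result");  // placeholder output node name

// Given a square Bitmap of inputSize x inputSize, this returns at most
// MAX_RESULTS labels scoring above THRESHOLD, sorted by descending confidence.
List<Classifier.Recognition> results = classifier.recognizeImage(bitmap);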
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import android.content.res.AssetManager;
19 -import android.graphics.Bitmap;
20 -import android.graphics.RectF;
21 -import android.os.Trace;
22 -import java.io.BufferedReader;
23 -import java.io.FileInputStream;
24 -import java.io.IOException;
25 -import java.io.InputStream;
26 -import java.io.InputStreamReader;
27 -import java.util.ArrayList;
28 -import java.util.Comparator;
29 -import java.util.List;
30 -import java.util.PriorityQueue;
31 -import java.util.StringTokenizer;
32 -import org.tensorflow.Graph;
33 -import org.tensorflow.Operation;
34 -import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
35 -import org.tensorflow.demo.env.Logger;
36 -
37 -/**
38 - * A detector for general purpose object detection as described in Scalable Object Detection using
39 - * Deep Neural Networks (https://arxiv.org/abs/1312.2249).
40 - */
41 -public class TensorFlowMultiBoxDetector implements Classifier {
42 - private static final Logger LOGGER = new Logger();
43 -
44 -  // Only return up to this many results (Integer.MAX_VALUE, i.e. effectively unlimited).
45 - private static final int MAX_RESULTS = Integer.MAX_VALUE;
46 -
47 - // Config values.
48 - private String inputName;
49 - private int inputSize;
50 - private int imageMean;
51 - private float imageStd;
52 -
53 - // Pre-allocated buffers.
54 - private int[] intValues;
55 - private float[] floatValues;
56 - private float[] outputLocations;
57 - private float[] outputScores;
58 - private String[] outputNames;
59 - private int numLocations;
60 -
61 - private boolean logStats = false;
62 -
63 - private TensorFlowInferenceInterface inferenceInterface;
64 -
65 - private float[] boxPriors;
66 -
67 - /**
68 -   * Initializes a native TensorFlow session for detecting objects in images.
69 - *
70 - * @param assetManager The asset manager to be used to load assets.
71 - * @param modelFilename The filepath of the model GraphDef protocol buffer.
72 -   * @param locationFilename The filepath of the file containing the box priors.
73 -   * @param imageMean The assumed mean of the image values.
74 -   * @param imageStd The assumed std of the image values.
75 -   * @param inputName The label of the image input node.
76 -   * @param outputLocationsName The label of the output locations node.
77 -   * @param outputScoresName The label of the output scores node.
78 - */
79 - public static Classifier create(
80 - final AssetManager assetManager,
81 - final String modelFilename,
82 - final String locationFilename,
83 - final int imageMean,
84 - final float imageStd,
85 - final String inputName,
86 - final String outputLocationsName,
87 - final String outputScoresName) {
88 - final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
89 -
90 - d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
91 -
92 - final Graph g = d.inferenceInterface.graph();
93 -
94 - d.inputName = inputName;
95 - // The inputName node has a shape of [N, H, W, C], where
96 - // N is the batch size
97 - // H = W are the height and width
98 - // C is the number of channels (3 for our purposes - RGB)
99 - final Operation inputOp = g.operation(inputName);
100 - if (inputOp == null) {
101 - throw new RuntimeException("Failed to find input Node '" + inputName + "'");
102 - }
103 - d.inputSize = (int) inputOp.output(0).shape().size(1);
104 - d.imageMean = imageMean;
105 - d.imageStd = imageStd;
106 - // The outputScoresName node has a shape of [N, NumLocations], where N
107 - // is the batch size.
108 - final Operation outputOp = g.operation(outputScoresName);
109 - if (outputOp == null) {
110 - throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'");
111 - }
112 - d.numLocations = (int) outputOp.output(0).shape().size(1);
113 -
114 - d.boxPriors = new float[d.numLocations * 8];
115 -
116 - try {
117 - d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
118 - } catch (final IOException e) {
119 -      throw new RuntimeException("Error initializing box priors from " + locationFilename, e);
120 - }
121 -
122 - // Pre-allocate buffers.
123 - d.outputNames = new String[] {outputLocationsName, outputScoresName};
124 - d.intValues = new int[d.inputSize * d.inputSize];
125 - d.floatValues = new float[d.inputSize * d.inputSize * 3];
126 - d.outputScores = new float[d.numLocations];
127 - d.outputLocations = new float[d.numLocations * 4];
128 -
129 - return d;
130 - }
131 -
132 - private TensorFlowMultiBoxDetector() {}
133 -
134 - private void loadCoderOptions(
135 - final AssetManager assetManager, final String locationFilename, final float[] boxPriors)
136 - throws IOException {
137 - // Try to be intelligent about opening from assets or sdcard depending on prefix.
138 - final String assetPrefix = "file:///android_asset/";
139 - InputStream is;
140 - if (locationFilename.startsWith(assetPrefix)) {
141 - is = assetManager.open(locationFilename.split(assetPrefix)[1]);
142 - } else {
143 - is = new FileInputStream(locationFilename);
144 - }
145 -
146 - // Read values. Number of values per line doesn't matter, as long as they are separated
147 - // by commas and/or whitespace, and there are exactly numLocations * 8 values total.
148 - // Values are in the order mean, std for each consecutive corner of each box, for a total of 8
149 - // per location.
150 - final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
151 - int priorIndex = 0;
152 - String line;
153 - while ((line = reader.readLine()) != null) {
154 - final StringTokenizer st = new StringTokenizer(line, ", ");
155 - while (st.hasMoreTokens()) {
156 - final String token = st.nextToken();
157 - try {
158 - final float number = Float.parseFloat(token);
159 - boxPriors[priorIndex++] = number;
160 - } catch (final NumberFormatException e) {
161 - // Silently ignore.
162 - }
163 - }
164 - }
165 - if (priorIndex != boxPriors.length) {
166 - throw new RuntimeException(
167 - "BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length);
168 - }
169 - }
170 -
171 - private float[] decodeLocationsEncoding(final float[] locationEncoding) {
172 - final float[] locations = new float[locationEncoding.length];
173 - boolean nonZero = false;
174 - for (int i = 0; i < numLocations; ++i) {
175 - for (int j = 0; j < 4; ++j) {
176 - final float currEncoding = locationEncoding[4 * i + j];
177 - nonZero = nonZero || currEncoding != 0.0f;
178 -
179 - final float mean = boxPriors[i * 8 + j * 2];
180 - final float stdDev = boxPriors[i * 8 + j * 2 + 1];
181 - float currentLocation = currEncoding * stdDev + mean;
182 - currentLocation = Math.max(currentLocation, 0.0f);
183 - currentLocation = Math.min(currentLocation, 1.0f);
184 - locations[4 * i + j] = currentLocation;
185 - }
186 - }
187 -
188 - if (!nonZero) {
189 - LOGGER.w("No non-zero encodings; check log for inference errors.");
190 - }
191 - return locations;
192 - }
193 -
194 - private float[] decodeScoresEncoding(final float[] scoresEncoding) {
195 - final float[] scores = new float[scoresEncoding.length];
196 - for (int i = 0; i < scoresEncoding.length; ++i) {
197 - scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i])));
198 - }
199 - return scores;
200 - }
201 -
202 - @Override
203 - public List<Recognition> recognizeImage(final Bitmap bitmap) {
204 - // Log this method so that it can be analyzed with systrace.
205 - Trace.beginSection("recognizeImage");
206 -
207 - Trace.beginSection("preprocessBitmap");
208 - // Preprocess the image data from 0-255 int to normalized float based
209 - // on the provided parameters.
210 - bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
211 -
212 - for (int i = 0; i < intValues.length; ++i) {
213 - floatValues[i * 3 + 0] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd;
214 - floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd;
215 - floatValues[i * 3 + 2] = ((intValues[i] & 0xFF) - imageMean) / imageStd;
216 - }
217 - Trace.endSection(); // preprocessBitmap
218 -
219 - // Copy the input data into TensorFlow.
220 - Trace.beginSection("feed");
221 - inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
222 - Trace.endSection();
223 -
224 - // Run the inference call.
225 - Trace.beginSection("run");
226 - inferenceInterface.run(outputNames, logStats);
227 - Trace.endSection();
228 -
229 - // Copy the output Tensor back into the output array.
230 - Trace.beginSection("fetch");
231 - final float[] outputScoresEncoding = new float[numLocations];
232 - final float[] outputLocationsEncoding = new float[numLocations * 4];
233 - inferenceInterface.fetch(outputNames[0], outputLocationsEncoding);
234 - inferenceInterface.fetch(outputNames[1], outputScoresEncoding);
235 - Trace.endSection();
236 -
237 - outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
238 - outputScores = decodeScoresEncoding(outputScoresEncoding);
239 -
240 - // Find the best detections.
241 - final PriorityQueue<Recognition> pq =
242 - new PriorityQueue<Recognition>(
243 - 1,
244 - new Comparator<Recognition>() {
245 - @Override
246 - public int compare(final Recognition lhs, final Recognition rhs) {
247 - // Intentionally reversed to put high confidence at the head of the queue.
248 - return Float.compare(rhs.getConfidence(), lhs.getConfidence());
249 - }
250 - });
251 -
252 - // Scale them back to the input size.
253 - for (int i = 0; i < outputScores.length; ++i) {
254 - final RectF detection =
255 - new RectF(
256 - outputLocations[4 * i] * inputSize,
257 - outputLocations[4 * i + 1] * inputSize,
258 - outputLocations[4 * i + 2] * inputSize,
259 - outputLocations[4 * i + 3] * inputSize);
260 - pq.add(new Recognition("" + i, null, outputScores[i], detection));
261 - }
262 -
263 - final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
264 - for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
265 - recognitions.add(pq.poll());
266 - }
267 - Trace.endSection(); // "recognizeImage"
268 - return recognitions;
269 - }
270 -
271 - @Override
272 - public void enableStatLogging(final boolean logStats) {
273 - this.logStats = logStats;
274 - }
275 -
276 - @Override
277 - public String getStatString() {
278 - return inferenceInterface.getStatString();
279 - }
280 -
281 - @Override
282 - public void close() {
283 - inferenceInterface.close();
284 - }
285 -}
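Two details of the class above are worth restating compactly: the prior file is a flat list of exactly numLocations * 8 numbers, and each box corner is decoded as encoding * std + mean, clamped to [0, 1]. A sketch (the prior values shown are made up):

// Illustrative box-prior layout: 8 values per location, ordered mean,std for
// each of the four box corners. loadCoderOptions() accepts any line layout,
// so a two-location file could read:
//   0.10, 0.05, 0.20, 0.05, 0.60, 0.05, 0.70, 0.05
//   0.15 0.04 0.25 0.04 0.65 0.04 0.75 0.04

// The per-corner decode performed by decodeLocationsEncoding(), as one helper.
static float decodeCorner(final float encoding, final float mean, final float stdDev) {
  return Math.min(1.0f, Math.max(0.0f, encoding * stdDev + mean));
}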
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import android.content.res.AssetManager;
19 -import android.graphics.Bitmap;
20 -import android.graphics.RectF;
21 -import android.os.Trace;
22 -import java.io.BufferedReader;
23 -import java.io.IOException;
24 -import java.io.InputStream;
25 -import java.io.InputStreamReader;
26 -import java.util.ArrayList;
27 -import java.util.Comparator;
28 -import java.util.List;
29 -import java.util.PriorityQueue;
30 -
31 -import org.tensorflow.Graph;
32 -import org.tensorflow.Operation;
33 -import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
34 -import org.tensorflow.demo.env.Logger;
35 -
36 -/**
37 - * Wrapper for frozen detection models trained using the TensorFlow Object Detection API:
38 - * github.com/tensorflow/models/tree/master/research/object_detection
39 - */
40 -public class TensorFlowObjectDetectionAPIModel implements Classifier {
41 - private static final Logger LOGGER = new Logger();
42 -
43 - // Only return this many results.
44 - private static final int MAX_RESULTS = 100;
45 -
46 - // Config values.
47 - private String inputName;
48 - private int inputSize;
49 -
50 - // Pre-allocated buffers.
51 - private Vector<String> labels = new Vector<String>();
52 - private int[] intValues;
53 - private byte[] byteValues;
54 - private float[] outputLocations;
55 - private float[] outputScores;
56 - private float[] outputClasses;
57 - private float[] outputNumDetections;
58 - private String[] outputNames;
59 -
60 - private boolean logStats = false;
61 -
62 - private TensorFlowInferenceInterface inferenceInterface;
63 -
64 - /**
65 -   * Initializes a native TensorFlow session for detecting objects in images.
66 - *
67 - * @param assetManager The asset manager to be used to load assets.
68 - * @param modelFilename The filepath of the model GraphDef protocol buffer.
69 - * @param labelFilename The filepath of label file for classes.
70 - */
71 - public static Classifier create(
72 - final AssetManager assetManager,
73 - final String modelFilename,
74 - final String labelFilename,
75 - final int inputSize) throws IOException {
76 - final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();
77 -
78 -    // Load the label list from assets, one label per line.
79 -    final String actualFilename = labelFilename.split("file:///android_asset/")[1];
80 -    final InputStream labelsInput = assetManager.open(actualFilename);
81 -    final BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
82 -
83 - String line;
84 - while ((line = br.readLine()) != null) {
85 - LOGGER.w(line);
86 - d.labels.add(line);
87 - }
88 - br.close();
89 -
90 -
91 - d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
92 -
93 - final Graph g = d.inferenceInterface.graph();
94 -
95 - d.inputName = "image_tensor";
96 - // The inputName node has a shape of [N, H, W, C], where
97 - // N is the batch size
 98 -    // H and W are the height and width (equal here, since the input is square)
99 - // C is the number of channels (3 for our purposes - RGB)
100 - final Operation inputOp = g.operation(d.inputName);
101 - if (inputOp == null) {
102 - throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
103 - }
104 - d.inputSize = inputSize;
105 - // The outputScoresName node has a shape of [N, NumLocations], where N
106 - // is the batch size.
107 - final Operation outputOp1 = g.operation("detection_scores");
108 - if (outputOp1 == null) {
109 - throw new RuntimeException("Failed to find output Node 'detection_scores'");
110 - }
111 - final Operation outputOp2 = g.operation("detection_boxes");
112 - if (outputOp2 == null) {
113 - throw new RuntimeException("Failed to find output Node 'detection_boxes'");
114 - }
115 - final Operation outputOp3 = g.operation("detection_classes");
116 - if (outputOp3 == null) {
117 - throw new RuntimeException("Failed to find output Node 'detection_classes'");
118 - }
119 -
120 - // Pre-allocate buffers.
121 - d.outputNames = new String[] {"detection_boxes", "detection_scores",
122 - "detection_classes", "num_detections"};
123 - d.intValues = new int[d.inputSize * d.inputSize];
124 - d.byteValues = new byte[d.inputSize * d.inputSize * 3];
125 - d.outputScores = new float[MAX_RESULTS];
126 - d.outputLocations = new float[MAX_RESULTS * 4];
127 - d.outputClasses = new float[MAX_RESULTS];
128 - d.outputNumDetections = new float[1];
129 - return d;
130 - }
131 -
132 - private TensorFlowObjectDetectionAPIModel() {}
133 -
134 - @Override
135 - public List<Recognition> recognizeImage(final Bitmap bitmap) {
136 - // Log this method so that it can be analyzed with systrace.
137 - Trace.beginSection("recognizeImage");
138 -
139 - Trace.beginSection("preprocessBitmap");
 140 -    // Preprocess the image data, extracting the R, G and B bytes from each
 141 -    // packed ARGB int pixel (0xAARRGGBB).
142 - bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
143 -
144 - for (int i = 0; i < intValues.length; ++i) {
145 - byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
146 - byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
147 - byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
148 - }
149 - Trace.endSection(); // preprocessBitmap
150 -
151 - // Copy the input data into TensorFlow.
152 - Trace.beginSection("feed");
153 - inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
154 - Trace.endSection();
155 -
156 - // Run the inference call.
157 - Trace.beginSection("run");
158 - inferenceInterface.run(outputNames, logStats);
159 - Trace.endSection();
160 -
161 - // Copy the output Tensor back into the output array.
162 - Trace.beginSection("fetch");
163 - outputLocations = new float[MAX_RESULTS * 4];
164 - outputScores = new float[MAX_RESULTS];
165 - outputClasses = new float[MAX_RESULTS];
166 - outputNumDetections = new float[1];
167 - inferenceInterface.fetch(outputNames[0], outputLocations);
168 - inferenceInterface.fetch(outputNames[1], outputScores);
169 - inferenceInterface.fetch(outputNames[2], outputClasses);
170 - inferenceInterface.fetch(outputNames[3], outputNumDetections);
171 - Trace.endSection();
172 -
173 - // Find the best detections.
174 - final PriorityQueue<Recognition> pq =
175 - new PriorityQueue<Recognition>(
176 - 1,
177 - new Comparator<Recognition>() {
178 - @Override
179 - public int compare(final Recognition lhs, final Recognition rhs) {
180 - // Intentionally reversed to put high confidence at the head of the queue.
181 - return Float.compare(rhs.getConfidence(), lhs.getConfidence());
182 - }
183 - });
184 -
185 - // Scale them back to the input size.
186 - for (int i = 0; i < outputScores.length; ++i) {
187 - final RectF detection =
188 - new RectF(
189 - outputLocations[4 * i + 1] * inputSize,
190 - outputLocations[4 * i] * inputSize,
191 - outputLocations[4 * i + 3] * inputSize,
192 - outputLocations[4 * i + 2] * inputSize);
193 - pq.add(
194 - new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection));
195 - }
196 -
197 - final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
198 - for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
199 - recognitions.add(pq.poll());
200 - }
201 - Trace.endSection(); // "recognizeImage"
202 - return recognitions;
203 - }
204 -
205 - @Override
206 - public void enableStatLogging(final boolean logStats) {
207 - this.logStats = logStats;
208 - }
209 -
210 - @Override
211 - public String getStatString() {
212 - return inferenceInterface.getStatString();
213 - }
214 -
215 - @Override
216 - public void close() {
217 - inferenceInterface.close();
218 - }
219 -}
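
The model's detection_boxes output arrives as normalized [ymin, xmin, ymax, xmax] quadruples, which recognizeImage() above scales by inputSize and reorders into RectF's (left, top, right, bottom) argument order. A minimal standalone sketch of that decoding (plain Java; the Box holder and the sample numbers are illustrative, not part of the app):

// Sketch: decode one normalized [ymin, xmin, ymax, xmax] quadruple into
// pixel-space (left, top, right, bottom), mirroring recognizeImage() above.
// The Box class is hypothetical, used only for illustration.
public class BoxDecodeSketch {
  static final class Box {
    final float left, top, right, bottom;
    Box(float left, float top, float right, float bottom) {
      this.left = left; this.top = top; this.right = right; this.bottom = bottom;
    }
  }

  static Box decode(float[] outputLocations, int i, int inputSize) {
    // outputLocations is packed as [ymin, xmin, ymax, xmax] per detection.
    return new Box(
        outputLocations[4 * i + 1] * inputSize,   // xmin -> left
        outputLocations[4 * i] * inputSize,       // ymin -> top
        outputLocations[4 * i + 3] * inputSize,   // xmax -> right
        outputLocations[4 * i + 2] * inputSize);  // ymax -> bottom
  }

  public static void main(String[] args) {
    float[] locations = {0.1f, 0.2f, 0.5f, 0.8f};  // one detection, normalized
    Box b = decode(locations, 0, 300);
    System.out.printf("left=%.0f top=%.0f right=%.0f bottom=%.0f%n",
        b.left, b.top, b.right, b.bottom);  // left=60 top=30 right=240 bottom=150
  }
}
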
1 -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo;
17 -
18 -import android.content.res.AssetManager;
19 -import android.graphics.Bitmap;
20 -import android.graphics.RectF;
21 -import android.os.Trace;
22 -import java.util.ArrayList;
23 -import java.util.Comparator;
24 -import java.util.List;
25 -import java.util.PriorityQueue;
26 -import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
27 -import org.tensorflow.demo.env.Logger;
28 -import org.tensorflow.demo.env.SplitTimer;
29 -
30 -/** An object detector that uses TF and a YOLO model to detect objects. */
31 -public class TensorFlowYoloDetector implements Classifier {
32 - private static final Logger LOGGER = new Logger();
33 -
 34 -  // Only return up to this many of the highest-confidence results.
35 - private static final int MAX_RESULTS = 5;
36 -
37 - private static final int NUM_CLASSES = 1;
38 -
39 - private static final int NUM_BOXES_PER_BLOCK = 5;
40 -
41 - // TODO(andrewharp): allow loading anchors and classes
42 - // from files.
43 - private static final double[] ANCHORS = {
44 - 1.08, 1.19,
45 - 3.42, 4.41,
46 - 6.63, 11.38,
47 - 9.42, 5.11,
48 - 16.62, 10.52
49 - };
50 -
51 - private static final String[] LABELS = {
52 - "dog"
53 - };
54 -
55 - // Config values.
56 - private String inputName;
57 - private int inputSize;
58 -
59 - // Pre-allocated buffers.
60 - private int[] intValues;
61 - private float[] floatValues;
62 - private String[] outputNames;
63 -
64 - private int blockSize;
65 -
66 - private boolean logStats = false;
67 -
68 - private TensorFlowInferenceInterface inferenceInterface;
69 -
70 - /** Initializes a native TensorFlow session for classifying images. */
71 - public static Classifier create(
72 - final AssetManager assetManager,
73 - final String modelFilename,
74 - final int inputSize,
75 - final String inputName,
76 - final String outputName,
77 - final int blockSize) {
78 - TensorFlowYoloDetector d = new TensorFlowYoloDetector();
79 - d.inputName = inputName;
80 - d.inputSize = inputSize;
81 -
82 - // Pre-allocate buffers.
83 - d.outputNames = outputName.split(",");
84 - d.intValues = new int[inputSize * inputSize];
85 - d.floatValues = new float[inputSize * inputSize * 3];
86 - d.blockSize = blockSize;
87 -
88 - d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
89 -
90 - return d;
91 - }
92 -
93 - private TensorFlowYoloDetector() {}
94 -
95 - private float expit(final float x) {
96 - return (float) (1. / (1. + Math.exp(-x)));
97 - }
98 -
99 - private void softmax(final float[] vals) {
100 - float max = Float.NEGATIVE_INFINITY;
101 - for (final float val : vals) {
102 - max = Math.max(max, val);
103 - }
104 - float sum = 0.0f;
105 - for (int i = 0; i < vals.length; ++i) {
106 - vals[i] = (float) Math.exp(vals[i] - max);
107 - sum += vals[i];
108 - }
109 - for (int i = 0; i < vals.length; ++i) {
110 - vals[i] = vals[i] / sum;
111 - }
112 - }
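
Subtracting the maximum logit before exponentiating, as softmax() does above, keeps Math.exp from overflowing on large inputs without changing the result. A standalone check (plain Java; the logit values are arbitrary):

// Sketch: the same max-subtracted softmax as above, checked on large logits
// for which a naive exp() would overflow to Infinity.
public class SoftmaxSketch {
  static void softmax(float[] vals) {
    float max = Float.NEGATIVE_INFINITY;
    for (float v : vals) max = Math.max(max, v);
    float sum = 0f;
    for (int i = 0; i < vals.length; ++i) {
      vals[i] = (float) Math.exp(vals[i] - max);  // exp argument is <= 0
      sum += vals[i];
    }
    for (int i = 0; i < vals.length; ++i) vals[i] /= sum;
  }

  public static void main(String[] args) {
    float[] logits = {1000f, 1001f, 1002f};  // naive Math.exp(1000) is Infinity
    softmax(logits);
    // Prints roughly 0.09, 0.24, 0.67; the values sum to 1.
    for (float p : logits) System.out.println(p);
  }
}
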
113 -
114 - @Override
115 - public List<Recognition> recognizeImage(final Bitmap bitmap) {
116 - final SplitTimer timer = new SplitTimer("recognizeImage");
117 -
118 - // Log this method so that it can be analyzed with systrace.
119 - Trace.beginSection("recognizeImage");
120 -
121 - Trace.beginSection("preprocessBitmap");
122 - // Preprocess the image data from 0-255 int to normalized float based
123 - // on the provided parameters.
124 - bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
125 -
126 - for (int i = 0; i < intValues.length; ++i) {
127 - floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f;
128 - floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f;
129 - floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f;
130 - }
131 - Trace.endSection(); // preprocessBitmap
132 -
133 - // Copy the input data into TensorFlow.
134 - Trace.beginSection("feed");
135 - inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
136 - Trace.endSection();
137 -
138 - timer.endSplit("ready for inference");
139 -
140 - // Run the inference call.
141 - Trace.beginSection("run");
142 - inferenceInterface.run(outputNames, logStats);
143 - Trace.endSection();
144 -
145 - timer.endSplit("ran inference");
146 -
147 - // Copy the output Tensor back into the output array.
148 - Trace.beginSection("fetch");
149 - final int gridWidth = bitmap.getWidth() / blockSize;
150 - final int gridHeight = bitmap.getHeight() / blockSize;
151 - final float[] output =
152 - new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];
153 - inferenceInterface.fetch(outputNames[0], output);
154 - Trace.endSection();
155 -
156 - // Find the best detections.
157 - final PriorityQueue<Recognition> pq =
158 - new PriorityQueue<Recognition>(
159 - 1,
160 - new Comparator<Recognition>() {
161 - @Override
162 - public int compare(final Recognition lhs, final Recognition rhs) {
163 - // Intentionally reversed to put high confidence at the head of the queue.
164 - return Float.compare(rhs.getConfidence(), lhs.getConfidence());
165 - }
166 - });
167 -
168 - for (int y = 0; y < gridHeight; ++y) {
169 - for (int x = 0; x < gridWidth; ++x) {
170 - for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
171 - final int offset =
172 - (gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y
173 - + (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x
174 - + (NUM_CLASSES + 5) * b;
175 -
176 - final float xPos = (x + expit(output[offset + 0])) * blockSize;
177 - final float yPos = (y + expit(output[offset + 1])) * blockSize;
178 -
179 - final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * blockSize;
180 - final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * blockSize;
181 -
182 - final RectF rect =
183 - new RectF(
184 - Math.max(0, xPos - w / 2),
185 - Math.max(0, yPos - h / 2),
186 - Math.min(bitmap.getWidth() - 1, xPos + w / 2),
187 - Math.min(bitmap.getHeight() - 1, yPos + h / 2));
188 - final float confidence = expit(output[offset + 4]);
189 -
190 - int detectedClass = -1;
191 - float maxClass = 0;
192 -
193 - final float[] classes = new float[NUM_CLASSES];
194 - for (int c = 0; c < NUM_CLASSES; ++c) {
195 - classes[c] = output[offset + 5 + c];
196 - }
197 - softmax(classes);
198 -
199 - for (int c = 0; c < NUM_CLASSES; ++c) {
200 - if (classes[c] > maxClass) {
201 - detectedClass = c;
202 - maxClass = classes[c];
203 - }
204 - }
205 -
206 - final float confidenceInClass = maxClass * confidence;
207 - if (confidenceInClass > 0.01) {
208 - LOGGER.i(
209 - "%s (%d) %f %s", LABELS[detectedClass], detectedClass, confidenceInClass, rect);
210 - pq.add(new Recognition("" + offset, LABELS[detectedClass], confidenceInClass, rect));
211 - }
212 - }
213 - }
214 - }
215 - timer.endSplit("decoded results");
216 -
217 - final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
218 - for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
219 - recognitions.add(pq.poll());
220 - }
221 - Trace.endSection(); // "recognizeImage"
222 -
223 - timer.endSplit("processed results");
224 -
225 - return recognitions;
226 - }
227 -
228 - @Override
229 - public void enableStatLogging(final boolean logStats) {
230 - this.logStats = logStats;
231 - }
232 -
233 - @Override
234 - public String getStatString() {
235 - return inferenceInterface.getStatString();
236 - }
237 -
238 - @Override
239 - public void close() {
240 - inferenceInterface.close();
241 - }
242 -}
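
The decode loop above follows the YOLOv2 box parameterization: per grid cell and anchor, the network emits (tx, ty, tw, th, to); the box center is the cell corner plus a sigmoid offset, scaled to pixels, and the box size is the anchor prior scaled by an exponential. A worked standalone sketch (plain Java; all input numbers are made up for illustration):

// Sketch: YOLOv2-style box decoding for one grid cell and one anchor,
// mirroring the loop in recognizeImage() above. All inputs are invented.
public class YoloDecodeSketch {
  static float expit(float x) {  // logistic sigmoid
    return (float) (1.0 / (1.0 + Math.exp(-x)));
  }

  public static void main(String[] args) {
    final int blockSize = 32;              // pixels per grid cell
    final double[] anchor = {3.42, 4.41};  // one (w, h) prior, in cell units
    final int x = 4, y = 2;                // grid cell indices
    // Raw network outputs for this cell/anchor: tx, ty, tw, th, to.
    final float tx = 0.3f, ty = -0.1f, tw = 0.2f, th = 0.1f, to = 1.5f;

    final float xPos = (x + expit(tx)) * blockSize;  // box center x in pixels
    final float yPos = (y + expit(ty)) * blockSize;  // box center y in pixels
    final float w = (float) (Math.exp(tw) * anchor[0]) * blockSize;
    final float h = (float) (Math.exp(th) * anchor[1]) * blockSize;
    final float confidence = expit(to);              // objectness in (0, 1)

    System.out.printf("center=(%.1f, %.1f) size=%.1fx%.1f conf=%.2f%n",
        xPos, yPos, w, h, confidence);
  }
}
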
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo.env;
17 -
18 -import android.graphics.Canvas;
19 -import android.graphics.Color;
20 -import android.graphics.Paint;
21 -import android.graphics.Paint.Align;
22 -import android.graphics.Paint.Style;
23 -import android.graphics.Rect;
24 -import android.graphics.Typeface;
25 -import java.util.Vector;
26 -
27 -/**
28 - * A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas.
29 - */
30 -public class BorderedText {
31 - private final Paint interiorPaint;
32 - private final Paint exteriorPaint;
33 -
34 - private final float textSize;
35 -
36 - /**
37 - * Creates a left-aligned bordered text object with a white interior, and a black exterior with
38 - * the specified text size.
39 - *
40 - * @param textSize text size in pixels
41 - */
42 - public BorderedText(final float textSize) {
43 - this(Color.WHITE, Color.BLACK, textSize);
44 - }
45 -
46 - /**
47 - * Create a bordered text object with the specified interior and exterior colors, text size and
48 - * alignment.
49 - *
50 - * @param interiorColor the interior text color
51 - * @param exteriorColor the exterior text color
52 - * @param textSize text size in pixels
53 - */
54 - public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) {
55 - interiorPaint = new Paint();
56 - interiorPaint.setTextSize(textSize);
57 - interiorPaint.setColor(interiorColor);
58 - interiorPaint.setStyle(Style.FILL);
59 - interiorPaint.setAntiAlias(false);
60 - interiorPaint.setAlpha(255);
61 -
62 - exteriorPaint = new Paint();
63 - exteriorPaint.setTextSize(textSize);
64 - exteriorPaint.setColor(exteriorColor);
65 - exteriorPaint.setStyle(Style.FILL_AND_STROKE);
66 - exteriorPaint.setStrokeWidth(textSize / 8);
67 - exteriorPaint.setAntiAlias(false);
68 - exteriorPaint.setAlpha(255);
69 -
70 - this.textSize = textSize;
71 - }
72 -
73 - public void setTypeface(Typeface typeface) {
74 - interiorPaint.setTypeface(typeface);
75 - exteriorPaint.setTypeface(typeface);
76 - }
77 -
78 - public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
79 - canvas.drawText(text, posX, posY, exteriorPaint);
80 - canvas.drawText(text, posX, posY, interiorPaint);
81 - }
82 -
83 - public void drawLines(Canvas canvas, final float posX, final float posY, Vector<String> lines) {
84 - int lineNum = 0;
85 - for (final String line : lines) {
86 - drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line);
87 - ++lineNum;
88 - }
89 - }
90 -
91 - public void setInteriorColor(final int color) {
92 - interiorPaint.setColor(color);
93 - }
94 -
95 - public void setExteriorColor(final int color) {
96 - exteriorPaint.setColor(color);
97 - }
98 -
99 - public float getTextSize() {
100 - return textSize;
101 - }
102 -
103 - public void setAlpha(final int alpha) {
104 - interiorPaint.setAlpha(alpha);
105 - exteriorPaint.setAlpha(alpha);
106 - }
107 -
108 - public void getTextBounds(
109 - final String line, final int index, final int count, final Rect lineBounds) {
110 - interiorPaint.getTextBounds(line, index, count, lineBounds);
111 - }
112 -
113 - public void setTextAlign(final Align align) {
114 - interiorPaint.setTextAlign(align);
115 - exteriorPaint.setTextAlign(align);
116 - }
117 -}
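
A hedged usage sketch for the class above: the two drawText passes (stroked exterior first, filled interior second) are what keep labels legible over arbitrary camera backgrounds. LabelDrawer is a hypothetical caller (Android-only code):

// Sketch (Android-only): drawing an outlined label the way the tracking
// overlay does; LabelDrawer is hypothetical.
import android.graphics.Canvas;
import android.graphics.Color;
import org.tensorflow.demo.env.BorderedText;

public class LabelDrawer {
  // White fill over a black stroke stays legible on any camera background.
  private final BorderedText borderedText = new BorderedText(Color.WHITE, Color.BLACK, 48f);

  void drawLabel(final Canvas canvas, final float x, final float y, final String label) {
    borderedText.drawText(canvas, x, y, label);  // stroke pass, then fill pass
  }
}
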
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo.env;
17 -
18 -import android.graphics.Bitmap;
19 -import android.graphics.Matrix;
20 -import android.os.Environment;
21 -import java.io.File;
22 -import java.io.FileOutputStream;
23 -
24 -/**
25 - * Utility class for manipulating images.
 26 - */
27 -public class ImageUtils {
28 - @SuppressWarnings("unused")
29 - private static final Logger LOGGER = new Logger();
30 -
31 - static {
32 - try {
33 - System.loadLibrary("tensorflow_demo");
34 - } catch (UnsatisfiedLinkError e) {
35 - LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable.");
36 - }
37 - }
38 -
39 - /**
40 - * Utility method to compute the allocated size in bytes of a YUV420SP image
41 - * of the given dimensions.
42 - */
43 - public static int getYUVByteSize(final int width, final int height) {
44 - // The luminance plane requires 1 byte per pixel.
45 - final int ySize = width * height;
46 -
47 - // The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up.
48 - // Each 2x2 block takes 2 bytes to encode, one each for U and V.
49 - final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2;
50 -
51 - return ySize + uvSize;
52 - }
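
As a concrete check of the arithmetic above (plain Java; the frame size is chosen only for illustration):

// Sketch: getYUVByteSize() worked through for a 640x480 frame.
public class YuvSizeSketch {
  public static void main(String[] args) {
    final int width = 640;
    final int height = 480;
    final int ySize = width * height;                               // 307200
    final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2;  // 320 * 240 * 2 = 153600
    System.out.println(ySize + uvSize);  // 460800, i.e. 1.5 bytes per pixel
  }
}
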
53 -
54 - /**
55 - * Saves a Bitmap object to disk for analysis.
56 - *
57 - * @param bitmap The bitmap to save.
58 - */
59 - public static void saveBitmap(final Bitmap bitmap) {
60 - saveBitmap(bitmap, "preview.png");
61 - }
62 -
63 - /**
64 - * Saves a Bitmap object to disk for analysis.
65 - *
66 - * @param bitmap The bitmap to save.
67 - * @param filename The location to save the bitmap to.
68 - */
69 - public static void saveBitmap(final Bitmap bitmap, final String filename) {
70 - final String root =
71 - Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow";
72 - LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root);
73 - final File myDir = new File(root);
74 -
75 - if (!myDir.mkdirs()) {
 76 -      LOGGER.i("mkdirs() returned false; the directory may already exist");
77 - }
78 -
79 - final String fname = filename;
80 - final File file = new File(myDir, fname);
81 - if (file.exists()) {
82 - file.delete();
83 - }
84 - try {
85 - final FileOutputStream out = new FileOutputStream(file);
86 - bitmap.compress(Bitmap.CompressFormat.PNG, 99, out);
87 - out.flush();
88 - out.close();
89 - } catch (final Exception e) {
90 - LOGGER.e(e, "Exception!");
91 - }
92 - }
93 -
94 - // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
95 - // are normalized to eight bits.
96 - static final int kMaxChannelValue = 262143;
97 -
98 - // Always prefer the native implementation if available.
99 - private static boolean useNativeConversion = true;
100 -
101 - public static void convertYUV420SPToARGB8888(
102 - byte[] input,
103 - int width,
104 - int height,
105 - int[] output) {
106 - if (useNativeConversion) {
107 - try {
108 - ImageUtils.convertYUV420SPToARGB8888(input, output, width, height, false);
109 - return;
110 - } catch (UnsatisfiedLinkError e) {
111 - LOGGER.w(
112 - "Native YUV420SP -> RGB implementation not found, falling back to Java implementation");
113 - useNativeConversion = false;
114 - }
115 - }
116 -
 117 -    // Java fallback implementation of the YUV420SP to ARGB8888 conversion.
118 - final int frameSize = width * height;
119 - for (int j = 0, yp = 0; j < height; j++) {
120 - int uvp = frameSize + (j >> 1) * width;
121 - int u = 0;
122 - int v = 0;
123 -
124 - for (int i = 0; i < width; i++, yp++) {
125 - int y = 0xff & input[yp];
126 - if ((i & 1) == 0) {
127 - v = 0xff & input[uvp++];
128 - u = 0xff & input[uvp++];
129 - }
130 -
131 - output[yp] = YUV2RGB(y, u, v);
132 - }
133 - }
134 - }
135 -
136 - private static int YUV2RGB(int y, int u, int v) {
137 - // Adjust and check YUV values
138 - y = (y - 16) < 0 ? 0 : (y - 16);
139 - u -= 128;
140 - v -= 128;
141 -
142 - // This is the floating point equivalent. We do the conversion in integer
143 - // because some Android devices do not have floating point in hardware.
 144 -    // nR = (int)(1.164 * nY + 1.596 * nV);
 145 -    // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
 146 -    // nB = (int)(1.164 * nY + 2.018 * nU);
147 - int y1192 = 1192 * y;
148 - int r = (y1192 + 1634 * v);
149 - int g = (y1192 - 833 * v - 400 * u);
150 - int b = (y1192 + 2066 * u);
151 -
152 - // Clipping RGB values to be inside boundaries [ 0 , kMaxChannelValue ]
153 - r = r > kMaxChannelValue ? kMaxChannelValue : (r < 0 ? 0 : r);
154 - g = g > kMaxChannelValue ? kMaxChannelValue : (g < 0 ? 0 : g);
155 - b = b > kMaxChannelValue ? kMaxChannelValue : (b < 0 ? 0 : b);
156 -
157 - return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff);
158 - }
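
The integer constants above are the BT.601 coefficients scaled by 1024 (1192 ≈ 1.164 * 1024, 1634 ≈ 1.596 * 1024, 2066 ≈ 2.018 * 1024), and the final shift-and-mask packing keeps the top 8 bits of each 18-bit channel. A standalone sketch comparing the fixed-point path with its floating-point equivalent for one sample (plain Java; the sample values are arbitrary):

// Sketch: the fixed-point path above vs. its floating-point equivalent for
// one YUV sample. Coefficients are BT.601 scaled by 1024.
public class Yuv2RgbSketch {
  public static void main(String[] args) {
    final int y = 120, u = 90, v = 200;  // raw 0-255 sample values

    // Integer path (as in YUV2RGB above); >> 10 undoes the x1024 scaling.
    final int yAdj = Math.max(y - 16, 0);
    final int uAdj = u - 128, vAdj = v - 128;
    final int y1192 = 1192 * yAdj;
    final int r = clamp(y1192 + 1634 * vAdj) >> 10;
    final int g = clamp(y1192 - 833 * vAdj - 400 * uAdj) >> 10;
    final int b = clamp(y1192 + 2066 * uAdj) >> 10;

    // Floating-point equivalent.
    final float rf = 1.164f * yAdj + 1.596f * vAdj;
    final float gf = 1.164f * yAdj - 0.813f * vAdj - 0.391f * uAdj;
    final float bf = 1.164f * yAdj + 2.018f * uAdj;

    System.out.printf("int: (%d, %d, %d)  float: (%.1f, %.1f, %.1f)%n",
        r, g, b, rf, gf, bf);  // the two paths agree to within rounding
  }

  static int clamp(final int c) {
    return Math.min(Math.max(c, 0), 262143);  // kMaxChannelValue = 2^18 - 1
  }
}
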
159 -
160 -
161 - public static void convertYUV420ToARGB8888(
162 - byte[] yData,
163 - byte[] uData,
164 - byte[] vData,
165 - int width,
166 - int height,
167 - int yRowStride,
168 - int uvRowStride,
169 - int uvPixelStride,
170 - int[] out) {
171 - if (useNativeConversion) {
172 - try {
173 - convertYUV420ToARGB8888(
174 - yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false);
175 - return;
176 - } catch (UnsatisfiedLinkError e) {
177 - LOGGER.w(
178 - "Native YUV420 -> RGB implementation not found, falling back to Java implementation");
179 - useNativeConversion = false;
180 - }
181 - }
182 -
183 - int yp = 0;
184 - for (int j = 0; j < height; j++) {
185 - int pY = yRowStride * j;
186 - int pUV = uvRowStride * (j >> 1);
187 -
188 - for (int i = 0; i < width; i++) {
 189 -        final int uvOffset = pUV + (i >> 1) * uvPixelStride;
 190 -
 191 -        out[yp++] = YUV2RGB(
 192 -            0xff & yData[pY + i],
 193 -            0xff & uData[uvOffset],
 194 -            0xff & vData[uvOffset]);
195 - }
196 - }
197 - }
198 -
199 -
200 - /**
201 - * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The
202 - * input and output must already be allocated and non-null. For efficiency, no error checking is
203 - * performed.
204 - *
205 - * @param input The array of YUV 4:2:0 input data.
206 - * @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
207 - * @param width The width of the input image.
208 - * @param height The height of the input image.
209 - * @param halfSize If true, downsample to 50% in each dimension, otherwise not.
210 - */
211 - private static native void convertYUV420SPToARGB8888(
212 - byte[] input, int[] output, int width, int height, boolean halfSize);
213 -
214 - /**
 215 -   * Converts YUV420 data, supplied as separate Y, U and V planes with explicit strides, to
 216 -   * ARGB 8888 data using the supplied width and height. The input and output must already be
 217 -   * allocated and non-null. For efficiency, no error checking is performed.
 218 -   *
 219 -   * @param y The array of Y (luminance) plane data.
 220 -   * @param u The array of U (Cb) chroma plane data.
 221 -   * @param v The array of V (Cr) chroma plane data.
 222 -   * @param uvPixelStride The stride in bytes between adjacent U or V samples within a row.
 223 -   * @param width The width of the input image.
 224 -   * @param height The height of the input image.
 225 -   * @param halfSize If true, downsample to 50% in each dimension, otherwise not.
 226 -   * @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
227 - */
228 - private static native void convertYUV420ToARGB8888(
229 - byte[] y,
230 - byte[] u,
231 - byte[] v,
232 - int[] output,
233 - int width,
234 - int height,
235 - int yRowStride,
236 - int uvRowStride,
237 - int uvPixelStride,
238 - boolean halfSize);
239 -
240 - /**
241 - * Converts YUV420 semi-planar data to RGB 565 data using the supplied width
242 - * and height. The input and output must already be allocated and non-null.
243 - * For efficiency, no error checking is performed.
244 - *
245 - * @param input The array of YUV 4:2:0 input data.
246 - * @param output A pre-allocated array for the RGB 5:6:5 output data.
247 - * @param width The width of the input image.
248 - * @param height The height of the input image.
249 - */
250 - private static native void convertYUV420SPToRGB565(
251 - byte[] input, byte[] output, int width, int height);
252 -
253 - /**
254 - * Converts 32-bit ARGB8888 image data to YUV420SP data. This is useful, for
255 - * instance, in creating data to feed the classes that rely on raw camera
256 - * preview frames.
257 - *
258 - * @param input An array of input pixels in ARGB8888 format.
259 - * @param output A pre-allocated array for the YUV420SP output data.
260 - * @param width The width of the input image.
261 - * @param height The height of the input image.
262 - */
263 - private static native void convertARGB8888ToYUV420SP(
264 - int[] input, byte[] output, int width, int height);
265 -
266 - /**
267 - * Converts 16-bit RGB565 image data to YUV420SP data. This is useful, for
268 - * instance, in creating data to feed the classes that rely on raw camera
269 - * preview frames.
270 - *
271 - * @param input An array of input pixels in RGB565 format.
272 - * @param output A pre-allocated array for the YUV420SP output data.
273 - * @param width The width of the input image.
274 - * @param height The height of the input image.
275 - */
276 - private static native void convertRGB565ToYUV420SP(
277 - byte[] input, byte[] output, int width, int height);
278 -
279 - /**
280 - * Returns a transformation matrix from one reference frame into another.
281 - * Handles cropping (if maintaining aspect ratio is desired) and rotation.
282 - *
283 - * @param srcWidth Width of source frame.
284 - * @param srcHeight Height of source frame.
285 - * @param dstWidth Width of destination frame.
286 - * @param dstHeight Height of destination frame.
287 - * @param applyRotation Amount of rotation to apply from one frame to another.
288 - * Must be a multiple of 90.
289 - * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant,
290 - * cropping the image if necessary.
291 - * @return The transformation fulfilling the desired requirements.
292 - */
293 - public static Matrix getTransformationMatrix(
294 - final int srcWidth,
295 - final int srcHeight,
296 - final int dstWidth,
297 - final int dstHeight,
298 - final int applyRotation,
299 - final boolean maintainAspectRatio) {
300 - final Matrix matrix = new Matrix();
301 -
302 - if (applyRotation != 0) {
303 - if (applyRotation % 90 != 0) {
 304 -        LOGGER.w("Rotation of %d %% 90 != 0", applyRotation);
305 - }
306 -
307 - // Translate so center of image is at origin.
308 - matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);
309 -
310 - // Rotate around origin.
311 - matrix.postRotate(applyRotation);
312 - }
313 -
314 - // Account for the already applied rotation, if any, and then determine how
315 - // much scaling is needed for each axis.
316 - final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;
317 -
318 - final int inWidth = transpose ? srcHeight : srcWidth;
319 - final int inHeight = transpose ? srcWidth : srcHeight;
320 -
321 - // Apply scaling if necessary.
322 - if (inWidth != dstWidth || inHeight != dstHeight) {
323 - final float scaleFactorX = dstWidth / (float) inWidth;
324 - final float scaleFactorY = dstHeight / (float) inHeight;
325 -
326 - if (maintainAspectRatio) {
327 - // Scale by minimum factor so that dst is filled completely while
328 - // maintaining the aspect ratio. Some image may fall off the edge.
329 - final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
330 - matrix.postScale(scaleFactor, scaleFactor);
331 - } else {
332 - // Scale exactly to fill dst from src.
333 - matrix.postScale(scaleFactorX, scaleFactorY);
334 - }
335 - }
336 -
337 - if (applyRotation != 0) {
338 - // Translate back from origin centered reference to destination frame.
339 - matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
340 - }
341 -
342 - return matrix;
343 - }
344 -}
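
A hedged usage sketch for getTransformationMatrix() above: mapping a landscape sensor frame onto a portrait view with a 90-degree sensor rotation, cropping to preserve aspect ratio (Android-only; the sizes are illustrative):

// Sketch (Android-only, illustrative sizes): build the frame-to-canvas matrix
// and map a detection rectangle through it. Assumes ImageUtils (above) is on
// the classpath.
import android.graphics.Matrix;
import android.graphics.RectF;
import org.tensorflow.demo.env.ImageUtils;

public class TransformSketch {
  public static RectF mapDetection() {
    final Matrix frameToCanvas =
        ImageUtils.getTransformationMatrix(
            640, 480,  // source (sensor) frame size
            480, 640,  // destination (view) size
            90,        // sensor rotation relative to the display, multiple of 90
            true);     // maintain aspect ratio, cropping any overflow

    final RectF detection = new RectF(100, 100, 200, 200);  // frame coordinates
    frameToCanvas.mapRect(detection);  // now in view coordinates
    return detection;
  }
}
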
1 -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo.env;
17 -
18 -import android.util.Log;
19 -
20 -import java.util.HashSet;
21 -import java.util.Set;
22 -
23 -/**
24 - * Wrapper for the platform log function, allows convenient message prefixing and log disabling.
25 - */
26 -public final class Logger {
27 - private static final String DEFAULT_TAG = "tensorflow";
28 - private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG;
29 -
30 - // Classes to be ignored when examining the stack trace
31 - private static final Set<String> IGNORED_CLASS_NAMES;
32 -
33 - static {
34 - IGNORED_CLASS_NAMES = new HashSet<String>(3);
35 - IGNORED_CLASS_NAMES.add("dalvik.system.VMStack");
36 - IGNORED_CLASS_NAMES.add("java.lang.Thread");
37 - IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName());
38 - }
39 -
40 - private final String tag;
41 - private final String messagePrefix;
42 - private int minLogLevel = DEFAULT_MIN_LOG_LEVEL;
43 -
44 - /**
45 - * Creates a Logger using the class name as the message prefix.
46 - *
47 - * @param clazz the simple name of this class is used as the message prefix.
48 - */
49 - public Logger(final Class<?> clazz) {
50 - this(clazz.getSimpleName());
51 - }
52 -
53 - /**
54 - * Creates a Logger using the specified message prefix.
55 - *
56 - * @param messagePrefix is prepended to the text of every message.
57 - */
58 - public Logger(final String messagePrefix) {
59 - this(DEFAULT_TAG, messagePrefix);
60 - }
61 -
62 - /**
63 - * Creates a Logger with a custom tag and a custom message prefix. If the message prefix
 64 -   * is set to {@code null}, the caller's class name is used as the prefix.
65 - *
66 - * @param tag identifies the source of a log message.
67 - * @param messagePrefix prepended to every message if non-null. If null, the name of the caller is
68 - * being used
69 - */
70 - public Logger(final String tag, final String messagePrefix) {
71 - this.tag = tag;
72 - final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix;
73 - this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix;
74 - }
75 -
76 - /**
77 - * Creates a Logger using the caller's class name as the message prefix.
78 - */
79 - public Logger() {
80 - this(DEFAULT_TAG, null);
81 - }
82 -
83 - /**
 84 -   * Creates a Logger using the caller's class name as the message prefix and the given minimum log level.
85 - */
86 - public Logger(final int minLogLevel) {
87 - this(DEFAULT_TAG, null);
88 - this.minLogLevel = minLogLevel;
89 - }
90 -
91 - public void setMinLogLevel(final int minLogLevel) {
92 - this.minLogLevel = minLogLevel;
93 - }
94 -
95 - public boolean isLoggable(final int logLevel) {
96 - return logLevel >= minLogLevel || Log.isLoggable(tag, logLevel);
97 - }
98 -
99 - /**
100 - * Return caller's simple name.
101 - *
102 - * Android getStackTrace() returns an array that looks like this:
103 - * stackTrace[0]: dalvik.system.VMStack
104 - * stackTrace[1]: java.lang.Thread
105 - * stackTrace[2]: com.google.android.apps.unveil.env.UnveilLogger
106 - * stackTrace[3]: com.google.android.apps.unveil.BaseApplication
107 - *
108 - * This function returns the simple version of the first non-filtered name.
109 - *
110 - * @return caller's simple name
111 - */
112 - private static String getCallerSimpleName() {
113 - // Get the current callstack so we can pull the class of the caller off of it.
114 - final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
115 -
116 - for (final StackTraceElement elem : stackTrace) {
117 - final String className = elem.getClassName();
118 - if (!IGNORED_CLASS_NAMES.contains(className)) {
119 - // We're only interested in the simple name of the class, not the complete package.
120 - final String[] classParts = className.split("\\.");
121 - return classParts[classParts.length - 1];
122 - }
123 - }
124 -
125 - return Logger.class.getSimpleName();
126 - }
127 -
128 - private String toMessage(final String format, final Object... args) {
129 - return messagePrefix + (args.length > 0 ? String.format(format, args) : format);
130 - }
131 -
132 - public void v(final String format, final Object... args) {
133 - if (isLoggable(Log.VERBOSE)) {
134 - Log.v(tag, toMessage(format, args));
135 - }
136 - }
137 -
138 - public void v(final Throwable t, final String format, final Object... args) {
139 - if (isLoggable(Log.VERBOSE)) {
140 - Log.v(tag, toMessage(format, args), t);
141 - }
142 - }
143 -
144 - public void d(final String format, final Object... args) {
145 - if (isLoggable(Log.DEBUG)) {
146 - Log.d(tag, toMessage(format, args));
147 - }
148 - }
149 -
150 - public void d(final Throwable t, final String format, final Object... args) {
151 - if (isLoggable(Log.DEBUG)) {
152 - Log.d(tag, toMessage(format, args), t);
153 - }
154 - }
155 -
156 - public void i(final String format, final Object... args) {
157 - if (isLoggable(Log.INFO)) {
158 - Log.i(tag, toMessage(format, args));
159 - }
160 - }
161 -
162 - public void i(final Throwable t, final String format, final Object... args) {
163 - if (isLoggable(Log.INFO)) {
164 - Log.i(tag, toMessage(format, args), t);
165 - }
166 - }
167 -
168 - public void w(final String format, final Object... args) {
169 - if (isLoggable(Log.WARN)) {
170 - Log.w(tag, toMessage(format, args));
171 - }
172 - }
173 -
174 - public void w(final Throwable t, final String format, final Object... args) {
175 - if (isLoggable(Log.WARN)) {
176 - Log.w(tag, toMessage(format, args), t);
177 - }
178 - }
179 -
180 - public void e(final String format, final Object... args) {
181 - if (isLoggable(Log.ERROR)) {
182 - Log.e(tag, toMessage(format, args));
183 - }
184 - }
185 -
186 - public void e(final Throwable t, final String format, final Object... args) {
187 - if (isLoggable(Log.ERROR)) {
188 - Log.e(tag, toMessage(format, args), t);
189 - }
190 - }
191 -}
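
A brief usage sketch for the class above; DetectorActivity is a hypothetical caller:

// Sketch: typical Logger usage.
import org.tensorflow.demo.env.Logger;

public class DetectorActivity {
  // Null prefix: the caller's simple class name ("DetectorActivity") is used.
  private static final Logger LOGGER = new Logger();

  void onDetection(final int count, final long timestampMs) {
    // String.format only runs once the level check passes.
    LOGGER.i("got %d detections at %d", count, timestampMs);
  }
}
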
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo.env;
17 -
18 -import android.graphics.Bitmap;
19 -import android.text.TextUtils;
20 -import java.io.Serializable;
21 -import java.util.ArrayList;
22 -import java.util.List;
23 -
24 -/**
25 - * Size class independent of a Camera object.
26 - */
27 -public class Size implements Comparable<Size>, Serializable {
28 -
29 - // 1.4 went out with this UID so we'll need to maintain it to preserve pending queries when
30 - // upgrading.
31 - public static final long serialVersionUID = 7689808733290872361L;
32 -
33 - public final int width;
34 - public final int height;
35 -
36 - public Size(final int width, final int height) {
37 - this.width = width;
38 - this.height = height;
39 - }
40 -
41 - public Size(final Bitmap bmp) {
42 - this.width = bmp.getWidth();
43 - this.height = bmp.getHeight();
44 - }
45 -
46 - /**
47 - * Rotate a size by the given number of degrees.
48 - * @param size Size to rotate.
49 - * @param rotation Degrees {0, 90, 180, 270} to rotate the size.
50 - * @return Rotated size.
51 - */
52 - public static Size getRotatedSize(final Size size, final int rotation) {
53 - if (rotation % 180 != 0) {
54 - // The phone is portrait, therefore the camera is sideways and frame should be rotated.
55 - return new Size(size.height, size.width);
56 - }
57 - return size;
58 - }
59 -
60 - public static Size parseFromString(String sizeString) {
61 - if (TextUtils.isEmpty(sizeString)) {
62 - return null;
63 - }
64 -
65 - sizeString = sizeString.trim();
66 -
67 - // The expected format is "<width>x<height>".
68 - final String[] components = sizeString.split("x");
69 - if (components.length == 2) {
70 - try {
71 - final int width = Integer.parseInt(components[0]);
72 - final int height = Integer.parseInt(components[1]);
73 - return new Size(width, height);
74 - } catch (final NumberFormatException e) {
75 - return null;
76 - }
77 - } else {
78 - return null;
79 - }
80 - }
81 -
82 - public static List<Size> sizeStringToList(final String sizes) {
83 - final List<Size> sizeList = new ArrayList<Size>();
84 - if (sizes != null) {
85 - final String[] pairs = sizes.split(",");
86 - for (final String pair : pairs) {
87 - final Size size = Size.parseFromString(pair);
88 - if (size != null) {
89 - sizeList.add(size);
90 - }
91 - }
92 - }
93 - return sizeList;
94 - }
95 -
96 - public static String sizeListToString(final List<Size> sizes) {
97 - String sizesString = "";
98 - if (sizes != null && sizes.size() > 0) {
99 - sizesString = sizes.get(0).toString();
100 - for (int i = 1; i < sizes.size(); i++) {
101 - sizesString += "," + sizes.get(i).toString();
102 - }
103 - }
104 - return sizesString;
105 - }
106 -
107 - public final float aspectRatio() {
108 - return (float) width / (float) height;
109 - }
110 -
111 - @Override
112 - public int compareTo(final Size other) {
113 - return width * height - other.width * other.height;
114 - }
115 -
116 - @Override
117 - public boolean equals(final Object other) {
118 - if (other == null) {
119 - return false;
120 - }
121 -
122 - if (!(other instanceof Size)) {
123 - return false;
124 - }
125 -
126 - final Size otherSize = (Size) other;
127 - return (width == otherSize.width && height == otherSize.height);
128 - }
129 -
130 - @Override
131 - public int hashCode() {
132 - return width * 32713 + height;
133 - }
134 -
135 - @Override
136 - public String toString() {
137 - return dimensionsAsString(width, height);
138 - }
139 -
140 - public static final String dimensionsAsString(final int width, final int height) {
141 - return width + "x" + height;
142 - }
143 -}
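
A usage sketch round-tripping camera sizes through the string helpers above (Android classpath needed, since Size uses TextUtils internally; the values are illustrative):

// Sketch: parsing, serializing and rotating sizes with the helpers above.
import java.util.List;
import org.tensorflow.demo.env.Size;

public class SizeSketch {
  public static void main(String[] args) {
    // "bogus" fails parseFromString() and is silently dropped.
    final List<Size> sizes = Size.sizeStringToList("640x480,1280x720,bogus");
    System.out.println(Size.sizeListToString(sizes));           // 640x480,1280x720
    System.out.println(sizes.get(1).aspectRatio());             // 1.7777778
    System.out.println(Size.getRotatedSize(sizes.get(0), 90));  // 480x640
  }
}
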
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo.env;
17 -
18 -import android.os.SystemClock;
19 -
20 -/**
21 - * A simple utility timer for measuring CPU time and wall-clock splits.
22 - */
23 -public class SplitTimer {
24 - private final Logger logger;
25 -
26 - private long lastWallTime;
27 - private long lastCpuTime;
28 -
29 - public SplitTimer(final String name) {
30 - logger = new Logger(name);
31 - newSplit();
32 - }
33 -
34 - public void newSplit() {
35 - lastWallTime = SystemClock.uptimeMillis();
36 - lastCpuTime = SystemClock.currentThreadTimeMillis();
37 - }
38 -
39 - public void endSplit(final String splitName) {
40 - final long currWallTime = SystemClock.uptimeMillis();
41 - final long currCpuTime = SystemClock.currentThreadTimeMillis();
42 -
43 - logger.i(
44 - "%s: cpu=%dms wall=%dms",
45 - splitName, currCpuTime - lastCpuTime, currWallTime - lastWallTime);
46 -
47 - lastWallTime = currWallTime;
48 - lastCpuTime = currCpuTime;
49 - }
50 -}
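
A usage sketch for the timer above; FrameProcessor and its stages are hypothetical:

// Sketch: instrumenting a frame handler with SplitTimer. Each endSplit()
// logs CPU and wall time since the previous split, then resets both clocks.
import org.tensorflow.demo.env.SplitTimer;

public class FrameProcessor {
  void processFrame(final byte[] frame) {
    final SplitTimer timer = new SplitTimer("processFrame");
    decode(frame);
    timer.endSplit("decoded frame");  // e.g. "decoded frame: cpu=3ms wall=5ms"
    detect(frame);
    timer.endSplit("ran detection");  // times only the detect() stage
  }

  private void decode(final byte[] frame) {}
  private void detect(final byte[] frame) {}
}
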
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo.tracking;
17 -
18 -import android.content.Context;
19 -import android.graphics.Canvas;
20 -import android.graphics.Color;
21 -import android.graphics.Matrix;
22 -import android.graphics.Paint;
23 -import android.graphics.Paint.Cap;
24 -import android.graphics.Paint.Join;
25 -import android.graphics.Paint.Style;
26 -import android.graphics.RectF;
27 -import android.text.TextUtils;
28 -import android.util.Pair;
29 -import android.util.TypedValue;
30 -import android.widget.Toast;
31 -import java.util.LinkedList;
32 -import java.util.List;
33 -import java.util.Queue;
34 -import org.tensorflow.demo.Classifier.Recognition;
35 -import org.tensorflow.demo.env.BorderedText;
36 -import org.tensorflow.demo.env.ImageUtils;
37 -import org.tensorflow.demo.env.Logger;
38 -
39 -/**
40 - * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
41 - * objects to new detections.
42 - */
43 -public class MultiBoxTracker {
44 - private final Logger logger = new Logger();
45 -
46 - private static final float TEXT_SIZE_DIP = 18;
47 -
 48 -  // Maximum intersection-over-union allowed between a new detection and an existing tracked
 49 -  // box; above this threshold, the lower-scored box (new or old) will be removed.
50 - private static final float MAX_OVERLAP = 0.2f;
51 -
52 - private static final float MIN_SIZE = 16.0f;
53 -
54 - // Allow replacement of the tracked box with new results if
55 - // correlation has dropped below this level.
56 - private static final float MARGINAL_CORRELATION = 0.75f;
57 -
58 - // Consider object to be lost if correlation falls below this threshold.
59 - private static final float MIN_CORRELATION = 0.3f;
60 -
61 - private static final int[] COLORS = {
62 - Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE,
63 - Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"),
64 - Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"),
65 - Color.parseColor("#AA33AA"), Color.parseColor("#0D0068")
66 - };
67 -
68 - private final Queue<Integer> availableColors = new LinkedList<Integer>();
69 -
70 - public ObjectTracker objectTracker;
71 -
72 - final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();
73 -
74 - private static class TrackedRecognition {
75 - ObjectTracker.TrackedObject trackedObject;
76 - RectF location;
77 - float detectionConfidence;
78 - int color;
79 - String title;
80 - }
81 -
82 - private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();
83 -
84 - private final Paint boxPaint = new Paint();
85 -
86 - private final float textSizePx;
87 - private final BorderedText borderedText;
88 -
89 - private Matrix frameToCanvasMatrix;
90 -
91 - private int frameWidth;
92 - private int frameHeight;
93 -
94 - private int sensorOrientation;
95 - private Context context;
96 -
97 - public MultiBoxTracker(final Context context) {
98 - this.context = context;
99 - for (final int color : COLORS) {
100 - availableColors.add(color);
101 - }
102 -
103 - boxPaint.setColor(Color.RED);
104 - boxPaint.setStyle(Style.STROKE);
105 - boxPaint.setStrokeWidth(12.0f);
106 - boxPaint.setStrokeCap(Cap.ROUND);
107 - boxPaint.setStrokeJoin(Join.ROUND);
108 - boxPaint.setStrokeMiter(100);
109 -
110 - textSizePx =
111 - TypedValue.applyDimension(
112 - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics());
113 - borderedText = new BorderedText(textSizePx);
114 - }
115 -
116 - private Matrix getFrameToCanvasMatrix() {
117 - return frameToCanvasMatrix;
118 - }
119 -
120 - public synchronized void drawDebug(final Canvas canvas) {
121 - final Paint textPaint = new Paint();
122 - textPaint.setColor(Color.WHITE);
123 - textPaint.setTextSize(60.0f);
124 -
125 - final Paint boxPaint = new Paint();
126 - boxPaint.setColor(Color.RED);
127 - boxPaint.setAlpha(200);
128 - boxPaint.setStyle(Style.STROKE);
129 -
130 - for (final Pair<Float, RectF> detection : screenRects) {
131 - final RectF rect = detection.second;
132 - canvas.drawRect(rect, boxPaint);
133 - canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
134 - borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
135 - }
136 -
137 - if (objectTracker == null) {
138 - return;
139 - }
140 -
141 - // Draw correlations.
142 - for (final TrackedRecognition recognition : trackedObjects) {
143 - final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
144 -
145 - final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
146 -
147 - if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
148 - final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
149 - borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
150 - }
151 - }
152 -
153 - final Matrix matrix = getFrameToCanvasMatrix();
154 - objectTracker.drawDebug(canvas, matrix);
155 - }
156 -
157 - public synchronized void trackResults(
158 - final List<Recognition> results, final byte[] frame, final long timestamp) {
159 - logger.i("Processing %d results from %d", results.size(), timestamp);
160 - processResults(timestamp, results, frame);
161 - }
162 -
163 - public synchronized void draw(final Canvas canvas) {
164 - final boolean rotated = sensorOrientation % 180 == 90;
165 - final float multiplier =
166 - Math.min(canvas.getHeight() / (float) (rotated ? frameWidth : frameHeight),
167 - canvas.getWidth() / (float) (rotated ? frameHeight : frameWidth));
168 - frameToCanvasMatrix =
169 - ImageUtils.getTransformationMatrix(
170 - frameWidth,
171 - frameHeight,
172 - (int) (multiplier * (rotated ? frameHeight : frameWidth)),
173 - (int) (multiplier * (rotated ? frameWidth : frameHeight)),
174 - sensorOrientation,
175 - false);
176 - for (final TrackedRecognition recognition : trackedObjects) {
177 - final RectF trackedPos =
178 - (objectTracker != null)
179 - ? recognition.trackedObject.getTrackedPositionInPreviewFrame()
180 - : new RectF(recognition.location);
181 -
182 - getFrameToCanvasMatrix().mapRect(trackedPos);
183 - boxPaint.setColor(recognition.color);
184 -
185 - final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
186 - canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);
187 -
188 - final String labelString =
189 - !TextUtils.isEmpty(recognition.title)
190 - ? String.format("%s %.2f", recognition.title, recognition.detectionConfidence)
191 - : String.format("%.2f", recognition.detectionConfidence);
192 - borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
193 - }
194 - }
195 -
196 - private boolean initialized = false;
197 -
198 - public synchronized void onFrame(
199 - final int w,
200 - final int h,
201 - final int rowStride,
202 - final int sensorOrientation,
203 - final byte[] frame,
204 - final long timestamp) {
205 - if (objectTracker == null && !initialized) {
206 - ObjectTracker.clearInstance();
207 -
208 - logger.i("Initializing ObjectTracker: %dx%d", w, h);
209 - objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
210 - frameWidth = w;
211 - frameHeight = h;
212 - this.sensorOrientation = sensorOrientation;
213 - initialized = true;
214 -
215 - if (objectTracker == null) {
216 - String message =
217 - "Object tracking support not found. "
218 - + "See tensorflow/examples/android/README.md for details.";
219 - Toast.makeText(context, message, Toast.LENGTH_LONG).show();
220 - logger.e(message);
221 - }
222 - }
223 -
224 - if (objectTracker == null) {
225 - return;
226 - }
227 -
228 - objectTracker.nextFrame(frame, null, timestamp, null, true);
229 -
230 - // Clean up any objects not worth tracking any more.
231 - final LinkedList<TrackedRecognition> copyList =
232 - new LinkedList<TrackedRecognition>(trackedObjects);
233 - for (final TrackedRecognition recognition : copyList) {
234 - final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
235 - final float correlation = trackedObject.getCurrentCorrelation();
236 - if (correlation < MIN_CORRELATION) {
237 - logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
238 - trackedObject.stopTracking();
239 - trackedObjects.remove(recognition);
240 -
241 - availableColors.add(recognition.color);
242 - }
243 - }
244 - }
245 -
246 - private void processResults(
247 - final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
248 - final List<Pair<Float, Recognition>> rectsToTrack = new LinkedList<Pair<Float, Recognition>>();
249 -
250 - screenRects.clear();
251 - final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());
252 -
253 - for (final Recognition result : results) {
254 - if (result.getLocation() == null) {
255 - continue;
256 - }
257 - final RectF detectionFrameRect = new RectF(result.getLocation());
258 -
259 - final RectF detectionScreenRect = new RectF();
260 - rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);
261 -
262 - logger.v(
263 - "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect);
264 -
265 - screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));
266 -
267 - if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
268 - logger.w("Degenerate rectangle! " + detectionFrameRect);
269 - continue;
270 - }
271 -
272 - rectsToTrack.add(new Pair<Float, Recognition>(result.getConfidence(), result));
273 - }
274 -
275 - if (rectsToTrack.isEmpty()) {
276 - logger.v("Nothing to track, aborting.");
277 - return;
278 - }
279 -
280 - if (objectTracker == null) {
281 - trackedObjects.clear();
282 - for (final Pair<Float, Recognition> potential : rectsToTrack) {
283 - final TrackedRecognition trackedRecognition = new TrackedRecognition();
284 - trackedRecognition.detectionConfidence = potential.first;
285 - trackedRecognition.location = new RectF(potential.second.getLocation());
286 - trackedRecognition.trackedObject = null;
287 - trackedRecognition.title = potential.second.getTitle();
288 - trackedRecognition.color = COLORS[trackedObjects.size()];
289 - trackedObjects.add(trackedRecognition);
290 -
291 - if (trackedObjects.size() >= COLORS.length) {
292 - break;
293 - }
294 - }
295 - return;
296 - }
297 -
298 - logger.i("%d rects to track", rectsToTrack.size());
299 - for (final Pair<Float, Recognition> potential : rectsToTrack) {
300 - handleDetection(originalFrame, timestamp, potential);
301 - }
302 - }
303 -
304 - private void handleDetection(
305 - final byte[] frameCopy, final long timestamp, final Pair<Float, Recognition> potential) {
306 - final ObjectTracker.TrackedObject potentialObject =
307 - objectTracker.trackObject(potential.second.getLocation(), timestamp, frameCopy);
308 -
309 - final float potentialCorrelation = potentialObject.getCurrentCorrelation();
310 - logger.v(
311 - "Tracked object went from %s to %s with correlation %.2f",
312 - potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);
313 -
314 - if (potentialCorrelation < MARGINAL_CORRELATION) {
315 - logger.v("Correlation too low to begin tracking %s.", potentialObject);
316 - potentialObject.stopTracking();
317 - return;
318 - }
319 -
320 - final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();
321 -
322 - float maxIntersect = 0.0f;
323 -
324 - // This is the current tracked object whose color we will take. If left null we'll take the
325 - // first one from the color queue.
326 - TrackedRecognition recogToReplace = null;
327 -
328 - // Look for intersections that will be overridden by this object or an intersection that would
329 - // prevent this one from being placed.
330 - for (final TrackedRecognition trackedRecognition : trackedObjects) {
331 - final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
332 - final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
333 - final RectF intersection = new RectF();
334 - final boolean intersects = intersection.setIntersect(a, b);
335 -
336 - final float intersectArea = intersection.width() * intersection.height();
337 - final float totalArea = a.width() * a.height() + b.width() * b.height() - intersectArea;
338 - final float intersectOverUnion = intersectArea / totalArea;
339 -
340 - // If there is an intersection with this currently tracked box above the maximum overlap
341 - // percentage allowed, either the new recognition needs to be dismissed or the old
342 - // recognition needs to be removed and possibly replaced with the new one.
343 - if (intersects && intersectOverUnion > MAX_OVERLAP) {
344 - if (potential.first < trackedRecognition.detectionConfidence
345 - && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
346 - // If track for the existing object is still going strong and the detection score was
347 - // good, reject this new object.
348 - potentialObject.stopTracking();
349 - return;
350 - } else {
351 - removeList.add(trackedRecognition);
352 -
353 - // Let the previously tracked object with max intersection amount donate its color to
354 - // the new object.
355 - if (intersectOverUnion > maxIntersect) {
356 - maxIntersect = intersectOverUnion;
357 - recogToReplace = trackedRecognition;
358 - }
359 - }
360 - }
361 - }
362 -
363 - // If we're already tracking the max object and no intersections were found to bump off,
364 - // pick the worst current tracked object to remove, if it's also worse than this candidate
365 - // object.
366 - if (availableColors.isEmpty() && removeList.isEmpty()) {
367 - for (final TrackedRecognition candidate : trackedObjects) {
368 - if (candidate.detectionConfidence < potential.first) {
369 - if (recogToReplace == null
370 - || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
371 - // Save it so that we use this color for the new object.
372 - recogToReplace = candidate;
373 - }
374 - }
375 - }
376 - if (recogToReplace != null) {
377 - logger.v("Found non-intersecting object to remove.");
378 - removeList.add(recogToReplace);
379 - } else {
380 - logger.v("No non-intersecting object found to remove.");
381 - }
382 - }
383 -
384 - // Remove everything that got intersected.
385 - for (final TrackedRecognition trackedRecognition : removeList) {
386 - logger.v(
387 - "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
388 - trackedRecognition.trackedObject,
389 - trackedRecognition.detectionConfidence,
390 - trackedRecognition.trackedObject.getCurrentCorrelation());
391 - trackedRecognition.trackedObject.stopTracking();
392 - trackedObjects.remove(trackedRecognition);
393 - if (trackedRecognition != recogToReplace) {
394 - availableColors.add(trackedRecognition.color);
395 - }
396 - }
397 -
398 - if (recogToReplace == null && availableColors.isEmpty()) {
399 - logger.e("No room to track this object, aborting.");
400 - potentialObject.stopTracking();
401 - return;
402 - }
403 -
404 - // Finally safe to say we can track this object.
405 - logger.v(
406 - "Tracking object %s (%s) with detection confidence %.2f at position %s",
407 - potentialObject,
408 - potential.second.getTitle(),
409 - potential.first,
410 - potential.second.getLocation());
411 - final TrackedRecognition trackedRecognition = new TrackedRecognition();
412 - trackedRecognition.detectionConfidence = potential.first;
413 - trackedRecognition.trackedObject = potentialObject;
414 - trackedRecognition.title = potential.second.getTitle();
415 -
416 - // Use the color from a replaced object before taking one from the color queue.
417 - trackedRecognition.color =
418 - recogToReplace != null ? recogToReplace.color : availableColors.poll();
419 - trackedObjects.add(trackedRecognition);
420 - }
421 -}
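
The replacement test in handleDetection above keys off intersection-over-union, intersectArea / (areaA + areaB - intersectArea), compared against MAX_OVERLAP. A minimal standalone sketch of that arithmetic, with a hypothetical plain-Java Box standing in for android.graphics.RectF:

// Hypothetical standalone IoU sketch; Box stands in for android.graphics.RectF.
public final class IouExample {
  static final class Box {
    final float left, top, right, bottom;
    Box(float left, float top, float right, float bottom) {
      this.left = left; this.top = top; this.right = right; this.bottom = bottom;
    }
    float area() { return Math.max(0, right - left) * Math.max(0, bottom - top); }
  }

  // Mirrors the area arithmetic in handleDetection: intersection over union.
  static float iou(Box a, Box b) {
    final float left = Math.max(a.left, b.left);
    final float top = Math.max(a.top, b.top);
    final float right = Math.min(a.right, b.right);
    final float bottom = Math.min(a.bottom, b.bottom);
    if (right <= left || bottom <= top) {
      return 0.0f; // No intersection.
    }
    final float intersectArea = (right - left) * (bottom - top);
    return intersectArea / (a.area() + b.area() - intersectArea);
  }

  public static void main(String[] args) {
    final Box tracked = new Box(0, 0, 100, 100);
    final Box candidate = new Box(50, 50, 150, 150);
    // 2500 / (10000 + 10000 - 2500) = ~0.143; whether that bumps the tracked
    // object depends on the MAX_OVERLAP threshold in use.
    System.out.println(iou(tracked, candidate));
  }
}
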
1 -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 -
3 -Licensed under the Apache License, Version 2.0 (the "License");
4 -you may not use this file except in compliance with the License.
5 -You may obtain a copy of the License at
6 -
7 - http://www.apache.org/licenses/LICENSE-2.0
8 -
9 -Unless required by applicable law or agreed to in writing, software
10 -distributed under the License is distributed on an "AS IS" BASIS,
11 -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -See the License for the specific language governing permissions and
13 -limitations under the License.
14 -==============================================================================*/
15 -
16 -package org.tensorflow.demo.tracking;
17 -
18 -import android.graphics.Canvas;
19 -import android.graphics.Color;
20 -import android.graphics.Matrix;
21 -import android.graphics.Paint;
22 -import android.graphics.PointF;
23 -import android.graphics.RectF;
24 -import android.graphics.Typeface;
25 -import java.util.ArrayList;
26 -import java.util.HashMap;
27 -import java.util.LinkedList;
28 -import java.util.List;
29 -import java.util.Map;
30 -import java.util.Vector;
31 -import javax.microedition.khronos.opengles.GL10;
32 -import org.tensorflow.demo.env.Logger;
33 -import org.tensorflow.demo.env.Size;
34 -
35 -/**
36 - * True object detector/tracker class that tracks objects across consecutive preview frames.
37 - * It provides a simplified Java interface to the analogous native object defined by
38 - * jni/client_vision/tracking/object_tracker.*.
39 - *
40 - * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must
41 - * be allocated by ObjectTracker.getInstance(). In addition, release() should be called
42 - * as soon as the ObjectTracker is no longer needed, and before a new one is created.
43 - *
44 - * nextFrame() should be called as new frames become available, preferably as often as possible.
45 - *
46 - * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects
47 - * are associated with the ObjectTracker that created them, and are only valid while that
48 - * ObjectTracker still exists.
49 - */
50 -public class ObjectTracker {
51 - private static final Logger LOGGER = new Logger();
52 -
53 - private static boolean libraryFound = false;
54 -
55 - static {
56 - try {
57 - System.loadLibrary("tensorflow_demo");
58 - libraryFound = true;
59 - } catch (UnsatisfiedLinkError e) {
60 - LOGGER.e("libtensorflow_demo.so not found, tracking unavailable");
61 - }
62 - }
63 -
64 - private static final boolean DRAW_TEXT = false;
65 -
66 - /**
67 - * How many history points to keep track of and draw in the red history line.
68 - */
69 - private static final int MAX_DEBUG_HISTORY_SIZE = 30;
70 -
71 - /**
72 - * How many frames of optical flow deltas to record.
73 - * TODO(andrewharp): Push this down to the native level so it can be polled
74 - * efficiently into an array for upload, instead of keeping a duplicate
75 - * copy in Java.
76 - */
77 - private static final int MAX_FRAME_HISTORY_SIZE = 200;
78 -
79 - private static final int DOWNSAMPLE_FACTOR = 2;
80 -
81 - private final byte[] downsampledFrame;
82 -
83 - protected static ObjectTracker instance;
84 -
85 - private final Map<String, TrackedObject> trackedObjects;
86 -
87 - private long lastTimestamp;
88 -
89 - private FrameChange lastKeypoints;
90 -
91 - private final Vector<PointF> debugHistory;
92 -
93 - private final LinkedList<TimestampedDeltas> timestampedDeltas;
94 -
95 - protected final int frameWidth;
96 - protected final int frameHeight;
97 - private final int rowStride;
98 - protected final boolean alwaysTrack;
99 -
100 - private static class TimestampedDeltas {
101 - final long timestamp;
102 - final byte[] deltas;
103 -
104 - public TimestampedDeltas(final long timestamp, final byte[] deltas) {
105 - this.timestamp = timestamp;
106 - this.deltas = deltas;
107 - }
108 - }
109 -
110 - /**
111 - * A simple class that records keypoint information, including its location,
112 - * score, and type. Used when calculating a FrameChange.
114 - */
115 - public static class Keypoint {
116 - public final float x;
117 - public final float y;
118 - public final float score;
119 - public final int type;
120 -
121 - public Keypoint(final float x, final float y) {
122 - this.x = x;
123 - this.y = y;
124 - this.score = 0;
125 - this.type = -1;
126 - }
127 -
128 - public Keypoint(final float x, final float y, final float score, final int type) {
129 - this.x = x;
130 - this.y = y;
131 - this.score = score;
132 - this.type = type;
133 - }
134 -
135 - Keypoint delta(final Keypoint other) {
136 - return new Keypoint(this.x - other.x, this.y - other.y);
137 - }
138 - }
139 -
140 - /**
141 - * A simple class that computes the delta between two Keypoints.
142 - * Used when calculating the frame translation delta for optical flow.
144 - */
145 - public static class PointChange {
146 - public final Keypoint keypointA;
147 - public final Keypoint keypointB;
148 - Keypoint pointDelta;
149 - private final boolean wasFound;
150 -
151 - public PointChange(final float x1, final float y1,
152 - final float x2, final float y2,
153 - final float score, final int type,
154 - final boolean wasFound) {
155 - this.wasFound = wasFound;
156 -
157 - keypointA = new Keypoint(x1, y1, score, type);
158 - keypointB = new Keypoint(x2, y2);
159 - }
160 -
161 - public Keypoint getDelta() {
162 - if (pointDelta == null) {
163 - pointDelta = keypointB.delta(keypointA);
164 - }
165 - return pointDelta;
166 - }
167 - }
168 -
169 - /** A class that records the per-keypoint deltas between consecutive frames, used for optical flow. */
170 - public static class FrameChange {
171 - public static final int KEYPOINT_STEP = 7;
172 -
173 - public final Vector<PointChange> pointDeltas;
174 -
175 - private final float minScore;
176 - private final float maxScore;
177 -
178 - public FrameChange(final float[] framePoints) {
179 - float minScore = 100.0f;
180 - float maxScore = -100.0f;
181 -
182 - pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP);
183 -
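      // Each keypoint occupies KEYPOINT_STEP consecutive floats:
      // [x1, y1, wasFound, x2, y2, score, type]; coordinates arrive in
      // downsampled space and are rescaled by DOWNSAMPLE_FACTOR below.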
184 - for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) {
185 - final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR;
186 - final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR;
187 -
188 - final boolean wasFound = framePoints[i + 2] > 0.0f;
189 -
190 - final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR;
191 - final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR;
192 - final float score = framePoints[i + 5];
193 - final int type = (int) framePoints[i + 6];
194 -
195 - minScore = Math.min(minScore, score);
196 - maxScore = Math.max(maxScore, score);
197 -
198 - pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound));
199 - }
200 -
201 - this.minScore = minScore;
202 - this.maxScore = maxScore;
203 - }
204 - }
205 -
206 - public static synchronized ObjectTracker getInstance(
207 - final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
208 - if (!libraryFound) {
209 - LOGGER.e(
210 - "Native object tracking support not found. "
211 - + "See tensorflow/examples/android/README.md for details.");
212 - return null;
213 - }
214 -
215 - if (instance == null) {
216 - instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack);
217 - instance.init();
218 - } else {
219 - throw new RuntimeException(
220 - "Tried to create a new ObjectTracker before releasing the old one!");
221 - }
222 - return instance;
223 - }
224 -
225 - public static synchronized void clearInstance() {
226 - if (instance != null) {
227 - instance.release();
228 - }
229 - }
230 -
231 - protected ObjectTracker(
232 - final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
233 - this.frameWidth = frameWidth;
234 - this.frameHeight = frameHeight;
235 - this.rowStride = rowStride;
236 - this.alwaysTrack = alwaysTrack;
237 - this.timestampedDeltas = new LinkedList<TimestampedDeltas>();
238 -
239 - trackedObjects = new HashMap<String, TrackedObject>();
240 -
241 - debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE);
242 -
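    // Ceiling division on both dimensions so frames whose sides are not
    // multiples of DOWNSAMPLE_FACTOR still fit in the downsampled buffer.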
243 - downsampledFrame =
244 - new byte
245 - [(frameWidth + DOWNSAMPLE_FACTOR - 1)
246 - / DOWNSAMPLE_FACTOR
247 - * (frameHeight + DOWNSAMPLE_FACTOR - 1)
248 - / DOWNSAMPLE_FACTOR];
249 - }
250 -
251 - protected void init() {
252 - // The native tracker never sees the full frame, so pre-scale dimensions
253 - // by the downsample factor.
254 - initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack);
255 - }
256 -
257 - private final float[] matrixValues = new float[9];
258 -
259 - private long downsampledTimestamp;
260 -
261 - @SuppressWarnings("unused")
262 - public synchronized void drawOverlay(final GL10 gl,
263 - final Size cameraViewSize, final Matrix matrix) {
264 - final Matrix tempMatrix = new Matrix(matrix);
265 - tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR);
266 - tempMatrix.getValues(matrixValues);
267 - drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues);
268 - }
269 -
270 - public synchronized void nextFrame(
271 - final byte[] frameData, final byte[] uvData,
272 - final long timestamp, final float[] transformationMatrix,
273 - final boolean updateDebugInfo) {
274 - if (downsampledTimestamp != timestamp) {
275 - ObjectTracker.downsampleImageNative(
276 - frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
277 - downsampledTimestamp = timestamp;
278 - }
279 -
280 - // Run Lucas-Kanade optical flow using the full-frame initializer.
281 - nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix);
282 -
283 - timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR)));
284 - while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) {
285 - timestampedDeltas.removeFirst();
286 - }
287 -
288 - for (final TrackedObject trackedObject : trackedObjects.values()) {
289 - trackedObject.updateTrackedPosition();
290 - }
291 -
292 - if (updateDebugInfo) {
293 - updateDebugHistory();
294 - }
295 -
296 - lastTimestamp = timestamp;
297 - }
298 -
299 - public synchronized void release() {
300 - releaseMemoryNative();
301 - synchronized (ObjectTracker.class) {
302 - instance = null;
303 - }
304 - }
305 -
306 - private void drawHistoryDebug(final Canvas canvas) {
307 - drawHistoryPoint(
308 - canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2);
309 - }
310 -
311 - private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) {
312 - final Paint p = new Paint();
313 - p.setAntiAlias(false);
314 - p.setTypeface(Typeface.SERIF);
315 -
316 - p.setColor(Color.RED);
317 - p.setStrokeWidth(2.0f);
318 -
319 - // Draw the center circle.
320 - p.setColor(Color.GREEN);
321 - canvas.drawCircle(startX, startY, 3.0f, p);
322 -
323 - p.setColor(Color.RED);
324 -
325 - // Iterate through in backwards order.
326 - synchronized (debugHistory) {
327 - final int numPoints = debugHistory.size();
328 - float lastX = startX;
329 - float lastY = startY;
330 - for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) {
331 - final PointF delta = debugHistory.get(numPoints - keypointNum - 1);
332 - final float newX = lastX + delta.x;
333 - final float newY = lastY + delta.y;
334 - canvas.drawLine(lastX, lastY, newX, newY, p);
335 - lastX = newX;
336 - lastY = newY;
337 - }
338 - }
339 - }
340 -
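  // Maps a float in [0, 1] to a clamped 0-255 color channel value.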
341 - private static int floatToChar(final float value) {
342 - return Math.max(0, Math.min((int) (value * 255.999f), 255));
343 - }
344 -
345 - private void drawKeypointsDebug(final Canvas canvas) {
346 - final Paint p = new Paint();
347 - if (lastKeypoints == null) {
348 - return;
349 - }
350 - final int keypointSize = 3;
351 -
352 - final float minScore = lastKeypoints.minScore;
353 - final float maxScore = lastKeypoints.maxScore;
354 -
355 - for (final PointChange keypoint : lastKeypoints.pointDeltas) {
356 - if (keypoint.wasFound) {
357 - final int r =
358 - floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore));
359 - final int b =
360 - floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore));
361 -
362 - final int color = 0xFF000000 | (r << 16) | b;
363 - p.setColor(color);
364 -
365 - final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y,
366 - keypoint.keypointB.x, keypoint.keypointB.y};
367 - canvas.drawRect(screenPoints[2] - keypointSize,
368 - screenPoints[3] - keypointSize,
369 - screenPoints[2] + keypointSize,
370 - screenPoints[3] + keypointSize, p);
371 - p.setColor(Color.CYAN);
372 - canvas.drawLine(screenPoints[2], screenPoints[3],
373 - screenPoints[0], screenPoints[1], p);
374 -
375 - if (DRAW_TEXT) {
376 - p.setColor(Color.WHITE);
377 - canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score,
378 - keypoint.keypointA.x, keypoint.keypointA.y, p);
379 - }
380 - } else {
381 - p.setColor(Color.YELLOW);
382 - final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y};
383 - canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p);
384 - }
385 - }
386 - }
387 -
388 - private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX,
389 - final float positionY, final float radius) {
390 - final RectF currPosition = getCurrentPosition(timestamp,
391 - new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius));
392 - return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY);
393 - }
394 -
395 - private synchronized RectF getCurrentPosition(final long timestamp, final RectF oldPosition) {
397 - final RectF downscaledFrameRect = downscaleRect(oldPosition);
398 -
399 - final float[] delta = new float[4];
400 - getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top,
401 - downscaledFrameRect.right, downscaledFrameRect.bottom, delta);
402 -
403 - final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
404 -
405 - return upscaleRect(newPosition);
406 - }
407 -
408 - private void updateDebugHistory() {
409 - lastKeypoints = new FrameChange(getKeypointsNative(false));
410 -
411 - if (lastTimestamp == 0) {
412 - return;
413 - }
414 -
415 - final PointF delta =
416 - getAccumulatedDelta(
417 - lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100);
418 -
419 - synchronized (debugHistory) {
420 - debugHistory.add(delta);
421 -
422 - while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) {
423 - debugHistory.remove(0);
424 - }
425 - }
426 - }
427 -
428 - public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) {
429 - canvas.save();
430 - canvas.setMatrix(frameToCanvas);
431 -
432 - drawHistoryDebug(canvas);
433 - drawKeypointsDebug(canvas);
434 -
435 - canvas.restore();
436 - }
437 -
438 - public Vector<String> getDebugText() {
439 - final Vector<String> lines = new Vector<String>();
440 -
441 - if (lastKeypoints != null) {
442 - lines.add("Num keypoints " + lastKeypoints.pointDeltas.size());
443 - lines.add("Min score: " + lastKeypoints.minScore);
444 - lines.add("Max score: " + lastKeypoints.maxScore);
445 - }
446 -
447 - return lines;
448 - }
449 -
450 - public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) {
451 - final List<byte[]> frameDeltas = new ArrayList<byte[]>();
452 - while (timestampedDeltas.size() > 0) {
453 - final TimestampedDeltas currentDeltas = timestampedDeltas.peek();
454 - if (currentDeltas.timestamp <= endFrameTime) {
455 - frameDeltas.add(currentDeltas.deltas);
456 - timestampedDeltas.removeFirst();
457 - } else {
458 - break;
459 - }
460 - }
461 -
462 - return frameDeltas;
463 - }
464 -
465 - private RectF downscaleRect(final RectF fullFrameRect) {
466 - return new RectF(
467 - fullFrameRect.left / DOWNSAMPLE_FACTOR,
468 - fullFrameRect.top / DOWNSAMPLE_FACTOR,
469 - fullFrameRect.right / DOWNSAMPLE_FACTOR,
470 - fullFrameRect.bottom / DOWNSAMPLE_FACTOR);
471 - }
472 -
473 - private RectF upscaleRect(final RectF downsampledFrameRect) {
474 - return new RectF(
475 - downsampledFrameRect.left * DOWNSAMPLE_FACTOR,
476 - downsampledFrameRect.top * DOWNSAMPLE_FACTOR,
477 - downsampledFrameRect.right * DOWNSAMPLE_FACTOR,
478 - downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR);
479 - }
480 -
481 - /**
482 - * A TrackedObject represents a native TrackedObject, and provides access to the
483 - * relevant native tracking information available after every frame update. Instances
484 - * may be safely passed around and accessed externally, but become invalid after
485 - * stopTracking() is called or the ObjectTracker that created them is deactivated.
486 - *
487 - * @author andrewharp@google.com (Andrew Harp)
488 - */
489 - public class TrackedObject {
490 - private final String id;
491 -
492 - private long lastExternalPositionTime;
493 -
494 - private RectF lastTrackedPosition;
495 - private boolean visibleInLastFrame;
496 -
497 - private boolean isDead;
498 -
499 - TrackedObject(final RectF position, final long timestamp, final byte[] data) {
500 - isDead = false;
501 -
502 - id = Integer.toString(this.hashCode());
503 -
504 - lastExternalPositionTime = timestamp;
505 -
506 - synchronized (ObjectTracker.this) {
507 - registerInitialAppearance(position, data);
508 - setPreviousPosition(position, timestamp);
509 - trackedObjects.put(id, this);
510 - }
511 - }
512 -
513 - public void stopTracking() {
514 - checkValidObject();
515 -
516 - synchronized (ObjectTracker.this) {
517 - isDead = true;
518 - forgetNative(id);
519 - trackedObjects.remove(id);
520 - }
521 - }
522 -
523 - public float getCurrentCorrelation() {
524 - checkValidObject();
525 - return ObjectTracker.this.getCurrentCorrelation(id);
526 - }
527 -
528 - void registerInitialAppearance(final RectF position, final byte[] data) {
529 - final RectF externalPosition = downscaleRect(position);
530 - registerNewObjectWithAppearanceNative(id,
531 - externalPosition.left, externalPosition.top,
532 - externalPosition.right, externalPosition.bottom,
533 - data);
534 - }
535 -
536 - synchronized void setPreviousPosition(final RectF position, final long timestamp) {
537 - checkValidObject();
538 - synchronized (ObjectTracker.this) {
539 - if (lastExternalPositionTime > timestamp) {
540 - LOGGER.w("Tried to use older position time!");
541 - return;
542 - }
543 - final RectF externalPosition = downscaleRect(position);
544 - lastExternalPositionTime = timestamp;
545 -
546 - setPreviousPositionNative(id,
547 - externalPosition.left, externalPosition.top,
548 - externalPosition.right, externalPosition.bottom,
549 - lastExternalPositionTime);
550 -
551 - updateTrackedPosition();
552 - }
553 - }
554 -
555 - void setCurrentPosition(final RectF position) {
556 - checkValidObject();
557 - final RectF downsampledPosition = downscaleRect(position);
558 - synchronized (ObjectTracker.this) {
559 - setCurrentPositionNative(id,
560 - downsampledPosition.left, downsampledPosition.top,
561 - downsampledPosition.right, downsampledPosition.bottom);
562 - }
563 - }
564 -
565 - private synchronized void updateTrackedPosition() {
566 - checkValidObject();
567 -
568 - final float[] delta = new float[4];
569 - getTrackedPositionNative(id, delta);
570 - lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
571 -
572 - visibleInLastFrame = isObjectVisible(id);
573 - }
574 -
575 - public synchronized RectF getTrackedPositionInPreviewFrame() {
576 - checkValidObject();
577 -
578 - if (lastTrackedPosition == null) {
579 - return null;
580 - }
581 - return upscaleRect(lastTrackedPosition);
582 - }
583 -
584 - synchronized long getLastExternalPositionTime() {
585 - return lastExternalPositionTime;
586 - }
587 -
588 - public synchronized boolean visibleInLastPreviewFrame() {
589 - return visibleInLastFrame;
590 - }
591 -
592 - private void checkValidObject() {
593 - if (isDead) {
594 - throw new RuntimeException("TrackedObject already removed from tracking!");
595 - } else if (ObjectTracker.this != instance) {
596 - throw new RuntimeException("TrackedObject created with another ObjectTracker!");
597 - }
598 - }
599 - }
600 -
601 - public synchronized TrackedObject trackObject(
602 - final RectF position, final long timestamp, final byte[] frameData) {
603 - if (downsampledTimestamp != timestamp) {
604 - ObjectTracker.downsampleImageNative(
605 - frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
606 - downsampledTimestamp = timestamp;
607 - }
608 - return new TrackedObject(position, timestamp, downsampledFrame);
609 - }
610 -
611 - public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) {
612 - return new TrackedObject(position, lastTimestamp, frameData);
613 - }
614 -
615 - /** ********************* NATIVE CODE ************************************ */
616 -
617 - /** This will contain an opaque pointer to the native ObjectTracker */
618 - private long nativeObjectTracker;
619 -
620 - private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack);
621 -
622 - protected native void registerNewObjectWithAppearanceNative(
623 - String objectId, float x1, float y1, float x2, float y2, byte[] data);
624 -
625 - protected native void setPreviousPositionNative(
626 - String objectId, float x1, float y1, float x2, float y2, long timestamp);
627 -
628 - protected native void setCurrentPositionNative(
629 - String objectId, float x1, float y1, float x2, float y2);
630 -
631 - protected native void forgetNative(String key);
632 -
633 - protected native String getModelIdNative(String key);
634 -
635 - protected native boolean haveObject(String key);
636 - protected native boolean isObjectVisible(String key);
637 - protected native float getCurrentCorrelation(String key);
638 -
639 - protected native float getMatchScore(String key);
640 -
641 - protected native void getTrackedPositionNative(String key, float[] points);
642 -
643 - protected native void nextFrameNative(
644 - byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix);
645 -
646 - protected native void releaseMemoryNative();
647 -
648 - protected native void getCurrentPositionNative(long timestamp,
649 - final float positionX1, final float positionY1,
650 - final float positionX2, final float positionY2,
651 - final float[] delta);
652 -
653 - protected native byte[] getKeypointsPacked(float scaleFactor);
654 -
655 - protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints);
656 -
657 - protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas);
658 -
659 - protected static native void downsampleImageNative(
660 - int width, int height, int rowStride, byte[] input, int factor, byte[] output);
661 -}
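
For reference, a hedged sketch of the lifecycle described in the class comment at the top of this file: getInstance(), nextFrame() per camera frame, trackObject() for each detection, then stopTracking() and clearInstance(). The caller class, frame geometry, and the null uv/alignment arguments are illustrative assumptions, not part of the demo sources:

// Hypothetical caller; assumes a camera pipeline supplying the luminance
// plane of each preview frame as a byte array.
import android.graphics.RectF;
import org.tensorflow.demo.tracking.ObjectTracker;

public class TrackerLifecycleSketch {
  public static void run(byte[] luminance, long timestampNs, RectF detectionBox) {
    // Frame geometry is illustrative; rowStride often equals the frame width.
    ObjectTracker tracker = ObjectTracker.getInstance(640, 480, 640, false);
    if (tracker == null) {
      return; // Native library missing; tracking unavailable.
    }

    // Feed each new frame. Passing null for the optional uv data and
    // alignment matrix is an assumption of this sketch.
    tracker.nextFrame(luminance, null, timestampNs, null, false);

    // Begin tracking a detector-reported box on this frame.
    ObjectTracker.TrackedObject obj =
        tracker.trackObject(detectionBox, timestampNs, luminance);

    // Poll the tracked position after later nextFrame() calls.
    RectF position = obj.getTrackedPositionInPreviewFrame();
    if (position != null) {
      System.out.println("Tracked position: " + position);
    }

    // Release everything before creating another tracker instance.
    obj.stopTracking();
    ObjectTracker.clearInstance();
  }
}
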