장수창

modified test app

Showing 114 changed files with 0 additions and 16626 deletions
# This file is based on https://github.com/github/gitignore/blob/master/Android.gitignore
*.iml
.idea/compiler.xml
.idea/copyright
.idea/dictionaries
.idea/gradle.xml
.idea/libraries
.idea/inspectionProfiles
.idea/misc.xml
.idea/modules.xml
.idea/runConfigurations.xml
.idea/tasks.xml
.idea/workspace.xml
.gradle
local.properties
.DS_Store
build/
gradleBuild/
*.apk
*.ap_
*.dex
*.class
bin/
gen/
out/
*.log
.navigation/
/captures
.externalNativeBuild
<component name="ProjectCodeStyleConfiguration">
<code_scheme name="Project" version="173">
<codeStyleSettings language="XML">
<indentOptions>
<option name="CONTINUATION_INDENT_SIZE" value="4" />
</indentOptions>
<arrangement>
<rules>
<section>
<rule>
<match>
<AND>
<NAME>xmlns:android</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>xmlns:.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*:id</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*:name</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>name</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>style</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
<order>ANDROID_ATTRIBUTE_ORDER</order>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>.*</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
</rules>
</arrangement>
</codeStyleSettings>
</code_scheme>
</component>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/../../.." vcs="Git" />
</component>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.demo">
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<uses-feature android:name="android.hardware.camera.autofocus" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<application android:allowBackup="true"
android:debuggable="true"
android:label="@string/app_name"
android:icon="@drawable/ic_launcher"
android:theme="@style/MaterialTheme">
<!-- <activity android:name="org.tensorflow.demo.ClassifierActivity"-->
<!-- android:screenOrientation="portrait"-->
<!-- android:label="@string/activity_name_classification">-->
<!-- <intent-filter>-->
<!-- <action android:name="android.intent.action.MAIN" />-->
<!-- <category android:name="android.intent.category.LAUNCHER" />-->
<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
<!-- </intent-filter>-->
<!-- </activity>-->
<activity android:name="org.tensorflow.demo.DetectorActivity"
android:screenOrientation="portrait"
android:label="@string/activity_name_detection">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
<category android:name="android.intent.category.LEANBACK_LAUNCHER" />
</intent-filter>
</activity>
<!-- <activity android:name="org.tensorflow.demo.StylizeActivity"-->
<!-- android:screenOrientation="portrait"-->
<!-- android:label="@string/activity_name_stylize">-->
<!-- <intent-filter>-->
<!-- <action android:name="android.intent.action.MAIN" />-->
<!-- <category android:name="android.intent.category.LAUNCHER" />-->
<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
<!-- </intent-filter>-->
<!-- </activity>-->
<!-- <activity android:name="org.tensorflow.demo.SpeechActivity"-->
<!-- android:screenOrientation="portrait"-->
<!-- android:label="@string/activity_name_speech">-->
<!-- <intent-filter>-->
<!-- <action android:name="android.intent.action.MAIN" />-->
<!-- <category android:name="android.intent.category.LAUNCHER" />-->
<!-- <category android:name="android.intent.category.LEANBACK_LAUNCHER" />-->
<!-- </intent-filter>-->
<!-- </activity>-->
</application>
</manifest>
# Description:
# TensorFlow camera demo app for Android.
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
LINKER_SCRIPT = "jni/version_script.lds"
# libtensorflow_demo.so contains the native code for image colorspace conversion
# and object tracking used by the demo. It does not require TF as a dependency
# to build if STANDALONE_DEMO_LIB is defined.
# TF support for the demo is provided separately by libtensorflow_inference.so.
cc_binary(
name = "libtensorflow_demo.so",
srcs = glob([
"jni/**/*.cc",
"jni/**/*.h",
]),
copts = tf_copts(),
defines = ["STANDALONE_DEMO_LIB"],
linkopts = [
"-landroid",
"-ldl",
"-ljnigraphics",
"-llog",
"-lm",
"-z defs",
"-s",
"-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
],
linkshared = 1,
linkstatic = 1,
tags = [
"manual",
"notap",
],
deps = [
LINKER_SCRIPT,
],
)
cc_library(
name = "tensorflow_native_libs",
srcs = [
":libtensorflow_demo.so",
"//tensorflow/tools/android/inference_interface:libtensorflow_inference.so",
],
tags = [
"manual",
"notap",
],
)
android_binary(
name = "tensorflow_demo",
srcs = glob([
"src/**/*.java",
]),
# Package assets from assets dir as well as all model targets. Remove undesired models
# (and corresponding Activities in source) to reduce APK size.
assets = [
"//tensorflow/examples/android/assets:asset_files",
":external_assets",
],
assets_dir = "",
custom_package = "org.tensorflow.demo",
manifest = "AndroidManifest.xml",
resource_files = glob(["res/**"]),
tags = [
"manual",
"notap",
],
deps = [
":tensorflow_native_libs",
"//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java",
],
)
# LINT.IfChange
filegroup(
name = "external_assets",
srcs = [
"@inception_v1//:model_files",
"@mobile_ssd//:model_files",
"@speech_commands//:model_files",
"@stylize//:model_files",
],
)
# LINT.ThenChange(//tensorflow/examples/android/download-models.gradle)
filegroup(
name = "java_files",
srcs = glob(["src/**/*.java"]),
)
filegroup(
name = "jni_files",
srcs = glob([
"jni/**/*.cc",
"jni/**/*.h",
]),
)
filegroup(
name = "resource_files",
srcs = glob(["res/**"]),
)
exports_files([
"AndroidManifest.xml",
])
# TensorFlow Android Camera Demo
This folder contains an example application utilizing TensorFlow for Android
devices.
## Description
The demos in this folder are designed to give straightforward samples of using
TensorFlow in mobile applications.
Inference is done using the [TensorFlow Android Inference
Interface](../../tools/android/inference_interface), which may be built
separately if you want a standalone library to drop into your existing
application. Object tracking and efficient YUV -> RGB conversion are handled by
`libtensorflow_demo.so`.
A device running Android 5.0 (API 21) or higher is required to run the demo due
to the use of the camera2 API, although the native libraries themselves can run
on API >= 14 devices.
## Current samples:
1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java):
Uses the [Google Inception](https://arxiv.org/abs/1409.4842)
model to classify camera frames in real-time, displaying the top results
in an overlay on the camera image.
2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
Demonstrates an SSD-Mobilenet model trained using the
[TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection/)
introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to
localize and track objects (from 80 categories) in the camera preview
in real-time.
3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java):
Uses a model based on [A Learned Representation For Artistic
Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
image to that of a number of different artists.
4. [TF
Speech](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java):
Runs a simple speech recognition model built by the [audio training
tutorial](https://www.tensorflow.org/versions/master/tutorials/audio_recognition). Listens
for a small set of words, and highlights them in the UI when they are
recognized.
<img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%">
## Prebuilt Components:
The fastest path to trying the demo is to download the [prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
Also available are precompiled native libraries and a JCenter package that you
may simply drop into your own applications. See
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
for more details.
## Running the Demo
Once the app is installed, it can be started via the "TF Classify", "TF Detect",
"TF Stylize", and "TF Speech" icons, all of which use the orange TensorFlow
logo.
While running the activities, pressing the volume keys on your device will
toggle debug visualizations on/off, rendering additional info to the screen that
may be useful for development purposes.
## Building in Android Studio using the TensorFlow AAR from JCenter
The simplest way to compile the demo app yourself and try out changes to the
project code is to use Android Studio. Simply set this `android` directory as the
project root.
Then edit the `build.gradle` file and change the value of `nativeBuildSystem` to
`'none'` so that the project is built in the simplest way possible:
```None
def nativeBuildSystem = 'none'
```
While this project includes full build integration for TensorFlow, this setting
disables it, and uses the TensorFlow Inference Interface package from JCenter.
Note: Currently, in this build mode, YUV -> RGB is done using a less efficient
Java implementation, and object tracking is not available in the "TF Detect"
activity. Setting the build system to `'cmake'` currently only builds
`libtensorflow_demo.so`, which provides fast YUV -> RGB conversion and object
tracking, while still acquiring TensorFlow support via the downloaded AAR, so it
may be a lightweight way to enable these features.
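To try that mode, it is the same one-line setting shown above; a minimal sketch, with the value taken from the options listed in this project's own `build.gradle`:

```groovy
// Build libtensorflow_demo.so via CMake for fast YUV -> RGB conversion and
// object tracking, while TensorFlow support still comes from the JCenter AAR.
def nativeBuildSystem = 'cmake'
```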
For any project that does not include custom low level TensorFlow code, this is
likely sufficient.
For details on how to include this JCenter package in your own project see
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
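As a rough sketch of what that typically looks like (the README above is authoritative; the open-ended `+` version below simply mirrors the dependency this demo's own `build.gradle` declares, and you would normally pin a concrete release):

```groovy
// Sketch for your own app module's build.gradle: pull the TensorFlow Android
// Inference Interface AAR from JCenter, as this demo does in its
// 'none'/'cmake' build modes.
repositories {
    jcenter()
}

dependencies {
    implementation 'org.tensorflow:tensorflow-android:+'
}
```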
## Building the Demo with TensorFlow from Source
Pick your preferred approach below. At the moment, we have full support for
Bazel, and partial support for gradle, cmake, make, and Android Studio.
As a first step for all build types, clone the TensorFlow repo with:
```
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
```
Note that `--recurse-submodules` is necessary to prevent some issues with
protobuf compilation.
### Bazel
NOTE: Bazel does not currently support building for Android on Windows. Full
support for gradle/cmake builds is coming soon, but in the meantime we suggest
that Windows users download the
[prebuilt demo APK](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk)
instead.
##### Install Bazel and Android Prerequisites
Bazel is the primary build system for TensorFlow. To build with Bazel, it and
the Android NDK and SDK must be installed on your system.
1. Install the latest version of Bazel as per the instructions [on the Bazel
website](https://bazel.build/versions/master/docs/install.html).
2. The Android NDK is required to build the native (C/C++) TensorFlow code. The
current recommended version is 14b, which may be found
[here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
3. The Android SDK and build tools may be obtained
[here](https://developer.android.com/tools/revisions/build-tools.html), or
alternatively as part of [Android
Studio](https://developer.android.com/studio/index.html). Build tools API >=
23 is required to build the TF Android demo (though it will run on API >= 21
devices).
##### Edit WORKSPACE
NOTE: As long as you have the SDK and NDK installed, the `./configure` script
will create these rules for you. Answer "Yes" when the script asks to
automatically configure the `./WORKSPACE`.
The Android entries in
[`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L19-L36) must be uncommented
with the paths filled in appropriately depending on where you installed the NDK
and SDK. Otherwise an error such as: "The external label
'//external:android/sdk' is not bound to anything" will be reported.
Also edit the API levels for the SDK in WORKSPACE to the highest level you have
installed in your SDK. This must be >= 23 (this is completely independent of the
API level of the demo, which is defined in AndroidManifest.xml). The NDK API
level may remain at 14.
##### Install Model Files (optional)
The TensorFlow `GraphDef`s that contain the model definitions and weights are
not packaged in the repo because of their size. They are downloaded
automatically and packaged with the APK by Bazel via a `new_http_archive` defined
in `WORKSPACE` during the build process, and by Gradle via
`download-models.gradle`.
**Optional**: If you wish to place the models in your assets manually, remove
all of the `model_files` entries from the `assets` list in `tensorflow_demo`
found in the [`BUILD`](BUILD#L92) file. Then download and extract the archives
yourself to the `assets` directory in the source tree:
```bash
BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models
for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip
do
curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP}
unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/
done
```
This will extract the models and their associated metadata files to the local
assets/ directory.
If you are using Gradle, make sure to remove the `download-models.gradle` reference
from `build.gradle` after you manually download the models; otherwise Gradle might
download the models again and overwrite them.
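Concretely, that reference is the `apply from` line in `build.gradle`; a minimal sketch of the change, assuming the models are already unpacked under `assets/`:

```groovy
// In build.gradle: disable the automatic model download/extraction once the
// models are already in the assets/ directory, so Gradle will not re-fetch
// and overwrite them.
// apply from: "download-models.gradle"
```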
##### Build
After editing your WORKSPACE file to update the SDK/NDK configuration, you may
build the APK. Run this from your workspace root:
```bash
bazel build --cxxopt='--std=c++11' -c opt //tensorflow/examples/android:tensorflow_demo
```
##### Install
Make sure that adb debugging is enabled on your Android 5.0 (API 21) or later
device, then after building use the following command from your workspace root
to install the APK:
```bash
adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
```
### Android Studio with Bazel
Android Studio may be used to build the demo in conjunction with Bazel. First,
make sure that you can build with Bazel following the above directions. Then,
look at [build.gradle](build.gradle) and make sure that the path to Bazel
matches that of your system.
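The setting in question is the `bazelLocation` definition in `build.gradle`; point it at wherever your Bazel binary actually lives (the project default is shown below):

```groovy
// If building with Bazel, this is the location of the bazel binary.
def bazelLocation = '/usr/local/bin/bazel'
```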
At this point you can add the tensorflow/examples/android directory as a new
Android Studio project. Click through installing all the Gradle extensions it
requests, and you should be able to have Android Studio build the demo like any
other application (it will call out to Bazel to build the native code with the
NDK).
### CMake
Full CMake support for the demo is coming soon, but for now it is possible to
build the TensorFlow Android Inference library using
[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake).
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
# It is necessary to use this filegroup rather than globbing the files in this
# folder directly in the examples/android:tensorflow_demo target, because
# assets_dir is necessarily set to "" there (to allow using other
# arbitrary targets as assets).
filegroup(
name = "asset_files",
srcs = glob(
["**/*"],
exclude = ["BUILD"],
),
)
This file is too large to display.
// This file provides basic support for building the TensorFlow demo
// in Android Studio with Gradle.
//
// Note that Bazel is still used by default to compile the native libs,
// and should be installed at the location noted below. This build file
// automates the process of calling out to it and copying the compiled
// libraries back into the appropriate directory.
//
// Alternatively, experimental support for Makefile builds is provided by
// setting nativeBuildSystem below to 'makefile'. This will allow building the demo
// on Windows machines, but note that full equivalence with the Bazel
// build is not yet guaranteed. See comments below for caveats and tips
// for speeding up the build, such as enabling ccache.
// NOTE: Running a make build will cause subsequent Bazel builds to *fail*
// unless the contrib/makefile/downloads/ and gen/ dirs are deleted afterwards.
// The cmake build only creates libtensorflow_demo.so. In this situation,
// libtensorflow_inference.so will be acquired via the tensorflow.aar dependency.
// It is necessary to customize Gradle's build directory, as otherwise
// it will conflict with the BUILD file used by Bazel on case-insensitive OSs.
project.buildDir = 'gradleBuild'
getProject().setBuildDir('gradleBuild')
buildscript {
repositories {
jcenter()
google()
}
dependencies {
classpath 'com.android.tools.build:gradle:3.3.1'
classpath 'org.apache.httpcomponents:httpclient:4.5.4'
}
}
allprojects {
repositories {
jcenter()
google()
}
}
// set to 'bazel', 'cmake', 'makefile', 'none'
def nativeBuildSystem = 'none'
// Controls output directory in APK and CPU type for Bazel builds.
// NOTE: Does not affect the Makefile build target API (yet), which currently
// assumes armeabi-v7a. If building with make, changing this will require
// editing the Makefile as well.
// The CMake build has only been tested with armeabi-v7a; others may not work.
def cpuType = 'armeabi-v7a'
// Output directory in the local directory for packaging into the APK.
def nativeOutDir = 'libs/' + cpuType
// Default to building with Bazel and override with make if requested.
def nativeBuildRule = 'buildNativeBazel'
def demoLibPath = '../../../bazel-bin/tensorflow/examples/android/libtensorflow_demo.so'
def inferenceLibPath = '../../../bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so'
// Override for Makefile builds.
if (nativeBuildSystem == 'makefile') {
nativeBuildRule = 'buildNativeMake'
demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_demo.so'
inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_inference.so'
}
// If building with Bazel, this is the location of the bazel binary.
// NOTE: Bazel does not yet support building for Android on Windows,
// so in this case the Makefile build must be used as described above.
def bazelLocation = '/usr/local/bin/bazel'
// import DownloadModels task
project.ext.ASSET_DIR = projectDir.toString() + '/assets'
project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
// Download default models; if you wish to use your own models then
// place them in the "assets" directory and comment out this line.
apply from: "download-models.gradle"
apply plugin: 'com.android.application'
android {
compileSdkVersion 23
if (nativeBuildSystem == 'cmake') {
defaultConfig {
applicationId = 'org.tensorflow.demo'
minSdkVersion 21
targetSdkVersion 23
ndk {
abiFilters "${cpuType}"
}
externalNativeBuild {
cmake {
arguments '-DANDROID_STL=c++_static'
}
}
}
externalNativeBuild {
cmake {
path './jni/CMakeLists.txt'
}
}
}
lintOptions {
abortOnError false
}
sourceSets {
main {
if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
// TensorFlow Java API sources.
java {
srcDir '../../java/src/main/java'
exclude '**/examples/**'
}
// Android TensorFlow wrappers, etc.
java {
srcDir '../../tools/android/inference_interface/java'
}
}
// Android demo app sources.
java {
srcDir 'src'
}
manifest.srcFile 'AndroidManifest.xml'
resources.srcDirs = ['src']
aidl.srcDirs = ['src']
renderscript.srcDirs = ['src']
res.srcDirs = ['res']
assets.srcDirs = [project.ext.ASSET_DIR]
jniLibs.srcDirs = ['libs']
}
debug.setRoot('build-types/debug')
release.setRoot('build-types/release')
}
defaultConfig {
targetSdkVersion 23
minSdkVersion 21
}
}
task buildNativeBazel(type: Exec) {
workingDir '../../..'
commandLine bazelLocation, 'build', '-c', 'opt', \
'tensorflow/examples/android:tensorflow_native_libs', \
'--crosstool_top=//external:android/crosstool', \
'--cpu=' + cpuType, \
'--host_crosstool_top=@bazel_tools//tools/cpp:toolchain'
}
task buildNativeMake(type: Exec) {
environment "NDK_ROOT", android.ndkDirectory
// Tip: install ccache and uncomment the following to speed up
// builds significantly.
// environment "CC_PREFIX", 'ccache'
workingDir '../../..'
commandLine 'tensorflow/contrib/makefile/build_all_android.sh', \
'-s', \
'tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in', \
'-t', \
'libtensorflow_inference.so libtensorflow_demo.so all' \
, '-a', cpuType \
//, '-T' // Uncomment to skip protobuf and speed up subsequent builds.
}
task copyNativeLibs(type: Copy) {
from demoLibPath
from inferenceLibPath
into nativeOutDir
duplicatesStrategy = 'include'
dependsOn nativeBuildRule
fileMode 0644
}
tasks.whenTaskAdded { task ->
if (nativeBuildSystem == 'bazel' || nativeBuildSystem == 'makefile') {
if (task.name == 'assembleDebug') {
task.dependsOn 'copyNativeLibs'
}
if (task.name == 'assembleRelease') {
task.dependsOn 'copyNativeLibs'
}
}
}
dependencies {
if (nativeBuildSystem == 'cmake' || nativeBuildSystem == 'none') {
implementation 'org.tensorflow:tensorflow-android:+'
}
}
/*
* download-models.gradle
* Downloads model files from ${MODEL_URL} into application's asset folder
* Input:
* project.ext.TMP_DIR: absolute path to hold downloaded zip files
* project.ext.ASSET_DIR: absolute path to save unzipped model files
* Output:
* 3 model files will be downloaded into given folder of ext.ASSET_DIR
*/
// hard coded model files
// LINT.IfChange
def models = ['inception_v1.zip',
'object_detection/ssd_mobilenet_v1_android_export.zip',
'stylize_v1.zip',
'speech_commands_conv_actions.zip']
// LINT.ThenChange(//tensorflow/examples/android/BUILD)
// Root URL for model archives
def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models'
buildscript {
repositories {
jcenter()
}
dependencies {
classpath 'de.undercouch:gradle-download-task:3.2.0'
}
}
import de.undercouch.gradle.tasks.download.Download
task downloadFile(type: Download){
for (f in models) {
src "${MODEL_URL}/" + f
}
dest new File(project.ext.TMP_DIR)
overwrite true
}
task extractModels(type: Copy) {
for (f in models) {
def localFile = f.split("/")[-1]
from zipTree(project.ext.TMP_DIR + '/' + localFile)
}
into file(project.ext.ASSET_DIR)
fileMode 0644
exclude '**/LICENSE'
def needDownload = false
for (f in models) {
def localFile = f.split("/")[-1]
if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) {
needDownload = true
}
}
if (needDownload) {
dependsOn downloadFile
}
}
tasks.whenTaskAdded { task ->
if (task.name == 'assembleDebug') {
task.dependsOn 'extractModels'
}
if (task.name == 'assembleRelease') {
task.dependsOn 'extractModels'
}
}
#Sat Nov 18 15:06:47 CET 2017
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-4.1-all.zip
#!/usr/bin/env bash
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn ( ) {
echo "$*"
}
die ( ) {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
esac
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin, switch paths to Windows format before running java
if $cygwin ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=$((i+1))
done
case $i in
(0) set -- ;;
(1) set -- "$args0" ;;
(2) set -- "$args0" "$args1" ;;
(3) set -- "$args0" "$args1" "$args2" ;;
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
function splitJvmOpts() {
JVM_OPTS=("$@")
}
eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS=
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto init
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:init
@rem Get command-line arguments, handling Windows variants
if not "%OS%" == "Windows_NT" goto win9xME_args
if "%@eval[2+2]" == "4" goto 4NT_args
:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2
:win9xME_args_slurp
if "x%~1" == "x" goto execute
set CMD_LINE_ARGS=%*
goto execute
:4NT_args
@rem Get arguments from the 4NT Shell from JP Software
set CMD_LINE_ARGS=%$
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega
#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
project(TENSORFLOW_DEMO)
cmake_minimum_required(VERSION 3.4.1)
set(CMAKE_VERBOSE_MAKEFILE on)
get_filename_component(TF_SRC_ROOT ${CMAKE_SOURCE_DIR}/../../../.. ABSOLUTE)
get_filename_component(SAMPLE_SRC_DIR ${CMAKE_SOURCE_DIR}/.. ABSOLUTE)
if (ANDROID_ABI MATCHES "^armeabi-v7a$")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
elseif(ANDROID_ABI MATCHES "^arm64-v8a")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -ftree-vectorize")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSTANDALONE_DEMO_LIB \
-std=c++11 -fno-exceptions -fno-rtti -O2 -Wno-narrowing \
-fPIE")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
-Wl,--allow-multiple-definition \
-Wl,--whole-archive -fPIE -v")
file(GLOB_RECURSE tensorflow_demo_sources ${SAMPLE_SRC_DIR}/jni/*.*)
add_library(tensorflow_demo SHARED
${tensorflow_demo_sources})
target_include_directories(tensorflow_demo PRIVATE
${TF_SRC_ROOT}
${CMAKE_SOURCE_DIR})
target_link_libraries(tensorflow_demo
android
log
jnigraphics
m
atomic
z)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file binds the native image utility code to the Java class
// which exposes them.
#include <jni.h>
#include <stdio.h>
#include <stdlib.h>
#include "tensorflow/examples/android/jni/rgb2yuv.h"
#include "tensorflow/examples/android/jni/yuv2rgb.h"
#define IMAGEUTILS_METHOD(METHOD_NAME) \
Java_org_tensorflow_demo_env_ImageUtils_##METHOD_NAME // NOLINT
#ifdef __cplusplus
extern "C" {
#endif
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
jint width, jint height, jboolean halfSize);
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
jintArray output, jint width, jint height, jint y_row_stride,
jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize);
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
jint height);
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
jint width, jint height);
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
jint width, jint height);
#ifdef __cplusplus
}
#endif
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertYUV420SPToARGB8888)(
JNIEnv* env, jclass clazz, jbyteArray input, jintArray output,
jint width, jint height, jboolean halfSize) {
jboolean inputCopy = JNI_FALSE;
jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
jboolean outputCopy = JNI_FALSE;
jint* const o = env->GetIntArrayElements(output, &outputCopy);
if (halfSize) {
ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(i),
reinterpret_cast<uint32_t*>(o), width,
height);
} else {
ConvertYUV420SPToARGB8888(reinterpret_cast<uint8_t*>(i),
reinterpret_cast<uint8_t*>(i) + width * height,
reinterpret_cast<uint32_t*>(o), width, height);
}
env->ReleaseByteArrayElements(input, i, JNI_ABORT);
env->ReleaseIntArrayElements(output, o, 0);
}
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420ToARGB8888)(
JNIEnv* env, jclass clazz, jbyteArray y, jbyteArray u, jbyteArray v,
jintArray output, jint width, jint height, jint y_row_stride,
jint uv_row_stride, jint uv_pixel_stride, jboolean halfSize) {
jboolean inputCopy = JNI_FALSE;
jbyte* const y_buff = env->GetByteArrayElements(y, &inputCopy);
jboolean outputCopy = JNI_FALSE;
jint* const o = env->GetIntArrayElements(output, &outputCopy);
if (halfSize) {
ConvertYUV420SPToARGB8888HalfSize(reinterpret_cast<uint8_t*>(y_buff),
reinterpret_cast<uint32_t*>(o), width,
height);
} else {
jbyte* const u_buff = env->GetByteArrayElements(u, &inputCopy);
jbyte* const v_buff = env->GetByteArrayElements(v, &inputCopy);
ConvertYUV420ToARGB8888(
reinterpret_cast<uint8_t*>(y_buff), reinterpret_cast<uint8_t*>(u_buff),
reinterpret_cast<uint8_t*>(v_buff), reinterpret_cast<uint32_t*>(o),
width, height, y_row_stride, uv_row_stride, uv_pixel_stride);
env->ReleaseByteArrayElements(u, u_buff, JNI_ABORT);
env->ReleaseByteArrayElements(v, v_buff, JNI_ABORT);
}
env->ReleaseByteArrayElements(y, y_buff, JNI_ABORT);
env->ReleaseIntArrayElements(output, o, 0);
}
JNIEXPORT void JNICALL IMAGEUTILS_METHOD(convertYUV420SPToRGB565)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output, jint width,
jint height) {
jboolean inputCopy = JNI_FALSE;
jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
jboolean outputCopy = JNI_FALSE;
jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
ConvertYUV420SPToRGB565(reinterpret_cast<uint8_t*>(i),
reinterpret_cast<uint16_t*>(o), width, height);
env->ReleaseByteArrayElements(input, i, JNI_ABORT);
env->ReleaseByteArrayElements(output, o, 0);
}
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertARGB8888ToYUV420SP)(
JNIEnv* env, jclass clazz, jintArray input, jbyteArray output,
jint width, jint height) {
jboolean inputCopy = JNI_FALSE;
jint* const i = env->GetIntArrayElements(input, &inputCopy);
jboolean outputCopy = JNI_FALSE;
jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
ConvertARGB8888ToYUV420SP(reinterpret_cast<uint32_t*>(i),
reinterpret_cast<uint8_t*>(o), width, height);
env->ReleaseIntArrayElements(input, i, JNI_ABORT);
env->ReleaseByteArrayElements(output, o, 0);
}
JNIEXPORT void JNICALL
IMAGEUTILS_METHOD(convertRGB565ToYUV420SP)(
JNIEnv* env, jclass clazz, jbyteArray input, jbyteArray output,
jint width, jint height) {
jboolean inputCopy = JNI_FALSE;
jbyte* const i = env->GetByteArrayElements(input, &inputCopy);
jboolean outputCopy = JNI_FALSE;
jbyte* const o = env->GetByteArrayElements(output, &outputCopy);
ConvertRGB565ToYUV420SP(reinterpret_cast<uint16_t*>(i),
reinterpret_cast<uint8_t*>(o), width, height);
env->ReleaseByteArrayElements(input, i, JNI_ABORT);
env->ReleaseByteArrayElements(output, o, 0);
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
#include <math.h>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
namespace tf_tracking {
// Arbitrary keypoint type ids for labeling the origin of tracked keypoints.
enum KeypointType {
KEYPOINT_TYPE_DEFAULT = 0,
KEYPOINT_TYPE_FAST = 1,
KEYPOINT_TYPE_INTEREST = 2
};
// Struct that can be used to more richly store the results of a detection
// than a single number, while still maintaining comparability.
struct MatchScore {
explicit MatchScore(double val) : value(val) {}
MatchScore() { value = 0.0; }
double value;
MatchScore& operator+(const MatchScore& rhs) {
value += rhs.value;
return *this;
}
friend std::ostream& operator<<(std::ostream& stream,
const MatchScore& detection) {
stream << detection.value;
return stream;
}
};
inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) {
return cC1.value < cC2.value;
}
inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) {
return cC1.value > cC2.value;
}
inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) {
return cC1.value >= cC2.value;
}
inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) {
return cC1.value <= cC2.value;
}
// Fixed seed used for all random number generators.
static const int kRandomNumberSeed = 11111;
// TODO(andrewharp): Move as many of these settings as possible into a settings
// object which can be passed in from Java at runtime.
// Whether or not to use ESM instead of LK flow.
static const bool kUseEsm = false;
// This constant gets added to the diagonal of the Hessian
// before solving for translation in 2dof ESM.
// It ensures better behavior especially in the absence of
// strong texture.
static const int kEsmRegularizer = 20;
// Do we want to brightness-normalize each keypoint patch when we compute
// its flow using ESM?
static const bool kDoBrightnessNormalize = true;
// Whether or not to use fixed-point interpolated pixel lookups in optical flow.
#define USE_FIXED_POINT_FLOW 1
// Whether to normalize keypoint windows for intensity in LK optical flow.
// This is a define for now because it helps keep the code streamlined.
#define NORMALIZE 1
// Number of keypoints to store per frame.
static const int kMaxKeypoints = 76;
// Keypoint detection.
static const int kMaxTempKeypoints = 1024;
// Number of floats each keypoint takes up when exporting to an array.
static const int kKeypointStep = 7;
// Number of frame deltas to keep around in the circular queue.
static const int kNumFrames = 512;
// Number of iterations to do tracking on each keypoint at each pyramid level.
static const int kNumIterations = 3;
// The number of bins (on a side) to divide each bin from the previous
// cache level into. Higher numbers will decrease performance by increasing
// cache misses, but mean that cache hits are more locally relevant.
static const int kCacheBranchFactor = 2;
// Number of levels to put in the cache.
// Each level of the cache is a square grid of bins, length:
// branch_factor^(level - 1) on each side.
//
// This may be greater than kNumPyramidLevels. Setting it to 0 means no
// caching is enabled.
static const int kNumCacheLevels = 3;
// The level at which the cache pyramid gets cut off and replaced by a matrix
// transform if such a matrix has been provided to the cache.
static const int kCacheCutoff = 1;
static const int kNumPyramidLevels = 4;
// The minimum number of keypoints needed in an object's area.
static const int kMaxKeypointsForObject = 16;
// Minimum number of pyramid levels to use after getting cached value.
// This allows fine-scale adjustment from the cached value, which is taken
// from the center of the corresponding top cache level box.
// Can be [0, kNumPyramidLevels).
static const int kMinNumPyramidLevelsToUseForAdjustment = 1;
// Window size to integrate over to find local image derivative.
static const int kFlowIntegrationWindowSize = 3;
// Total area of integration windows.
static const int kFlowArraySize =
(2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1);
// Error that's considered good enough to early abort tracking.
static const float kTrackingAbortThreshold = 0.03f;
// Maximum number of deviations a keypoint-correspondence delta can be from the
// weighted average before being thrown out for region-based queries.
static const float kNumDeviations = 2.0f;
// The length of the allowed delta between the forward and the backward
// flow deltas in terms of the length of the forward flow vector.
static const float kMaxForwardBackwardErrorAllowed = 0.5f;
// Threshold for pixels to be considered different.
static const int kFastDiffAmount = 10;
// How far from edge of frame to stop looking for FAST keypoints.
static const int kFastBorderBuffer = 10;
// Determines if non-detected arbitrary keypoints should be added to regions.
// This will help if no keypoints have been detected in the region yet.
static const bool kAddArbitraryKeypoints = true;
// How many arbitrary keypoints to add along each axis as candidates for each
// region?
static const int kNumToAddAsCandidates = 1;
// In terms of region dimensions, how closely can we place keypoints
// next to each other?
static const float kClosestPercent = 0.6f;
// How many FAST qualifying pixels must be connected to a pixel for it to be
// considered a candidate keypoint for Harris filtering.
static const int kMinNumConnectedForFastKeypoint = 8;
// Size of the window to integrate over for Harris filtering.
// Compare to kFlowIntegrationWindowSize.
static const int kHarrisWindowSize = 2;
// DETECTOR PARAMETERS
// Before relocalizing, make sure the new proposed position is better than
// the existing position by a small amount to prevent thrashing.
static const MatchScore kMatchScoreBuffer(0.01f);
// Minimum score a tracked object can have and still be considered a match.
// TODO(andrewharp): Make this a per detector thing.
static const MatchScore kMinimumMatchScore(0.5f);
static const float kMinimumCorrelationForTracking = 0.4f;
static const MatchScore kMatchScoreForImmediateTermination(0.0f);
// Run the detector every N frames.
static const int kDetectEveryNFrames = 4;
// How many features does each feature_set contain?
static const int kFeaturesPerFeatureSet = 10;
// The number of FeatureSets managed by the object detector.
// More FeatureSets can increase recall at the cost of performance.
static const int kNumFeatureSets = 7;
// How many FeatureSets must respond affirmatively for a candidate descriptor
// and position to be given more thorough attention?
static const int kNumFeatureSetsForCandidate = 2;
// How large the thumbnails used for correlation validation are. Used for both
// width and height.
static const int kNormalizedThumbnailSize = 11;
// The area of intersection divided by union for the bounding boxes that tells
// if this tracking has slipped enough to invalidate all unlocked examples.
static const float kPositionOverlapThreshold = 0.6f;
// The number of detection failures allowed before an object goes invisible.
// Tracking will still occur, so if it is actually still being tracked and
// comes back into a detectable position, it's likely to be found.
static const int kMaxNumDetectionFailures = 4;
// Minimum square size to scan with sliding window.
static const float kScanMinSquareSize = 16.0f;
// Minimum square size to scan with sliding window.
static const float kScanMaxSquareSize = 64.0f;
// Scale difference for consecutive scans of the sliding window.
static const float kScanScaleFactor = sqrtf(2.0f);
// Step size for sliding window.
static const int kScanStepSize = 10;
// How tightly to pack the descriptor boxes for confirmed exemplars.
static const float kLockedScaleFactor = 1 / sqrtf(2.0f);
// How tightly to pack the descriptor boxes for unconfirmed exemplars.
static const float kUnlockedScaleFactor = 1 / 2.0f;
// How tightly the boxes to scan centered at the last known position will be
// packed.
static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f);
// The bounds on how close a new object example must be to existing object
// examples for detection to be valid.
static const float kMinCorrelationForNewExample = 0.75f;
static const float kMaxCorrelationForNewExample = 0.99f;
// The number of safe tries an exemplar has after being created before
// missed detections count against it.
static const int kFreeTries = 5;
// A false positive is worth this many missed detections.
static const int kFalsePositivePenalty = 5;
struct ObjectDetectorConfig {
const Size image_size;
explicit ObjectDetectorConfig(const Size& image_size)
: image_size(image_size) {}
virtual ~ObjectDetectorConfig() = default;
};
struct KeypointDetectorConfig {
const Size image_size;
bool detect_skin;
explicit KeypointDetectorConfig(const Size& image_size)
: image_size(image_size),
detect_skin(false) {}
};
struct OpticalFlowConfig {
const Size image_size;
explicit OpticalFlowConfig(const Size& image_size)
: image_size(image_size) {}
};
struct TrackerConfig {
const Size image_size;
KeypointDetectorConfig keypoint_detector_config;
OpticalFlowConfig flow_config;
bool always_track;
float object_box_scale_factor_for_features;
explicit TrackerConfig(const Size& image_size)
: image_size(image_size),
keypoint_detector_config(image_size),
flow_config(image_size),
always_track(false),
object_box_scale_factor_for_features(1.0f) {}
};
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
namespace tf_tracking {
// Class that helps OpticalFlow to speed up flow computation
// by caching coarse-grained flow.
class FlowCache {
public:
explicit FlowCache(const OpticalFlowConfig* const config)
: config_(config),
image_size_(config->image_size),
optical_flow_(config),
fullframe_matrix_(NULL) {
for (int i = 0; i < kNumCacheLevels; ++i) {
const int curr_dims = BlockDimForCacheLevel(i);
has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
}
}
~FlowCache() {
for (int i = 0; i < kNumCacheLevels; ++i) {
SAFE_DELETE(has_cache_[i]);
SAFE_DELETE(displacements_[i]);
}
delete[](fullframe_matrix_);
fullframe_matrix_ = NULL;
}
void NextFrame(ImageData* const new_frame,
const float* const align_matrix23) {
ClearCache();
SetFullframeAlignmentMatrix(align_matrix23);
optical_flow_.NextFrame(new_frame);
}
void ClearCache() {
for (int i = 0; i < kNumCacheLevels; ++i) {
has_cache_[i]->Clear(false);
}
delete[](fullframe_matrix_);
fullframe_matrix_ = NULL;
}
// Finds the flow at a point, using the cache for performance.
bool FindFlowAtPoint(const float u_x, const float u_y,
float* const flow_x, float* const flow_y) const {
// Get the best guess from the cache.
const Point2f guess_from_cache = LookupGuess(u_x, u_y);
*flow_x = guess_from_cache.x;
*flow_y = guess_from_cache.y;
// Now refine the guess using the image pyramid.
for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
pyramid_level >= 0; --pyramid_level) {
if (!optical_flow_.FindFlowAtPointSingleLevel(
pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
return false;
}
}
return true;
}
// Determines the displacement of a point, and uses that to calculate a new
// position.
// Returns true iff the displacement determination worked and the new position
// is in the image.
bool FindNewPositionOfPoint(const float u_x, const float u_y,
float* final_x, float* final_y) const {
float flow_x;
float flow_y;
if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
return false;
}
// Add in the displacement to get the final position.
*final_x = u_x + flow_x;
*final_y = u_y + flow_y;
// Assign the best guess, if we're still in the image.
if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
return true;
} else {
return false;
}
}
// Comparison function for qsort. Note: returning the raw float difference
// truncated to int would report equality for values less than 1.0 apart,
// so compare explicitly instead.
static int Compare(const void* a, const void* b) {
  const float lhs = *reinterpret_cast<const float*>(a);
  const float rhs = *reinterpret_cast<const float*>(b);
  return (lhs > rhs) - (lhs < rhs);
}
// Returns the median flow within the given bounding box as determined
// by a grid_width x grid_height grid.
Point2f GetMedianFlow(const BoundingBox& bounding_box,
const bool filter_by_fb_error,
const int grid_width,
const int grid_height) const {
const int kMaxPoints = 100;
SCHECK(grid_width * grid_height <= kMaxPoints,
"Too many points for Median flow!");
const BoundingBox valid_box = bounding_box.Intersect(
BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
if (valid_box.GetArea() <= 0.0f) {
return Point2f(0, 0);
}
float x_deltas[kMaxPoints];
float y_deltas[kMaxPoints];
int curr_offset = 0;
for (int i = 0; i < grid_width; ++i) {
for (int j = 0; j < grid_height; ++j) {
const float x_in = valid_box.left_ +
(valid_box.GetWidth() * i) / (grid_width - 1);
const float y_in = valid_box.top_ +
(valid_box.GetHeight() * j) / (grid_height - 1);
float curr_flow_x;
float curr_flow_y;
const bool success = FindNewPositionOfPoint(x_in, y_in,
&curr_flow_x, &curr_flow_y);
if (success) {
x_deltas[curr_offset] = curr_flow_x;
y_deltas[curr_offset] = curr_flow_y;
++curr_offset;
} else {
LOGW("Tracking failure!");
}
}
}
if (curr_offset > 0) {
qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
}
LOGW("No points were valid!");
return Point2f(0, 0);
}
void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
if (align_matrix23 != NULL) {
if (fullframe_matrix_ == NULL) {
fullframe_matrix_ = new float[6];
}
memcpy(fullframe_matrix_, align_matrix23,
6 * sizeof(fullframe_matrix_[0]));
}
}
private:
Point2f LookupGuessFromLevel(
const int cache_level, const float x, const float y) const {
// LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
// Cutoff at the target level and use the matrix transform instead.
if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
const float xnew = x * fullframe_matrix_[0] +
y * fullframe_matrix_[1] +
fullframe_matrix_[2];
const float ynew = x * fullframe_matrix_[3] +
y * fullframe_matrix_[4] +
fullframe_matrix_[5];
return Point2f(xnew - x, ynew - y);
}
const int level_dim = BlockDimForCacheLevel(cache_level);
const int pixels_per_cache_block_x =
(image_size_.width + level_dim - 1) / level_dim;
const int pixels_per_cache_block_y =
(image_size_.height + level_dim - 1) / level_dim;
const int index_x = x / pixels_per_cache_block_x;
const int index_y = y / pixels_per_cache_block_y;
Point2f displacement;
if (!(*has_cache_[cache_level])[index_y][index_x]) {
(*has_cache_[cache_level])[index_y][index_x] = true;
// Get the lower cache level's best guess, if it exists.
displacement = cache_level >= kNumCacheLevels - 1 ?
Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
// LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
// best_guess.x, best_guess.y);
// Find the center of the block.
const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
// LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
// "Querying %5.2f, %5.2f at pyramid level %d, ",
// cache_level, index_x, index_y,
// x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
// center_x, center_y, pyramid_level);
// TODO(andrewharp): Turn on FB error filtering.
const bool success = optical_flow_.FindFlowAtPointSingleLevel(
pyramid_level, center_x, center_y, false,
&displacement.x, &displacement.y);
if (!success) {
LOGV("Computation of cached value failed for level %d!", cache_level);
}
// Store the value for later use.
(*displacements_[cache_level])[index_y][index_x] = displacement;
} else {
displacement = (*displacements_[cache_level])[index_y][index_x];
}
// LOGI("Returning %5.2f, %5.2f for level %d",
// displacement.x, displacement.y, cache_level);
return displacement;
}
Point2f LookupGuess(const float x, const float y) const {
if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
return Point2f(0, 0);
}
// LOGI("Looking up guess at %5.2f %5.2f.", x, y);
if (kNumCacheLevels > 0) {
return LookupGuessFromLevel(0, x, y);
} else {
return Point2f(0, 0);
}
}
// Returns the number of cache bins in each dimension for a given level
// of the cache.
int BlockDimForCacheLevel(const int cache_level) const {
// The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
// thus if there are 4 cache levels, requesting level 3 (0-based) should
// return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
// and so on.
int block_dim = kNumCacheLevels;
for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
--curr_level) {
block_dim *= kCacheBranchFactor;
}
return block_dim;
}
// Returns the level of the image pyramid that a given cache level maps to.
int PyramidLevelForCacheLevel(const int cache_level) const {
// Higher cache and pyramid levels have smaller dimensions. The highest
// cache level should refer to the highest image pyramid level. The
// lower, finer image pyramid levels are uncached (assuming
// kNumCacheLevels < kNumPyramidLevels).
return cache_level + (kNumPyramidLevels - kNumCacheLevels);
}
const OpticalFlowConfig* const config_;
const Size image_size_;
OpticalFlow optical_flow_;
float* fullframe_matrix_;
// Whether this value is currently present in the cache.
Image<bool>* has_cache_[kNumCacheLevels];
// The cached displacement values.
Image<Point2f>* displacements_[kNumCacheLevels];
TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
};
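// Illustrative sketch (added for this write-up, not part of the original
// tracker): a standalone walk through of the cache-level arithmetic used by
// BlockDimForCacheLevel() and PyramidLevelForCacheLevel() above. The
// kExample* constants are hypothetical stand-ins for the values in config.h.
inline void ExampleCacheLevelMapping() {
  const int kExampleNumCacheLevels = 3;    // Hypothetical.
  const int kExampleBranchFactor = 10;     // Hypothetical.
  const int kExampleNumPyramidLevels = 5;  // Hypothetical.
  for (int cache_level = 0; cache_level < kExampleNumCacheLevels;
       ++cache_level) {
    // Same loop as BlockDimForCacheLevel(): coarser cache levels use fewer,
    // larger blocks.
    int block_dim = kExampleNumCacheLevels;
    for (int curr_level = kExampleNumCacheLevels - 1; curr_level > cache_level;
         --curr_level) {
      block_dim *= kExampleBranchFactor;
    }
    // Same offset as PyramidLevelForCacheLevel(): the coarsest cache level
    // maps to the coarsest pyramid level.
    const int pyramid_level =
        cache_level + (kExampleNumPyramidLevels - kExampleNumCacheLevels);
    LOGV("cache level %d -> %d x %d blocks, pyramid level %d",
         cache_level, block_dim, block_dim, pyramid_level);
  }
}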
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <float.h>
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
namespace tf_tracking {
void FramePair::Init(const int64_t start_time, const int64_t end_time) {
start_time_ = start_time;
end_time_ = end_time;
memset(optical_flow_found_keypoint_, false,
sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
number_of_keypoints_ = 0;
}
void FramePair::AdjustBox(const BoundingBox box,
float* const translation_x,
float* const translation_y,
float* const scale_x,
float* const scale_y) const {
static float weights[kMaxKeypoints];
static Point2f deltas[kMaxKeypoints];
memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
BoundingBox resized_box(box);
resized_box.Scale(0.4f, 0.4f);
FillWeights(resized_box, weights);
FillTranslations(deltas);
const Point2f translation = GetWeightedMedian(weights, deltas);
*translation_x = translation.x;
*translation_y = translation.y;
const Point2f old_center = box.GetCenter();
const int good_scale_points =
FillScales(old_center, translation, weights, deltas);
// Default scale factor is 1 for x and y.
*scale_x = 1.0f;
*scale_y = 1.0f;
// The assumption is that all deltas that make it to this stage with a
// corresponding optical_flow_found_keypoint_[i] == true are not in
// themselves degenerate.
//
// The degeneracy with scale arises because if the points are too close to
// the center of the object, the scale ratio cannot be determined reliably.
//
// The check for kMinNumInRange is not a degeneracy check, but merely an
// attempt to ensure some sort of stability. The actual degeneracy check is in
// the comparison to EPSILON in FillScales (which has been updated to also
// return the number of good points remaining).
static const int kMinNumInRange = 5;
if (good_scale_points >= kMinNumInRange) {
const float scale_factor = GetWeightedMedianScale(weights, deltas);
if (scale_factor > 0.0f) {
*scale_x = scale_factor;
*scale_y = scale_factor;
}
}
}
int FramePair::FillWeights(const BoundingBox& box,
float* const weights) const {
// Compute the max score.
float max_score = -FLT_MAX;
float min_score = FLT_MAX;
for (int i = 0; i < kMaxKeypoints; ++i) {
if (optical_flow_found_keypoint_[i]) {
max_score = MAX(max_score, frame1_keypoints_[i].score_);
min_score = MIN(min_score, frame1_keypoints_[i].score_);
}
}
int num_in_range = 0;
for (int i = 0; i < kMaxKeypoints; ++i) {
if (!optical_flow_found_keypoint_[i]) {
weights[i] = 0.0f;
continue;
}
const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
if (in_box) {
++num_in_range;
}
// The weighting based on distance. Anything within the bounding box
// has a weight of 1, and everything outside of it falls within the range
// [0, kOutOfBoxMultiplier), decreasing with the squared distance ratio.
float distance_score = 1.0f;
if (!in_box) {
const Point2f initial = box.GetCenter();
const float sq_x_dist =
Square(initial.x - frame1_keypoints_[i].pos_.x);
const float sq_y_dist =
Square(initial.y - frame1_keypoints_[i].pos_.y);
const float squared_half_width = Square(box.GetWidth() / 2.0f);
const float squared_half_height = Square(box.GetHeight() / 2.0f);
static const float kOutOfBoxMultiplier = 0.5f;
distance_score = kOutOfBoxMultiplier *
MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
}
// The weighting based on relative score strength, ranging from kBaseScore to 1.0f.
float intrinsic_score = 1.0f;
if (max_score > min_score) {
static const float kBaseScore = 0.5f;
intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
(max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
}
// The final score will be in the range [0, 1].
weights[i] = distance_score * intrinsic_score;
}
return num_in_range;
}
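// Illustrative sketch (added for this write-up, not part of the original
// tracker): the weighting from FillWeights() above applied to one
// hypothetical keypoint so the numbers are easy to follow. For a 20x20 box
// centered at the origin and a keypoint at (20, 15) whose score lies exactly
// halfway between the min and max scores, the distance term is
// 0.5f * MIN(100 / 225, 100 / 400) = 0.125f, the intrinsic term is
// 0.5f * (1.0f - 0.5f) + 0.5f = 0.75f, and the final weight is 0.09375f.
inline float ExampleOutOfBoxKeypointWeight() {
  const float squared_half_width = Square(20.0f / 2.0f);   // 100
  const float squared_half_height = Square(20.0f / 2.0f);  // 100
  const float sq_x_dist = Square(20.0f - 0.0f);            // 400
  const float sq_y_dist = Square(15.0f - 0.0f);            // 225
  const float distance_score =
      0.5f * MIN(squared_half_height / sq_y_dist,
                 squared_half_width / sq_x_dist);              // 0.125
  const float intrinsic_score = 0.5f * (1.0f - 0.5f) + 0.5f;   // 0.75
  return distance_score * intrinsic_score;                     // 0.09375
}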
void FramePair::FillTranslations(Point2f* const translations) const {
for (int i = 0; i < kMaxKeypoints; ++i) {
if (!optical_flow_found_keypoint_[i]) {
continue;
}
translations[i].x =
frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
translations[i].y =
frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
}
}
int FramePair::FillScales(const Point2f& old_center,
const Point2f& translation,
float* const weights,
Point2f* const scales) const {
int num_good = 0;
for (int i = 0; i < kMaxKeypoints; ++i) {
if (!optical_flow_found_keypoint_[i]) {
continue;
}
const Keypoint keypoint1 = frame1_keypoints_[i];
const Keypoint keypoint2 = frame2_keypoints_[i];
const float dist1_x = keypoint1.pos_.x - old_center.x;
const float dist1_y = keypoint1.pos_.y - old_center.y;
const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
// Make sure that the scale makes sense; points too close to the center
// will result in either NaNs or infinite results for scale due to
// limited tracking and floating point resolution.
// Also check that the parity of the points is the same with respect to
// x and y, as we can't really make sense of data that has flipped.
if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
(dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
((dist2_y > EPSILON && dist1_y > EPSILON) ||
(dist2_y < -EPSILON && dist1_y < -EPSILON))) {
scales[i].x = dist2_x / dist1_x;
scales[i].y = dist2_y / dist1_y;
++num_good;
} else {
weights[i] = 0.0f;
scales[i].x = 1.0f;
scales[i].y = 1.0f;
}
}
return num_good;
}
struct WeightedDelta {
float weight;
float delta;
};
// Sort by delta, not by weight.
inline int WeightedDeltaCompare(const void* const a, const void* const b) {
return (reinterpret_cast<const WeightedDelta*>(a)->delta -
reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1;
}
// Returns the median delta from a sorted set of weighted deltas.
static float GetMedian(const int num_items,
const WeightedDelta* const weighted_deltas,
const float sum) {
if (num_items == 0 || sum < EPSILON) {
return 0.0f;
}
float current_weight = 0.0f;
const float target_weight = sum / 2.0f;
for (int i = 0; i < num_items; ++i) {
if (weighted_deltas[i].weight > 0.0f) {
current_weight += weighted_deltas[i].weight;
if (current_weight >= target_weight) {
return weighted_deltas[i].delta;
}
}
}
LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
return 0.0f;
}
Point2f FramePair::GetWeightedMedian(
const float* const weights, const Point2f* const deltas) const {
Point2f median_delta;
// TODO(andrewharp): only sort deltas that could possibly have an effect.
static WeightedDelta weighted_deltas[kMaxKeypoints];
// Compute median X value.
{
float total_weight = 0.0f;
// Compute weighted mean and deltas.
for (int i = 0; i < kMaxKeypoints; ++i) {
weighted_deltas[i].delta = deltas[i].x;
const float weight = weights[i];
weighted_deltas[i].weight = weight;
if (weight > 0.0f) {
total_weight += weight;
}
}
qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
WeightedDeltaCompare);
median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
}
// Compute median Y value.
{
float total_weight = 0.0f;
// Compute weighted mean and deltas.
for (int i = 0; i < kMaxKeypoints; ++i) {
const float weight = weights[i];
weighted_deltas[i].weight = weight;
weighted_deltas[i].delta = deltas[i].y;
if (weight > 0.0f) {
total_weight += weight;
}
}
qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
WeightedDeltaCompare);
median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
}
return median_delta;
}
float FramePair::GetWeightedMedianScale(
const float* const weights, const Point2f* const deltas) const {
float median_delta;
// TODO(andrewharp): only sort deltas that could possibly have an effect.
static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
// Compute median scale value across x and y.
{
float total_weight = 0.0f;
// Add X values.
for (int i = 0; i < kMaxKeypoints; ++i) {
weighted_deltas[i].delta = deltas[i].x;
const float weight = weights[i];
weighted_deltas[i].weight = weight;
if (weight > 0.0f) {
total_weight += weight;
}
}
// Add Y values.
for (int i = 0; i < kMaxKeypoints; ++i) {
weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
const float weight = weights[i];
weighted_deltas[i + kMaxKeypoints].weight = weight;
if (weight > 0.0f) {
total_weight += weight;
}
}
qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
WeightedDeltaCompare);
median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
}
return median_delta;
}
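// Illustrative sketch (added for this write-up, not part of the original
// tracker): how the weighted median above behaves on a tiny hand-made set of
// deltas. As in GetWeightedMedian()/GetWeightedMedianScale(), the entries are
// sorted by delta first, then GetMedian() accumulates weight down the sorted
// list and returns the delta at which half of the total weight is reached.
// With three equally weighted deltas {1, 2, 5} it returns 2.0f, the middle
// value.
inline float ExampleWeightedMedian() {
  WeightedDelta weighted_deltas[3];
  weighted_deltas[0].delta = 1.0f;
  weighted_deltas[0].weight = 1.0f;
  weighted_deltas[1].delta = 2.0f;
  weighted_deltas[1].weight = 1.0f;
  weighted_deltas[2].delta = 5.0f;
  weighted_deltas[2].weight = 1.0f;
  qsort(weighted_deltas, 3, sizeof(WeightedDelta), WeightedDeltaCompare);
  return GetMedian(3, weighted_deltas, 3.0f);  // 2.0f
}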
} // namespace tf_tracking
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
namespace tf_tracking {
// A class that records keypoint correspondences from pairs of
// consecutive frames.
class FramePair {
public:
FramePair()
: start_time_(0),
end_time_(0),
number_of_keypoints_(0) {}
// Cleans up the FramePair so that it can be reused.
void Init(const int64_t start_time, const int64_t end_time);
void AdjustBox(const BoundingBox box,
float* const translation_x,
float* const translation_y,
float* const scale_x,
float* const scale_y) const;
private:
// Returns the weighted median of the given deltas, computed independently on
// x and y. Returns 0,0 in case of failure. The assumption is that a
// translation of 0.0 in the degenerate case is the best that can be done, and
// should not be considered an error.
//
// In the case of scale, a slight exception is made just to be safe: 0.0 is
// checked for explicitly, although it should never occur naturally because
// of the non-zero and parity checks in FillScales.
Point2f GetWeightedMedian(const float* const weights,
const Point2f* const deltas) const;
float GetWeightedMedianScale(const float* const weights,
const Point2f* const deltas) const;
// Weights points based on their distance from the given box and their
// relative score strength.
int FillWeights(const BoundingBox& box,
float* const weights) const;
// Fills in the array of deltas with the translations of the points
// between frames.
void FillTranslations(Point2f* const translations) const;
// Fills in the array of deltas with the relative scale factor of points
// relative to a given center. Has the ability to override the weight to 0 if
// a degenerate scale is detected.
// Translation is the amount the center of the box has moved from one frame to
// the next.
int FillScales(const Point2f& old_center,
const Point2f& translation,
float* const weights,
Point2f* const scales) const;
// TODO(andrewharp): Make these private.
public:
// The time at frame1.
int64_t start_time_;
// The time at frame2.
int64_t end_time_;
// This array will contain the keypoints found in frame 1.
Keypoint frame1_keypoints_[kMaxKeypoints];
// Contains the locations of the keypoints from frame 1 as found in frame 2.
Keypoint frame2_keypoints_[kMaxKeypoints];
// The number of keypoints in frame 1.
int number_of_keypoints_;
// Keeps track of which keypoint correspondences were actually found from one
// frame to another.
// The i-th element of this array will be non-zero if and only if the i-th
// keypoint of frame 1 was found in frame 2.
bool optical_flow_found_keypoint_[kMaxKeypoints];
private:
TF_DISALLOW_COPY_AND_ASSIGN(FramePair);
};
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
namespace tf_tracking {
struct Size {
Size(const int width, const int height) : width(width), height(height) {}
int width;
int height;
};
class Point2f {
public:
Point2f() : x(0.0f), y(0.0f) {}
Point2f(const float x, const float y) : x(x), y(y) {}
inline Point2f operator- (const Point2f& that) const {
return Point2f(this->x - that.x, this->y - that.y);
}
inline Point2f operator+ (const Point2f& that) const {
return Point2f(this->x + that.x, this->y + that.y);
}
inline Point2f& operator+= (const Point2f& that) {
this->x += that.x;
this->y += that.y;
return *this;
}
inline Point2f& operator-= (const Point2f& that) {
this->x -= that.x;
this->y -= that.y;
return *this;
}
inline Point2f operator- (const Point2f& that) {
return Point2f(this->x - that.x, this->y - that.y);
}
inline float LengthSquared() {
return Square(this->x) + Square(this->y);
}
inline float Length() {
return sqrtf(LengthSquared());
}
inline float DistanceSquared(const Point2f& that) {
return Square(this->x - that.x) + Square(this->y - that.y);
}
inline float Distance(const Point2f& that) {
return sqrtf(DistanceSquared(that));
}
float x;
float y;
};
inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) {
stream << point.x << "," << point.y;
return stream;
}
class BoundingBox {
public:
BoundingBox()
: left_(0),
top_(0),
right_(0),
bottom_(0) {}
BoundingBox(const BoundingBox& bounding_box)
: left_(bounding_box.left_),
top_(bounding_box.top_),
right_(bounding_box.right_),
bottom_(bounding_box.bottom_) {
SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
}
BoundingBox(const float left,
const float top,
const float right,
const float bottom)
: left_(left),
top_(top),
right_(right),
bottom_(bottom) {
SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
}
BoundingBox(const Point2f& point1, const Point2f& point2)
: left_(MIN(point1.x, point2.x)),
top_(MIN(point1.y, point2.y)),
right_(MAX(point1.x, point2.x)),
bottom_(MAX(point1.y, point2.y)) {}
inline void CopyToArray(float* const bounds_array) const {
bounds_array[0] = left_;
bounds_array[1] = top_;
bounds_array[2] = right_;
bounds_array[3] = bottom_;
}
inline float GetWidth() const {
return right_ - left_;
}
inline float GetHeight() const {
return bottom_ - top_;
}
inline float GetArea() const {
const float width = GetWidth();
const float height = GetHeight();
if (width <= 0 || height <= 0) {
return 0.0f;
}
return width * height;
}
inline Point2f GetCenter() const {
return Point2f((left_ + right_) / 2.0f,
(top_ + bottom_) / 2.0f);
}
inline bool ValidBox() const {
return GetArea() > 0.0f;
}
// Returns a bounding box created from the overlapping area of these two.
inline BoundingBox Intersect(const BoundingBox& that) const {
const float new_left = MAX(this->left_, that.left_);
const float new_right = MIN(this->right_, that.right_);
if (new_left >= new_right) {
return BoundingBox();
}
const float new_top = MAX(this->top_, that.top_);
const float new_bottom = MIN(this->bottom_, that.bottom_);
if (new_top >= new_bottom) {
return BoundingBox();
}
return BoundingBox(new_left, new_top, new_right, new_bottom);
}
// Returns a bounding box that can contain both boxes.
inline BoundingBox Union(const BoundingBox& that) const {
return BoundingBox(MIN(this->left_, that.left_),
MIN(this->top_, that.top_),
MAX(this->right_, that.right_),
MAX(this->bottom_, that.bottom_));
}
inline float PascalScore(const BoundingBox& that) const {
SCHECK(GetArea() > 0.0f, "Empty bounding box!");
SCHECK(that.GetArea() > 0.0f, "Empty bounding box!");
const float intersect_area = this->Intersect(that).GetArea();
if (intersect_area <= 0) {
return 0;
}
const float score =
intersect_area / (GetArea() + that.GetArea() - intersect_area);
SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score);
return score;
}
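// Note (added for clarity): the edge-containment test in Intersects() below
// can report false when one box strictly contains the other; when full
// containment matters, use Contains() or check Intersect(...).ValidBox()
// instead.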
inline bool Intersects(const BoundingBox& that) const {
return InRange(that.left_, left_, right_)
|| InRange(that.right_, left_, right_)
|| InRange(that.top_, top_, bottom_)
|| InRange(that.bottom_, top_, bottom_);
}
// Returns whether another bounding box is completely inside of this bounding
// box. Sharing edges is ok.
inline bool Contains(const BoundingBox& that) const {
return that.left_ >= left_ &&
that.right_ <= right_ &&
that.top_ >= top_ &&
that.bottom_ <= bottom_;
}
inline bool Contains(const Point2f& point) const {
return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_);
}
inline void Shift(const Point2f shift_amount) {
left_ += shift_amount.x;
top_ += shift_amount.y;
right_ += shift_amount.x;
bottom_ += shift_amount.y;
}
inline void ScaleOrigin(const float scale_x, const float scale_y) {
left_ *= scale_x;
right_ *= scale_x;
top_ *= scale_y;
bottom_ *= scale_y;
}
inline void Scale(const float scale_x, const float scale_y) {
const Point2f center = GetCenter();
const float half_width = GetWidth() / 2.0f;
const float half_height = GetHeight() / 2.0f;
left_ = center.x - half_width * scale_x;
right_ = center.x + half_width * scale_x;
top_ = center.y - half_height * scale_y;
bottom_ = center.y + half_height * scale_y;
}
float left_;
float top_;
float right_;
float bottom_;
};
inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) {
stream << "[" << box.left_ << " - " << box.right_
<< ", " << box.top_ << " - " << box.bottom_
<< ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]";
return stream;
}
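// Illustrative sketch (added for this write-up, not part of the original
// library): PascalScore() above is the standard intersection-over-union
// overlap measure. Two 10x10 boxes offset by 5 pixels in x share a
// 5 x 10 = 50 pixel overlap, so the score is 50 / (100 + 100 - 50) = 1/3.
inline float ExamplePascalScore() {
  const BoundingBox box_a(0.0f, 0.0f, 10.0f, 10.0f);
  const BoundingBox box_b(5.0f, 0.0f, 15.0f, 10.0f);
  return box_a.PascalScore(box_b);  // Approximately 0.333f.
}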
class BoundingSquare {
public:
BoundingSquare(const float x, const float y, const float size)
: x_(x), y_(y), size_(size) {}
explicit BoundingSquare(const BoundingBox& box)
: x_(box.left_), y_(box.top_), size_(box.GetWidth()) {
#ifdef SANITY_CHECKS
if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) {
LOG(WARNING) << "This is not a square: " << box << std::endl;
}
#endif
}
inline BoundingBox ToBoundingBox() const {
return BoundingBox(x_, y_, x_ + size_, y_ + size_);
}
inline bool ValidBox() {
return size_ > 0.0f;
}
inline void Shift(const Point2f shift_amount) {
x_ += shift_amount.x;
y_ += shift_amount.y;
}
inline void Scale(const float scale) {
const float new_size = size_ * scale;
const float position_diff = (new_size - size_) / 2.0f;
x_ -= position_diff;
y_ -= position_diff;
size_ = new_size;
}
float x_;
float y_;
float size_;
};
inline std::ostream& operator<<(std::ostream& stream,
const BoundingSquare& square) {
stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]";
return stream;
}
inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box,
const float size) {
const float width_diff = (original_box.GetWidth() - size) / 2.0f;
const float height_diff = (original_box.GetHeight() - size) / 2.0f;
return BoundingSquare(original_box.left_ + width_diff,
original_box.top_ + height_diff,
size);
}
inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
return GetCenteredSquare(
original_box, MIN(original_box.GetWidth(), original_box.GetHeight()));
}
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
#include <GLES/gl.h>
#include <GLES/glext.h>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
namespace tf_tracking {
// Draws a box at the given position.
inline static void DrawBox(const BoundingBox& bounding_box) {
const GLfloat line[] = {
bounding_box.left_, bounding_box.bottom_,
bounding_box.left_, bounding_box.top_,
bounding_box.left_, bounding_box.top_,
bounding_box.right_, bounding_box.top_,
bounding_box.right_, bounding_box.top_,
bounding_box.right_, bounding_box.bottom_,
bounding_box.right_, bounding_box.bottom_,
bounding_box.left_, bounding_box.bottom_
};
glVertexPointer(2, GL_FLOAT, 0, line);
glEnableClientState(GL_VERTEX_ARRAY);
glDrawArrays(GL_LINES, 0, 8);
}
// Changes the coordinate system such that an arbitrary square in the world
// can thereafter be drawn to using coordinates in the range 0 - 1.
inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) {
glScalef(square.size_, square.size_, 1.0f);
glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f);
}
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
#include <stdint.h>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
namespace tf_tracking {
template <typename T>
Image<T>::Image(const int width, const int height)
: width_less_one_(width - 1),
height_less_one_(height - 1),
data_size_(width * height),
own_data_(true),
width_(width),
height_(height),
stride_(width) {
Allocate();
}
template <typename T>
Image<T>::Image(const Size& size)
: width_less_one_(size.width - 1),
height_less_one_(size.height - 1),
data_size_(size.width * size.height),
own_data_(true),
width_(size.width),
height_(size.height),
stride_(size.width) {
Allocate();
}
// Constructor that creates an image from preallocated data.
// Note: The image takes ownership of the data lifecycle, unless own_data is
// set to false.
template <typename T>
Image<T>::Image(const int width, const int height, T* const image_data,
const bool own_data) :
width_less_one_(width - 1),
height_less_one_(height - 1),
data_size_(width * height),
own_data_(own_data),
width_(width),
height_(height),
stride_(width) {
image_data_ = image_data;
SCHECK(image_data_ != NULL, "Can't create image with NULL data!");
}
template <typename T>
Image<T>::~Image() {
if (own_data_) {
delete[] image_data_;
}
image_data_ = NULL;
}
template<typename T>
template<class DstType>
bool Image<T>::ExtractPatchAtSubpixelFixed1616(const int fp_x,
const int fp_y,
const int patchwidth,
const int patchheight,
DstType* to_data) const {
// Calculate weights.
const int trunc_x = fp_x >> 16;
const int trunc_y = fp_y >> 16;
if (trunc_x < 0 || trunc_y < 0 ||
(trunc_x + patchwidth) >= width_less_one_ ||
(trunc_y + patchheight) >= height_less_one_) {
return false;
}
// Now walk over destination patch and fill from interpolated source image.
for (int y = 0; y < patchheight; ++y, to_data += patchwidth) {
for (int x = 0; x < patchwidth; ++x) {
to_data[x] =
static_cast<DstType>(GetPixelInterpFixed1616(fp_x + (x << 16),
fp_y + (y << 16)));
}
}
return true;
}
template <typename T>
Image<T>* Image<T>::Crop(
const int left, const int top, const int right, const int bottom) const {
SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left);
SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right);
SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top);
SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom);
SCHECK(left <= right, "mismatch!");
SCHECK(top <= bottom, "mismatch!");
const int new_width = right - left + 1;
const int new_height = bottom - top + 1;
Image<T>* const cropped_image = new Image(new_width, new_height);
for (int y = 0; y < new_height; ++y) {
memcpy((*cropped_image)[y], ((*this)[y + top] + left),
new_width * sizeof(T));
}
return cropped_image;
}
template <typename T>
inline float Image<T>::GetPixelInterp(const float x, const float y) const {
// Do int conversion one time.
const int floored_x = static_cast<int>(x);
const int floored_y = static_cast<int>(y);
// Note: it might be the case that the *_[min|max] values are clipped, and
// these (the a b c d vals) aren't (for speed purposes), but that doesn't
// matter. We'll just be blending the pixel with itself in that case anyway.
const float b = x - floored_x;
const float a = 1.0f - b;
const float d = y - floored_y;
const float c = 1.0f - d;
SCHECK(ValidInterpPixel(x, y),
"x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)",
x, width_less_one_, y, height_less_one_);
const T* const pix_ptr = (*this)[floored_y] + floored_x;
// Get the pixel values surrounding this point.
const T& p1 = pix_ptr[0];
const T& p2 = pix_ptr[1];
const T& p3 = pix_ptr[width_];
const T& p4 = pix_ptr[width_ + 1];
// Simple bilinear interpolation between four reference pixels.
// If x is the value requested:
// a b
// -------
// c |p1 p2|
// | x |
// d |p3 p4|
// -------
return c * ((a * p1) + (b * p2)) +
d * ((a * p3) + (b * p4));
}
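// Illustrative sketch (added for this write-up, not part of the original
// library): bilinear interpolation as performed by GetPixelInterp() above.
// Sampling the exact center (0.5, 0.5) of a 2x2 image holding {10, 20, 30,
// 40} weights each pixel by 0.25, giving (10 + 20 + 30 + 40) / 4 = 25.
inline float ExampleBilinearSample() {
  uint8_t pixels[4] = {10, 20, 30, 40};
  // own_data = false: the image must not delete the stack-allocated array.
  const Image<uint8_t> image(2, 2, pixels, false);
  return image.GetPixelInterp(0.5f, 0.5f);  // 25.0f
}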
template <typename T>
inline T Image<T>::GetPixelInterpFixed1616(
const int fp_x_whole, const int fp_y_whole) const {
static const int kFixedPointOne = 0x00010000;
static const int kFixedPointHalf = 0x00008000;
static const int kFixedPointTruncateMask = 0xFFFF0000;
int trunc_x = fp_x_whole & kFixedPointTruncateMask;
int trunc_y = fp_y_whole & kFixedPointTruncateMask;
const int fp_x = fp_x_whole - trunc_x;
const int fp_y = fp_y_whole - trunc_y;
// Scale the truncated values back to regular ints.
trunc_x >>= 16;
trunc_y >>= 16;
const int one_minus_fp_x = kFixedPointOne - fp_x;
const int one_minus_fp_y = kFixedPointOne - fp_y;
const T* trunc_start = (*this)[trunc_y] + trunc_x;
const T a = trunc_start[0];
const T b = trunc_start[1];
const T c = trunc_start[stride_];
const T d = trunc_start[stride_ + 1];
return (
(one_minus_fp_y * static_cast<int64_t>(one_minus_fp_x * a + fp_x * b) +
fp_y * static_cast<int64_t>(one_minus_fp_x * c + fp_x * d) +
kFixedPointHalf) >>
32);
}
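// Worked example (added for clarity): in the 16:16 fixed point format used
// above, a real value r is stored as r * 65536, with the high 16 bits holding
// the integer part and the low 16 bits the fraction. The value 2.25 is stored
// as 0x00024000; masking with kFixedPointTruncateMask recovers the integer
// pixel (2 << 16), and the remaining 0x4000 (= 0.25 * 65536) is the sub-pixel
// fraction used as the blend weight.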
template <typename T>
inline bool Image<T>::ValidPixel(const int x, const int y) const {
return InRange(x, ZERO, width_less_one_) &&
InRange(y, ZERO, height_less_one_);
}
template <typename T>
inline BoundingBox Image<T>::GetContainingBox() const {
return BoundingBox(
0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON);
}
template <typename T>
inline bool Image<T>::Contains(const BoundingBox& bounding_box) const {
// TODO(andrewharp): Come up with a more elegant way of ensuring that bounds
// are ok.
return GetContainingBox().Contains(bounding_box);
}
template <typename T>
inline bool Image<T>::ValidInterpPixel(const float x, const float y) const {
// Exclusive of max because we can be more efficient if we don't handle
// interpolating on or past the last pixel.
return (x >= ZERO) && (x < width_less_one_) &&
(y >= ZERO) && (y < height_less_one_);
}
template <typename T>
void Image<T>::DownsampleAveraged(const T* const original, const int stride,
const int factor) {
#ifdef __ARM_NEON
if (factor == 4 || factor == 2) {
DownsampleAveragedNeon(original, stride, factor);
return;
}
#endif
// TODO(andrewharp): delete or enable this for non-uint8_t downsamples.
const int pixels_per_block = factor * factor;
// For every pixel in resulting image.
for (int y = 0; y < height_; ++y) {
const int orig_y = y * factor;
const int y_bound = orig_y + factor;
// Sum up the original pixels.
for (int x = 0; x < width_; ++x) {
const int orig_x = x * factor;
const int x_bound = orig_x + factor;
// Making this int32_t because type T might overflow during accumulation.
int32_t pixel_sum = 0;
// Grab all the pixels that make up this pixel.
for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) {
const T* p = original + curr_y * stride + orig_x;
for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) {
pixel_sum += *p++;
}
}
(*this)[y][x] = pixel_sum / pixels_per_block;
}
}
}
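// Worked example (added for this write-up, not part of the original library):
// the inner loops of DownsampleAveraged() above, unrolled by hand for a
// single 2x2 block with factor == 2. The output pixel is simply the mean of
// the block.
inline int32_t ExampleBlockAverage() {
  const int32_t block[4] = {10, 20, 30, 40};  // One 2x2 block of the source.
  int32_t pixel_sum = 0;
  for (int i = 0; i < 4; ++i) {
    pixel_sum += block[i];
  }
  return pixel_sum / 4;  // 25, the value written to the corresponding output.
}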
template <typename T>
void Image<T>::DownsampleInterpolateNearest(const Image<T>& original) {
// Calculating the scaling factors based on target image size.
const float factor_x = static_cast<float>(original.GetWidth()) /
static_cast<float>(width_);
const float factor_y = static_cast<float>(original.GetHeight()) /
static_cast<float>(height_);
// Calculating initial offset in x-axis.
const float offset_x = 0.5f * (original.GetWidth() - width_) / width_;
// Calculating initial offset in y-axis.
const float offset_y = 0.5f * (original.GetHeight() - height_) / height_;
float orig_y = offset_y;
// For every pixel in resulting image.
for (int y = 0; y < height_; ++y) {
float orig_x = offset_x;
// Finding nearest pixel on y-axis.
const int nearest_y = static_cast<int>(orig_y + 0.5f);
const T* row_data = original[nearest_y];
T* pixel_ptr = (*this)[y];
for (int x = 0; x < width_; ++x) {
// Finding nearest pixel on x-axis.
const int nearest_x = static_cast<int>(orig_x + 0.5f);
*pixel_ptr++ = row_data[nearest_x];
orig_x += factor_x;
}
orig_y += factor_y;
}
}
template <typename T>
void Image<T>::DownsampleInterpolateLinear(const Image<T>& original) {
// TODO(andrewharp): Turn this into a general compare sizes/bulk
// copy method.
if (original.GetWidth() == GetWidth() &&
original.GetHeight() == GetHeight() &&
original.stride() == stride()) {
memcpy(image_data_, original.data(), data_size_ * sizeof(T));
return;
}
// Calculating the scaling factors based on target image size.
const float factor_x = static_cast<float>(original.GetWidth()) /
static_cast<float>(width_);
const float factor_y = static_cast<float>(original.GetHeight()) /
static_cast<float>(height_);
// Calculating initial offset in x-axis.
const float offset_x = 0;
const int offset_x_fp = RealToFixed1616(offset_x);
// Calculating initial offset in y-axis.
const float offset_y = 0;
const int offset_y_fp = RealToFixed1616(offset_y);
// Get the fixed point scaling factor value.
// Shift by 8 so we can fit everything into a 4 byte int later for speed
// reasons. This means the precision is limited to 1 / 256th of a pixel,
// but this should be good enough.
const int factor_x_fp = RealToFixed1616(factor_x) >> 8;
const int factor_y_fp = RealToFixed1616(factor_y) >> 8;
int src_y_fp = offset_y_fp >> 8;
static const int kFixedPointOne8 = 0x00000100;
static const int kFixedPointHalf8 = 0x00000080;
static const int kFixedPointTruncateMask8 = 0xFFFFFF00;
// For every pixel in resulting image.
for (int y = 0; y < height_; ++y) {
int src_x_fp = offset_x_fp >> 8;
int trunc_y = src_y_fp & kFixedPointTruncateMask8;
const int fp_y = src_y_fp - trunc_y;
// Scale the truncated values back to regular ints.
trunc_y >>= 8;
const int one_minus_fp_y = kFixedPointOne8 - fp_y;
T* pixel_ptr = (*this)[y];
// Make sure not to read from an invalid row.
const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1);
const T* other_top_ptr = original[trunc_y];
const T* other_bot_ptr = original[trunc_y_b];
int last_trunc_x = -1;
int trunc_x = -1;
T a = 0;
T b = 0;
T c = 0;
T d = 0;
for (int x = 0; x < width_; ++x) {
trunc_x = src_x_fp & kFixedPointTruncateMask8;
const int fp_x = (src_x_fp - trunc_x) >> 8;
// Scale the truncated values back to regular ints.
trunc_x >>= 8;
// It's possible we're reading from the same pixels as in the previous
// iteration, in which case the cached corner values can be reused.
if (trunc_x != last_trunc_x) {
// Make sure not to read from an invalid column.
const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1);
a = other_top_ptr[trunc_x];
b = other_top_ptr[trunc_x_b];
c = other_bot_ptr[trunc_x];
d = other_bot_ptr[trunc_x_b];
last_trunc_x = trunc_x;
}
const int one_minus_fp_x = kFixedPointOne8 - fp_x;
const int32_t value =
((one_minus_fp_y * one_minus_fp_x * a + fp_x * b) +
(fp_y * one_minus_fp_x * c + fp_x * d) + kFixedPointHalf8) >>
16;
*pixel_ptr++ = value;
src_x_fp += factor_x_fp;
}
src_y_fp += factor_y_fp;
}
}
template <typename T>
void Image<T>::DownsampleSmoothed3x3(const Image<T>& original) {
for (int y = 0; y < height_; ++y) {
const int orig_y = Clip(2 * y, ZERO, original.height_less_one_);
const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_);
const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_);
for (int x = 0; x < width_; ++x) {
const int orig_x = Clip(2 * x, ZERO, original.width_less_one_);
const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_);
const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_);
// Center.
int32_t pixel_sum = original[orig_y][orig_x] * 4;
// Sides.
pixel_sum += (original[orig_y][max_x] +
original[orig_y][min_x] +
original[max_y][orig_x] +
original[min_y][orig_x]) * 2;
// Diagonals.
pixel_sum += (original[min_y][max_x] +
original[min_y][min_x] +
original[max_y][max_x] +
original[max_y][min_x]);
(*this)[y][x] = pixel_sum >> 4; // 16
}
}
}
template <typename T>
void Image<T>::DownsampleSmoothed5x5(const Image<T>& original) {
const int max_x = original.width_less_one_;
const int max_y = original.height_less_one_;
// The J.-Y. Bouguet paper on pyramidal Lucas-Kanade recommends a
// [1/16 1/4 3/8 1/4 1/16]^2 filter.
// This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below.
static const int window_radius = 2;
static const int window_size = window_radius*2 + 1;
static const int window_weights[] = {1, 4, 6, 4, 1, // 16 +
4, 16, 24, 16, 4, // 64 +
6, 24, 36, 24, 6, // 96 +
4, 16, 24, 16, 4, // 64 +
1, 4, 6, 4, 1}; // 16 = 256
// We'll multiply and sum with the whole numbers first, then divide by
// the total weight to normalize at the last moment.
for (int y = 0; y < height_; ++y) {
for (int x = 0; x < width_; ++x) {
int32_t pixel_sum = 0;
const int* w = window_weights;
const int start_x = Clip((x << 1) - window_radius, ZERO, max_x);
// Clip the boundaries to the size of the image.
for (int window_y = 0; window_y < window_size; ++window_y) {
const int start_y =
Clip((y << 1) - window_radius + window_y, ZERO, max_y);
const T* p = original[start_y] + start_x;
for (int window_x = 0; window_x < window_size; ++window_x) {
pixel_sum += *p++ * *w++;
}
}
// Conversion to type T will happen here after shifting right 8 bits to
// divide by 256.
(*this)[y][x] = pixel_sum >> 8;
}
}
}
template <typename T>
template <typename U>
inline T Image<T>::ScharrPixelX(const Image<U>& original,
const int center_x, const int center_y) const {
const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_);
const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_);
const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_);
const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_);
// Convolution loop unrolled for performance...
return (3 * (original[min_y][max_x]
+ original[max_y][max_x]
- original[min_y][min_x]
- original[max_y][min_x])
+ 10 * (original[center_y][max_x]
- original[center_y][min_x])) / 32;
}
template <typename T>
template <typename U>
inline T Image<T>::ScharrPixelY(const Image<U>& original,
const int center_x, const int center_y) const {
const int min_x = Clip(center_x - 1, 0, original.width_less_one_);
const int max_x = Clip(center_x + 1, 0, original.width_less_one_);
const int min_y = Clip(center_y - 1, 0, original.height_less_one_);
const int max_y = Clip(center_y + 1, 0, original.height_less_one_);
// Convolution loop unrolled for performance...
return (3 * (original[max_y][min_x]
+ original[max_y][max_x]
- original[min_y][min_x]
- original[min_y][max_x])
+ 10 * (original[max_y][center_x]
- original[min_y][center_x])) / 32;
}
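// Illustrative sketch (added for this write-up, not part of the original
// library): the unrolled expressions above implement the 3x3 Scharr kernels
//   X: [-3 0 3; -10 0 10; -3 0 3] / 32   and   Y: its transpose,
// which reduce to a unit-scale central difference on smooth data. On a
// horizontal ramp where pixel (x, y) == x, the X response at the center of a
// 3x3 patch is (3 * (2 + 2) + 10 * 2) / 32 = 1.
inline int32_t ExampleScharrXOnRamp() {
  int32_t ramp[9] = {0, 1, 2,
                     0, 1, 2,
                     0, 1, 2};
  const Image<int32_t> source(3, 3, ramp, false);
  // The destination image only supplies the output type for ScharrPixelX.
  const Image<int32_t> gradient(3, 3);
  return gradient.ScharrPixelX(source, 1, 1);  // 1
}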
template <typename T>
template <typename U>
inline void Image<T>::ScharrX(const Image<U>& original) {
for (int y = 0; y < height_; ++y) {
for (int x = 0; x < width_; ++x) {
SetPixel(x, y, ScharrPixelX(original, x, y));
}
}
}
template <typename T>
template <typename U>
inline void Image<T>::ScharrY(const Image<U>& original) {
for (int y = 0; y < height_; ++y) {
for (int x = 0; x < width_; ++x) {
SetPixel(x, y, ScharrPixelY(original, x, y));
}
}
}
template <typename T>
template <typename U>
void Image<T>::DerivativeX(const Image<U>& original) {
for (int y = 0; y < height_; ++y) {
const U* const source_row = original[y];
T* const dest_row = (*this)[y];
// Compute first pixel. Approximated with forward difference.
dest_row[0] = source_row[1] - source_row[0];
// All the pixels in between. Central difference method.
const U* source_prev_pixel = source_row;
T* dest_pixel = dest_row + 1;
const U* source_next_pixel = source_row + 2;
for (int x = 1; x < width_less_one_; ++x) {
*dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
}
// Last pixel. Approximated with backward difference.
dest_row[width_less_one_] =
source_row[width_less_one_] - source_row[width_less_one_ - 1];
}
}
template <typename T>
template <typename U>
void Image<T>::DerivativeY(const Image<U>& original) {
const int src_stride = original.stride();
// Compute 1st row. Approximated with forward difference.
{
const U* const src_row = original[0];
T* dest_row = (*this)[0];
for (int x = 0; x < width_; ++x) {
dest_row[x] = src_row[x + src_stride] - src_row[x];
}
}
// Compute all rows in between using central difference.
for (int y = 1; y < height_less_one_; ++y) {
T* dest_row = (*this)[y];
const U* source_prev_pixel = original[y - 1];
const U* source_next_pixel = original[y + 1];
for (int x = 0; x < width_; ++x) {
*dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
}
}
// Compute last row. Approximated with backward difference.
{
const U* const src_row = original[height_less_one_];
T* dest_row = (*this)[height_less_one_];
for (int x = 0; x < width_; ++x) {
dest_row[x] = src_row[x] - src_row[x - src_stride];
}
}
}
template <typename T>
template <typename U>
inline T Image<T>::ConvolvePixel3x3(const Image<U>& original,
const int* const filter,
const int center_x, const int center_y,
const int total) const {
int32_t sum = 0;
for (int filter_y = 0; filter_y < 3; ++filter_y) {
const int y = Clip(center_y - 1 + filter_y, 0, original.GetHeight());
for (int filter_x = 0; filter_x < 3; ++filter_x) {
const int x = Clip(center_x - 1 + filter_x, 0, original.GetWidth());
sum += original[y][x] * filter[filter_y * 3 + filter_x];
}
}
return sum / total;
}
template <typename T>
template <typename U>
inline void Image<T>::Convolve3x3(const Image<U>& original,
const int32_t* const filter) {
int32_t sum = 0;
for (int i = 0; i < 9; ++i) {
sum += abs(filter[i]);
}
for (int y = 0; y < height_; ++y) {
for (int x = 0; x < width_; ++x) {
SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum));
}
}
}
template <typename T>
inline void Image<T>::FromArray(const T* const pixels, const int stride,
const int factor) {
if (factor == 1 && stride == width_) {
// If not subsampling, memcpy per line should be faster.
memcpy(this->image_data_, pixels, data_size_ * sizeof(T));
return;
}
DownsampleAveraged(pixels, stride, factor);
}
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
#include <stdint.h>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
// TODO(andrewharp): Make this a cast to uint32_t if/when we go unsigned for
// operations.
#define ZERO 0
#ifdef SANITY_CHECKS
#define CHECK_PIXEL(IMAGE, X, Y) {\
SCHECK((IMAGE)->ValidPixel((X), (Y)), \
"CHECK_PIXEL(%d,%d) in %dx%d image.", \
static_cast<int>(X), static_cast<int>(Y), \
(IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
}
#define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\
SCHECK((IMAGE)->validInterpPixel((X), (Y)), \
"CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \
static_cast<float>(X), static_cast<float>(Y), \
(IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
}
#else
#define CHECK_PIXEL(image, x, y) {}
#define CHECK_PIXEL_INTERP(IMAGE, X, Y) {}
#endif
namespace tf_tracking {
#ifdef SANITY_CHECKS
// Class which exists solely to provide bounds checking for array-style image
// data access.
template <typename T>
class RowData {
public:
RowData(T* const row_data, const int max_col)
: row_data_(row_data), max_col_(max_col) {}
inline T& operator[](const int col) const {
SCHECK(InRange(col, 0, max_col_),
"Column out of range: %d (%d max)", col, max_col_);
return row_data_[col];
}
inline operator T*() const {
return row_data_;
}
private:
T* const row_data_;
const int max_col_;
};
#endif
// Naive templated comparison function for qsort-style sorting.
template <typename T>
int Comp(const void* a, const void* b) {
const T val1 = *reinterpret_cast<const T*>(a);
const T val2 = *reinterpret_cast<const T*>(b);
if (val1 == val2) {
return 0;
} else if (val1 < val2) {
return -1;
} else {
return 1;
}
}
// TODO(andrewharp): Make explicit which operations support negative numbers or
// struct/class types in image data (possibly create fast multi-dim array class
// for data where pixel arithmetic does not make sense).
// Image class optimized for working on numeric arrays as grayscale image data.
// Supports other data types as a 2D array class, so long as no pixel math
// operations are called (convolution, downsampling, etc).
template <typename T>
class Image {
public:
Image(const int width, const int height);
explicit Image(const Size& size);
// Constructor that creates an image from preallocated data.
// Note: The image takes ownership of the data lifecycle, unless own_data is
// set to false.
Image(const int width, const int height, T* const image_data,
const bool own_data = true);
~Image();
// Extract a pixel patch from this image, starting at a subpixel location.
// Uses 16:16 fixed point format for representing real values and doing the
// bilinear interpolation.
//
// Arguments fp_x and fp_y tell the subpixel position in fixed point format,
// patchwidth/patchheight give the size of the patch in pixels and
// to_data must be a valid pointer to a *contiguous* destination data array.
template<class DstType>
bool ExtractPatchAtSubpixelFixed1616(const int fp_x,
const int fp_y,
const int patchwidth,
const int patchheight,
DstType* to_data) const;
Image<T>* Crop(
const int left, const int top, const int right, const int bottom) const;
inline int GetWidth() const { return width_; }
inline int GetHeight() const { return height_; }
// Bilinearly sample a value between pixels. Values must be within the image.
inline float GetPixelInterp(const float x, const float y) const;
// Bilinearly sample a pixel at a subpixel position using fixed point
// arithmetic.
// Avoids float<->int conversions.
// Values must be within the image.
// Arguments fp_x and fp_y tell the subpixel position in
// 16:16 fixed point format.
//
// Important: This function only makes sense for integer-valued images, such
// as Image<uint8_t> or Image<int> etc.
inline T GetPixelInterpFixed1616(const int fp_x_whole,
const int fp_y_whole) const;
// Returns true iff the pixel is in the image's boundaries.
inline bool ValidPixel(const int x, const int y) const;
inline BoundingBox GetContainingBox() const;
inline bool Contains(const BoundingBox& bounding_box) const;
inline T GetMedianValue() {
qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp<T>);
return image_data_[data_size_ >> 1];
}
// Returns true iff the pixel is in the image's boundaries for interpolation
// purposes.
// TODO(andrewharp): check in interpolation follow-up change.
inline bool ValidInterpPixel(const float x, const float y) const;
// Safe lookup with boundary enforcement.
inline T GetPixelClipped(const int x, const int y) const {
return (*this)[Clip(y, ZERO, height_less_one_)]
[Clip(x, ZERO, width_less_one_)];
}
#ifdef SANITY_CHECKS
inline RowData<T> operator[](const int row) {
SCHECK(InRange(row, 0, height_less_one_),
"Row out of range: %d (%d max)", row, height_less_one_);
return RowData<T>(image_data_ + row * stride_, width_less_one_);
}
inline const RowData<T> operator[](const int row) const {
SCHECK(InRange(row, 0, height_less_one_),
"Row out of range: %d (%d max)", row, height_less_one_);
return RowData<T>(image_data_ + row * stride_, width_less_one_);
}
#else
inline T* operator[](const int row) {
return image_data_ + row * stride_;
}
inline const T* operator[](const int row) const {
return image_data_ + row * stride_;
}
#endif
const T* data() const { return image_data_; }
inline int stride() const { return stride_; }
// Clears image to a single value.
inline void Clear(const T& val) {
memset(image_data_, val, sizeof(*image_data_) * data_size_);
}
#ifdef __ARM_NEON
void Downsample2x32ColumnsNeon(const uint8_t* const original,
const int stride, const int orig_x);
void Downsample4x32ColumnsNeon(const uint8_t* const original,
const int stride, const int orig_x);
void DownsampleAveragedNeon(const uint8_t* const original, const int stride,
const int factor);
#endif
// Naive downsampler that reduces image size by the given factor, averaging
// pixels in blocks of size factor x factor.
void DownsampleAveraged(const T* const original, const int stride,
const int factor);
// Naive downsampler that reduces image size by the given factor, averaging
// pixels in blocks of size factor x factor.
inline void DownsampleAveraged(const Image<T>& original, const int factor) {
DownsampleAveraged(original.data(), original.GetWidth(), factor);
}
// Naive downsampler that reduces image size using nearest-neighbor
// interpolation.
void DownsampleInterpolateNearest(const Image<T>& original);
// Naive downsampler that reduces image size using fixed-point bilinear
// interpolation.
void DownsampleInterpolateLinear(const Image<T>& original);
// Relatively efficient downsampling of an image by a factor of two with a
// low-pass 3x3 smoothing operation thrown in.
void DownsampleSmoothed3x3(const Image<T>& original);
// Relatively efficient downsampling of an image by a factor of two with a
// low-pass 5x5 smoothing operation thrown in.
void DownsampleSmoothed5x5(const Image<T>& original);
// Optimized Scharr filter on a single pixel in the X direction.
// Scharr filters are like central-difference operators, but have more
// rotational symmetry in their response because they also consider the
// diagonal neighbors.
template <typename U>
inline T ScharrPixelX(const Image<U>& original,
const int center_x, const int center_y) const;
// Optimized Scharr filter on a single pixel in the Y direction.
// Scharr filters are like central-difference operators, but have more
// rotational symmetry in their response because they also consider the
// diagonal neighbors.
template <typename U>
inline T ScharrPixelY(const Image<U>& original,
const int center_x, const int center_y) const;
// Convolve the image with a Scharr filter in the X direction.
// Much faster than an equivalent generic convolution.
template <typename U>
inline void ScharrX(const Image<U>& original);
// Convolve the image with a Scharr filter in the Y direction.
// Much faster than an equivalent generic convolution.
template <typename U>
inline void ScharrY(const Image<U>& original);
static inline T HalfDiff(int32_t first, int32_t second) {
return (second - first) / 2;
}
template <typename U>
void DerivativeX(const Image<U>& original);
template <typename U>
void DerivativeY(const Image<U>& original);
// Generic function for convolving pixel with 3x3 filter.
// Filter pixels should be in row major order.
template <typename U>
inline T ConvolvePixel3x3(const Image<U>& original,
const int* const filter,
const int center_x, const int center_y,
const int total) const;
// Generic function for convolving an image with a 3x3 filter.
// TODO(andrewharp): Generalize this for any size filter.
template <typename U>
inline void Convolve3x3(const Image<U>& original,
const int32_t* const filter);
// Load this image's data from a data array. The data at pixels is assumed to
// have dimensions equivalent to this image's dimensions * factor.
inline void FromArray(const T* const pixels, const int stride,
const int factor = 1);
// Copy the image back out to an appropriately sized data array.
inline void ToArray(T* const pixels) const {
// If not subsampling, memcpy should be faster.
memcpy(pixels, this->image_data_, data_size_ * sizeof(T));
}
// Precompute these for efficiency's sake as they're used by a lot of
// clipping code and loop code.
// TODO(andrewharp): make these only accessible by other Images.
const int width_less_one_;
const int height_less_one_;
// The raw size of the allocated data.
const int data_size_;
private:
inline void Allocate() {
image_data_ = new T[data_size_];
if (image_data_ == NULL) {
LOGE("Couldn't allocate image data!");
}
}
T* image_data_;
bool own_data_;
const int width_;
const int height_;
// The image stride (offset to next row).
// TODO(andrewharp): Make sure that stride is honored in all code.
const int stride_;
TF_DISALLOW_COPY_AND_ASSIGN(Image);
};
template <typename t>
inline std::ostream& operator<<(std::ostream& stream, const Image<t>& image) {
for (int y = 0; y < image.GetHeight(); ++y) {
for (int x = 0; x < image.GetWidth(); ++x) {
stream << image[y][x] << " ";
}
stream << std::endl;
}
return stream;
}
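// Illustrative sketch (added for this write-up, not part of the original
// library): extracting a 2x2 patch at the sub-pixel location (1.5, 1.5) with
// ExtractPatchAtSubpixelFixed1616(). Real coordinates are converted to 16:16
// fixed point by multiplying by 65536; the call returns false if the patch
// would read outside the source image.
inline bool ExamplePatchExtraction(const Image<uint8_t>& source,
                                   uint8_t* const patch_2x2) {
  const int fp_x = static_cast<int>(1.5f * 65536.0f);  // 1.5 in 16:16 format.
  const int fp_y = static_cast<int>(1.5f * 65536.0f);
  return source.ExtractPatchAtSubpixelFixed1616(fp_x, fp_y, 2, 2, patch_2x2);
}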
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
#include <stdint.h>
#include <memory>
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
namespace tf_tracking {
// Class that encapsulates all bulky processed data for a frame.
class ImageData {
public:
explicit ImageData(const int width, const int height)
: uv_frame_width_(width << 1),
uv_frame_height_(height << 1),
timestamp_(0),
image_(width, height) {
InitPyramid(width, height);
ResetComputationCache();
}
private:
void ResetComputationCache() {
uv_data_computed_ = false;
integral_image_computed_ = false;
for (int i = 0; i < kNumPyramidLevels; ++i) {
spatial_x_computed_[i] = false;
spatial_y_computed_[i] = false;
pyramid_sqrt2_computed_[i * 2] = false;
pyramid_sqrt2_computed_[i * 2 + 1] = false;
}
}
void InitPyramid(const int width, const int height) {
int level_width = width;
int level_height = height;
for (int i = 0; i < kNumPyramidLevels; ++i) {
pyramid_sqrt2_[i * 2] = NULL;
pyramid_sqrt2_[i * 2 + 1] = NULL;
spatial_x_[i] = NULL;
spatial_y_[i] = NULL;
level_width /= 2;
level_height /= 2;
}
// Alias the first pyramid level to image_.
pyramid_sqrt2_[0] = &image_;
}
public:
~ImageData() {
// The first pyramid level is actually an alias to image_,
// so make sure it doesn't get deleted here.
pyramid_sqrt2_[0] = NULL;
for (int i = 0; i < kNumPyramidLevels; ++i) {
SAFE_DELETE(pyramid_sqrt2_[i * 2]);
SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
SAFE_DELETE(spatial_x_[i]);
SAFE_DELETE(spatial_y_[i]);
}
}
void SetData(const uint8_t* const new_frame, const int stride,
const int64_t timestamp, const int downsample_factor) {
SetData(new_frame, NULL, stride, timestamp, downsample_factor);
}
void SetData(const uint8_t* const new_frame, const uint8_t* const uv_frame,
const int stride, const int64_t timestamp,
const int downsample_factor) {
ResetComputationCache();
timestamp_ = timestamp;
TimeLog("SetData!");
pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
pyramid_sqrt2_computed_[0] = true;
TimeLog("Downsampled image");
if (uv_frame != NULL) {
if (u_data_.get() == NULL) {
u_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
v_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
}
GetUV(uv_frame, u_data_.get(), v_data_.get());
uv_data_computed_ = true;
TimeLog("Copied UV data");
} else {
LOGV("No uv data!");
}
#ifdef LOG_TIME
// If profiling is enabled, precompute here to make it easier to distinguish
// total costs.
Precompute();
#endif
}
inline int64_t GetTimestamp() const { return timestamp_; }
inline const Image<uint8_t>* GetImage() const {
SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
return pyramid_sqrt2_[0];
}
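// Returns the requested level of the sqrt(2) pyramid, computing and caching it
// on demand. Levels are interleaved: even indices (0, 2, 4, ...) hold the
// power-of-two octaves (full, 1/2, 1/4 resolution, ...), and each odd index
// holds the preceding even level further downsampled by a factor of sqrt(2).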
const Image<uint8_t>* GetPyramidSqrt2Level(const int level) const {
if (!pyramid_sqrt2_computed_[level]) {
SCHECK(level != 0, "Level equals 0!");
if (level == 1) {
const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(0);
if (pyramid_sqrt2_[level] == NULL) {
const int new_width =
(static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
const int new_height =
(static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
2;
pyramid_sqrt2_[level] = new Image<uint8_t>(new_width, new_height);
}
pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
} else {
const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(level - 2);
if (pyramid_sqrt2_[level] == NULL) {
pyramid_sqrt2_[level] = new Image<uint8_t>(
upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
}
pyramid_sqrt2_[level]->DownsampleAveraged(
upper_level.data(), upper_level.stride(), 2);
}
pyramid_sqrt2_computed_[level] = true;
}
return pyramid_sqrt2_[level];
}
inline const Image<int32_t>* GetSpatialX(const int level) const {
if (!spatial_x_computed_[level]) {
const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
if (spatial_x_[level] == NULL) {
spatial_x_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
}
spatial_x_[level]->DerivativeX(src);
spatial_x_computed_[level] = true;
}
return spatial_x_[level];
}
inline const Image<int32_t>* GetSpatialY(const int level) const {
if (!spatial_y_computed_[level]) {
const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
if (spatial_y_[level] == NULL) {
spatial_y_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
}
spatial_y_[level]->DerivativeY(src);
spatial_y_computed_[level] = true;
}
return spatial_y_[level];
}
// The integral image is currently only used for object detection, so lazily
// initialize it on request.
inline const IntegralImage* GetIntegralImage() const {
if (integral_image_.get() == NULL) {
integral_image_.reset(new IntegralImage(image_));
} else if (!integral_image_computed_) {
integral_image_->Recompute(image_);
}
integral_image_computed_ = true;
return integral_image_.get();
}
inline const Image<uint8_t>* GetU() const {
SCHECK(uv_data_computed_, "UV data not provided!");
return u_data_.get();
}
inline const Image<uint8_t>* GetV() const {
SCHECK(uv_data_computed_, "UV data not provided!");
return v_data_.get();
}
private:
void Precompute() {
// Create the smoothed pyramids.
for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
(void) GetPyramidSqrt2Level(i);
}
TimeLog("Created smoothed pyramids");
// Create the smoothed pyramids.
for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
(void) GetPyramidSqrt2Level(i);
}
TimeLog("Created smoothed sqrt pyramids");
// Create the spatial derivatives for frame 1.
for (int i = 0; i < kNumPyramidLevels; ++i) {
(void) GetSpatialX(i);
(void) GetSpatialY(i);
}
TimeLog("Created spatial derivatives");
(void) GetIntegralImage();
TimeLog("Got integral image!");
}
const int uv_frame_width_;
const int uv_frame_height_;
int64_t timestamp_;
Image<uint8_t> image_;
bool uv_data_computed_;
std::unique_ptr<Image<uint8_t> > u_data_;
std::unique_ptr<Image<uint8_t> > v_data_;
mutable bool spatial_x_computed_[kNumPyramidLevels];
mutable Image<int32_t>* spatial_x_[kNumPyramidLevels];
mutable bool spatial_y_computed_[kNumPyramidLevels];
mutable Image<int32_t>* spatial_y_[kNumPyramidLevels];
// Mutable so the lazy initialization can work when this class is const.
// Whether or not the integral image has been computed for the current image.
mutable bool integral_image_computed_;
mutable std::unique_ptr<IntegralImage> integral_image_;
mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
mutable Image<uint8_t>* pyramid_sqrt2_[kNumPyramidLevels * 2];
TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
};
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// NEON implementations of Image methods for compatible devices. Control
// should never enter this compilation unit on incompatible devices.
#ifdef __ARM_NEON
#include <arm_neon.h>
#include <stdint.h>
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
namespace tf_tracking {
// This function does the bulk of the work.
template <>
void Image<uint8_t>::Downsample2x32ColumnsNeon(const uint8_t* const original,
const int stride,
const int orig_x) {
// Divide input x offset by 2 to find output offset.
const int new_x = orig_x >> 1;
// Initial offset into top row.
const uint8_t* offset = original + orig_x;
// This points to the leftmost pixel of our 16 horizontally arranged
// pixels in the destination data.
uint8_t* ptr_dst = (*this)[0] + new_x;
// Sum along vertical columns.
// Process 32x2 input pixels and 16x1 output pixels per iteration.
for (int new_y = 0; new_y < height_; ++new_y) {
uint16x8_t accum1 = vdupq_n_u16(0);
uint16x8_t accum2 = vdupq_n_u16(0);
// Go top to bottom across the two rows of input pixels that make up
// this output row.
for (int row_num = 0; row_num < 2; ++row_num) {
// First 16 bytes.
{
// Load 16 bytes of data from current offset.
const uint8x16_t curr_data1 = vld1q_u8(offset);
// Pairwise add and accumulate into accum vectors (16 bit to account
// for values above 255).
accum1 = vpadalq_u8(accum1, curr_data1);
}
// Second 16 bytes.
{
// Load 16 bytes of data from current offset.
const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
// Pairwise add and accumulate into accum vectors (16 bit to account
// for values above 255).
accum2 = vpadalq_u8(accum2, curr_data2);
}
// Move offset down one row.
offset += stride;
}
// Divide by 4 (number of input pixels per output
// pixel) and narrow data from 16 bits per pixel to 8 bpp.
const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2);
const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2);
// Concatenate 8x1 pixel strips into 16x1 pixel strip.
const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2);
// Copy all pixels from composite 16x1 vector into output strip.
vst1q_u8(ptr_dst, allpixels);
ptr_dst += stride_;
}
}
// This function does the bulk of the work.
template <>
void Image<uint8_t>::Downsample4x32ColumnsNeon(const uint8_t* const original,
const int stride,
const int orig_x) {
// Divide input x offset by 4 to find output offset.
const int new_x = orig_x >> 2;
// Initial offset into top row.
const uint8_t* offset = original + orig_x;
// This points to the leftmost pixel of our 8 horizontally arranged
// pixels in the destination data.
uint8_t* ptr_dst = (*this)[0] + new_x;
// Sum along vertical columns.
// Process 32x4 input pixels and 8x1 output pixels per iteration.
for (int new_y = 0; new_y < height_; ++new_y) {
uint16x8_t accum1 = vdupq_n_u16(0);
uint16x8_t accum2 = vdupq_n_u16(0);
// Go top to bottom across the four rows of input pixels that make up
// this output row.
for (int row_num = 0; row_num < 4; ++row_num) {
// First 16 bytes.
{
// Load 16 bytes of data from current offset.
const uint8x16_t curr_data1 = vld1q_u8(offset);
// Pairwise add and accumulate into accum vectors (16 bit to account
// for values above 255).
accum1 = vpadalq_u8(accum1, curr_data1);
}
// Second 16 bytes.
{
// Load 16 bytes of data from current offset.
const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
// Pairwise add and accumulate into accum vectors (16 bit to account
// for values above 255).
accum2 = vpadalq_u8(accum2, curr_data2);
}
// Move offset down one row.
offset += stride;
}
// Add and widen, then divide by 16 (number of input pixels per output
// pixel) and narrow data from 32 bits per pixel to 16 bpp.
const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4);
const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4);
// Combine 4x1 pixel strips into 8x1 pixel strip and narrow from
// 16 bits to 8 bits per pixel.
const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2));
// Copy all pixels from composite 8x1 vector into output strip.
vst1_u8(ptr_dst, allpixels);
ptr_dst += stride_;
}
}
// Hardware accelerated downsampling method for supported devices.
// Requires that image size be a multiple of 16 pixels in each dimension,
// and that downsampling be by a factor of 2 or 4.
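// For example, downsampling a 640x480 frame by a factor of 2 produces a
// 320x240 image in which each output pixel is the average of a 2x2 block of
// input pixels; a factor of 4 averages 4x4 blocks into a 160x120 image.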
template <>
void Image<uint8_t>::DownsampleAveragedNeon(const uint8_t* const original,
const int stride,
const int factor) {
// TODO(andrewharp): stride is a bad approximation for the src image's width.
// Better to pass that in directly.
SCHECK(width_ * factor <= stride, "Uh oh!");
const int last_starting_index = width_ * factor - 32;
// We process 32 input pixels lengthwise at a time.
// The output per pass of this loop is an 8 wide by downsampled height tall
// pixel strip.
int orig_x = 0;
for (; orig_x <= last_starting_index; orig_x += 32) {
if (factor == 2) {
Downsample2x32ColumnsNeon(original, stride, orig_x);
} else {
Downsample4x32ColumnsNeon(original, stride, orig_x);
}
}
// If a last pass is required, push it to the left enough so that it never
// goes out of bounds. This will result in some extra computation on devices
// whose frame widths are multiples of 16 and not 32.
if (orig_x < last_starting_index + 32) {
if (factor == 2) {
Downsample2x32ColumnsNeon(original, stride, last_starting_index);
} else {
Downsample4x32ColumnsNeon(original, stride, last_starting_index);
}
}
}
// Puts the image gradient matrix about a pixel into the 2x2 float array G.
// vals_x should be an array of the window x gradient values, whose indices
// can be in any order but are parallel to the vals_y entries.
// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
void CalculateGNeon(const float* const vals_x, const float* const vals_y,
const int num_vals, float* const G) {
const float32_t* const arm_vals_x = (const float32_t*) vals_x;
const float32_t* const arm_vals_y = (const float32_t*) vals_y;
// Running sums.
float32x4_t xx = vdupq_n_f32(0.0f);
float32x4_t xy = vdupq_n_f32(0.0f);
float32x4_t yy = vdupq_n_f32(0.0f);
// Maximum index we can load 4 consecutive values from.
// e.g. if there are 81 values, our last full pass can be from index 77:
// 81-4=>77 (77, 78, 79, 80)
const int max_i = num_vals - 4;
// Defined here because we want to keep track of how many values were
// processed by NEON, so that we can finish off the remainder the normal
// way.
int i = 0;
// Process values 4 at a time, accumulating the sums of
// the pixel-wise x*x, x*y, and y*y values.
for (; i <= max_i; i += 4) {
// Load xs
float32x4_t x = vld1q_f32(arm_vals_x + i);
// Multiply x*x and accumulate.
xx = vmlaq_f32(xx, x, x);
// Load ys
float32x4_t y = vld1q_f32(arm_vals_y + i);
// Multiply x*y and accumulate.
xy = vmlaq_f32(xy, x, y);
// Multiply y*y and accumulate.
yy = vmlaq_f32(yy, y, y);
}
static float32_t xx_vals[4];
static float32_t xy_vals[4];
static float32_t yy_vals[4];
vst1q_f32(xx_vals, xx);
vst1q_f32(xy_vals, xy);
vst1q_f32(yy_vals, yy);
// Accumulated values are stored in sets of 4, so we have to manually add
// the last bits together.
for (int j = 0; j < 4; ++j) {
G[0] += xx_vals[j];
G[1] += xy_vals[j];
G[3] += yy_vals[j];
}
// Finishes off last few values (< 4) from above.
for (; i < num_vals; ++i) {
G[0] += Square(vals_x[i]);
G[1] += vals_x[i] * vals_y[i];
G[3] += Square(vals_y[i]);
}
// The matrix is symmetric, so this is a given.
G[2] = G[1];
}
} // namespace tf_tracking
#endif
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
#include <stdint.h>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
namespace tf_tracking {
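// Splits an interleaved semi-planar chroma plane (as typically delivered in
// Android camera YUV frames) into separate U and V images. The byte order of
// the interleaved pairs differs by platform, hence the #ifdef below.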
inline void GetUV(const uint8_t* const input, Image<uint8_t>* const u,
Image<uint8_t>* const v) {
const uint8_t* pUV = input;
for (int row = 0; row < u->GetHeight(); ++row) {
uint8_t* u_curr = (*u)[row];
uint8_t* v_curr = (*v)[row];
for (int col = 0; col < u->GetWidth(); ++col) {
#ifdef __APPLE__
*u_curr++ = *pUV++;
*v_curr++ = *pUV++;
#else
*v_curr++ = *pUV++;
*u_curr++ = *pUV++;
#endif
}
}
}
// Marks every point within a circle of a given radius on the given boolean
// image true.
template <typename U>
inline static void MarkImage(const int x, const int y, const int radius,
Image<U>* const img) {
SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y);
// Precomputed for efficiency.
const int squared_radius = Square(radius);
// Mark every row in the circle.
for (int d_y = 0; d_y <= radius; ++d_y) {
const int squared_y_dist = Square(d_y);
const int min_y = MAX(y - d_y, 0);
const int max_y = MIN(y + d_y, img->height_less_one_);
// The max d_x of the circle must be greater than or equal to
// radius - d_y for any positive d_y. Thus, starting from radius - d_y will
// reduce the number of iterations required as compared to starting from
// either 0 and counting up or radius and counting down.
for (int d_x = radius - d_y; d_x <= radius; ++d_x) {
// The first time this criterion is met, we know the width of the circle at
// this row (without using sqrt).
if (squared_y_dist + Square(d_x) >= squared_radius) {
const int min_x = MAX(x - d_x, 0);
const int max_x = MIN(x + d_x, img->width_less_one_);
// Mark both above and below the center row.
bool* const top_row_start = (*img)[min_y] + min_x;
bool* const bottom_row_start = (*img)[max_y] + min_x;
const int x_width = max_x - min_x + 1;
memset(top_row_start, true, sizeof(*top_row_start) * x_width);
memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width);
// This row is marked, time to move on to the next row.
break;
}
}
}
}
#ifdef __ARM_NEON
void CalculateGNeon(
const float* const vals_x, const float* const vals_y,
const int num_vals, float* const G);
#endif
// Puts the image gradient matrix about a pixel into the 2x2 float array G.
// vals_x should be an array of the window x gradient values, whose indices
// can be in any order but are parallel to the vals_y entries.
// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
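// G is stored row-major as a length-4 array: G[0] = sum(x*x), G[1] = G[2] =
// sum(x*y), G[3] = sum(y*y). Callers are expected to zero-initialize G, since
// the sums are accumulated into it.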
inline void CalculateG(const float* const vals_x, const float* const vals_y,
const int num_vals, float* const G) {
#ifdef __ARM_NEON
CalculateGNeon(vals_x, vals_y, num_vals, G);
return;
#endif
// Non-accelerated version.
for (int i = 0; i < num_vals; ++i) {
G[0] += Square(vals_x[i]);
G[1] += vals_x[i] * vals_y[i];
G[3] += Square(vals_y[i]);
}
// The matrix is symmetric, so this is a given.
G[2] = G[1];
}
inline void CalculateGInt16(const int16_t* const vals_x,
const int16_t* const vals_y, const int num_vals,
int* const G) {
// Non-accelerated version.
for (int i = 0; i < num_vals; ++i) {
G[0] += Square(vals_x[i]);
G[1] += vals_x[i] * vals_y[i];
G[3] += Square(vals_y[i]);
}
// The matrix is symmetric, so this is a given.
G[2] = G[1];
}
// Puts the image gradient matrix about a pixel into the 2x2 float array G.
// Looks up interpolated pixels, then calls above method for implementation.
inline void CalculateG(const int window_radius, const float center_x,
const float center_y, const Image<int32_t>& I_x,
const Image<int32_t>& I_y, float* const G) {
SCHECK(I_x.ValidPixel(center_x, center_y), "Problem in calculateG!");
// Hardcoded to allow for a max window radius of 5 (11 pixels x 11 pixels).
static const int kMaxWindowRadius = 5;
SCHECK(window_radius <= kMaxWindowRadius,
"Window %d > %d!", window_radius, kMaxWindowRadius);
// Diameter of window is 2 * radius + 1 for center pixel.
static const int kWindowBufferSize =
(kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1);
// Preallocate buffers statically for efficiency.
static int16_t vals_x[kWindowBufferSize];
static int16_t vals_y[kWindowBufferSize];
const int src_left_fixed = RealToFixed1616(center_x - window_radius);
const int src_top_fixed = RealToFixed1616(center_y - window_radius);
int16_t* vals_x_ptr = vals_x;
int16_t* vals_y_ptr = vals_y;
const int window_size = 2 * window_radius + 1;
for (int y = 0; y < window_size; ++y) {
const int fp_y = src_top_fixed + (y << 16);
for (int x = 0; x < window_size; ++x) {
const int fp_x = src_left_fixed + (x << 16);
*vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
*vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
}
}
int32_t g_temp[] = {0, 0, 0, 0};
CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp);
for (int i = 0; i < 4; ++i) {
G[i] = g_temp[i];
}
}
inline float ImageCrossCorrelation(const Image<float>& image1,
const Image<float>& image2,
const int x_offset, const int y_offset) {
SCHECK(image1.GetWidth() == image2.GetWidth() &&
image1.GetHeight() == image2.GetHeight(),
"Dimension mismatch! %dx%d vs %dx%d",
image1.GetWidth(), image1.GetHeight(),
image2.GetWidth(), image2.GetHeight());
const int num_pixels = image1.GetWidth() * image1.GetHeight();
const float* data1 = image1.data();
const float* data2 = image2.data();
return ComputeCrossCorrelation(data1, data2, num_pixels);
}
// Copies an arbitrary region of an image to another (floating point)
// image, scaling as it goes using bilinear interpolation.
inline void CopyArea(const Image<uint8_t>& image,
const BoundingBox& area_to_copy,
Image<float>* const patch_image) {
VLOG(2) << "Copying from: " << area_to_copy << std::endl;
const int patch_width = patch_image->GetWidth();
const int patch_height = patch_image->GetHeight();
const float x_dist_between_samples = patch_width > 0 ?
area_to_copy.GetWidth() / (patch_width - 1) : 0;
const float y_dist_between_samples = patch_height > 0 ?
area_to_copy.GetHeight() / (patch_height - 1) : 0;
for (int y_index = 0; y_index < patch_height; ++y_index) {
const float sample_y =
y_index * y_dist_between_samples + area_to_copy.top_;
for (int x_index = 0; x_index < patch_width; ++x_index) {
const float sample_x =
x_index * x_dist_between_samples + area_to_copy.left_;
if (image.ValidInterpPixel(sample_x, sample_y)) {
// TODO(andrewharp): Do area averaging when downsampling.
(*patch_image)[y_index][x_index] =
image.GetPixelInterp(sample_x, sample_y);
} else {
(*patch_image)[y_index][x_index] = -1.0f;
}
}
}
}
// Takes a floating point image and normalizes it in-place.
//
// First, negative values will be set to the mean of the non-negative pixels
// in the image.
//
// Then, the resulting image will be normalized such that it has a mean value
// of 0.0 and a standard deviation of 1.0.
inline void NormalizeImage(Image<float>* const image) {
const float* const data_ptr = image->data();
// Copy only the non-negative values to some temp memory.
float running_sum = 0.0f;
int num_data_gte_zero = 0;
{
float* const curr_data = (*image)[0];
for (int i = 0; i < image->data_size_; ++i) {
if (curr_data[i] >= 0.0f) {
running_sum += curr_data[i];
++num_data_gte_zero;
} else {
curr_data[i] = -1.0f;
}
}
}
// If none of the pixels are valid, just set the entire thing to 0.0f.
if (num_data_gte_zero == 0) {
image->Clear(0.0f);
return;
}
const float corrected_mean = running_sum / num_data_gte_zero;
float* curr_data = (*image)[0];
for (int i = 0; i < image->data_size_; ++i) {
const float curr_val = *curr_data;
*curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean;
}
const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f);
if (std_dev > 0.0f) {
curr_data = (*image)[0];
for (int i = 0; i < image->data_size_; ++i) {
*curr_data++ /= std_dev;
}
#ifdef SANITY_CHECKS
LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev);
const float correlation =
ComputeCrossCorrelation(image->data(),
image->data(),
image->data_size_);
if (std::abs(correlation - 1.0f) > EPSILON) {
LOG(ERROR) << "Bad image!" << std::endl;
LOG(ERROR) << *image << std::endl;
}
SCHECK(std::abs(correlation - 1.0f) < EPSILON,
"Correlation wasn't 1.0f: %.10f", correlation);
#endif
}
}
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
namespace tf_tracking {
typedef uint8_t Code;
class IntegralImage : public Image<uint32_t> {
public:
explicit IntegralImage(const Image<uint8_t>& image_base)
: Image<uint32_t>(image_base.GetWidth(), image_base.GetHeight()) {
Recompute(image_base);
}
IntegralImage(const int width, const int height)
: Image<uint32_t>(width, height) {}
void Recompute(const Image<uint8_t>& image_base) {
SCHECK(image_base.GetWidth() == GetWidth() &&
image_base.GetHeight() == GetHeight(), "Dimensions don't match!");
// Sum along first row.
{
int x_sum = 0;
for (int x = 0; x < image_base.GetWidth(); ++x) {
x_sum += image_base[0][x];
(*this)[0][x] = x_sum;
}
}
// Sum everything else.
for (int y = 1; y < image_base.GetHeight(); ++y) {
uint32_t* curr_sum = (*this)[y];
// Previously summed pointers.
const uint32_t* up_one = (*this)[y - 1];
// Current value pointer.
const uint8_t* curr_delta = image_base[y];
uint32_t row_till_now = 0;
for (int x = 0; x < GetWidth(); ++x) {
// Add the one above and the one to the left.
row_till_now += *curr_delta;
*curr_sum = *up_one + row_till_now;
// Scoot everything along.
++curr_sum;
++up_one;
++curr_delta;
}
}
SCHECK(VerifyData(image_base), "Images did not match!");
}
bool VerifyData(const Image<uint8_t>& image_base) {
for (int y = 0; y < GetHeight(); ++y) {
for (int x = 0; x < GetWidth(); ++x) {
uint32_t curr_val = (*this)[y][x];
if (x > 0) {
curr_val -= (*this)[y][x - 1];
}
if (y > 0) {
curr_val -= (*this)[y - 1][x];
}
if (x > 0 && y > 0) {
curr_val += (*this)[y - 1][x - 1];
}
if (curr_val != image_base[y][x]) {
LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]);
return false;
}
if (GetRegionSum(x, y, x, y) != curr_val) {
LOGE("Mismatch!");
}
}
}
return true;
}
// Returns the sum of all pixels in the specified region.
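// Uses the standard integral-image identity: for a rectangle with top-left
// corner (x1, y1) and bottom-right corner (x2, y2),
//   sum = I(x2, y2) - I(x1 - 1, y2) - I(x2, y1 - 1) + I(x1 - 1, y1 - 1),
// with the subtracted terms dropped along the image borders.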
inline uint32_t GetRegionSum(const int x1, const int y1, const int x2,
const int y2) const {
SCHECK(x1 >= 0 && y1 >= 0 &&
x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(),
"indices out of bounds! %d-%d / %d, %d-%d / %d, ",
x1, x2, GetWidth(), y1, y2, GetHeight());
const uint32_t everything = (*this)[y2][x2];
uint32_t sum = everything;
if (x1 > 0 && y1 > 0) {
// Most common case.
const uint32_t left = (*this)[y2][x1 - 1];
const uint32_t top = (*this)[y1 - 1][x2];
const uint32_t top_left = (*this)[y1 - 1][x1 - 1];
sum = everything - left - top + top_left;
SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d",
everything, left, top, top_left, sum, x1, y1, x2, y2);
} else if (x1 > 0) {
// Flush against top of image.
// Subtract out the region to the left only.
const uint32_t top = (*this)[y2][x1 - 1];
sum = everything - top;
SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum);
} else if (y1 > 0) {
// Flush against left side of image.
// Subtract out the region above only.
const uint32_t left = (*this)[y1 - 1][x2];
sum = everything - left;
SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum);
}
SCHECK(sum >= 0, "Negative sum!");
return sum;
}
// Returns the 2-bit code associated with this region, which represents
// the overall gradient.
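// Bit 1 is set when the top half of the region is brighter than the bottom
// half, and bit 0 is set when the left half is brighter than the right half,
// giving the four possible codes 0-3.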
inline Code GetCode(const BoundingBox& bounding_box) const {
return GetCode(bounding_box.left_, bounding_box.top_,
bounding_box.right_, bounding_box.bottom_);
}
inline Code GetCode(const int x1, const int y1,
const int x2, const int y2) const {
SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d",
x1, y1, x2, y2);
// Gradient computed vertically.
const int box_height = (y2 - y1) / 2;
const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height);
const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2);
const bool vertical_code = top_sum > bottom_sum;
// Gradient computed horizontally.
const int box_width = (x2 - x1) / 2;
const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2);
const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2);
const bool horizontal_code = left_sum > right_sum;
const Code final_code = (vertical_code << 1) | horizontal_code;
SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)),
"Invalid code! %d", final_code);
// Returns a value 0-3.
return final_code;
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage);
};
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
#include <jni.h>
#include <stdint.h>
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
// The JniLongField class is used to access Java fields from native code. This
// technique of hiding pointers to native objects in opaque Java fields is how
// the Android hardware libraries work. This reduces the number of static
// native methods and makes it easier to manage the lifetime of native objects.
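// Illustrative usage (the Java field and native type below are hypothetical):
// given a Java class declaring "private long nativeHandle;", native code can
// stash and retrieve a C++ pointer like so:
//
//   static JniLongField handle_field("nativeHandle");
//   handle_field.set(env, thiz, reinterpret_cast<int64_t>(native_object));
//   NativeObject* obj =
//       reinterpret_cast<NativeObject*>(handle_field.get(env, thiz));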
class JniLongField {
public:
JniLongField(const char* field_name)
: field_name_(field_name), field_ID_(0) {}
int64_t get(JNIEnv* env, jobject thiz) {
if (field_ID_ == 0) {
jclass cls = env->GetObjectClass(thiz);
CHECK_ALWAYS(cls != 0, "Unable to find class");
field_ID_ = env->GetFieldID(cls, field_name_, "J");
CHECK_ALWAYS(field_ID_ != 0,
"Unable to find field %s. (Check proguard cfg)", field_name_);
}
return env->GetLongField(thiz, field_ID_);
}
void set(JNIEnv* env, jobject thiz, int64_t value) {
if (field_ID_ == 0) {
jclass cls = env->GetObjectClass(thiz);
CHECK_ALWAYS(cls != 0, "Unable to find class");
field_ID_ = env->GetFieldID(cls, field_name_, "J");
CHECK_ALWAYS(field_ID_ != 0,
"Unable to find field %s (Check proguard cfg)", field_name_);
}
env->SetLongField(thiz, field_ID_, value);
}
private:
const char* const field_name_;
// This is just a cache
jfieldID field_ID_;
};
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
namespace tf_tracking {
// For keeping track of keypoints.
struct Keypoint {
Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {}
Keypoint(const float x, const float y)
: pos_(x, y), score_(0.0f), type_(0) {}
Point2f pos_;
float score_;
uint8_t type_;
};
inline std::ostream& operator<<(std::ostream& stream, const Keypoint keypoint) {
return stream << "[" << keypoint.pos_ << ", "
<< keypoint.score_ << ", " << keypoint.type_ << "]";
}
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Various keypoint detecting functions.
#include <float.h>
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
namespace tf_tracking {
static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) {
return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]);
}
void KeypointDetector::ScoreKeypoints(const ImageData& image_data,
const int num_candidates,
Keypoint* const candidate_keypoints) {
const Image<int>& I_x = *image_data.GetSpatialX(0);
const Image<int>& I_y = *image_data.GetSpatialY(0);
if (config_->detect_skin) {
const Image<uint8_t>& u_data = *image_data.GetU();
const Image<uint8_t>& v_data = *image_data.GetV();
static const int reference[] = {111, 155};
// Score all the keypoints.
for (int i = 0; i < num_candidates; ++i) {
Keypoint* const keypoint = candidate_keypoints + i;
const int x_pos = keypoint->pos_.x * 2;
const int y_pos = keypoint->pos_.y * 2;
const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]};
keypoint->score_ =
HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) /
GetDistSquaredBetween(reference, curr_color);
}
} else {
// Score all the keypoints.
for (int i = 0; i < num_candidates; ++i) {
Keypoint* const keypoint = candidate_keypoints + i;
keypoint->score_ =
HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y);
}
}
}
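// qsort comparator that orders keypoints by descending score (highest first).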
inline int KeypointCompare(const void* const a, const void* const b) {
return (reinterpret_cast<const Keypoint*>(a)->score_ -
reinterpret_cast<const Keypoint*>(b)->score_) <= 0 ? 1 : -1;
}
// Quicksorts detected keypoints by score.
void KeypointDetector::SortKeypoints(const int num_candidates,
Keypoint* const candidate_keypoints) const {
qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare);
#ifdef SANITY_CHECKS
// Verify that the array got sorted.
float last_score = FLT_MAX;
for (int i = 0; i < num_candidates; ++i) {
const float curr_score = candidate_keypoints[i].score_;
// Scores should be monotonically decreasing (sorted highest first).
SCHECK(last_score >= curr_score,
"Quicksort failure! %d: %.5f > %d: %.5f (%d total)",
i - 1, last_score, i, curr_score, num_candidates);
last_score = curr_score;
}
#endif
}
int KeypointDetector::SelectKeypointsInBox(
const BoundingBox& box,
const Keypoint* const candidate_keypoints,
const int num_candidates,
const int max_keypoints,
const int num_existing_keypoints,
const Keypoint* const existing_keypoints,
Keypoint* const final_keypoints) const {
if (max_keypoints <= 0) {
return 0;
}
// This is the minimum distance that must separate keypoints from each other
// within this box, roughly based on the box dimensions.
const int distance =
MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f);
// First, mark keypoints that already happen to be inside this region. Ignore
// keypoints that are outside it, however close they might be.
interest_map_->Clear(false);
for (int i = 0; i < num_existing_keypoints; ++i) {
const Keypoint& candidate = existing_keypoints[i];
const int x_pos = candidate.pos_.x;
const int y_pos = candidate.pos_.y;
if (box.Contains(candidate.pos_)) {
MarkImage(x_pos, y_pos, distance, interest_map_.get());
}
}
// Now, go through and check which keypoints will still fit in the box.
int num_keypoints_selected = 0;
for (int i = 0; i < num_candidates; ++i) {
const Keypoint& candidate = candidate_keypoints[i];
const int x_pos = candidate.pos_.x;
const int y_pos = candidate.pos_.y;
if (!box.Contains(candidate.pos_) ||
!interest_map_->ValidPixel(x_pos, y_pos)) {
continue;
}
if (!(*interest_map_)[y_pos][x_pos]) {
final_keypoints[num_keypoints_selected++] = candidate;
if (num_keypoints_selected >= max_keypoints) {
break;
}
MarkImage(x_pos, y_pos, distance, interest_map_.get());
}
}
return num_keypoints_selected;
}
void KeypointDetector::SelectKeypoints(
const std::vector<BoundingBox>& boxes,
const Keypoint* const candidate_keypoints,
const int num_candidates,
FramePair* const curr_change) const {
// Now select all the interesting keypoints that fall inside our boxes.
curr_change->number_of_keypoints_ = 0;
for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
iter != boxes.end(); ++iter) {
const BoundingBox bounding_box = *iter;
// Count up keypoints that have already been selected, and fall within our
// box.
int num_keypoints_already_in_box = 0;
for (int i = 0; i < curr_change->number_of_keypoints_; ++i) {
if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) {
++num_keypoints_already_in_box;
}
}
const int max_keypoints_to_find_in_box =
MIN(kMaxKeypointsForObject - num_keypoints_already_in_box,
kMaxKeypoints - curr_change->number_of_keypoints_);
const int num_new_keypoints_in_box = SelectKeypointsInBox(
bounding_box,
candidate_keypoints,
num_candidates,
max_keypoints_to_find_in_box,
curr_change->number_of_keypoints_,
curr_change->frame1_keypoints_,
curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_);
curr_change->number_of_keypoints_ += num_new_keypoints_in_box;
LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_);
}
}
// Walks along the given circle checking for pixels above or below the center.
// Returns a score, or 0 if the keypoint did not pass the criteria.
//
// Parameters:
// circle_perimeter: the circumference in pixels of the circle.
// threshold: the minimum number of contiguous pixels that must be above or
// below the center value.
// center_ptr: the location of the center pixel in memory
// offsets: the relative offsets from the center pixel of the edge pixels.
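// For example, with the full 16-pixel circle used below and a threshold of 12,
// a candidate passes only if at least 12 contiguous perimeter pixels are all
// brighter, or all darker, than the center by more than kFastDiffAmount.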
inline int TestCircle(const int circle_perimeter, const int threshold,
const uint8_t* const center_ptr, const int* offsets) {
// Get the actual value of the center pixel for easier reference later on.
const int center_value = static_cast<int>(*center_ptr);
// Number of total pixels to check. Have to wrap around some in case
// the contiguous section is split by the array edges.
const int num_total = circle_perimeter + threshold - 1;
int num_above = 0;
int above_diff = 0;
int num_below = 0;
int below_diff = 0;
// Used to tell when this is definitely not going to meet the threshold so we
// can early abort.
int minimum_by_now = threshold - num_total + 1;
// Go through every pixel along the perimeter of the circle, and then around
// again a little bit.
for (int i = 0; i < num_total; ++i) {
// This should be faster than mod.
const int perim_index = i < circle_perimeter ? i : i - circle_perimeter;
// This gets the value of the current pixel along the perimeter by using
// a precomputed offset.
const int curr_value =
static_cast<int>(center_ptr[offsets[perim_index]]);
const int difference = curr_value - center_value;
if (difference > kFastDiffAmount) {
above_diff += difference;
++num_above;
num_below = 0;
below_diff = 0;
if (num_above >= threshold) {
return above_diff;
}
} else if (difference < -kFastDiffAmount) {
below_diff += difference;
++num_below;
num_above = 0;
above_diff = 0;
if (num_below >= threshold) {
return below_diff;
}
} else {
num_above = 0;
num_below = 0;
above_diff = 0;
below_diff = 0;
}
// See if there's any chance of making the threshold.
if (MAX(num_above, num_below) < minimum_by_now) {
// Didn't pass.
return 0;
}
++minimum_by_now;
}
// Didn't pass.
return 0;
}
// Returns a score in the range [0.0, positive infinity) which represents the
// relative likelihood of a point being a corner.
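// The score below is the Harris-Nobel measure det(G) / (trace(G) + epsilon),
// i.e. (sum_xx * sum_yy - sum_xy^2) / (sum_xx + sum_yy + FLT_MIN), computed
// over a window of radius kHarrisWindowSize around (x, y).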
float KeypointDetector::HarrisFilter(const Image<int32_t>& I_x,
const Image<int32_t>& I_y, const float x,
const float y) const {
if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) &&
I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) {
// Image gradient matrix.
float G[] = { 0, 0, 0, 0 };
CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G);
const float dx = G[0];
const float dy = G[3];
const float dxy = G[1];
// Harris-Nobel corner score.
return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN);
}
return 0.0f;
}
int KeypointDetector::AddExtraCandidatesForBoxes(
const std::vector<BoundingBox>& boxes,
const int max_num_keypoints,
Keypoint* const keypoints) const {
int num_keypoints_added = 0;
for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
iter != boxes.end(); ++iter) {
const BoundingBox box = *iter;
for (int i = 0; i < kNumToAddAsCandidates; ++i) {
for (int j = 0; j < kNumToAddAsCandidates; ++j) {
if (num_keypoints_added >= max_num_keypoints) {
LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints);
return num_keypoints_added;
}
Keypoint& curr_keypoint = keypoints[num_keypoints_added++];
curr_keypoint.pos_ = Point2f(
box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates,
box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates);
curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST;
}
}
}
return num_keypoints_added;
}
void KeypointDetector::FindKeypoints(const ImageData& image_data,
const std::vector<BoundingBox>& rois,
const FramePair& prev_change,
FramePair* const curr_change) {
// Copy keypoints from second frame of last pass to temp keypoints of this
// pass.
int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_);
const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints;
number_of_tmp_keypoints +=
FindFastKeypoints(image_data, max_num_fast,
tmp_keypoints_ + number_of_tmp_keypoints);
TimeLog("Found FAST keypoints");
if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints",
kMaxTempKeypoints, number_of_tmp_keypoints);
}
if (kAddArbitraryKeypoints) {
// Add some for each object prior to scoring.
const int max_num_box_keypoints =
kMaxTempKeypoints - number_of_tmp_keypoints;
number_of_tmp_keypoints +=
AddExtraCandidatesForBoxes(rois, max_num_box_keypoints,
tmp_keypoints_ + number_of_tmp_keypoints);
TimeLog("Added box keypoints");
if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints",
kMaxTempKeypoints, number_of_tmp_keypoints);
}
}
// Score them...
LOGV("Scoring %d keypoints!", number_of_tmp_keypoints);
ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_);
TimeLog("Scored keypoints");
// Now pare it down a bit.
SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_);
TimeLog("Sorted keypoints");
LOGV("%d keypoints to select from!", number_of_tmp_keypoints);
SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change);
TimeLog("Selected keypoints");
LOGV("Picked %d (%d max) final keypoints out of %d potential.",
curr_change->number_of_keypoints_,
kMaxKeypoints, number_of_tmp_keypoints);
}
int KeypointDetector::CopyKeypoints(const FramePair& prev_change,
Keypoint* const new_keypoints) {
int number_of_keypoints = 0;
// Caching values from last pass, just copy and compact.
for (int i = 0; i < prev_change.number_of_keypoints_; ++i) {
if (prev_change.optical_flow_found_keypoint_[i]) {
new_keypoints[number_of_keypoints] =
prev_change.frame2_keypoints_[i];
new_keypoints[number_of_keypoints].score_ =
prev_change.frame1_keypoints_[i].score_;
++number_of_keypoints;
}
}
TimeLog("Copied keypoints");
return number_of_keypoints;
}
// FAST keypoint detector.
int KeypointDetector::FindFastKeypoints(const Image<uint8_t>& frame,
const int quadrant,
const int downsample_factor,
const int max_num_keypoints,
Keypoint* const keypoints) {
/*
// Reference for a circle of diameter 7.
const int circle[] = {0, 0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 0, 1, 0,
1, 0, 0, 0, 0, 0, 1,
1, 0, 0, 0, 0, 0, 1,
1, 0, 0, 0, 0, 0, 1,
0, 1, 0, 0, 0, 1, 0,
0, 0, 1, 1, 1, 0, 0};
const int circle_offset[] =
{2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46};
*/
// Quick test of compass directions. Any length 16 circle with a break of up
// to 4 pixels will have at least 3 of these 4 pixels active.
static const int short_circle_perimeter = 4;
static const int short_threshold = 3;
static const int short_circle_x[] = { -3, 0, +3, 0 };
static const int short_circle_y[] = { 0, -3, 0, +3 };
// Precompute image offsets.
int short_offsets[short_circle_perimeter];
for (int i = 0; i < short_circle_perimeter; ++i) {
short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth();
}
// Large circle values.
static const int full_circle_perimeter = 16;
static const int full_threshold = 12;
static const int full_circle_x[] =
{ -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 };
static const int full_circle_y[] =
{ -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 };
// Precompute image offsets.
int full_offsets[full_circle_perimeter];
for (int i = 0; i < full_circle_perimeter; ++i) {
full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth();
}
const int scratch_stride = frame.stride();
keypoint_scratch_->Clear(0);
// Set up the bounds on the region to test based on the passed-in quadrant.
const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer;
const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer;
const int start_x =
kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width);
const int start_y =
kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height);
const int end_x = start_x + quadrant_width;
const int end_y = start_y + quadrant_height;
// Loop through once to find FAST keypoint clumps.
for (int img_y = start_y; img_y < end_y; ++img_y) {
const uint8_t* curr_pixel_ptr = frame[img_y] + start_x;
for (int img_x = start_x; img_x < end_x; ++img_x) {
// Only insert it if it meets the quick minimum requirements test.
if (TestCircle(short_circle_perimeter, short_threshold,
curr_pixel_ptr, short_offsets) != 0) {
// Longer test for actual keypoint score.
const int fast_score = TestCircle(full_circle_perimeter,
full_threshold,
curr_pixel_ptr,
full_offsets);
// Non-zero score means the keypoint was found.
if (fast_score != 0) {
uint8_t* const center_ptr = (*keypoint_scratch_)[img_y] + img_x;
// Increase the keypoint count on this pixel and the pixels in all
// 4 cardinal directions.
*center_ptr += 5;
*(center_ptr - 1) += 1;
*(center_ptr + 1) += 1;
*(center_ptr - scratch_stride) += 1;
*(center_ptr + scratch_stride) += 1;
}
}
++curr_pixel_ptr;
} // x
} // y
TimeLog("Found FAST keypoints.");
int num_keypoints = 0;
// Loop through again and Harris filter pixels in the center of clumps.
// We can shrink the window by 1 pixel on every side.
for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) {
const uint8_t* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x;
for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) {
if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) {
Keypoint* const keypoint = keypoints + num_keypoints;
keypoint->pos_ = Point2f(
img_x * downsample_factor, img_y * downsample_factor);
keypoint->score_ = 0;
keypoint->type_ = KEYPOINT_TYPE_FAST;
++num_keypoints;
if (num_keypoints >= max_num_keypoints) {
return num_keypoints;
}
}
++curr_pixel_ptr;
} // x
} // y
TimeLog("Picked FAST keypoints.");
return num_keypoints;
}
int KeypointDetector::FindFastKeypoints(const ImageData& image_data,
const int max_num_keypoints,
Keypoint* const keypoints) {
int downsample_factor = 1;
int num_found = 0;
// TODO(andrewharp): Get this working for multiple image scales.
for (int i = 0; i < 1; ++i) {
const Image<uint8_t>& frame = *image_data.GetPyramidSqrt2Level(i);
num_found += FindFastKeypoints(
frame, fast_quadrant_,
downsample_factor, max_num_keypoints, keypoints + num_found);
downsample_factor *= 2;
}
// Increment the current quadrant.
fast_quadrant_ = (fast_quadrant_ + 1) % 4;
return num_found;
}
} // namespace tf_tracking
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
#include <stdint.h>
#include <vector>
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
namespace tf_tracking {
struct Keypoint;
class KeypointDetector {
public:
explicit KeypointDetector(const KeypointDetectorConfig* const config)
: config_(config),
keypoint_scratch_(new Image<uint8_t>(config_->image_size)),
interest_map_(new Image<bool>(config_->image_size)),
fast_quadrant_(0) {
interest_map_->Clear(false);
}
~KeypointDetector() {}
// Finds a new set of keypoints for the current frame, picked from the current
// set of keypoints and also from a set discovered via a keypoint detector.
// Special attention is applied to make sure that keypoints are distributed
// within the supplied ROIs.
void FindKeypoints(const ImageData& image_data,
const std::vector<BoundingBox>& rois,
const FramePair& prev_change,
FramePair* const curr_change);
private:
// Compute the corneriness of a point in the image.
float HarrisFilter(const Image<int32_t>& I_x, const Image<int32_t>& I_y,
const float x, const float y) const;
// Adds a grid of candidate keypoints to the given box, up to
// max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower.
int AddExtraCandidatesForBoxes(
const std::vector<BoundingBox>& boxes,
const int max_num_keypoints,
Keypoint* const keypoints) const;
// Scan the frame for potential keypoints using the FAST keypoint detector.
// Quadrant is an argument 0-3 which refers to the quadrant of the image in
// which to detect keypoints.
int FindFastKeypoints(const Image<uint8_t>& frame, const int quadrant,
const int downsample_factor,
const int max_num_keypoints, Keypoint* const keypoints);
int FindFastKeypoints(const ImageData& image_data,
const int max_num_keypoints,
Keypoint* const keypoints);
// Score a bunch of candidate keypoints. Assigns the scores to the input
// candidate_keypoints array entries.
void ScoreKeypoints(const ImageData& image_data,
const int num_candidates,
Keypoint* const candidate_keypoints);
void SortKeypoints(const int num_candidates,
Keypoint* const candidate_keypoints) const;
// Selects a set of keypoints falling within the supplied box such that the
// most highly rated keypoints are picked first, and so that none of them are
// too close together.
int SelectKeypointsInBox(
const BoundingBox& box,
const Keypoint* const candidate_keypoints,
const int num_candidates,
const int max_keypoints,
const int num_existing_keypoints,
const Keypoint* const existing_keypoints,
Keypoint* const final_keypoints) const;
// Selects from the supplied sorted keypoint pool a set of keypoints that will
// best cover the given set of boxes, such that each box is covered at a
// resolution proportional to its size.
void SelectKeypoints(
const std::vector<BoundingBox>& boxes,
const Keypoint* const candidate_keypoints,
const int num_candidates,
FramePair* const frame_change) const;
// Copies and compacts the found keypoints in the second frame of prev_change
// into the array at new_keypoints.
static int CopyKeypoints(const FramePair& prev_change,
Keypoint* const new_keypoints);
const KeypointDetectorConfig* const config_;
// Scratch memory for keypoint candidacy detection and non-max suppression.
std::unique_ptr<Image<uint8_t> > keypoint_scratch_;
// Regions of the image to pay special attention to.
std::unique_ptr<Image<bool> > interest_map_;
// The current quadrant of the image to detect FAST keypoints in.
// Keypoint detection is staggered for performance reasons. Every four frames
// a full scan of the frame will have been performed.
int fast_quadrant_;
Keypoint tmp_keypoints_[kMaxTempKeypoints];
};
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#ifdef STANDALONE_DEMO_LIB
#include <android/log.h>
#include <stdlib.h>
#include <time.h>
#include <iostream>
#include <sstream>
LogMessage::LogMessage(const char* fname, int line, int severity)
: fname_(fname), line_(line), severity_(severity) {}
void LogMessage::GenerateLogMessage() {
int android_log_level;
switch (severity_) {
case INFO:
android_log_level = ANDROID_LOG_INFO;
break;
case WARNING:
android_log_level = ANDROID_LOG_WARN;
break;
case ERROR:
android_log_level = ANDROID_LOG_ERROR;
break;
case FATAL:
android_log_level = ANDROID_LOG_FATAL;
break;
default:
if (severity_ < INFO) {
android_log_level = ANDROID_LOG_VERBOSE;
} else {
android_log_level = ANDROID_LOG_ERROR;
}
break;
}
std::stringstream ss;
const char* const partial_name = strrchr(fname_, '/');
ss << (partial_name != nullptr ? partial_name + 1 : fname_) << ":" << line_
<< " " << str();
__android_log_write(android_log_level, "native", ss.str().c_str());
// Also log to stderr (for standalone Android apps).
std::cerr << "native : " << ss.str() << std::endl;
// Android logging at level FATAL does not terminate execution, so abort()
// is still required to stop the program.
if (severity_ == FATAL) {
abort();
}
}
namespace {
// Parse log level (int64) from environment variable (char*)
int64_t LogLevelStrToInt(const char* tf_env_var_val) {
if (tf_env_var_val == nullptr) {
return 0;
}
// Ideally we would use env_var / safe_strto64, but it is
// hard to use here without pulling in a lot of dependencies,
// so we use std::istringstream instead.
std::string min_log_level(tf_env_var_val);
std::istringstream ss(min_log_level);
int64_t level;
if (!(ss >> level)) {
// Invalid vlog level setting, set level to default (0)
level = 0;
}
return level;
}
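// Reads the minimum log level from the TF_CPP_MIN_LOG_LEVEL environment
// variable. For example, running with TF_CPP_MIN_LOG_LEVEL=2 suppresses INFO
// and WARNING messages while keeping ERROR and FATAL ones.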
int64_t MinLogLevelFromEnv() {
const char* tf_env_var_val = getenv("TF_CPP_MIN_LOG_LEVEL");
return LogLevelStrToInt(tf_env_var_val);
}
int64_t MinVLogLevelFromEnv() {
const char* tf_env_var_val = getenv("TF_CPP_MIN_VLOG_LEVEL");
return LogLevelStrToInt(tf_env_var_val);
}
} // namespace
LogMessage::~LogMessage() {
// Read the min log level once during the first call to logging.
static int64_t min_log_level = MinLogLevelFromEnv();
if (TF_PREDICT_TRUE(severity_ >= min_log_level)) GenerateLogMessage();
}
int64_t LogMessage::MinVLogLevel() {
static const int64_t min_vlog_level = MinVLogLevelFromEnv();
return min_vlog_level;
}
LogMessageFatal::LogMessageFatal(const char* file, int line)
: LogMessage(file, line, FATAL) {}
LogMessageFatal::~LogMessageFatal() {
// abort() ensures we don't return (we promised we would not via
// ATTRIBUTE_NORETURN).
GenerateLogMessage();
abort();
}
void LogString(const char* fname, int line, int severity,
const std::string& message) {
LogMessage(fname, line, severity) << message;
}
void LogPrintF(const int severity, const char* format, ...) {
char message[1024];
va_list argptr;
va_start(argptr, format);
vsnprintf(message, 1024, format, argptr);
va_end(argptr);
__android_log_write(severity, "native", message);
// Also log to stderr (for standalone Android apps).
std::cerr << "native : " << message << std::endl;
}
#endif
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
#include <android/log.h>
#include <string.h>
#include <ostream>
#include <sstream>
#include <string>
// Allow this library to be built without depending on TensorFlow by
// defining STANDALONE_DEMO_LIB. Otherwise TensorFlow headers will be
// used.
#ifdef STANDALONE_DEMO_LIB
// A macro to disallow the copy constructor and operator= functions
// This is usually placed in the private: declarations for a class.
#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
#if defined(COMPILER_GCC3)
#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0))
#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
#else
#define TF_PREDICT_FALSE(x) (x)
#define TF_PREDICT_TRUE(x) (x)
#endif
// Log levels equivalent to those defined by
// third_party/tensorflow/core/platform/logging.h
const int INFO = 0; // base_logging::INFO;
const int WARNING = 1; // base_logging::WARNING;
const int ERROR = 2; // base_logging::ERROR;
const int FATAL = 3; // base_logging::FATAL;
const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES;
class LogMessage : public std::basic_ostringstream<char> {
public:
LogMessage(const char* fname, int line, int severity);
~LogMessage();
// Returns the minimum log level for VLOG statements.
// E.g., if MinVLogLevel() is 2, then VLOG(2) statements will produce output,
// but VLOG(3) will not. Defaults to 0.
static int64_t MinVLogLevel();
protected:
void GenerateLogMessage();
private:
const char* fname_;
int line_;
int severity_;
};
// LogMessageFatal ensures the process will exit in failure after
// logging this message.
class LogMessageFatal : public LogMessage {
public:
LogMessageFatal(const char* file, int line);
~LogMessageFatal();
};
// In standalone mode these expand to the LogMessage classes declared above
// (there is no tensorflow::internal namespace available here).
#define _TF_LOG_INFO LogMessage(__FILE__, __LINE__, INFO)
#define _TF_LOG_WARNING LogMessage(__FILE__, __LINE__, WARNING)
#define _TF_LOG_ERROR LogMessage(__FILE__, __LINE__, ERROR)
#define _TF_LOG_FATAL LogMessageFatal(__FILE__, __LINE__)
#define _TF_LOG_QFATAL _TF_LOG_FATAL
#define LOG(severity) _TF_LOG_##severity
#define VLOG_IS_ON(lvl) ((lvl) <= LogMessage::MinVLogLevel())
#define VLOG(lvl) \
if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \
  LogMessage(__FILE__, __LINE__, INFO)
void LogPrintF(const int severity, const char* format, ...);
// Support for printf style logging.
#define LOGV(...)
#define LOGD(...)
#define LOGI(...) LogPrintF(ANDROID_LOG_INFO, __VA_ARGS__);
#define LOGW(...) LogPrintF(ANDROID_LOG_WARN, __VA_ARGS__);
#define LOGE(...) LogPrintF(ANDROID_LOG_ERROR, __VA_ARGS__);
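// Minimal usage sketch for the macros above (format strings and variables are
// illustrative only):
//
//   LOGI("Tracking %d objects", num_objects);
//   LOGW("Frame %d arrived late", frame_id);
//   LOGE("Could not allocate %zu bytes", size);
//   VLOG(2) << "only produced when TF_CPP_MIN_VLOG_LEVEL >= 2";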
#else
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
// Support for printf style logging.
#define LOGV(...)
#define LOGD(...)
#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__);
#define LOGW(...) LOG(WARNING) << tensorflow::strings::Printf(__VA_ARGS__);
#define LOGE(...) LOG(ERROR) << tensorflow::strings::Printf(__VA_ARGS__);
#endif
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// NOTE: no native object detectors are currently provided or used by the code
// in this directory. This class remains mainly for historical reasons.
// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
namespace tf_tracking {
// This is here so that the vtable gets created properly.
ObjectDetectorBase::~ObjectDetectorBase() {}
} // namespace tf_tracking
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// NOTE: no native object detectors are currently provided or used by the code
// in this directory. This class remains mainly for historical reasons.
// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
// Defines the ObjectDetector class that is the main interface for detecting
// ObjectModelBases in frames.
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
#include <float.h>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#ifdef __RENDER_OPENGL__
#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
#endif
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
namespace tf_tracking {
// Adds BoundingSquares to a vector such that the first square added is
// centered at the given position with side length starting_square_size, and
// the remaining squares are added concentrically, scaling down by scale_factor
// until the size falls below smallest_square_size.
// Squares that do not fall completely within image_bounds will not be added.
static inline void FillWithSquares(
const BoundingBox& image_bounds,
const BoundingBox& position,
const float starting_square_size,
const float smallest_square_size,
const float scale_factor,
std::vector<BoundingSquare>* const squares) {
BoundingSquare descriptor_area =
GetCenteredSquare(position, starting_square_size);
SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor);
// Use a do/while loop to ensure that at least one descriptor is created.
do {
if (image_bounds.Contains(descriptor_area.ToBoundingBox())) {
squares->push_back(descriptor_area);
}
descriptor_area.Scale(scale_factor);
} while (descriptor_area.size_ >= smallest_square_size - EPSILON);
LOGV("Created %zu squares starting from size %.2f to min size %.2f "
"using scale factor: %.2f",
squares->size(), starting_square_size, smallest_square_size,
scale_factor);
}
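// Sketch of how a detector implementation might call FillWithSquares to build
// its scan list (the numeric constants are illustrative, not values from
// config.h):
//
//   std::vector<BoundingSquare> squares;
//   FillWithSquares(image_bounds,          // e.g. image->GetContainingBox()
//                   last_known_position,   // BoundingBox of the target
//                   64.0f,                 // starting square size (pixels)
//                   16.0f,                 // smallest square size (pixels)
//                   0.75f,                 // scale factor, must be < 1.0f
//                   &squares);
//   // squares now holds the concentric candidates, largest first.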
// Represents a potential detection of a specific ObjectExemplar and Descriptor
// at a specific position in the image.
class Detection {
public:
explicit Detection(const ObjectModelBase* const object_model,
const MatchScore match_score,
const BoundingBox& bounding_box)
: object_model_(object_model),
match_score_(match_score),
bounding_box_(bounding_box) {}
Detection(const Detection& other)
: object_model_(other.object_model_),
match_score_(other.match_score_),
bounding_box_(other.bounding_box_) {}
virtual ~Detection() {}
inline BoundingBox GetObjectBoundingBox() const {
return bounding_box_;
}
inline MatchScore GetMatchScore() const {
return match_score_;
}
inline const ObjectModelBase* GetObjectModel() const {
return object_model_;
}
inline bool Intersects(const Detection& other) {
    // Two boxes fail to intersect only if one of the four box edges forms a
    // separating axis.
return bounding_box_.Intersects(other.bounding_box_);
}
struct Comp {
inline bool operator()(const Detection& a, const Detection& b) const {
return a.match_score_ > b.match_score_;
}
};
// TODO(andrewharp): add accessors to update these instead.
const ObjectModelBase* object_model_;
MatchScore match_score_;
BoundingBox bounding_box_;
};
inline std::ostream& operator<<(std::ostream& stream,
const Detection& detection) {
const BoundingBox actual_area = detection.GetObjectBoundingBox();
stream << actual_area;
return stream;
}
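// Detection::Comp orders detections by descending match score, so candidates
// can be ranked as in this minimal sketch (detector and positions are assumed
// to exist; std::sort requires <algorithm>):
//
//   std::vector<Detection> detections;
//   detector->Detect(positions, &detections);
//   std::sort(detections.begin(), detections.end(), Detection::Comp());
//   // detections.front() is now the strongest candidate, if any were found.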
class ObjectDetectorBase {
public:
explicit ObjectDetectorBase(const ObjectDetectorConfig* const config)
: config_(config),
image_data_(NULL) {}
virtual ~ObjectDetectorBase();
// Sets the current image data. All calls to ObjectDetector other than
// FillDescriptors use the image data last set.
inline void SetImageData(const ImageData* const image_data) {
image_data_ = image_data;
}
// Main entry point into the detection algorithm.
// Scans the frame for candidates, tweaks them, and fills in the
// given std::vector of Detection objects with acceptable matches.
virtual void Detect(const std::vector<BoundingSquare>& positions,
std::vector<Detection>* const detections) const = 0;
virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0;
virtual void DeleteObjectModel(const std::string& name) = 0;
virtual void GetObjectModels(
std::vector<const ObjectModelBase*>* models) const = 0;
  // Updates the given object model with the appearance found at bounding_box
  // in base_image (and its precomputed integral_image), taken from the last
  // frame passed to NextFrame. The bounding box should be completely
  // contained within the frame.
virtual void UpdateModel(const Image<uint8_t>& base_image,
const IntegralImage& integral_image,
const BoundingBox& bounding_box, const bool locked,
ObjectModelBase* model) const = 0;
virtual void Draw() const = 0;
virtual bool AllowSpontaneousDetections() = 0;
protected:
const std::unique_ptr<const ObjectDetectorConfig> config_;
// The latest frame data, upon which all detections will be performed.
// Not owned by this object, just provided for reference by ObjectTracker
// via SetImageData().
const ImageData* image_data_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase);
};
template <typename ModelType>
class ObjectDetector : public ObjectDetectorBase {
public:
explicit ObjectDetector(const ObjectDetectorConfig* const config)
: ObjectDetectorBase(config) {}
virtual ~ObjectDetector() {
typename std::map<std::string, ModelType*>::const_iterator it =
object_models_.begin();
for (; it != object_models_.end(); ++it) {
ModelType* model = it->second;
delete model;
}
}
virtual void DeleteObjectModel(const std::string& name) {
ModelType* model = object_models_[name];
CHECK_ALWAYS(model != NULL, "Model was null!");
object_models_.erase(name);
SAFE_DELETE(model);
}
virtual void GetObjectModels(
std::vector<const ObjectModelBase*>* models) const {
typename std::map<std::string, ModelType*>::const_iterator it =
object_models_.begin();
for (; it != object_models_.end(); ++it) {
models->push_back(it->second);
}
}
virtual bool AllowSpontaneousDetections() {
return false;
}
protected:
std::map<std::string, ModelType*> object_models_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector);
};
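// Sketch of what a concrete detector built on this interface could look like.
// No such native detector ships with the demo (see the note at the top of this
// file); every name prefixed with "My" is hypothetical.
//
//   class MyDetector;
//
//   class MyModel : public ObjectModel<MyDetector> {
//    public:
//     MyModel(const MyDetector* const detector, const std::string& name)
//         : ObjectModel<MyDetector>(detector, name) {}
//     // ...plus the pure virtuals inherited from ObjectModelBase...
//   };
//
//   class MyDetector : public ObjectDetector<MyModel> {
//    public:
//     explicit MyDetector(const ObjectDetectorConfig* const config)
//         : ObjectDetector<MyModel>(config) {}
//     virtual void Detect(const std::vector<BoundingSquare>& positions,
//                         std::vector<Detection>* const detections) const;
//     virtual ObjectModelBase* CreateObjectModel(const std::string& name) {
//       MyModel* const model = new MyModel(this, name);
//       object_models_[name] = model;  // Freed by ~ObjectDetector().
//       return model;
//     }
//     virtual void UpdateModel(const Image<uint8_t>& base_image,
//                              const IntegralImage& integral_image,
//                              const BoundingBox& bounding_box,
//                              const bool locked,
//                              ObjectModelBase* model) const;
//     virtual void Draw() const {}
//   };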
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// NOTE: no native object detectors are currently provided or used by the code
// in this directory. This class remains mainly for historical reasons.
// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
// Contains ObjectModelBase declaration.
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
#ifdef __RENDER_OPENGL__
#include <GLES/gl.h>
#include <GLES/glext.h>
#endif
#include <vector>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#ifdef __RENDER_OPENGL__
#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
#endif
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
namespace tf_tracking {
// The ObjectModelBase class represents all the known appearance information for
// an object. It is not a specific instance of the object in the world,
// but just the general appearance information that enables detection. An
// ObjectModelBase can be reused across multiple-instances of TrackedObjects.
class ObjectModelBase {
public:
ObjectModelBase(const std::string& name) : name_(name) {}
virtual ~ObjectModelBase() {}
// Called when the next step in an ongoing track occurs.
virtual void TrackStep(const BoundingBox& position,
const Image<uint8_t>& image,
const IntegralImage& integral_image,
const bool authoritative) {}
// Called when an object track is lost.
virtual void TrackLost() {}
// Called when an object track is confirmed as legitimate.
virtual void TrackConfirmed() {}
virtual float GetMaxCorrelation(const Image<float>& patch_image) const = 0;
virtual MatchScore GetMatchScore(
const BoundingBox& position, const ImageData& image_data) const = 0;
virtual void Draw(float* const depth) const = 0;
inline const std::string& GetName() const {
return name_;
}
protected:
const std::string name_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase);
};
template <typename DetectorType>
class ObjectModel : public ObjectModelBase {
public:
ObjectModel<DetectorType>(const DetectorType* const detector,
const std::string& name)
: ObjectModelBase(name), detector_(detector) {}
protected:
const DetectorType* const detector_;
TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel<DetectorType>);
};
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef __RENDER_OPENGL__
#include <GLES/gl.h>
#include <GLES/glext.h>
#endif
#include <cinttypes>
#include <map>
#include <string>
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
namespace tf_tracking {
ObjectTracker::ObjectTracker(const TrackerConfig* const config,
ObjectDetectorBase* const detector)
: config_(config),
frame_width_(config->image_size.width),
frame_height_(config->image_size.height),
curr_time_(0),
num_frames_(0),
flow_cache_(&config->flow_config),
keypoint_detector_(&config->keypoint_detector_config),
curr_num_frame_pairs_(0),
first_frame_index_(0),
frame1_(new ImageData(frame_width_, frame_height_)),
frame2_(new ImageData(frame_width_, frame_height_)),
detector_(detector),
num_detected_(0) {
for (int i = 0; i < kNumFrames; ++i) {
frame_pairs_[i].Init(-1, -1);
}
}
ObjectTracker::~ObjectTracker() {
for (TrackedObjectMap::iterator iter = objects_.begin();
iter != objects_.end(); iter++) {
TrackedObject* object = iter->second;
SAFE_DELETE(object);
}
}
// Finds the correspondences for all the points in the current pair of frames.
// Stores the results in the given FramePair.
void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const {
  // Mark all keypoints as unmatched; the loop below sets the flag for those
  // that optical flow can follow into the second frame.
memset(frame_pair->optical_flow_found_keypoint_, false,
sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints);
TimeLog("Cleared old found keypoints");
int num_keypoints_found = 0;
// For every keypoint...
for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) {
Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat;
Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat;
if (flow_cache_.FindNewPositionOfPoint(
keypoint1->pos_.x, keypoint1->pos_.y,
&keypoint2->pos_.x, &keypoint2->pos_.y)) {
frame_pair->optical_flow_found_keypoint_[i_feat] = true;
++num_keypoints_found;
}
}
TimeLog("Found correspondences");
LOGV("Found %d of %d keypoint correspondences",
num_keypoints_found, frame_pair->number_of_keypoints_);
}
void ObjectTracker::NextFrame(const uint8_t* const new_frame,
const uint8_t* const uv_frame,
const int64_t timestamp,
const float* const alignment_matrix_2x3) {
IncrementFrameIndex();
LOGV("Received frame %d", num_frames_);
FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0);
curr_change->Init(curr_time_, timestamp);
CHECK_ALWAYS(curr_time_ < timestamp,
"Timestamp must monotonically increase! Went from %" PRId64
" to %" PRId64 " on frame %d.",
curr_time_, timestamp, num_frames_);
curr_time_ = timestamp;
// Swap the frames.
frame1_.swap(frame2_);
frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1);
if (detector_.get() != NULL) {
detector_->SetImageData(frame2_.get());
}
flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3);
if (num_frames_ == 1) {
// This must be the first frame, so abort.
return;
}
if (config_->always_track || objects_.size() > 0) {
LOGV("Tracking %zu targets", objects_.size());
ComputeKeypoints(true);
TimeLog("Keypoints computed!");
FindCorrespondences(curr_change);
TimeLog("Flow computed!");
TrackObjects();
}
TimeLog("Targets tracked!");
if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) {
DetectTargets();
}
TimeLog("Detected objects.");
}
TrackedObject* ObjectTracker::MaybeAddObject(
const std::string& id, const Image<uint8_t>& source_image,
const BoundingBox& bounding_box, const ObjectModelBase* object_model) {
// Train the detector if this is a new object.
if (objects_.find(id) != objects_.end()) {
return objects_[id];
}
// Need to get a non-const version of the model, or create a new one if it
// wasn't given.
ObjectModelBase* model = NULL;
if (detector_ != NULL) {
// If a detector is registered, then this new object must have a model.
CHECK_ALWAYS(object_model != NULL, "No model given!");
model = detector_->CreateObjectModel(object_model->GetName());
}
TrackedObject* const object =
new TrackedObject(id, source_image, bounding_box, model);
objects_[id] = object;
return object;
}
void ObjectTracker::RegisterNewObjectWithAppearance(
const std::string& id, const uint8_t* const new_frame,
const BoundingBox& bounding_box) {
ObjectModelBase* object_model = NULL;
Image<uint8_t> image(frame_width_, frame_height_);
image.FromArray(new_frame, frame_width_, 1);
if (detector_ != NULL) {
object_model = detector_->CreateObjectModel(id);
CHECK_ALWAYS(object_model != NULL, "Null object model!");
const IntegralImage integral_image(image);
object_model->TrackStep(bounding_box, image, integral_image, true);
}
// Create an object at this position.
CHECK_ALWAYS(!HaveObject(id), "Already have this object!");
if (objects_.find(id) == objects_.end()) {
TrackedObject* const object =
MaybeAddObject(id, image, bounding_box, object_model);
CHECK_ALWAYS(object != NULL, "Object not created!");
}
}
void ObjectTracker::SetPreviousPositionOfObject(const std::string& id,
const BoundingBox& bounding_box,
const int64_t timestamp) {
CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
CHECK_ALWAYS(timestamp <= curr_time_,
"Timestamp too great! %" PRId64 " vs %" PRId64, timestamp,
curr_time_);
TrackedObject* const object = GetObject(id);
// Track this bounding box from the past to the current time.
const BoundingBox current_position = TrackBox(bounding_box, timestamp);
object->UpdatePosition(current_position, curr_time_, *frame2_, false);
VLOG(2) << "Set tracked position for " << id << " to " << bounding_box
<< std::endl;
}
void ObjectTracker::SetCurrentPositionOfObject(
const std::string& id, const BoundingBox& bounding_box) {
SetPreviousPositionOfObject(id, bounding_box, curr_time_);
}
void ObjectTracker::ForgetTarget(const std::string& id) {
LOGV("Forgetting object %s", id.c_str());
TrackedObject* const object = GetObject(id);
delete object;
objects_.erase(id);
if (detector_ != NULL) {
detector_->DeleteObjectModel(id);
}
}
int ObjectTracker::GetKeypointsPacked(uint16_t* const out_data,
const float scale) const {
const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
uint16_t* curr_data = out_data;
int num_keypoints = 0;
for (int i = 0; i < change.number_of_keypoints_; ++i) {
if (change.optical_flow_found_keypoint_[i]) {
++num_keypoints;
const Point2f& point1 = change.frame1_keypoints_[i].pos_;
*curr_data++ = RealToFixed115(point1.x * scale);
*curr_data++ = RealToFixed115(point1.y * scale);
const Point2f& point2 = change.frame2_keypoints_[i].pos_;
*curr_data++ = RealToFixed115(point2.x * scale);
*curr_data++ = RealToFixed115(point2.y * scale);
}
}
return num_keypoints;
}
int ObjectTracker::GetKeypoints(const bool only_found,
float* const out_data) const {
int curr_keypoint = 0;
const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
for (int i = 0; i < change.number_of_keypoints_; ++i) {
if (!only_found || change.optical_flow_found_keypoint_[i]) {
const int base = curr_keypoint * kKeypointStep;
out_data[base + 0] = change.frame1_keypoints_[i].pos_.x;
out_data[base + 1] = change.frame1_keypoints_[i].pos_.y;
out_data[base + 2] =
change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f;
out_data[base + 3] = change.frame2_keypoints_[i].pos_.x;
out_data[base + 4] = change.frame2_keypoints_[i].pos_.y;
out_data[base + 5] = change.frame1_keypoints_[i].score_;
out_data[base + 6] = change.frame1_keypoints_[i].type_;
++curr_keypoint;
}
}
LOGV("Got %d keypoints.", curr_keypoint);
return curr_keypoint;
}
BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
const FramePair& frame_pair) const {
float translation_x;
float translation_y;
float scale_x;
float scale_y;
BoundingBox tracked_box(region);
frame_pair.AdjustBox(
tracked_box, &translation_x, &translation_y, &scale_x, &scale_y);
tracked_box.Shift(Point2f(translation_x, translation_y));
if (scale_x > 0 && scale_y > 0) {
tracked_box.Scale(scale_x, scale_y);
}
return tracked_box;
}
BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
const int64_t timestamp) const {
CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!");
// Anything that ended before the requested timestamp is of no concern to us.
bool found_it = false;
int num_frames_back = -1;
for (int i = 0; i < curr_num_frame_pairs_; ++i) {
const FramePair& frame_pair =
frame_pairs_[GetNthIndexFromEnd(i)];
if (frame_pair.end_time_ <= timestamp) {
num_frames_back = i - 1;
if (num_frames_back > 0) {
LOGV("Went %d out of %d frames before finding frame. (index: %d)",
num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i));
}
found_it = true;
break;
}
}
if (!found_it) {
LOGW("History did not go back far enough! %" PRId64 " vs %" PRId64,
frame_pairs_[GetNthIndexFromEnd(0)].end_time_ -
frame_pairs_[GetNthIndexFromStart(0)].end_time_,
frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp);
}
// Loop over all the frames in the queue, tracking the accumulated delta
// of the point from frame to frame. It's possible the point could
// go out of frame, but keep tracking as best we can, using points near
// the edge of the screen where it went out of bounds.
BoundingBox tracked_box(region);
for (int i = num_frames_back; i >= 0; --i) {
const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)];
SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!");
tracked_box = TrackBox(tracked_box, frame_pair);
}
return tracked_box;
}
// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4
// 3d transformation matrix.
inline void Convert3x3To4x4(
const float* const in_matrix, float* const out_matrix) {
// X
out_matrix[0] = in_matrix[0];
out_matrix[1] = in_matrix[3];
out_matrix[2] = 0.0f;
out_matrix[3] = 0.0f;
// Y
out_matrix[4] = in_matrix[1];
out_matrix[5] = in_matrix[4];
out_matrix[6] = 0.0f;
out_matrix[7] = 0.0f;
// Z
out_matrix[8] = 0.0f;
out_matrix[9] = 0.0f;
out_matrix[10] = 1.0f;
out_matrix[11] = 0.0f;
// Translation
out_matrix[12] = in_matrix[2];
out_matrix[13] = in_matrix[5];
out_matrix[14] = 0.0f;
out_matrix[15] = 1.0f;
}
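// Worked example: only in_matrix[0..5] are read (the third row is implicitly
// [0 0 1]). A row-major input of
//   [ s  0  tx ]
//   [ 0  s  ty ]
// becomes the column-major 4x4 matrix expected by glMultMatrixf() in Draw():
//   { s, 0, 0, 0,   0, s, 0, 0,   0, 0, 1, 0,   tx, ty, 0, 1 }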
void ObjectTracker::Draw(const int canvas_width, const int canvas_height,
const float* const frame_to_canvas) const {
#ifdef __RENDER_OPENGL__
glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
glMatrixMode(GL_PROJECTION);
glLoadIdentity();
glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f);
// To make Y go the right direction (0 at top of frame).
glScalef(1.0f, -1.0f, 1.0f);
glTranslatef(0.0f, -canvas_height, 0.0f);
glMatrixMode(GL_MODELVIEW);
glLoadIdentity();
glPushMatrix();
// Apply the frame to canvas transformation.
static GLfloat transformation[16];
Convert3x3To4x4(frame_to_canvas, transformation);
glMultMatrixf(transformation);
// Draw tracked object bounding boxes.
for (TrackedObjectMap::const_iterator iter = objects_.begin();
iter != objects_.end(); ++iter) {
TrackedObject* tracked_object = iter->second;
tracked_object->Draw();
}
static const bool kRenderDebugPyramid = false;
if (kRenderDebugPyramid) {
glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
for (int i = 0; i < kNumPyramidLevels * 2; ++i) {
Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw();
}
}
static const bool kRenderDebugDerivative = false;
if (kRenderDebugDerivative) {
glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
for (int i = 0; i < kNumPyramidLevels; ++i) {
const Image<int32_t>& dx = *frame1_->GetSpatialX(i);
Image<uint8_t> render_image(dx.GetWidth(), dx.GetHeight());
for (int y = 0; y < dx.GetHeight(); ++y) {
const int32_t* dx_ptr = dx[y];
uint8_t* dst_ptr = render_image[y];
for (int x = 0; x < dx.GetWidth(); ++x) {
*dst_ptr++ = Clip(-(*dx_ptr++), 0, 255);
}
}
Sprite(render_image).Draw();
}
}
if (detector_ != NULL) {
glDisable(GL_CULL_FACE);
detector_->Draw();
}
glPopMatrix();
#endif
}
static void AddQuadrants(const BoundingBox& box,
std::vector<BoundingBox>* boxes) {
const Point2f center = box.GetCenter();
float x1 = box.left_;
float x2 = center.x;
float x3 = box.right_;
float y1 = box.top_;
float y2 = center.y;
float y3 = box.bottom_;
// Upper left.
boxes->push_back(BoundingBox(x1, y1, x2, y2));
// Upper right.
boxes->push_back(BoundingBox(x2, y1, x3, y2));
// Bottom left.
boxes->push_back(BoundingBox(x1, y2, x2, y3));
// Bottom right.
boxes->push_back(BoundingBox(x2, y2, x3, y3));
// Whole thing.
boxes->push_back(box);
}
void ObjectTracker::ComputeKeypoints(const bool cached_ok) {
const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)];
FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)];
std::vector<BoundingBox> boxes;
for (TrackedObjectMap::iterator object_iter = objects_.begin();
object_iter != objects_.end(); ++object_iter) {
BoundingBox box = object_iter->second->GetPosition();
box.Scale(config_->object_box_scale_factor_for_features,
config_->object_box_scale_factor_for_features);
AddQuadrants(box, &boxes);
}
AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes);
keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change);
}
// Given a detection, finds the tracked object that best matches it (by
// overlap, jump distance and match score) and writes it to *match; *match may
// be NULL if no object qualifies. Returns false if the detection overlaps an
// object it is not permitted to be matched with.
bool ObjectTracker::GetBestObjectForDetection(
const Detection& detection, TrackedObject** match) const {
TrackedObject* best_match = NULL;
float best_overlap = -FLT_MAX;
LOGV("Looking for matches in %zu objects!", objects_.size());
for (TrackedObjectMap::const_iterator object_iter = objects_.begin();
object_iter != objects_.end(); ++object_iter) {
TrackedObject* const tracked_object = object_iter->second;
const float overlap = tracked_object->GetPosition().PascalScore(
detection.GetObjectBoundingBox());
if (!detector_->AllowSpontaneousDetections() &&
(detection.GetObjectModel() != tracked_object->GetModel())) {
if (overlap > 0.0f) {
return false;
}
continue;
}
const float jump_distance =
(tracked_object->GetPosition().GetCenter() -
detection.GetObjectBoundingBox().GetCenter()).LengthSquared();
const float allowed_distance =
tracked_object->GetAllowableDistanceSquared();
LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f",
jump_distance, allowed_distance, overlap);
// TODO(andrewharp): No need to do this verification twice, eliminate
// one of the score checks (the other being in OnDetection).
if (jump_distance < allowed_distance &&
overlap > best_overlap &&
tracked_object->GetMatchScore() + kMatchScoreBuffer <
detection.GetMatchScore()) {
best_match = tracked_object;
best_overlap = overlap;
} else if (overlap > 0.0f) {
return false;
}
}
*match = best_match;
return true;
}
void ObjectTracker::ProcessDetections(
std::vector<Detection>* const detections) {
LOGV("Initial detection done, iterating over %zu detections now.",
detections->size());
const bool spontaneous_detections_allowed =
detector_->AllowSpontaneousDetections();
for (std::vector<Detection>::const_iterator it = detections->begin();
it != detections->end(); ++it) {
const Detection& detection = *it;
SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()),
"Frame does not contain bounding box!");
TrackedObject* best_match = NULL;
const bool no_collisions =
GetBestObjectForDetection(detection, &best_match);
// Need to get a non-const version of the model, or create a new one if it
// wasn't given.
ObjectModelBase* model =
const_cast<ObjectModelBase*>(detection.GetObjectModel());
if (best_match != NULL) {
if (model != best_match->GetModel()) {
CHECK_ALWAYS(detector_->AllowSpontaneousDetections(),
"Model for object changed but spontaneous detections not allowed!");
}
best_match->OnDetection(model,
detection.GetObjectBoundingBox(),
detection.GetMatchScore(),
curr_time_, *frame2_);
} else if (no_collisions && spontaneous_detections_allowed) {
if (detection.GetMatchScore() > kMinimumMatchScore) {
LOGV("No match, adding it!");
const ObjectModelBase* model = detection.GetObjectModel();
std::ostringstream ss;
// TODO(andrewharp): Generate this in a more general fashion.
ss << "hand_" << num_detected_++;
std::string object_name = ss.str();
MaybeAddObject(object_name, *frame2_->GetImage(),
detection.GetObjectBoundingBox(), model);
}
}
}
}
void ObjectTracker::DetectTargets() {
// Detect all object model types that we're currently tracking.
std::vector<const ObjectModelBase*> object_models;
detector_->GetObjectModels(&object_models);
if (object_models.size() == 0) {
LOGV("No objects to search for, aborting.");
return;
}
LOGV("Trying to detect %zu models", object_models.size());
LOGV("Creating test vector!");
std::vector<BoundingSquare> positions;
for (TrackedObjectMap::iterator object_iter = objects_.begin();
object_iter != objects_.end(); ++object_iter) {
TrackedObject* const tracked_object = object_iter->second;
#if DEBUG_PREDATOR
positions.push_back(GetCenteredSquare(
frame2_->GetImage()->GetContainingBox(), 32.0f));
#else
const BoundingBox& position = tracked_object->GetPosition();
const float square_size = MAX(
kScanMinSquareSize / (kLastKnownPositionScaleFactor *
kLastKnownPositionScaleFactor),
MIN(position.GetWidth(),
position.GetHeight())) / kLastKnownPositionScaleFactor;
FillWithSquares(frame2_->GetImage()->GetContainingBox(),
tracked_object->GetPosition(),
square_size,
kScanMinSquareSize,
kLastKnownPositionScaleFactor,
&positions);
#endif
  }
LOGV("Created test vector!");
std::vector<Detection> detections;
LOGV("Detecting!");
detector_->Detect(positions, &detections);
LOGV("Found %zu detections", detections.size());
TimeLog("Finished detection.");
ProcessDetections(&detections);
TimeLog("iterated over detections");
LOGV("Done detecting!");
}
void ObjectTracker::TrackObjects() {
// TODO(andrewharp): Correlation should be allowed to remove objects too.
const bool automatic_removal_allowed = detector_.get() != NULL ?
detector_->AllowSpontaneousDetections() : false;
LOGV("Tracking %zu objects!", objects_.size());
std::vector<std::string> dead_objects;
for (TrackedObjectMap::iterator iter = objects_.begin();
iter != objects_.end(); iter++) {
TrackedObject* object = iter->second;
const BoundingBox tracked_position = TrackBox(
object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]);
object->UpdatePosition(tracked_position, curr_time_, *frame2_, false);
if (automatic_removal_allowed &&
object->GetNumConsecutiveFramesBelowThreshold() >
kMaxNumDetectionFailures * 5) {
dead_objects.push_back(iter->first);
}
}
if (detector_ != NULL && automatic_removal_allowed) {
for (std::vector<std::string>::iterator iter = dead_objects.begin();
iter != dead_objects.end(); iter++) {
LOGE("Removing object! %s", iter->c_str());
ForgetTarget(*iter);
}
}
TimeLog("Tracked all objects.");
LOGV("%zu objects tracked!", objects_.size());
}
} // namespace tf_tracking
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
#include <map>
#include <string>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
namespace tf_tracking {
typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;
inline std::ostream& operator<<(std::ostream& stream,
const TrackedObjectMap& map) {
for (TrackedObjectMap::const_iterator iter = map.begin();
iter != map.end(); ++iter) {
const TrackedObject& tracked_object = *iter->second;
const std::string& key = iter->first;
stream << key << ": " << tracked_object;
}
return stream;
}
// ObjectTracker is the highest-level class in the tracking/detection framework.
// It handles basic image processing, keypoint detection, keypoint tracking,
// object tracking, and object detection/relocalization.
class ObjectTracker {
public:
ObjectTracker(const TrackerConfig* const config,
ObjectDetectorBase* const detector);
virtual ~ObjectTracker();
virtual void NextFrame(const uint8_t* const new_frame,
const int64_t timestamp,
const float* const alignment_matrix_2x3) {
NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
}
  // Called upon the arrival of a new frame of raw data.
  // Does all image processing, keypoint detection, and object
  // tracking/detection for registered objects.
  // Argument uv_frame may be NULL if no chroma plane is available.
  // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
  // represents the main transformation that has happened between the last
  // and the current frame; it may be NULL if no such estimate is available.
virtual void NextFrame(const uint8_t* const new_frame,
const uint8_t* const uv_frame, const int64_t timestamp,
const float* const alignment_matrix_2x3);
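  // A per-frame sketch from a hypothetical caller (identifiers are
  // illustrative; the JNI bridge in object_tracker_jni.cc performs roughly the
  // same steps with data handed over from Java):
  //
  //   tracker->NextFrame(y_plane, uv_plane, timestamp, alignment_or_null);
  //   if (tracker->HaveObject("target")) {
  //     const BoundingBox box = tracker->GetObject("target")->GetPosition();
  //     // ...render or report box...
  //   }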
virtual void RegisterNewObjectWithAppearance(const std::string& id,
const uint8_t* const new_frame,
const BoundingBox& bounding_box);
// Updates the position of a tracked object, given that it was known to be at
// a certain position at some point in the past.
virtual void SetPreviousPositionOfObject(const std::string& id,
const BoundingBox& bounding_box,
const int64_t timestamp);
// Sets the current position of the object in the most recent frame provided.
virtual void SetCurrentPositionOfObject(const std::string& id,
const BoundingBox& bounding_box);
// Tells the ObjectTracker to stop tracking a target.
void ForgetTarget(const std::string& id);
// Fills the given out_data buffer with the latest detected keypoint
// correspondences, first scaled by scale_factor (to adjust for downsampling
// that may have occurred elsewhere), then packed in a fixed-point format.
int GetKeypointsPacked(uint16_t* const out_data,
const float scale_factor) const;
  // Copy the keypoint arrays after computeFlow is called.
  // out_data should be at least kMaxKeypoints * kKeypointStep long.
  // Currently, its format is [x1 y1 found x2 y2 score type] repeated N times,
  // where N is the number of keypoints copied. N is returned as the result.
int GetKeypoints(const bool only_found, float* const out_data) const;
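  // Sketch of consuming the buffer (assumes kKeypointStep == 7, matching the
  // seven floats written per keypoint in ObjectTracker::GetKeypoints):
  //
  //   float data[kMaxKeypoints * kKeypointStep];
  //   const int n = tracker.GetKeypoints(true, data);
  //   for (int i = 0; i < n; ++i) {
  //     const float* k = data + i * kKeypointStep;
  //     // k[0], k[1]: frame1 position   k[2]: found flag (1.0f / -1.0f)
  //     // k[3], k[4]: frame2 position   k[5]: score        k[6]: type
  //   }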
// Returns the current position of a box, given that it was at a certain
// position at the given time.
BoundingBox TrackBox(const BoundingBox& region,
const int64_t timestamp) const;
// Returns the number of frames that have been passed to NextFrame().
inline int GetNumFrames() const {
return num_frames_;
}
inline bool HaveObject(const std::string& id) const {
return objects_.find(id) != objects_.end();
}
// Returns the TrackedObject associated with the given id.
inline const TrackedObject* GetObject(const std::string& id) const {
TrackedObjectMap::const_iterator iter = objects_.find(id);
CHECK_ALWAYS(iter != objects_.end(),
"Unknown object key! \"%s\"", id.c_str());
TrackedObject* const object = iter->second;
return object;
}
// Returns the TrackedObject associated with the given id.
inline TrackedObject* GetObject(const std::string& id) {
TrackedObjectMap::iterator iter = objects_.find(id);
CHECK_ALWAYS(iter != objects_.end(),
"Unknown object key! \"%s\"", id.c_str());
TrackedObject* const object = iter->second;
return object;
}
bool IsObjectVisible(const std::string& id) const {
SCHECK(HaveObject(id), "Don't have this object.");
const TrackedObject* object = GetObject(id);
return object->IsVisible();
}
virtual void Draw(const int canvas_width, const int canvas_height,
const float* const frame_to_canvas) const;
protected:
// Creates a new tracked object at the given position.
// If an object model is provided, then that model will be associated with the
// object. If not, a new model may be created from the appearance at the
// initial position and registered with the object detector.
virtual TrackedObject* MaybeAddObject(const std::string& id,
const Image<uint8_t>& image,
const BoundingBox& bounding_box,
const ObjectModelBase* object_model);
// Find the keypoints in the frame before the current frame.
// If only one frame exists, keypoints will be found in that frame.
void ComputeKeypoints(const bool cached_ok = false);
// Finds the correspondences for all the points in the current pair of frames.
// Stores the results in the given FramePair.
void FindCorrespondences(FramePair* const curr_change) const;
inline int GetNthIndexFromEnd(const int offset) const {
return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
}
BoundingBox TrackBox(const BoundingBox& region,
const FramePair& frame_pair) const;
inline void IncrementFrameIndex() {
// Move the current framechange index up.
++num_frames_;
++curr_num_frame_pairs_;
// If we've got too many, push up the start of the queue.
if (curr_num_frame_pairs_ > kNumFrames) {
first_frame_index_ = GetNthIndexFromStart(1);
--curr_num_frame_pairs_;
}
}
inline int GetNthIndexFromStart(const int offset) const {
SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
"Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_);
return (first_frame_index_ + offset) % kNumFrames;
}
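  // Illustration of the circular buffer maintained by the members below (using
  // kNumFrames == 4 purely for the example): after six calls to
  // IncrementFrameIndex(), first_frame_index_ == 2, curr_num_frame_pairs_ == 4,
  // and GetNthIndexFromEnd(0) == (2 + 3) % 4 == 1, i.e. the newest FramePair
  // lives at slot 1 of frame_pairs_ while slot 2 holds the oldest.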
void TrackObjects();
const std::unique_ptr<const TrackerConfig> config_;
const int frame_width_;
const int frame_height_;
int64_t curr_time_;
int num_frames_;
TrackedObjectMap objects_;
FlowCache flow_cache_;
KeypointDetector keypoint_detector_;
int curr_num_frame_pairs_;
int first_frame_index_;
std::unique_ptr<ImageData> frame1_;
std::unique_ptr<ImageData> frame2_;
FramePair frame_pairs_[kNumFrames];
std::unique_ptr<ObjectDetectorBase> detector_;
int num_detected_;
private:
void TrackTarget(TrackedObject* const object);
bool GetBestObjectForDetection(
const Detection& detection, TrackedObject** match) const;
void ProcessDetections(std::vector<Detection>* const detections);
void DetectTargets();
// Temp object used in ObjectTracker::CreateNewExample.
mutable std::vector<BoundingSquare> squares;
friend std::ostream& operator<<(std::ostream& stream,
const ObjectTracker& tracker);
TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
};
inline std::ostream& operator<<(std::ostream& stream,
const ObjectTracker& tracker) {
stream << "Frame size: " << tracker.frame_width_ << "x"
<< tracker.frame_height_ << std::endl;
stream << "Num frames: " << tracker.num_frames_ << std::endl;
stream << "Curr time: " << tracker.curr_time_ << std::endl;
const int first_frame_index = tracker.GetNthIndexFromStart(0);
const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];
const int last_frame_index = tracker.GetNthIndexFromEnd(0);
const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];
stream << "first frame: " << first_frame_index << ","
<< first_frame_pair.end_time_ << " "
<< "last frame: " << last_frame_index << ","
<< last_frame_pair.end_time_ << " diff: "
<< last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
<< std::endl;
stream << "Tracked targets:";
stream << tracker.objects_;
return stream;
}
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <android/log.h>
#include <jni.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <cstdint>
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
namespace tf_tracking {
#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT
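// For example, OBJECT_TRACKER_METHOD(initNative) expands to
// Java_org_tensorflow_demo_tracking_ObjectTracker_initNative, the symbol the
// JVM resolves for the native method initNative() declared on
// org.tensorflow.demo.tracking.ObjectTracker.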
JniLongField object_tracker_field("nativeObjectTracker");
ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
ObjectTracker* const object_tracker =
reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
return object_tracker;
}
void set_object_tracker(JNIEnv* env, jobject thiz,
const ObjectTracker* object_tracker) {
object_tracker_field.set(env, thiz,
reinterpret_cast<intptr_t>(object_tracker));
}
#ifdef __cplusplus
extern "C" {
#endif
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
jint width, jint height,
jboolean always_track);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
jobject thiz);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
jfloat x2, jfloat y2, jbyteArray frame_data);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
jfloat x2, jfloat y2, jlong timestamp);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
jfloat x2, jfloat y2);
JNIEXPORT
jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
jstring object_id);
JNIEXPORT
jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
jobject thiz,
jstring object_id);
JNIEXPORT
jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
jobject thiz,
jstring object_id);
JNIEXPORT
jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
jobject thiz,
jstring object_id);
JNIEXPORT
jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
jstring object_id);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
jbyteArray y_data,
jbyteArray uv_data,
jlong timestamp,
jfloatArray vg_matrix_2x3);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
jstring object_id);
JNIEXPORT
jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
JNIEnv* env, jobject thiz, jfloat scale_factor);
JNIEXPORT
jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
JNIEnv* env, jobject thiz, jboolean only_found_);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
jfloat position_y1, jfloat position_x2, jfloat position_y2,
jfloatArray delta);
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
jint view_width,
jint view_height,
jfloatArray delta);
JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
jbyteArray input, jint factor, jbyteArray output);
#ifdef __cplusplus
}
#endif
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
jint width, jint height,
jboolean always_track) {
LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
const Size image_size(width, height);
TrackerConfig* const tracker_config = new TrackerConfig(image_size);
tracker_config->always_track = always_track;
  // No native detector is registered here; detection in the demo is handled on
  // the Java side (see TensorFlowMultiBoxDetector.java).
ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
set_object_tracker(env, thiz, tracker);
LOGI("Initialized!");
CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
"Failure to set hand tracker!");
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
jobject thiz) {
delete get_object_tracker(env, thiz);
set_object_tracker(env, thiz, NULL);
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
jfloat x2, jfloat y2, jbyteArray frame_data) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
x2, y2);
jboolean iCopied = JNI_FALSE;
// Copy image into currFrame.
jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
BoundingBox bounding_box(x1, y1, x2, y2);
get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
id_str, reinterpret_cast<const uint8_t*>(pixels), bounding_box);
env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
env->ReleaseStringUTFChars(object_id, id_str);
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
jfloat x2, jfloat y2, jlong timestamp) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
LOGI(
"Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
" at time %lld",
id_str, x1, y1, x2, y2, static_cast<long long>(timestamp));
get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
id_str, BoundingBox(x1, y1, x2, y2), timestamp);
env->ReleaseStringUTFChars(object_id, id_str);
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
jfloat x2, jfloat y2) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
x2, y2);
get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
id_str, BoundingBox(x1, y1, x2, y2));
env->ReleaseStringUTFChars(object_id, id_str);
}
JNIEXPORT
jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
jstring object_id) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
env->ReleaseStringUTFChars(object_id, id_str);
return haveObject;
}
JNIEXPORT
jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
jobject thiz,
jstring object_id) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
env->ReleaseStringUTFChars(object_id, id_str);
return visible;
}
JNIEXPORT
jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
jobject thiz,
jstring object_id) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
const TrackedObject* const object =
get_object_tracker(env, thiz)->GetObject(id_str);
env->ReleaseStringUTFChars(object_id, id_str);
jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
return model_name;
}
JNIEXPORT
jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
jobject thiz,
jstring object_id) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
const float correlation =
get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
env->ReleaseStringUTFChars(object_id, id_str);
return correlation;
}
JNIEXPORT
jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
jstring object_id) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
const float match_score =
get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
env->ReleaseStringUTFChars(object_id, id_str);
return match_score;
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
jboolean iCopied = JNI_FALSE;
const char* const id_str = env->GetStringUTFChars(object_id, 0);
const BoundingBox bounding_box =
get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
env->ReleaseStringUTFChars(object_id, id_str);
jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
env->ReleaseFloatArrayElements(rect_array, rect, 0);
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
jbyteArray y_data,
jbyteArray uv_data,
jlong timestamp,
jfloatArray vg_matrix_2x3) {
TimeLog("Starting object tracker");
jboolean iCopied = JNI_FALSE;
float vision_gyro_matrix_array[6];
jfloat* jmat = NULL;
if (vg_matrix_2x3 != NULL) {
// Copy the alignment matrix into a float array.
jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
for (int i = 0; i < 6; ++i) {
vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
}
}
// Copy image into currFrame.
jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
jbyte* uv_pixels =
uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
TimeLog("Got elements");
// Add the frame to the object tracker object.
get_object_tracker(env, thiz)->NextFrame(
reinterpret_cast<uint8_t*>(pixels), reinterpret_cast<uint8_t*>(uv_pixels),
timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
if (uv_data != NULL) {
env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
}
if (vg_matrix_2x3 != NULL) {
env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
}
TimeLog("Released elements");
PrintTimeLog();
ResetTimeLog();
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
jstring object_id) {
const char* const id_str = env->GetStringUTFChars(object_id, 0);
get_object_tracker(env, thiz)->ForgetTarget(id_str);
env->ReleaseStringUTFChars(object_id, id_str);
}
JNIEXPORT
jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
JNIEnv* env, jobject thiz, jboolean only_found) {
jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
const int number_of_keypoints =
get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
// Create and return the array that will be passed back to Java.
jfloatArray keypoints =
env->NewFloatArray(number_of_keypoints * kKeypointStep);
if (keypoints == NULL) {
LOGE("null array!");
return NULL;
}
env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
keypoint_arr);
return keypoints;
}
JNIEXPORT
jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
JNIEnv* env, jobject thiz, jfloat scale_factor) {
  // Each keypoint is packed as two (x, y) pairs of uint16_t values (the frame1
  // and frame2 positions), i.e. sizeof(uint16_t) * 2 coordinates * 2 points.
const int bytes_per_keypoint = sizeof(uint16_t) * 2 * 2;
jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
const int number_of_keypoints =
get_object_tracker(env, thiz)->GetKeypointsPacked(
reinterpret_cast<uint16_t*>(keypoint_arr), scale_factor);
// Create and return the array that will be passed back to Java.
jbyteArray keypoints =
env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
if (keypoints == NULL) {
LOGE("null array!");
return NULL;
}
env->SetByteArrayRegion(
keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
return keypoints;
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
jfloat position_y1, jfloat position_x2, jfloat position_y2,
jfloatArray delta) {
jfloat point_arr[4];
const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
BoundingBox(position_x1, position_y1, position_x2, position_y2),
timestamp);
new_position.CopyToArray(point_arr);
env->SetFloatArrayRegion(delta, 0, 4, point_arr);
}
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
JNIEnv* env, jobject thiz, jint view_width, jint view_height,
jfloatArray frame_to_canvas_arr) {
ObjectTracker* object_tracker = get_object_tracker(env, thiz);
if (object_tracker != NULL) {
jfloat* frame_to_canvas =
env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
object_tracker->Draw(view_width, view_height, frame_to_canvas);
env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
JNI_ABORT);
}
}
JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
jbyteArray input, jint factor, jbyteArray output) {
if (input == NULL || output == NULL) {
LOGW("Received null arrays, hopefully this is a test!");
return;
}
jbyte* const input_array = env->GetByteArrayElements(input, 0);
jbyte* const output_array = env->GetByteArrayElements(output, 0);
{
tf_tracking::Image<uint8_t> full_image(
width, height, reinterpret_cast<uint8_t*>(input_array), false);
const int new_width = (width + factor - 1) / factor;
const int new_height = (height + factor - 1) / factor;
tf_tracking::Image<uint8_t> downsampled_image(
new_width, new_height, reinterpret_cast<uint8_t*>(output_array), false);
downsampled_image.DownsampleAveraged(
reinterpret_cast<uint8_t*>(input_array), row_stride, factor);
}
env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
env->ReleaseByteArrayElements(output, output_array, 0);
}
} // namespace tf_tracking
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
namespace tf_tracking {
OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config)
: config_(config),
frame1_(NULL),
frame2_(NULL),
working_size_(config->image_size) {}
void OpticalFlow::NextFrame(const ImageData* const image_data) {
// Special case for the first frame: make sure the image ends up in
// frame1_ so that keypoint detection can be done on it if desired.
frame1_ = (frame1_ == NULL) ? image_data : frame2_;
frame2_ = image_data;
}
// Static heart of the optical flow computation.
// Lucas-Kanade algorithm.
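// In matrix form, each iteration solves the 2x2 system n = G^-1 * b, where G
// is the spatial gradient matrix accumulated from (I_x, I_y) over the patch
// (via CalculateG) and b accumulates dI * (I_x, I_y) over the same patch; the
// residual n is then added to the running flow estimate (g_x, g_y).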
bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8_t>& img_I,
const Image<uint8_t>& img_J,
const Image<int32_t>& I_x,
const Image<int32_t>& I_y, const float p_x,
const float p_y, float* out_g_x,
float* out_g_y) {
float g_x = *out_g_x;
float g_y = *out_g_y;
// Get values for frame 1. They remain constant through the inner
// iteration loop.
float vals_I[kFlowArraySize];
float vals_I_x[kFlowArraySize];
float vals_I_y[kFlowArraySize];
const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1;
const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize);
#if USE_FIXED_POINT_FLOW
const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1;
const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1;
#else
const float real_x_max = I_x.width_less_one_ - EPSILON;
const float real_y_max = I_x.height_less_one_ - EPSILON;
#endif
// Get the window around the original point.
const float src_left_real = p_x - kWindowSizeFloat;
const float src_top_real = p_y - kWindowSizeFloat;
float* vals_I_ptr = vals_I;
float* vals_I_x_ptr = vals_I_x;
float* vals_I_y_ptr = vals_I_y;
#if USE_FIXED_POINT_FLOW
// Source integer coordinates.
const int src_left_fixed = RealToFixed1616(src_left_real);
const int src_top_fixed = RealToFixed1616(src_top_real);
for (int y = 0; y < kPatchSize; ++y) {
const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max);
for (int x = 0; x < kPatchSize; ++x) {
const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max);
*vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y);
*vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
*vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
}
}
#else
for (int y = 0; y < kPatchSize; ++y) {
const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max);
for (int x = 0; x < kPatchSize; ++x) {
const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max);
*vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos);
*vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos);
*vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos);
}
}
#endif
// Compute the spatial gradient matrix about point p.
float G[] = { 0, 0, 0, 0 };
CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G);
// Find the inverse of G.
float G_inv[4];
if (!Invert2x2(G, G_inv)) {
return false;
}
#if NORMALIZE
const float mean_I = ComputeMean(vals_I, kFlowArraySize);
const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I);
#endif
// Iterate kNumIterations times or until we converge.
for (int iteration = 0; iteration < kNumIterations; ++iteration) {
// Get values for frame 2.
float vals_J[kFlowArraySize];
// Get the window around the destination point.
const float left_real = p_x + g_x - kWindowSizeFloat;
const float top_real = p_y + g_y - kWindowSizeFloat;
float* vals_J_ptr = vals_J;
#if USE_FIXED_POINT_FLOW
// The top-left sub-pixel is set for the current iteration (in 16:16
// fixed). This is constant over one iteration.
const int left_fixed = RealToFixed1616(left_real);
const int top_fixed = RealToFixed1616(top_real);
for (int win_y = 0; win_y < kPatchSize; ++win_y) {
const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max);
for (int win_x = 0; win_x < kPatchSize; ++win_x) {
const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max);
*vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y);
}
}
#else
for (int win_y = 0; win_y < kPatchSize; ++win_y) {
const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max);
for (int win_x = 0; win_x < kPatchSize; ++win_x) {
const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max);
*vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos);
}
}
#endif
#if NORMALIZE
const float mean_J = ComputeMean(vals_J, kFlowArraySize);
const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J);
// TODO(andrewharp): Probably better to completely detect and handle the
// "corner case" where the patch is fully outside the image diagonally.
const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f;
#endif
// Compute image mismatch vector.
float b_x = 0.0f;
float b_y = 0.0f;
vals_I_ptr = vals_I;
vals_J_ptr = vals_J;
vals_I_x_ptr = vals_I_x;
vals_I_y_ptr = vals_I_y;
for (int win_y = 0; win_y < kPatchSize; ++win_y) {
for (int win_x = 0; win_x < kPatchSize; ++win_x) {
#if NORMALIZE
// Normalized Image difference.
const float dI =
(*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio;
#else
const float dI = *vals_I_ptr++ - *vals_J_ptr++;
#endif
b_x += dI * *vals_I_x_ptr++;
b_y += dI * *vals_I_y_ptr++;
}
}
// Optical flow... solve n = G^-1 * b
const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y);
const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y);
// Update best guess with residual displacement from this level and
// iteration.
g_x += n_x;
g_y += n_y;
// LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y);
// Abort early if we're already below the threshold.
if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) {
break;
}
} // Iteration.
// Copy value back into output.
*out_g_x = g_x;
*out_g_y = g_y;
return true;
}
// Pointwise flow using translational 2dof ESM.
bool OpticalFlow::FindFlowAtPoint_ESM(
const Image<uint8_t>& img_I, const Image<uint8_t>& img_J,
const Image<int32_t>& I_x, const Image<int32_t>& I_y,
const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x,
const float p_y, float* out_g_x, float* out_g_y) {
float g_x = *out_g_x;
float g_y = *out_g_y;
const float area_inv = 1.0f / static_cast<float>(kFlowArraySize);
// Get values for frame 1. They remain constant through the inner
// iteration loop.
uint8_t vals_I[kFlowArraySize];
uint8_t vals_J[kFlowArraySize];
int16_t src_gradient_x[kFlowArraySize];
int16_t src_gradient_y[kFlowArraySize];
// TODO(rspring): try out the IntegerPatchAlign() method once
// the code for that is in ../common.
const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize);
const int src_left_fixed = RealToFixed1616(p_x - wsize_float);
const int src_top_fixed = RealToFixed1616(p_y - wsize_float);
const int patch_size = 2 * kFlowIntegrationWindowSize + 1;
// Create the keypoint template patch from a subpixel location.
if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
patch_size, patch_size, vals_I) ||
!I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
patch_size, patch_size,
src_gradient_x) ||
!I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
patch_size, patch_size,
src_gradient_y)) {
return false;
}
int bright_offset = 0;
int sum_diff = 0;
// The top-left sub-pixel is set for the current iteration (in 16:16 fixed).
// This is constant over one iteration.
int left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
int top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
  // The truncated version gives the top-left-most pixel that is used.
int left_trunc = left_fixed >> 16;
int top_trunc = top_fixed >> 16;
// Compute an initial brightness offset.
if (kDoBrightnessNormalize &&
left_trunc >= 0 && top_trunc >= 0 &&
(left_trunc + patch_size) < img_J.width_less_one_ &&
(top_trunc + patch_size) < img_J.height_less_one_) {
int templ_index = 0;
const uint8_t* j_row = img_J[top_trunc] + left_trunc;
const int j_stride = img_J.stride();
for (int y = 0; y < patch_size; ++y, j_row += j_stride) {
for (int x = 0; x < patch_size; ++x) {
sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++];
}
}
bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv);
}
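  // bright_offset now holds the mean intensity difference between the target
  // patch and the template, modeling a global brightness change between the
  // frames; it is refined further inside the iteration loop below.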
// Iterate kNumIterations times or until we go out of image.
for (int iteration = 0; iteration < kNumIterations; ++iteration) {
int jtj[3] = { 0, 0, 0 };
int jtr[2] = { 0, 0 };
sum_diff = 0;
// Extract the target image values.
// Extract the gradient from the target image patch and accumulate to
// the gradient of the source image patch.
if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed,
patch_size, patch_size,
vals_J)) {
break;
}
const uint8_t* templ_row = vals_I;
const uint8_t* extract_row = vals_J;
const int16_t* src_dx_row = src_gradient_x;
const int16_t* src_dy_row = src_gradient_y;
for (int y = 0; y < patch_size; ++y, templ_row += patch_size,
src_dx_row += patch_size, src_dy_row += patch_size,
extract_row += patch_size) {
const int fp_y = top_fixed + (y << 16);
for (int x = 0; x < patch_size; ++x) {
const int fp_x = left_fixed + (x << 16);
int32_t target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y);
int32_t target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y);
// Combine the two Jacobians.
// Right-shift by one to account for the fact that we add
// two Jacobians.
int32_t dx = (src_dx_row[x] + target_dx) >> 1;
int32_t dy = (src_dy_row[x] + target_dy) >> 1;
// The current residual b - h(q) == extracted - (template + offset)
int32_t diff = static_cast<int32_t>(extract_row[x]) -
static_cast<int32_t>(templ_row[x]) - bright_offset;
jtj[0] += dx * dx;
jtj[1] += dx * dy;
jtj[2] += dy * dy;
jtr[0] += dx * diff;
jtr[1] += dy * diff;
sum_diff += diff;
}
}
const float jtr1_float = static_cast<float>(jtr[0]);
const float jtr2_float = static_cast<float>(jtr[1]);
// Add some baseline stability to the system.
jtj[0] += kEsmRegularizer;
jtj[2] += kEsmRegularizer;
const int64_t prod1 = static_cast<int64_t>(jtj[0]) * jtj[2];
const int64_t prod2 = static_cast<int64_t>(jtj[1]) * jtj[1];
// One ESM step.
const float jtj_1[4] = { static_cast<float>(jtj[2]),
static_cast<float>(-jtj[1]),
static_cast<float>(-jtj[1]),
static_cast<float>(jtj[0]) };
const double det_inv = 1.0 / static_cast<double>(prod1 - prod2);
g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float);
g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float);
if (kDoBrightnessNormalize) {
bright_offset +=
static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f);
}
// Update top left position.
left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
left_trunc = left_fixed >> 16;
top_trunc = top_fixed >> 16;
// Abort iterations if we go out of borders.
if (left_trunc < 0 || top_trunc < 0 ||
(left_trunc + patch_size) >= J_x.width_less_one_ ||
(top_trunc + patch_size) >= J_y.height_less_one_) {
break;
}
} // Iteration.
// Copy value back into output.
*out_g_x = g_x;
*out_g_y = g_y;
return true;
}
bool OpticalFlow::FindFlowAtPointReversible(
const int level, const float u_x, const float u_y,
const bool reverse_flow,
float* flow_x, float* flow_y) const {
const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_;
const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_;
// Images I (prev) and J (next).
const Image<uint8_t>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2);
const Image<uint8_t>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2);
// Computed gradients.
const Image<int32_t>& I_x = *frame_a.GetSpatialX(level);
const Image<int32_t>& I_y = *frame_a.GetSpatialY(level);
const Image<int32_t>& J_x = *frame_b.GetSpatialX(level);
const Image<int32_t>& J_y = *frame_b.GetSpatialY(level);
// Shrink factor from original.
const float shrink_factor = (1 << level);
// Image position vector (p := u^l), scaled for this level.
const float scaled_p_x = u_x / shrink_factor;
const float scaled_p_y = u_y / shrink_factor;
float scaled_flow_x = *flow_x / shrink_factor;
float scaled_flow_y = *flow_y / shrink_factor;
// LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level,
// scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y);
const bool success = kUseEsm ?
FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y,
scaled_p_x, scaled_p_y,
&scaled_flow_x, &scaled_flow_y) :
FindFlowAtPoint_LK(img_I, img_J, I_x, I_y,
scaled_p_x, scaled_p_y,
&scaled_flow_x, &scaled_flow_y);
*flow_x = scaled_flow_x * shrink_factor;
*flow_y = scaled_flow_y * shrink_factor;
return success;
}
bool OpticalFlow::FindFlowAtPointSingleLevel(
const int level,
const float u_x, const float u_y,
const bool filter_by_fb_error,
float* flow_x, float* flow_y) const {
if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) {
return false;
}
if (filter_by_fb_error) {
const float new_position_x = u_x + *flow_x;
const float new_position_y = u_y + *flow_y;
float reverse_flow_x = 0.0f;
float reverse_flow_y = 0.0f;
// Now find the backwards flow and confirm it lines up with the original
// starting point.
if (!FindFlowAtPointReversible(level, new_position_x, new_position_y,
true,
&reverse_flow_x, &reverse_flow_y)) {
LOGE("Backward error!");
return false;
}
const float discrepancy_length =
sqrtf(Square(*flow_x + reverse_flow_x) +
Square(*flow_y + reverse_flow_y));
const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y));
return discrepancy_length <
(kMaxForwardBackwardErrorAllowed * flow_length);
}
return true;
}
// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm.
// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details.
bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y,
const bool filter_by_fb_error,
float* flow_x, float* flow_y) const {
const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment,
kNumPyramidLevels - kNumCacheLevels);
// For every level in the pyramid, update the coordinates of the best match.
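  // Note that *flow_x / *flow_y serve as both the initial guess and the
  // output: each coarser level's estimate seeds the next finer level.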
for (int l = max_level - 1; l >= 0; --l) {
if (!FindFlowAtPointSingleLevel(l, u_x, u_y,
filter_by_fb_error, flow_x, flow_y)) {
return false;
}
}
return true;
}
} // namespace tf_tracking
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
namespace tf_tracking {
class FlowCache;
// Class encapsulating all the data and logic necessary for performing optical
// flow.
class OpticalFlow {
public:
explicit OpticalFlow(const OpticalFlowConfig* const config);
  // Add a new frame to the optical flow. Will update all the non-keypoint
  // related member variables.
  //
  // image_data should wrap the grayscale frame (along with its image pyramid
  // and spatial gradients) already downsampled to the working size this
  // OpticalFlow object was configured with; see ImageData for details.
void NextFrame(const ImageData* const image_data);
// An implementation of the Lucas-Kanade Optical Flow algorithm.
static bool FindFlowAtPoint_LK(const Image<uint8_t>& img_I,
const Image<uint8_t>& img_J,
const Image<int32_t>& I_x,
const Image<int32_t>& I_y, const float p_x,
const float p_y, float* out_g_x,
float* out_g_y);
// Pointwise flow using translational 2dof ESM.
static bool FindFlowAtPoint_ESM(
const Image<uint8_t>& img_I, const Image<uint8_t>& img_J,
const Image<int32_t>& I_x, const Image<int32_t>& I_y,
const Image<int32_t>& J_x, const Image<int32_t>& J_y, const float p_x,
const float p_y, float* out_g_x, float* out_g_y);
// Finds the flow using a specific level, in either direction.
// If reversed, the coordinates are in the context of the latest
// frame, not the frame before it.
// All coordinates used in parameters are global, not scaled.
bool FindFlowAtPointReversible(
const int level, const float u_x, const float u_y,
const bool reverse_flow,
float* final_x, float* final_y) const;
// Finds the flow using a specific level, filterable by forward-backward
// error. All coordinates used in parameters are global, not scaled.
bool FindFlowAtPointSingleLevel(const int level,
const float u_x, const float u_y,
const bool filter_by_fb_error,
float* flow_x, float* flow_y) const;
// Pyramidal optical-flow using all levels.
bool FindFlowAtPointPyramidal(const float u_x, const float u_y,
const bool filter_by_fb_error,
float* flow_x, float* flow_y) const;
private:
const OpticalFlowConfig* const config_;
const ImageData* frame1_;
const ImageData* frame2_;
// Size of the internally allocated images (after original is downsampled).
const Size working_size_;
TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow);
};
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
#ifdef __RENDER_OPENGL__
#include <GLES/gl.h>
#include <GLES/glext.h>
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
namespace tf_tracking {
// This class encapsulates the logic necessary to load and render image data
// at the same aspect ratio as the original source.
class Sprite {
public:
// Only create Sprites when you have an OpenGl context.
explicit Sprite(const Image<uint8_t>& image) { LoadTexture(image, NULL); }
Sprite(const Image<uint8_t>& image, const BoundingBox* const area) {
LoadTexture(image, area);
}
// Also, try to only delete a Sprite when holding an OpenGl context.
~Sprite() {
glDeleteTextures(1, &texture_);
}
inline int GetWidth() const {
return actual_width_;
}
inline int GetHeight() const {
return actual_height_;
}
// Draw the sprite at 0,0 - original width/height in the current reference
// frame. Any transformations desired must be applied before calling this
// function.
void Draw() const {
const float float_width = static_cast<float>(actual_width_);
const float float_height = static_cast<float>(actual_height_);
// Where it gets rendered to.
const float vertices[] = { 0.0f, 0.0f, 0.0f,
0.0f, float_height, 0.0f,
float_width, 0.0f, 0.0f,
float_width, float_height, 0.0f,
};
// The coordinates the texture gets drawn from.
const float max_x = float_width / texture_width_;
const float max_y = float_height / texture_height_;
const float textureVertices[] = {
0, 0,
0, max_y,
max_x, 0,
max_x, max_y,
};
glEnable(GL_TEXTURE_2D);
glBindTexture(GL_TEXTURE_2D, texture_);
glEnableClientState(GL_VERTEX_ARRAY);
glEnableClientState(GL_TEXTURE_COORD_ARRAY);
glVertexPointer(3, GL_FLOAT, 0, vertices);
glTexCoordPointer(2, GL_FLOAT, 0, textureVertices);
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
glDisableClientState(GL_VERTEX_ARRAY);
glDisableClientState(GL_TEXTURE_COORD_ARRAY);
}
private:
inline int GetNextPowerOfTwo(const int number) const {
int power_of_two = 1;
while (power_of_two < number) {
power_of_two *= 2;
}
return power_of_two;
}
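  // For example, GetNextPowerOfTwo(300) returns 512, while power-of-two
  // inputs such as 256 are returned unchanged.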
// TODO(andrewharp): Allow sprites to have their textures reloaded.
void LoadTexture(const Image<uint8_t>& texture_source,
const BoundingBox* const area) {
glEnable(GL_TEXTURE_2D);
glGenTextures(1, &texture_);
glBindTexture(GL_TEXTURE_2D, texture_);
int left = 0;
int top = 0;
if (area != NULL) {
// If a sub-region was provided to pull the texture from, use that.
left = area->left_;
top = area->top_;
actual_width_ = area->GetWidth();
actual_height_ = area->GetHeight();
} else {
actual_width_ = texture_source.GetWidth();
actual_height_ = texture_source.GetHeight();
}
// The textures must be a power of two, so find the sizes that are large
// enough to contain the image data.
texture_width_ = GetNextPowerOfTwo(actual_width_);
texture_height_ = GetNextPowerOfTwo(actual_height_);
bool allocated_data = false;
uint8_t* texture_data;
// Except in the lucky case where we're not using a sub-region of the
// original image AND the source data has dimensions that are power of two,
// care must be taken to copy data at the appropriate source and destination
// strides so that the final block can be copied directly into texture
// memory.
// TODO(andrewharp): Figure out if data can be pulled directly from the
// source image with some alignment modifications.
if (left != 0 || top != 0 ||
actual_width_ != texture_source.GetWidth() ||
actual_height_ != texture_source.GetHeight()) {
texture_data = new uint8_t[actual_width_ * actual_height_];
for (int y = 0; y < actual_height_; ++y) {
memcpy(texture_data + actual_width_ * y, texture_source[top + y] + left,
actual_width_ * sizeof(uint8_t));
}
allocated_data = true;
} else {
// Cast away const-ness because for some reason glTexSubImage2D wants
// a non-const data pointer.
texture_data = const_cast<uint8_t*>(texture_source.data());
}
glTexImage2D(GL_TEXTURE_2D,
0,
GL_LUMINANCE,
texture_width_,
texture_height_,
0,
GL_LUMINANCE,
GL_UNSIGNED_BYTE,
NULL);
glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
glTexSubImage2D(GL_TEXTURE_2D,
0,
0,
0,
actual_width_,
actual_height_,
GL_LUMINANCE,
GL_UNSIGNED_BYTE,
texture_data);
if (allocated_data) {
      // texture_data was allocated with new[] above, so release with delete[].
      delete[] texture_data;
}
glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
}
// The id for the texture on the GPU.
GLuint texture_;
// The width and height to be used for display purposes, referring to the
// dimensions of the original texture.
int actual_width_;
int actual_height_;
// The allocated dimensions of the texture data, which must be powers of 2.
int texture_width_;
int texture_height_;
TF_DISALLOW_COPY_AND_ASSIGN(Sprite);
};
} // namespace tf_tracking
#endif // __RENDER_OPENGL__
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#ifdef LOG_TIME
// Storage for logging functionality.
int num_time_logs = 0;
LogEntry time_logs[NUM_LOGS];
int num_avg_entries = 0;
AverageEntry avg_entries[NUM_LOGS];
#endif
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility functions for performance profiling.
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
#include <stdint.h>
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
#ifdef LOG_TIME
// Blend constant for running average.
#define ALPHA 0.98f
#define NUM_LOGS 100
struct LogEntry {
const char* id;
int64_t time_stamp;
};
struct AverageEntry {
const char* id;
float average_duration;
};
// Storage for keeping track of this frame's values.
extern int num_time_logs;
extern LogEntry time_logs[NUM_LOGS];
// Storage for keeping track of average values (each entry may not be printed
// out each frame).
extern AverageEntry avg_entries[NUM_LOGS];
extern int num_avg_entries;
// Call this at the start of a logging phase.
inline static void ResetTimeLog() {
num_time_logs = 0;
}
// Log a message to be printed out when printTimeLog is called, along with the
// amount of time in ms that has passed since the last call to this function.
inline static void TimeLog(const char* const str) {
LOGV("%s", str);
if (num_time_logs >= NUM_LOGS) {
LOGE("Out of log entries!");
return;
}
time_logs[num_time_logs].id = str;
time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos();
++num_time_logs;
}
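// Typical per-frame usage (see nextFrameNative earlier in this code): call
// TimeLog() after each stage of work, then PrintTimeLog() followed by
// ResetTimeLog() once the frame is finished.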
inline static float Blend(float old_val, float new_val) {
return ALPHA * old_val + (1.0f - ALPHA) * new_val;
}
inline static float UpdateAverage(const char* str, const float new_val) {
for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) {
AverageEntry* const entry = avg_entries + entry_num;
if (str == entry->id) {
entry->average_duration = Blend(entry->average_duration, new_val);
return entry->average_duration;
}
}
if (num_avg_entries >= NUM_LOGS) {
LOGE("Too many log entries!");
}
// If it wasn't there already, add it.
avg_entries[num_avg_entries].id = str;
avg_entries[num_avg_entries].average_duration = new_val;
++num_avg_entries;
return new_val;
}
// Prints out all the timeLog statements in chronological order with the
// interval that passed between subsequent statements. The total time between
// the first and last statements is printed last.
inline static void PrintTimeLog() {
LogEntry* last_time = time_logs;
float average_running_total = 0.0f;
for (int i = 0; i < num_time_logs; ++i) {
LogEntry* const this_time = time_logs + i;
const float curr_time =
(this_time->time_stamp - last_time->time_stamp) / 1000000.0f;
const float avg_time = UpdateAverage(this_time->id, curr_time);
average_running_total += avg_time;
LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time);
last_time = this_time;
}
const float total_time =
(last_time->time_stamp - time_logs->time_stamp) / 1000000.0f;
LOGD("TOTAL TIME: %6.3fms %6.4fms\n",
total_time, average_running_total);
LOGD(" ");
}
#else
inline static void ResetTimeLog() {}
inline static void TimeLog(const char* const str) {
LOGV("%s", str);
}
inline static void PrintTimeLog() {}
#endif
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
namespace tf_tracking {
static const float kInitialDistance = 20.0f;
static void InitNormalized(const Image<uint8_t>& src_image,
const BoundingBox& position,
Image<float>* const dst_image) {
BoundingBox scaled_box(position);
CopyArea(src_image, scaled_box, dst_image);
NormalizeImage(dst_image);
}
TrackedObject::TrackedObject(const std::string& id, const Image<uint8_t>& image,
const BoundingBox& bounding_box,
ObjectModelBase* const model)
: id_(id),
last_known_position_(bounding_box),
last_detection_position_(bounding_box),
position_last_computed_time_(-1),
object_model_(model),
last_detection_thumbnail_(kNormalizedThumbnailSize,
kNormalizedThumbnailSize),
last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize),
tracked_correlation_(0.0f),
tracked_match_score_(0.0),
num_consecutive_frames_below_threshold_(0),
allowable_detection_distance_(Square(kInitialDistance)) {
InitNormalized(image, bounding_box, &last_detection_thumbnail_);
}
TrackedObject::~TrackedObject() {}
void TrackedObject::UpdatePosition(const BoundingBox& new_position,
const int64_t timestamp,
const ImageData& image_data,
const bool authoritative) {
last_known_position_ = new_position;
position_last_computed_time_ = timestamp;
InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_);
const float last_localization_correlation = ComputeCrossCorrelation(
last_detection_thumbnail_.data(),
last_frame_thumbnail_.data(),
last_frame_thumbnail_.data_size_);
LOGV("Tracked correlation to last localization: %.6f",
last_localization_correlation);
// Correlation to object model, if it exists.
if (object_model_ != NULL) {
tracked_correlation_ =
object_model_->GetMaxCorrelation(last_frame_thumbnail_);
LOGV("Tracked correlation to model: %.6f",
tracked_correlation_);
tracked_match_score_ =
object_model_->GetMatchScore(new_position, image_data);
LOGV("Tracked match score with model: %.6f",
tracked_match_score_.value);
} else {
// If there's no model to check against, set the tracked correlation to
// simply be the correlation to the last set position.
tracked_correlation_ = last_localization_correlation;
tracked_match_score_ = MatchScore(0.0f);
}
// Determine if it's still being tracked.
if (tracked_correlation_ >= kMinimumCorrelationForTracking &&
tracked_match_score_ >= kMinimumMatchScore) {
num_consecutive_frames_below_threshold_ = 0;
if (object_model_ != NULL) {
object_model_->TrackStep(last_known_position_, *image_data.GetImage(),
*image_data.GetIntegralImage(), authoritative);
}
} else if (tracked_match_score_ < kMatchScoreForImmediateTermination) {
if (num_consecutive_frames_below_threshold_ < 1000) {
LOGD("Tracked match score is way too low (%.6f), aborting track.",
tracked_match_score_.value);
}
// Add an absurd amount of missed frames so that all heuristics will
// consider it a lost track.
num_consecutive_frames_below_threshold_ += 1000;
if (object_model_ != NULL) {
object_model_->TrackLost();
}
} else {
++num_consecutive_frames_below_threshold_;
allowable_detection_distance_ *= 1.1f;
}
}
void TrackedObject::OnDetection(ObjectModelBase* const model,
const BoundingBox& detection_position,
const MatchScore match_score,
const int64_t timestamp,
const ImageData& image_data) {
const float overlap = detection_position.PascalScore(last_known_position_);
if (overlap > kPositionOverlapThreshold) {
// If the position agreement with the current tracked position is good
// enough, lock all the current unlocked examples.
object_model_->TrackConfirmed();
num_consecutive_frames_below_threshold_ = 0;
}
// Before relocalizing, make sure the new proposed position is better than
// the existing position by a small amount to prevent thrashing.
if (match_score <= tracked_match_score_ + kMatchScoreBuffer) {
LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f",
match_score.value, tracked_match_score_.value,
kMatchScoreBuffer.value);
return;
}
LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to "
"(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f",
last_known_position_.left_, last_known_position_.top_,
last_known_position_.GetWidth(), last_known_position_.GetHeight(),
detection_position.left_, detection_position.top_,
detection_position.GetWidth(), detection_position.GetHeight(),
match_score.value, tracked_match_score_.value);
if (overlap < kPositionOverlapThreshold) {
// The path might be good, it might be bad, but it's no longer a path
// since we're moving the box to a new position, so just nuke it from
// orbit to be safe.
object_model_->TrackLost();
}
object_model_ = model;
// Reset the last detected appearance.
InitNormalized(
*image_data.GetImage(), detection_position, &last_detection_thumbnail_);
num_consecutive_frames_below_threshold_ = 0;
last_detection_position_ = detection_position;
UpdatePosition(detection_position, timestamp, image_data, false);
allowable_detection_distance_ = Square(kInitialDistance);
}
} // namespace tf_tracking
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
#ifdef __RENDER_OPENGL__
#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
#endif
#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
namespace tf_tracking {
// A TrackedObject is a specific instance of an ObjectModel, with a known
// position in the world.
// It provides the last known position and number of recent detection failures,
// in addition to the more general appearance data associated with the object
// class (which is in ObjectModel).
// TODO(andrewharp): Make getters/setters follow styleguide.
class TrackedObject {
public:
TrackedObject(const std::string& id, const Image<uint8_t>& image,
const BoundingBox& bounding_box, ObjectModelBase* const model);
~TrackedObject();
void UpdatePosition(const BoundingBox& new_position, const int64_t timestamp,
const ImageData& image_data, const bool authoritative);
// This method is called when the tracked object is detected at a
// given position, and allows the associated Model to grow and/or prune
// itself based on where the detection occurred.
void OnDetection(ObjectModelBase* const model,
const BoundingBox& detection_position,
const MatchScore match_score, const int64_t timestamp,
const ImageData& image_data);
// Called when there's no detection of the tracked object. This will cause
// a tracking failure after enough consecutive failures if the area under
// the current bounding box also doesn't meet a minimum correlation threshold
// with the model.
void OnDetectionFailure() {}
inline bool IsVisible() const {
return tracked_correlation_ >= kMinimumCorrelationForTracking ||
num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures;
}
inline float GetCorrelation() {
return tracked_correlation_;
}
inline MatchScore GetMatchScore() {
return tracked_match_score_;
}
inline BoundingBox GetPosition() const {
return last_known_position_;
}
inline BoundingBox GetLastDetectionPosition() const {
return last_detection_position_;
}
inline const ObjectModelBase* GetModel() const {
return object_model_;
}
inline const std::string& GetName() const {
return id_;
}
inline void Draw() const {
#ifdef __RENDER_OPENGL__
if (tracked_correlation_ < kMinimumCorrelationForTracking) {
glColor4f(MAX(0.0f, -tracked_correlation_),
MAX(0.0f, tracked_correlation_),
0.0f,
1.0f);
} else {
glColor4f(MAX(0.0f, -tracked_correlation_),
MAX(0.0f, tracked_correlation_),
1.0f,
1.0f);
}
// Render the box itself.
BoundingBox temp_box(last_known_position_);
DrawBox(temp_box);
    // Render a box outside this one (in case the actual box is hidden).
const float kBufferSize = 1.0f;
temp_box.left_ -= kBufferSize;
temp_box.top_ -= kBufferSize;
temp_box.right_ += kBufferSize;
temp_box.bottom_ += kBufferSize;
DrawBox(temp_box);
    // Render one inside as well.
temp_box.left_ -= -2.0f * kBufferSize;
temp_box.top_ -= -2.0f * kBufferSize;
temp_box.right_ += -2.0f * kBufferSize;
temp_box.bottom_ += -2.0f * kBufferSize;
DrawBox(temp_box);
#endif
}
// Get current object's num_consecutive_frames_below_threshold_.
inline int64_t GetNumConsecutiveFramesBelowThreshold() {
return num_consecutive_frames_below_threshold_;
}
// Reset num_consecutive_frames_below_threshold_ to 0.
inline void resetNumConsecutiveFramesBelowThreshold() {
num_consecutive_frames_below_threshold_ = 0;
}
inline float GetAllowableDistanceSquared() const {
return allowable_detection_distance_;
}
private:
// The unique id used throughout the system to identify this
// tracked object.
const std::string id_;
// The last known position of the object.
BoundingBox last_known_position_;
  // The position at which the object was last detected.
BoundingBox last_detection_position_;
// When the position was last computed.
int64_t position_last_computed_time_;
// The object model this tracked object is representative of.
ObjectModelBase* object_model_;
Image<float> last_detection_thumbnail_;
Image<float> last_frame_thumbnail_;
// The correlation of the object model with the preview frame at its last
// tracked position.
float tracked_correlation_;
MatchScore tracked_match_score_;
// The number of consecutive frames that the tracked position for this object
// has been under the correlation threshold.
int num_consecutive_frames_below_threshold_;
float allowable_detection_distance_;
friend std::ostream& operator<<(std::ostream& stream,
const TrackedObject& tracked_object);
TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject);
};
inline std::ostream& operator<<(std::ostream& stream,
const TrackedObject& tracked_object) {
stream << tracked_object.id_
<< " " << tracked_object.last_known_position_
<< " " << tracked_object.position_last_computed_time_
<< " " << tracked_object.num_consecutive_frames_below_threshold_
<< " " << tracked_object.object_model_
<< " " << tracked_object.tracked_correlation_;
return stream;
}
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
#include <math.h>
#include <stdint.h>
#include <stdlib.h>
#include <time.h>
#include <cmath> // for std::abs(float)
#ifndef HAVE_CLOCK_GETTIME
// Use gettimeofday() instead of clock_gettime().
#include <sys/time.h>
#endif // ifdef HAVE_CLOCK_GETTIME
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
// TODO(andrewharp): clean up these macros to use the codebase standard.
// A very small number, generally used as the tolerance for accumulated
// floating point errors in bounds-checks.
#define EPSILON 0.00001f
#define SAFE_DELETE(pointer) {\
if ((pointer) != NULL) {\
LOGV("Safe deleting pointer: %s", #pointer);\
delete (pointer);\
(pointer) = NULL;\
} else {\
LOGV("Pointer already null: %s", #pointer);\
}\
}
#ifdef __GOOGLE__
#define CHECK_ALWAYS(condition, format, ...) {\
CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
}
#define SCHECK(condition, format, ...) {\
DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
}
#else
#define CHECK_ALWAYS(condition, format, ...) {\
if (!(condition)) {\
LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\
abort();\
}\
}
#ifdef SANITY_CHECKS
#define SCHECK(condition, format, ...) {\
CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\
}
#else
#define SCHECK(condition, format, ...) {}
#endif // SANITY_CHECKS
#endif // __GOOGLE__
#ifndef MAX
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) (((a) > (b)) ? (b) : (a))
#endif
inline static int64_t CurrentThreadTimeNanos() {
#ifdef HAVE_CLOCK_GETTIME
struct timespec tm;
clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm);
return tm.tv_sec * 1000000000LL + tm.tv_nsec;
#else
struct timeval tv;
gettimeofday(&tv, NULL);
  return tv.tv_sec * 1000000000LL + tv.tv_usec * 1000LL;
#endif
}
inline static int64_t CurrentRealTimeMillis() {
#ifdef HAVE_CLOCK_GETTIME
struct timespec tm;
clock_gettime(CLOCK_MONOTONIC, &tm);
return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL;
#else
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec * 1000 + tv.tv_usec / 1000;
#endif
}
template<typename T>
inline static T Square(const T a) {
return a * a;
}
template<typename T>
inline static T Clip(const T a, const T floor, const T ceil) {
SCHECK(ceil >= floor, "Bounds mismatch!");
return (a <= floor) ? floor : ((a >= ceil) ? ceil : a);
}
template<typename T>
inline static int Floor(const T a) {
return static_cast<int>(a);
}
template<typename T>
inline static int Ceil(const T a) {
return Floor(a) + 1;
}
template<typename T>
inline static bool InRange(const T a, const T min, const T max) {
return (a >= min) && (a <= max);
}
inline static bool ValidIndex(const int a, const int max) {
return (a >= 0) && (a < max);
}
inline bool NearlyEqual(const float a, const float b, const float tolerance) {
return std::abs(a - b) < tolerance;
}
inline bool NearlyEqual(const float a, const float b) {
return NearlyEqual(a, b, EPSILON);
}
template<typename T>
inline static int Round(const float a) {
  return (a - static_cast<float>(floor(a)) > 0.5f) ? ceil(a) : floor(a);
}
template<typename T>
inline static void Swap(T* const a, T* const b) {
// Cache out the VALUE of what's at a.
T tmp = *a;
*a = *b;
*b = tmp;
}
static inline float randf() {
return rand() / static_cast<float>(RAND_MAX);
}
static inline float randf(const float min_value, const float max_value) {
return randf() * (max_value - min_value) + min_value;
}
static inline uint16_t RealToFixed115(const float real_number) {
SCHECK(InRange(real_number, 0.0f, 2048.0f),
"Value out of range! %.2f", real_number);
static const float kMult = 32.0f;
const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
return static_cast<uint16_t>(real_number * kMult + round_add);
}
static inline float FixedToFloat115(const uint16_t fp_number) {
const float kDiv = 32.0f;
return (static_cast<float>(fp_number) / kDiv);
}
static inline int RealToFixed1616(const float real_number) {
static const float kMult = 65536.0f;
SCHECK(InRange(real_number, -kMult, kMult),
"Value out of range! %.2f", real_number);
const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
return static_cast<int>(real_number * kMult + round_add);
}
static inline float FixedToFloat1616(const int fp_number) {
const float kDiv = 65536.0f;
return (static_cast<float>(fp_number) / kDiv);
}
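// Illustrative sketch (an addition, not part of the original file) of the
// 16.16 round trip: 1.5f maps to 1.5 * 65536 = 98304 (0x00018000) and comes
// back exactly; values that are not multiples of 1/65536 return within
// 1/65536 of the input.
inline static bool FixedPoint1616RoundTripExample() {
  const int fixed = RealToFixed1616(1.5f);         // 98304 == 0x00018000
  const float restored = FixedToFloat1616(fixed);  // 1.5f
  return fixed == 0x00018000 && restored == 1.5f;
}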
template<typename T>
// produces numbers in range [0,2*M_PI] (rather than -PI,PI)
inline T FastAtan2(const T y, const T x) {
static const T coeff_1 = (T)(M_PI / 4.0);
static const T coeff_2 = (T)(3.0 * coeff_1);
const T abs_y = fabs(y);
T angle;
if (x >= 0) {
T r = (x - abs_y) / (x + abs_y);
angle = coeff_1 - coeff_1 * r;
} else {
T r = (x + abs_y) / (abs_y - x);
angle = coeff_2 - coeff_1 * r;
}
static const T PI_2 = 2.0 * M_PI;
return y < 0 ? PI_2 - angle : angle;
}
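// For example, FastAtan2(1.0f, 1.0f) returns M_PI / 4 (up to float rounding)
// and FastAtan2(-1.0f, 0.0f) returns 3 * M_PI / 2; elsewhere the linear
// approximation stays within roughly 0.07 radians of the true angle.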
#define NELEMS(X) (sizeof(X) / sizeof(X[0]))
namespace tf_tracking {
#ifdef __ARM_NEON
float ComputeMeanNeon(const float* const values, const int num_vals);
float ComputeStdDevNeon(const float* const values, const int num_vals,
const float mean);
float ComputeWeightedMeanNeon(const float* const values,
const float* const weights, const int num_vals);
float ComputeCrossCorrelationNeon(const float* const values1,
const float* const values2,
const int num_vals);
#endif
inline float ComputeMeanCpu(const float* const values, const int num_vals) {
// Get mean.
float sum = values[0];
for (int i = 1; i < num_vals; ++i) {
sum += values[i];
}
return sum / static_cast<float>(num_vals);
}
inline float ComputeMean(const float* const values, const int num_vals) {
return
#ifdef __ARM_NEON
(num_vals >= 8) ? ComputeMeanNeon(values, num_vals) :
#endif
ComputeMeanCpu(values, num_vals);
}
inline float ComputeStdDevCpu(const float* const values,
const int num_vals,
const float mean) {
// Get Std dev.
float squared_sum = 0.0f;
for (int i = 0; i < num_vals; ++i) {
squared_sum += Square(values[i] - mean);
}
return sqrt(squared_sum / static_cast<float>(num_vals));
}
inline float ComputeStdDev(const float* const values,
const int num_vals,
const float mean) {
return
#ifdef __ARM_NEON
(num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) :
#endif
ComputeStdDevCpu(values, num_vals, mean);
}
// TODO(andrewharp): Accelerate with NEON.
inline float ComputeWeightedMean(const float* const values,
const float* const weights,
const int num_vals) {
float sum = 0.0f;
float total_weight = 0.0f;
for (int i = 0; i < num_vals; ++i) {
sum += values[i] * weights[i];
total_weight += weights[i];
}
  // Normalize by the accumulated weight rather than the count so the result
  // is a true weighted mean.
  return total_weight > 0.0f ? sum / total_weight : 0.0f;
}
inline float ComputeCrossCorrelationCpu(const float* const values1,
const float* const values2,
const int num_vals) {
float sxy = 0.0f;
for (int offset = 0; offset < num_vals; ++offset) {
sxy += values1[offset] * values2[offset];
}
const float cross_correlation = sxy / num_vals;
return cross_correlation;
}
inline float ComputeCrossCorrelation(const float* const values1,
const float* const values2,
const int num_vals) {
return
#ifdef __ARM_NEON
(num_vals >= 8) ? ComputeCrossCorrelationNeon(values1, values2, num_vals)
:
#endif
ComputeCrossCorrelationCpu(values1, values2, num_vals);
}
inline void NormalizeNumbers(float* const values, const int num_vals) {
// Find the mean and then subtract so that the new mean is 0.0.
const float mean = ComputeMean(values, num_vals);
VLOG(2) << "Mean is " << mean;
float* curr_data = values;
for (int i = 0; i < num_vals; ++i) {
*curr_data -= mean;
curr_data++;
}
// Now divide by the std deviation so the new standard deviation is 1.0.
// The numbers might all be identical (and thus shifted to 0.0 now),
// so only scale by the standard deviation if this is not the case.
const float std_dev = ComputeStdDev(values, num_vals, 0.0f);
if (std_dev > 0.0f) {
VLOG(2) << "Std dev is " << std_dev;
curr_data = values;
for (int i = 0; i < num_vals; ++i) {
*curr_data /= std_dev;
curr_data++;
}
}
}
// Returns the determinant of a 2x2 matrix.
template<class T>
inline T FindDeterminant2x2(const T* const a) {
// Determinant: (ad - bc)
return a[0] * a[3] - a[1] * a[2];
}
// Finds the inverse of a 2x2 matrix.
// Returns true upon success, false if the matrix is not invertible.
template<class T>
inline bool Invert2x2(const T* const a, float* const a_inv) {
const float det = static_cast<float>(FindDeterminant2x2(a));
if (fabs(det) < EPSILON) {
return false;
}
const float inv_det = 1.0f / det;
a_inv[0] = inv_det * static_cast<float>(a[3]); // d
a_inv[1] = inv_det * static_cast<float>(-a[1]); // -b
a_inv[2] = inv_det * static_cast<float>(-a[2]); // -c
a_inv[3] = inv_det * static_cast<float>(a[0]); // a
return true;
}
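// Illustrative sketch (an addition, not part of the original file): inverting
// the row-major matrix [4 7; 2 6] (determinant 10) yields [0.6 -0.7; -0.2 0.4]
// and Invert2x2 reports success.
inline bool Invert2x2Example() {
  const float a[4] = { 4.0f, 7.0f, 2.0f, 6.0f };
  float a_inv[4];
  return Invert2x2(a, a_inv) &&
         NearlyEqual(a_inv[0], 0.6f) && NearlyEqual(a_inv[1], -0.7f) &&
         NearlyEqual(a_inv[2], -0.2f) && NearlyEqual(a_inv[3], 0.4f);
}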
} // namespace tf_tracking
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// NEON implementations of Image methods for compatible devices. Control
// should never enter this compilation unit on incompatible devices.
#ifdef __ARM_NEON
#include <arm_neon.h>
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
namespace tf_tracking {
inline static float GetSum(const float32x4_t& values) {
  // Use a stack buffer rather than a static one so concurrent callers don't
  // share state.
  float32_t summed_values[4];
vst1q_f32(summed_values, values);
return summed_values[0]
+ summed_values[1]
+ summed_values[2]
+ summed_values[3];
}
float ComputeMeanNeon(const float* const values, const int num_vals) {
SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
const float32_t* const arm_vals = (const float32_t* const) values;
float32x4_t accum = vdupq_n_f32(0.0f);
int offset = 0;
for (; offset <= num_vals - 4; offset += 4) {
accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset]));
}
// Pull the accumulated values into a single variable.
float sum = GetSum(accum);
// Get the remaining 1 to 3 values.
for (; offset < num_vals; ++offset) {
sum += values[offset];
}
const float mean_neon = sum / static_cast<float>(num_vals);
#ifdef SANITY_CHECKS
const float mean_cpu = ComputeMeanCpu(values, num_vals);
SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals),
"Neon mismatch with CPU mean! %.10f vs %.10f",
mean_neon, mean_cpu);
#endif
return mean_neon;
}
float ComputeStdDevNeon(const float* const values,
const int num_vals, const float mean) {
SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
const float32_t* const arm_vals = (const float32_t* const) values;
const float32x4_t mean_vec = vdupq_n_f32(-mean);
float32x4_t accum = vdupq_n_f32(0.0f);
int offset = 0;
for (; offset <= num_vals - 4; offset += 4) {
const float32x4_t deltas =
vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset]));
accum = vmlaq_f32(accum, deltas, deltas);
}
// Pull the accumulated values into a single variable.
float squared_sum = GetSum(accum);
// Get the remaining 1 to 3 values.
for (; offset < num_vals; ++offset) {
squared_sum += Square(values[offset] - mean);
}
const float std_dev_neon = sqrt(squared_sum / static_cast<float>(num_vals));
#ifdef SANITY_CHECKS
const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean);
SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals),
"Neon mismatch with CPU std dev! %.10f vs %.10f",
std_dev_neon, std_dev_cpu);
#endif
return std_dev_neon;
}
float ComputeCrossCorrelationNeon(const float* const values1,
const float* const values2,
const int num_vals) {
SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
const float32_t* const arm_vals1 = (const float32_t* const) values1;
const float32_t* const arm_vals2 = (const float32_t* const) values2;
float32x4_t accum = vdupq_n_f32(0.0f);
int offset = 0;
for (; offset <= num_vals - 4; offset += 4) {
accum = vmlaq_f32(accum,
vld1q_f32(&arm_vals1[offset]),
vld1q_f32(&arm_vals2[offset]));
}
// Pull the accumulated values into a single variable.
float sxy = GetSum(accum);
// Get the remaining 1 to 3 values.
for (; offset < num_vals; ++offset) {
sxy += values1[offset] * values2[offset];
}
const float cross_correlation_neon = sxy / num_vals;
#ifdef SANITY_CHECKS
const float cross_correlation_cpu =
ComputeCrossCorrelationCpu(values1, values2, num_vals);
SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu,
EPSILON * num_vals),
"Neon mismatch with CPU cross correlation! %.10f vs %.10f",
cross_correlation_neon, cross_correlation_cpu);
#endif
return cross_correlation_neon;
}
} // namespace tf_tracking
#endif // __ARM_NEON
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// These utility functions allow for the conversion of RGB data to YUV data.
#include "tensorflow/examples/android/jni/rgb2yuv.h"
static inline void WriteYUV(const int x, const int y, const int width,
const int r8, const int g8, const int b8,
uint8_t* const pY, uint8_t* const pUV) {
// Using formulas from http://msdn.microsoft.com/en-us/library/ms893078
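// The floating point equivalents of the fixed point math below (coefficients
// scaled by 256 and rounded) are:
//   Y =  0.257 * R + 0.504 * G + 0.098 * B + 16
//   U = -0.148 * R - 0.291 * G + 0.439 * B + 128
//   V =  0.439 * R - 0.368 * G - 0.071 * B + 128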
*pY = ((66 * r8 + 129 * g8 + 25 * b8 + 128) >> 8) + 16;
// Odd widths get rounded up so that UV blocks on the side don't get cut off.
const int blocks_per_row = (width + 1) / 2;
// 2 bytes per UV block
const int offset = 2 * (((y / 2) * blocks_per_row + (x / 2)));
// U and V are the average values of all 4 pixels in the block.
if (!(x & 1) && !(y & 1)) {
// Explicitly clear the block if this is the first pixel in it.
pUV[offset] = 0;
pUV[offset + 1] = 0;
}
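// Each of the four pixels in the 2x2 block adds a quarter of its chroma to the
// shared U/V bytes: the >> 10 below folds the usual >> 8 together with the
// divide by 4, and the four + 32 contributions sum to the standard + 128 bias.
// Blocks clipped at an odd right or bottom edge receive fewer contributions
// and therefore a slightly low average.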
// V (with divide by 4 factored in)
#ifdef __APPLE__
const int u_offset = 0;
const int v_offset = 1;
#else
const int u_offset = 1;
const int v_offset = 0;
#endif
pUV[offset + v_offset] += ((112 * r8 - 94 * g8 - 18 * b8 + 128) >> 10) + 32;
// U (with divide by 4 factored in)
pUV[offset + u_offset] += ((-38 * r8 - 74 * g8 + 112 * b8 + 128) >> 10) + 32;
}
void ConvertARGB8888ToYUV420SP(const uint32_t* const input,
uint8_t* const output, int width, int height) {
uint8_t* pY = output;
uint8_t* pUV = output + (width * height);
const uint32_t* in = input;
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
const uint32_t rgb = *in++;
#ifdef __APPLE__
const int nB = (rgb >> 8) & 0xFF;
const int nG = (rgb >> 16) & 0xFF;
const int nR = (rgb >> 24) & 0xFF;
#else
const int nR = (rgb >> 16) & 0xFF;
const int nG = (rgb >> 8) & 0xFF;
const int nB = rgb & 0xFF;
#endif
WriteYUV(x, y, width, nR, nG, nB, pY++, pUV);
}
}
}
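// Sizing note (illustrative, not part of the original sources): the output
// buffer must hold width * height luminance bytes followed by
// 2 * ((width + 1) / 2) * ((height + 1) / 2) interleaved U/V bytes, matching
// the blocks_per_row rounding in WriteYUV. A minimal sketch with a
// hypothetical helper name:
//
//   static size_t YUV420SPBufferSize(int width, int height) {
//     return (size_t)width * height +
//            2u * ((width + 1) / 2) * ((height + 1) / 2);
//   }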
void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output,
const int width, const int height) {
uint8_t* pY = output;
uint8_t* pUV = output + (width * height);
const uint16_t* in = input;
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
const uint32_t rgb = *in++;
const int r5 = ((rgb >> 11) & 0x1F);
const int g6 = ((rgb >> 5) & 0x3F);
const int b5 = (rgb & 0x1F);
// Shift left, then fill in the empty low bits with a copy of the high
// bits so we can stretch across the entire 0 - 255 range.
const int r8 = r5 << 3 | r5 >> 2;
const int g8 = g6 << 2 | g6 >> 4;
const int b8 = b5 << 3 | b5 >> 2;
WriteYUV(x, y, width, r8, g8, b8, pY++, pUV);
}
}
}
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
void ConvertARGB8888ToYUV420SP(const uint32_t* const input,
uint8_t* const output, int width, int height);
void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output,
const int width, const int height);
#ifdef __cplusplus
}
#endif
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
VERS_1.0 {
# Export JNI symbols.
global:
Java_*;
JNI_OnLoad;
JNI_OnUnload;
# Hide everything else.
local:
*;
};
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a collection of routines which converts various YUV image formats
// to ARGB.
#include "tensorflow/examples/android/jni/yuv2rgb.h"
#ifndef MAX
#define MAX(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a > _b ? _a : _b; })
#endif
#ifndef MIN
#define MIN(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a < _b ? _a : _b; })
#endif
// This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
// are normalized to eight bits.
static const int kMaxChannelValue = 262143;
static inline uint32_t YUV2RGB(int nY, int nU, int nV) {
nY -= 16;
nU -= 128;
nV -= 128;
if (nY < 0) nY = 0;
// This is the floating point equivalent. We do the conversion in integer
// because some Android devices do not have floating point in hardware.
// nR = (int)(1.164 * nY + 1.596 * nV);
// nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
// nB = (int)(1.164 * nY + 2.018 * nU);
int nR = 1192 * nY + 1634 * nV;
int nG = 1192 * nY - 833 * nV - 400 * nU;
int nB = 1192 * nY + 2066 * nU;
nR = MIN(kMaxChannelValue, MAX(0, nR));
nG = MIN(kMaxChannelValue, MAX(0, nG));
nB = MIN(kMaxChannelValue, MAX(0, nB));
nR = (nR >> 10) & 0xff;
nG = (nG >> 10) & 0xff;
nB = (nB >> 10) & 0xff;
return 0xff000000 | (nR << 16) | (nG << 8) | nB;
}
// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by
// separate u and v planes with arbitrary row and column strides,
// containing 8 bit 2x2 subsampled chroma samples.
// Converts to a packed ARGB 32 bit output of the same pixel dimensions.
void ConvertYUV420ToARGB8888(const uint8_t* const yData,
const uint8_t* const uData,
const uint8_t* const vData, uint32_t* const output,
const int width, const int height,
const int y_row_stride, const int uv_row_stride,
const int uv_pixel_stride) {
uint32_t* out = output;
for (int y = 0; y < height; y++) {
const uint8_t* pY = yData + y_row_stride * y;
const int uv_row_start = uv_row_stride * (y >> 1);
const uint8_t* pU = uData + uv_row_start;
const uint8_t* pV = vData + uv_row_start;
for (int x = 0; x < width; x++) {
const int uv_offset = (x >> 1) * uv_pixel_stride;
*out++ = YUV2RGB(pY[x], pU[uv_offset], pV[uv_offset]);
}
}
}
// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an
// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples,
// except the interleave order of U and V is reversed. Converts to a packed
// ARGB 32 bit output of the same pixel dimensions.
void ConvertYUV420SPToARGB8888(const uint8_t* const yData,
const uint8_t* const uvData,
uint32_t* const output, const int width,
const int height) {
const uint8_t* pY = yData;
const uint8_t* pUV = uvData;
uint32_t* out = output;
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int nY = *pY++;
int offset = (y >> 1) * width + 2 * (x >> 1);
#ifdef __APPLE__
int nU = pUV[offset];
int nV = pUV[offset + 1];
#else
int nV = pUV[offset];
int nU = pUV[offset + 1];
#endif
*out++ = YUV2RGB(nY, nU, nV);
}
}
}
// The same as above, but downsamples each dimension to half size.
void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input,
uint32_t* const output, int width,
int height) {
const uint8_t* pY = input;
const uint8_t* pUV = input + (width * height);
uint32_t* out = output;
int stride = width;
width >>= 1;
height >>= 1;
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int nY = (pY[0] + pY[1] + pY[stride] + pY[stride + 1]) >> 2;
pY += 2;
#ifdef __APPLE__
int nU = *pUV++;
int nV = *pUV++;
#else
int nV = *pUV++;
int nU = *pUV++;
#endif
*out++ = YUV2RGB(nY, nU, nV);
}
pY += stride;
}
}
// Accepts a YUV 4:2:0 image with a plane of 8 bit Y samples followed by an
// interleaved U/V plane containing 8 bit 2x2 subsampled chroma samples,
// except the interleave order of U and V is reversed. Converts to a packed
// RGB 565 bit output of the same pixel dimensions.
void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output,
const int width, const int height) {
const uint8_t* pY = input;
const uint8_t* pUV = input + (width * height);
uint16_t* out = output;
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int nY = *pY++;
int offset = (y >> 1) * width + 2 * (x >> 1);
#ifdef __APPLE__
int nU = pUV[offset];
int nV = pUV[offset + 1];
#else
int nV = pUV[offset];
int nU = pUV[offset + 1];
#endif
nY -= 16;
nU -= 128;
nV -= 128;
if (nY < 0) nY = 0;
// This is the floating point equivalent. We do the conversion in integer
// because some Android devices do not have floating point in hardware.
// nR = (int)(1.164 * nY + 1.596 * nV);
// nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
// nB = (int)(1.164 * nY + 2.018 * nU);
int nR = 1192 * nY + 1634 * nV;
int nG = 1192 * nY - 833 * nV - 400 * nU;
int nB = 1192 * nY + 2066 * nU;
nR = MIN(kMaxChannelValue, MAX(0, nR));
nG = MIN(kMaxChannelValue, MAX(0, nG));
nB = MIN(kMaxChannelValue, MAX(0, nB));
// Shift more than for ARGB8888 and apply appropriate bitmask.
nR = (nR >> 13) & 0x1f;
nG = (nG >> 12) & 0x3f;
nB = (nB >> 13) & 0x1f;
// R is high 5 bits, G is middle 6 bits, and B is low 5 bits.
*out++ = (nR << 11) | (nG << 5) | nB;
}
}
}
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a collection of routines which converts various YUV image formats
// to (A)RGB.
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
void ConvertYUV420ToARGB8888(const uint8_t* const yData,
const uint8_t* const uData,
const uint8_t* const vData, uint32_t* const output,
const int width, const int height,
const int y_row_stride, const int uv_row_stride,
const int uv_pixel_stride);
// Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width
// and height. The input and output must already be allocated and non-null.
// For efficiency, no error checking is performed.
void ConvertYUV420SPToARGB8888(const uint8_t* const pY,
const uint8_t* const pUV, uint32_t* const output,
const int width, const int height);
// The same as above, but downsamples each dimension to half size.
void ConvertYUV420SPToARGB8888HalfSize(const uint8_t* const input,
uint32_t* const output, int width,
int height);
// Converts YUV420 semi-planar data to RGB 565 data using the supplied width
// and height. The input and output must already be allocated and non-null.
// For efficiency, no error checking is performed.
void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output,
const int width, const int height);
#ifdef __cplusplus
}
#endif
#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<set xmlns:android="http://schemas.android.com/apk/res/android"
android:ordering="sequentially">
<objectAnimator
android:propertyName="backgroundColor"
android:duration="375"
android:valueFrom="0x00b3ccff"
android:valueTo="0xffb3ccff"
android:valueType="colorType"/>
<objectAnimator
android:propertyName="backgroundColor"
android:duration="375"
android:valueFrom="0xffb3ccff"
android:valueTo="0x00b3ccff"
android:valueType="colorType"/>
</set>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle" >
<solid android:color="#00000000" />
<stroke android:width="1dip" android:color="#cccccc" />
</shape>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:id="@+id/container"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="#000"
tools:context="org.tensorflow.demo.CameraActivity" />
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<FrameLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context="org.tensorflow.demo.SpeechActivity">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Say one of the words below!"
android:id="@+id/textView"
android:textAlignment="center"
android:layout_gravity="top"
android:textSize="24dp"
android:layout_marginTop="10dp"
android:layout_marginLeft="10dp"
/>
<ListView
android:id="@+id/list_view"
android:layout_width="240dp"
android:layout_height="wrap_content"
android:background="@drawable/border"
android:layout_gravity="top|center_horizontal"
android:textAlignment="center"
android:layout_marginTop="100dp"
/>
<Button
android:id="@+id/quit"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Quit"
android:layout_gravity="bottom|center_horizontal"
android:layout_marginBottom="10dp"
/>
</FrameLayout>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent">
<org.tensorflow.demo.AutoFitTextureView
android:id="@+id/texture"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true" />
<org.tensorflow.demo.RecognitionScoreView
android:id="@+id/results"
android:layout_width="match_parent"
android:layout_height="112dp"
android:layout_alignParentTop="true" />
<org.tensorflow.demo.OverlayView
android:id="@+id/debug_overlay"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentBottom="true" />
</RelativeLayout>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:orientation="vertical"
android:layout_width="match_parent"
android:layout_height="match_parent">
<org.tensorflow.demo.AutoFitTextureView
android:id="@+id/texture"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_alignParentTop="true" />
<RelativeLayout
android:id="@+id/black"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="#FF000000" />
<GridView
android:id="@+id/grid_layout"
android:numColumns="7"
android:stretchMode="columnWidth"
android:layout_alignParentBottom="true"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
<org.tensorflow.demo.OverlayView
android:id="@+id/overlay"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentTop="true" />
<org.tensorflow.demo.OverlayView
android:id="@+id/debug_overlay"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentTop="true" />
</RelativeLayout>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent">
<org.tensorflow.demo.AutoFitTextureView
android:id="@+id/texture"
android:layout_width="wrap_content"
android:layout_height="wrap_content"/>
<org.tensorflow.demo.OverlayView
android:id="@+id/tracking_overlay"
android:layout_width="match_parent"
android:layout_height="match_parent"/>
<org.tensorflow.demo.OverlayView
android:id="@+id/debug_overlay"
android:layout_width="match_parent"
android:layout_height="match_parent"/>
</FrameLayout>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<TextView
xmlns:android="http://schemas.android.com/apk/res/android"
android:id="@+id/list_text_item"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="TextView"
android:textSize="24dp"
android:textAlignment="center"
android:gravity="center_horizontal"
/>
<!--
Copyright 2013 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<!-- Semantic definitions -->
<dimen name="horizontal_page_margin">@dimen/margin_huge</dimen>
<dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
</resources>
<!--
Copyright 2013 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<style name="Widget.SampleMessage">
<item name="android:textAppearance">?android:textAppearanceLarge</item>
<item name="android:lineSpacingMultiplier">1.2</item>
<item name="android:shadowDy">-6.5</item>
</style>
</resources>
<?xml version="1.0" encoding="utf-8"?>
<resources>
<!--
Base application theme for API 11+. This theme completely replaces
AppBaseTheme from res/values/styles.xml on API 11+ devices.
-->
<style name="AppBaseTheme" parent="android:Theme.Holo.Light">
<!-- API 11 theme customizations can go here. -->
</style>
<style name="FullscreenTheme" parent="android:Theme.Holo">
<item name="android:actionBarStyle">@style/FullscreenActionBarStyle</item>
<item name="android:windowActionBarOverlay">true</item>
<item name="android:windowBackground">@null</item>
<item name="metaButtonBarStyle">?android:attr/buttonBarStyle</item>
<item name="metaButtonBarButtonStyle">?android:attr/buttonBarButtonStyle</item>
</style>
<style name="FullscreenActionBarStyle" parent="android:Widget.Holo.ActionBar">
<!-- <item name="android:background">@color/black_overlay</item> -->
</style>
</resources>
<!--
Copyright 2013 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<!-- Activity themes -->
<style name="Theme.Base" parent="android:Theme.Holo.Light" />
</resources>
<resources>
<!--
Base application theme for API 14+. This theme completely replaces
AppBaseTheme from BOTH res/values/styles.xml and
res/values-v11/styles.xml on API 14+ devices.
-->
<style name="AppBaseTheme" parent="android:Theme.Holo.Light.DarkActionBar">
<!-- API 14 theme customizations can go here. -->
</style>
</resources>
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright 2013 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
</resources>
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<!-- Activity themes -->
<style name="Theme.Base" parent="android:Theme.Material.Light">
</style>
</resources>
<resources>
<!--
Declare custom theme attributes that allow changing which styles are
used for button bars depending on the API level.
?android:attr/buttonBarStyle is new as of API 11 so this is
necessary to support previous API levels.
-->
<declare-styleable name="ButtonBarContainerTheme">
<attr name="metaButtonBarStyle" format="reference" />
<attr name="metaButtonBarButtonStyle" format="reference" />
</declare-styleable>
</resources>
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<string name="app_name">TensorFlow Demo</string>
<string name="activity_name_classification">TF Classify</string>
<string name="activity_name_detection">TF Detect</string>
<string name="activity_name_stylize">TF Stylize</string>
<string name="activity_name_speech">TF Speech</string>
</resources>
<?xml version="1.0" encoding="utf-8"?>
<!--
Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<color name="control_background">#cc4285f4</color>
</resources>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<string name="description_info">Info</string>
<string name="request_permission">This sample needs camera permission.</string>
<string name="camera_error">This device doesn\'t support Camera2 API.</string>
</resources>
<?xml version="1.0" encoding="utf-8"?><!--
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" />
</resources>
<!--
Copyright 2013 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<!-- Define standard dimensions to comply with Holo-style grids and rhythm. -->
<dimen name="margin_tiny">4dp</dimen>
<dimen name="margin_small">8dp</dimen>
<dimen name="margin_medium">16dp</dimen>
<dimen name="margin_large">32dp</dimen>
<dimen name="margin_huge">64dp</dimen>
<!-- Semantic definitions -->
<dimen name="horizontal_page_margin">@dimen/margin_medium</dimen>
<dimen name="vertical_page_margin">@dimen/margin_medium</dimen>
</resources>
<!--
Copyright 2013 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<resources>
<!-- Activity themes -->
<style name="Theme.Base" parent="android:Theme.Light" />
<style name="Theme.Sample" parent="Theme.Base" />
<style name="AppTheme" parent="Theme.Sample" />
<!-- Widget styling -->
<style name="Widget" />
<style name="Widget.SampleMessage">
<item name="android:textAppearance">?android:textAppearanceMedium</item>
<item name="android:lineSpacingMultiplier">1.1</item>
</style>
<style name="Widget.SampleMessageTile">
<item name="android:background">@drawable/tile</item>
<item name="android:shadowColor">#7F000000</item>
<item name="android:shadowDy">-3.5</item>
<item name="android:shadowRadius">2</item>
</style>
</resources>
/*
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.demo;
import android.content.Context;
import android.util.AttributeSet;
import android.view.TextureView;
/**
* A {@link TextureView} that can be adjusted to a specified aspect ratio.
*/
public class AutoFitTextureView extends TextureView {
private int ratioWidth = 0;
private int ratioHeight = 0;
public AutoFitTextureView(final Context context) {
this(context, null);
}
public AutoFitTextureView(final Context context, final AttributeSet attrs) {
this(context, attrs, 0);
}
public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) {
super(context, attrs, defStyle);
}
/**
* Sets the aspect ratio for this view. The size of the view will be measured based on the ratio
* calculated from the parameters. Note that the actual values of the parameters don't matter, only
* their ratio: calling setAspectRatio(2, 3) and setAspectRatio(4, 6) produce the same result.
*
* @param width Relative horizontal size
* @param height Relative vertical size
*/
public void setAspectRatio(final int width, final int height) {
if (width < 0 || height < 0) {
throw new IllegalArgumentException("Size cannot be negative.");
}
ratioWidth = width;
ratioHeight = height;
requestLayout();
}
@Override
protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
super.onMeasure(widthMeasureSpec, heightMeasureSpec);
final int width = MeasureSpec.getSize(widthMeasureSpec);
final int height = MeasureSpec.getSize(heightMeasureSpec);
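// When an aspect ratio has been set, measure the largest rectangle with that
// ratio that fits inside the measured bounds, shrinking whichever dimension is
// the limiting one.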
if (0 == ratioWidth || 0 == ratioHeight) {
setMeasuredDimension(width, height);
} else {
if (width < height * ratioWidth / ratioHeight) {
setMeasuredDimension(width, width * ratioHeight / ratioWidth);
} else {
setMeasuredDimension(height * ratioWidth / ratioHeight, height);
}
}
}
}
/*
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.demo;
import android.Manifest;
import android.app.Activity;
import android.app.Fragment;
import android.content.Context;
import android.content.pm.PackageManager;
import android.hardware.Camera;
import android.hardware.camera2.CameraAccessException;
import android.hardware.camera2.CameraCharacteristics;
import android.hardware.camera2.CameraManager;
import android.hardware.camera2.params.StreamConfigurationMap;
import android.media.Image;
import android.media.Image.Plane;
import android.media.ImageReader;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.Build;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.Trace;
import android.util.Size;
import android.view.KeyEvent;
import android.view.Surface;
import android.view.WindowManager;
import android.widget.Toast;
import java.nio.ByteBuffer;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
public abstract class CameraActivity extends Activity
implements OnImageAvailableListener, Camera.PreviewCallback {
private static final Logger LOGGER = new Logger();
private static final int PERMISSIONS_REQUEST = 1;
private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA;
private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
private boolean debug = false;
private Handler handler;
private HandlerThread handlerThread;
private boolean useCamera2API;
private boolean isProcessingFrame = false;
private byte[][] yuvBytes = new byte[3][];
private int[] rgbBytes = null;
private int yRowStride;
protected int previewWidth = 0;
protected int previewHeight = 0;
private Runnable postInferenceCallback;
private Runnable imageConverter;
@Override
protected void onCreate(final Bundle savedInstanceState) {
LOGGER.d("onCreate " + this);
super.onCreate(null);
getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
setContentView(R.layout.activity_camera);
if (hasPermission()) {
setFragment();
} else {
requestPermission();
}
}
private byte[] lastPreviewFrame;
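/**
 * Runs the deferred YUV-to-ARGB conversion for the most recent frame (set up in onPreviewFrame or
 * onImageAvailable) and returns the shared RGB buffer.
 */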
protected int[] getRgbBytes() {
imageConverter.run();
return rgbBytes;
}
protected int getLuminanceStride() {
return yRowStride;
}
protected byte[] getLuminance() {
return yuvBytes[0];
}
/**
* Callback for android.hardware.Camera API
*/
@Override
public void onPreviewFrame(final byte[] bytes, final Camera camera) {
if (isProcessingFrame) {
LOGGER.w("Dropping frame!");
return;
}
try {
// Initialize the storage bitmaps once when the resolution is known.
if (rgbBytes == null) {
Camera.Size previewSize = camera.getParameters().getPreviewSize();
previewHeight = previewSize.height;
previewWidth = previewSize.width;
rgbBytes = new int[previewWidth * previewHeight];
onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90);
}
} catch (final Exception e) {
LOGGER.e(e, "Exception!");
return;
}
isProcessingFrame = true;
lastPreviewFrame = bytes;
yuvBytes[0] = bytes;
yRowStride = previewWidth;
imageConverter =
new Runnable() {
@Override
public void run() {
ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes);
}
};
postInferenceCallback =
new Runnable() {
@Override
public void run() {
camera.addCallbackBuffer(bytes);
isProcessingFrame = false;
}
};
processImage();
}
/**
* Callback for Camera2 API
*/
@Override
public void onImageAvailable(final ImageReader reader) {
// We need to wait until we have a preview size from onPreviewSizeChosen().
if (previewWidth == 0 || previewHeight == 0) {
return;
}
if (rgbBytes == null) {
rgbBytes = new int[previewWidth * previewHeight];
}
try {
final Image image = reader.acquireLatestImage();
if (image == null) {
return;
}
if (isProcessingFrame) {
image.close();
return;
}
isProcessingFrame = true;
Trace.beginSection("imageAvailable");
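// For ImageFormat.YUV_420_888, plane 0 is Y and planes 1 and 2 are U and V.
// Their row and pixel strides are device dependent, so they are read from the
// planes rather than assumed.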
final Plane[] planes = image.getPlanes();
fillBytes(planes, yuvBytes);
yRowStride = planes[0].getRowStride();
final int uvRowStride = planes[1].getRowStride();
final int uvPixelStride = planes[1].getPixelStride();
imageConverter =
new Runnable() {
@Override
public void run() {
ImageUtils.convertYUV420ToARGB8888(
yuvBytes[0],
yuvBytes[1],
yuvBytes[2],
previewWidth,
previewHeight,
yRowStride,
uvRowStride,
uvPixelStride,
rgbBytes);
}
};
postInferenceCallback =
new Runnable() {
@Override
public void run() {
image.close();
isProcessingFrame = false;
}
};
processImage();
} catch (final Exception e) {
LOGGER.e(e, "Exception!");
Trace.endSection();
return;
}
Trace.endSection();
}
@Override
public synchronized void onStart() {
LOGGER.d("onStart " + this);
super.onStart();
}
@Override
public synchronized void onResume() {
LOGGER.d("onResume " + this);
super.onResume();
handlerThread = new HandlerThread("inference");
handlerThread.start();
handler = new Handler(handlerThread.getLooper());
}
@Override
public synchronized void onPause() {
LOGGER.d("onPause " + this);
if (!isFinishing()) {
LOGGER.d("Requesting finish");
finish();
}
handlerThread.quitSafely();
try {
handlerThread.join();
handlerThread = null;
handler = null;
} catch (final InterruptedException e) {
LOGGER.e(e, "Exception!");
}
super.onPause();
}
@Override
public synchronized void onStop() {
LOGGER.d("onStop " + this);
super.onStop();
}
@Override
public synchronized void onDestroy() {
LOGGER.d("onDestroy " + this);
super.onDestroy();
}
protected synchronized void runInBackground(final Runnable r) {
if (handler != null) {
handler.post(r);
}
}
@Override
public void onRequestPermissionsResult(
final int requestCode, final String[] permissions, final int[] grantResults) {
if (requestCode == PERMISSIONS_REQUEST) {
if (grantResults.length > 1
&& grantResults[0] == PackageManager.PERMISSION_GRANTED
&& grantResults[1] == PackageManager.PERMISSION_GRANTED) {
setFragment();
} else {
requestPermission();
}
}
}
private boolean hasPermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED &&
checkSelfPermission(PERMISSION_STORAGE) == PackageManager.PERMISSION_GRANTED;
} else {
return true;
}
}
private void requestPermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA) ||
shouldShowRequestPermissionRationale(PERMISSION_STORAGE)) {
Toast.makeText(CameraActivity.this,
"Camera AND storage permission are required for this demo", Toast.LENGTH_LONG).show();
}
requestPermissions(new String[] {PERMISSION_CAMERA, PERMISSION_STORAGE}, PERMISSIONS_REQUEST);
}
}
// Returns true if the device supports the required hardware level, or better.
private boolean isHardwareLevelSupported(
CameraCharacteristics characteristics, int requiredLevel) {
int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL);
if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) {
return requiredLevel == deviceLevel;
}
// deviceLevel is not LEGACY, so we can compare the levels numerically.
return requiredLevel <= deviceLevel;
}
private String chooseCamera() {
final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE);
try {
for (final String cameraId : manager.getCameraIdList()) {
final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
// We don't use a front facing camera in this sample.
final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING);
if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) {
continue;
}
final StreamConfigurationMap map =
characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
if (map == null) {
continue;
}
// Fall back to the Camera1 API for internal cameras that don't have full support.
// This should help with legacy situations where using the camera2 API causes
// distorted or otherwise broken previews.
useCamera2API = (facing == CameraCharacteristics.LENS_FACING_EXTERNAL)
|| isHardwareLevelSupported(characteristics,
CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL);
LOGGER.i("Camera API lv2?: %s", useCamera2API);
return cameraId;
}
} catch (CameraAccessException e) {
LOGGER.e(e, "Not allowed to access camera");
}
return null;
}
protected void setFragment() {
String cameraId = chooseCamera();
if (cameraId == null) {
Toast.makeText(this, "No Camera Detected", Toast.LENGTH_SHORT).show();
finish();
return;
}
Fragment fragment;
if (useCamera2API) {
CameraConnectionFragment camera2Fragment =
CameraConnectionFragment.newInstance(
new CameraConnectionFragment.ConnectionCallback() {
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
previewHeight = size.getHeight();
previewWidth = size.getWidth();
CameraActivity.this.onPreviewSizeChosen(size, rotation);
}
},
this,
getLayoutId(),
getDesiredPreviewFrameSize());
camera2Fragment.setCamera(cameraId);
fragment = camera2Fragment;
} else {
fragment =
new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize());
}
getFragmentManager()
.beginTransaction()
.replace(R.id.container, fragment)
.commit();
}
protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) {
// Because of the variable row stride it's not possible to know in
// advance the actual necessary dimensions of the yuv planes.
for (int i = 0; i < planes.length; ++i) {
final ByteBuffer buffer = planes[i].getBuffer();
if (yuvBytes[i] == null) {
LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity());
yuvBytes[i] = new byte[buffer.capacity()];
}
buffer.get(yuvBytes[i]);
}
}
public boolean isDebug() {
return debug;
}
public void requestRender() {
final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay);
if (overlay != null) {
overlay.postInvalidate();
}
}
public void addCallback(final OverlayView.DrawCallback callback) {
final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay);
if (overlay != null) {
overlay.addCallback(callback);
}
}
public void onSetDebug(final boolean debug) {}
@Override
public boolean onKeyDown(final int keyCode, final KeyEvent event) {
if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP
|| keyCode == KeyEvent.KEYCODE_BUTTON_L1 || keyCode == KeyEvent.KEYCODE_DPAD_CENTER) {
debug = !debug;
requestRender();
onSetDebug(debug);
return true;
}
return super.onKeyDown(keyCode, event);
}
protected void readyForNextImage() {
if (postInferenceCallback != null) {
postInferenceCallback.run();
}
}
protected int getScreenOrientation() {
switch (getWindowManager().getDefaultDisplay().getRotation()) {
case Surface.ROTATION_270:
return 270;
case Surface.ROTATION_180:
return 180;
case Surface.ROTATION_90:
return 90;
default:
return 0;
}
}
protected abstract void processImage();
protected abstract void onPreviewSizeChosen(final Size size, final int rotation);
protected abstract int getLayoutId();
protected abstract Size getDesiredPreviewFrameSize();
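// A minimal sketch (illustrative only; names and values here are hypothetical)
// of what a concrete subclass has to provide: the four abstract methods above,
// plus a call to readyForNextImage() once each frame has been consumed.
//
//   public class ExampleClassifierActivity extends CameraActivity {
//     @Override protected int getLayoutId() { return R.layout.example_camera_fragment; }
//     @Override protected Size getDesiredPreviewFrameSize() { return new Size(640, 480); }
//     @Override protected void onPreviewSizeChosen(final Size size, final int rotation) {}
//     @Override protected void processImage() {
//       final int[] rgb = getRgbBytes();  // runs the deferred YUV -> ARGB conversion
//       // ... run inference on rgb ...
//       readyForNextImage();
//     }
//   }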
}
/*
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.demo;
import android.app.Activity;
import android.app.AlertDialog;
import android.app.Dialog;
import android.app.DialogFragment;
import android.app.Fragment;
import android.content.Context;
import android.content.DialogInterface;
import android.content.res.Configuration;
import android.graphics.ImageFormat;
import android.graphics.Matrix;
import android.graphics.RectF;
import android.graphics.SurfaceTexture;
import android.hardware.camera2.CameraAccessException;
import android.hardware.camera2.CameraCaptureSession;
import android.hardware.camera2.CameraCharacteristics;
import android.hardware.camera2.CameraDevice;
import android.hardware.camera2.CameraManager;
import android.hardware.camera2.CaptureRequest;
import android.hardware.camera2.CaptureResult;
import android.hardware.camera2.TotalCaptureResult;
import android.hardware.camera2.params.StreamConfigurationMap;
import android.media.ImageReader;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.text.TextUtils;
import android.util.Size;
import android.util.SparseIntArray;
import android.view.LayoutInflater;
import android.view.Surface;
import android.view.TextureView;
import android.view.View;
import android.view.ViewGroup;
import android.widget.Toast;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
public class CameraConnectionFragment extends Fragment {
private static final Logger LOGGER = new Logger();
/**
* The camera preview size will be chosen to be the smallest frame by pixel size capable of
* containing a square whose side is the desired input size, but no smaller than
* MINIMUM_PREVIEW_SIZE.
*/
private static final int MINIMUM_PREVIEW_SIZE = 320;
/**
* Conversion from screen rotation to JPEG orientation.
*/
private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
private static final String FRAGMENT_DIALOG = "dialog";
static {
ORIENTATIONS.append(Surface.ROTATION_0, 90);
ORIENTATIONS.append(Surface.ROTATION_90, 0);
ORIENTATIONS.append(Surface.ROTATION_180, 270);
ORIENTATIONS.append(Surface.ROTATION_270, 180);
}
/**
* {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a
* {@link TextureView}.
*/
private final TextureView.SurfaceTextureListener surfaceTextureListener =
new TextureView.SurfaceTextureListener() {
@Override
public void onSurfaceTextureAvailable(
final SurfaceTexture texture, final int width, final int height) {
openCamera(width, height);
}
@Override
public void onSurfaceTextureSizeChanged(
final SurfaceTexture texture, final int width, final int height) {
configureTransform(width, height);
}
@Override
public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) {
return true;
}
@Override
public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
};
/**
* Callback for Activities to use to initialize their data once the
* selected preview size is known.
*/
public interface ConnectionCallback {
void onPreviewSizeChosen(Size size, int cameraRotation);
}
/**
* ID of the current {@link CameraDevice}.
*/
private String cameraId;
/**
* An {@link AutoFitTextureView} for camera preview.
*/
private AutoFitTextureView textureView;
/**
* A {@link CameraCaptureSession } for camera preview.
*/
private CameraCaptureSession captureSession;
/**
* A reference to the opened {@link CameraDevice}.
*/
private CameraDevice cameraDevice;
/**
* The rotation in degrees of the camera sensor from the display.
*/
private Integer sensorOrientation;
/**
* The {@link android.util.Size} of camera preview.
*/
private Size previewSize;
/**
* {@link android.hardware.camera2.CameraDevice.StateCallback}
* is called when {@link CameraDevice} changes its state.
*/
private final CameraDevice.StateCallback stateCallback =
new CameraDevice.StateCallback() {
@Override
public void onOpened(final CameraDevice cd) {
// This method is called when the camera is opened. We start camera preview here.
cameraOpenCloseLock.release();
cameraDevice = cd;
createCameraPreviewSession();
}
@Override
public void onDisconnected(final CameraDevice cd) {
cameraOpenCloseLock.release();
cd.close();
cameraDevice = null;
}
@Override
public void onError(final CameraDevice cd, final int error) {
cameraOpenCloseLock.release();
cd.close();
cameraDevice = null;
final Activity activity = getActivity();
if (null != activity) {
activity.finish();
}
}
};
/**
* An additional thread for running tasks that shouldn't block the UI.
*/
private HandlerThread backgroundThread;
/**
* A {@link Handler} for running tasks in the background.
*/
private Handler backgroundHandler;
/**
* An {@link ImageReader} that handles preview frame capture.
*/
private ImageReader previewReader;
/**
* {@link android.hardware.camera2.CaptureRequest.Builder} for the camera preview
*/
private CaptureRequest.Builder previewRequestBuilder;
/**
* {@link CaptureRequest} generated by {@link #previewRequestBuilder}
*/
private CaptureRequest previewRequest;
/**
* A {@link Semaphore} to prevent the app from exiting before closing the camera.
*/
private final Semaphore cameraOpenCloseLock = new Semaphore(1);
/**
* A {@link OnImageAvailableListener} to receive frames as they are available.
*/
private final OnImageAvailableListener imageListener;
/** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */
private final Size inputSize;
/**
* The layout identifier to inflate for this Fragment.
*/
private final int layout;
private final ConnectionCallback cameraConnectionCallback;
private CameraConnectionFragment(
final ConnectionCallback connectionCallback,
final OnImageAvailableListener imageListener,
final int layout,
final Size inputSize) {
this.cameraConnectionCallback = connectionCallback;
this.imageListener = imageListener;
this.layout = layout;
this.inputSize = inputSize;
}
/**
* Shows a {@link Toast} on the UI thread.
*
* @param text The message to show
*/
private void showToast(final String text) {
final Activity activity = getActivity();
if (activity != null) {
activity.runOnUiThread(
new Runnable() {
@Override
public void run() {
Toast.makeText(activity, text, Toast.LENGTH_SHORT).show();
}
});
}
}
/**
* Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose
* width and height are both at least as large as the smaller of the requested dimensions (and no
* smaller than MINIMUM_PREVIEW_SIZE), or an exact match if one exists.
*
* @param choices The list of sizes that the camera supports for the intended output class
* @param width The minimum desired width
* @param height The minimum desired height
* @return The optimal {@code Size}, or an arbitrary one if none were big enough
*/
protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) {
final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE);
final Size desiredSize = new Size(width, height);
// Collect the supported resolutions that are at least as big as the preview Surface
boolean exactSizeFound = false;
final List<Size> bigEnough = new ArrayList<Size>();
final List<Size> tooSmall = new ArrayList<Size>();
for (final Size option : choices) {
if (option.equals(desiredSize)) {
// Set the size but don't return yet so that remaining sizes will still be logged.
exactSizeFound = true;
}
if (option.getHeight() >= minSize && option.getWidth() >= minSize) {
bigEnough.add(option);
} else {
tooSmall.add(option);
}
}
LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize);
LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]");
LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]");
if (exactSizeFound) {
LOGGER.i("Exact size match found.");
return desiredSize;
}
// Pick the smallest of those, assuming we found any
if (bigEnough.size() > 0) {
final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea());
LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight());
return chosenSize;
} else {
LOGGER.e("Couldn't find any suitable preview size");
return choices[0];
}
}
public static CameraConnectionFragment newInstance(
final ConnectionCallback callback,
final OnImageAvailableListener imageListener,
final int layout,
final Size inputSize) {
return new CameraConnectionFragment(callback, imageListener, layout, inputSize);
}
@Override
public View onCreateView(
final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) {
return inflater.inflate(layout, container, false);
}
@Override
public void onViewCreated(final View view, final Bundle savedInstanceState) {
textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
}
@Override
public void onActivityCreated(final Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
}
@Override
public void onResume() {
super.onResume();
startBackgroundThread();
// When the screen is turned off and turned back on, the SurfaceTexture is already
// available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
// a camera and start preview from here (otherwise, we wait until the surface is ready in
// the SurfaceTextureListener).
if (textureView.isAvailable()) {
openCamera(textureView.getWidth(), textureView.getHeight());
} else {
textureView.setSurfaceTextureListener(surfaceTextureListener);
}
}
@Override
public void onPause() {
closeCamera();
stopBackgroundThread();
super.onPause();
}
public void setCamera(String cameraId) {
this.cameraId = cameraId;
}
/**
* Sets up member variables related to camera.
*/
private void setUpCameraOutputs() {
final Activity activity = getActivity();
final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
try {
final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
final StreamConfigurationMap map =
characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
// For still image captures, we use the largest available size.
final Size largest =
Collections.max(
Arrays.asList(map.getOutputSizes(ImageFormat.YUV_420_888)),
new CompareSizesByArea());
sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION);
// Danger, W.R.! Attempting to use too large a preview size could exceed the camera
// bus' bandwidth limitation, resulting in gorgeous previews but the storage of
// garbage capture data.
previewSize =
chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class),
inputSize.getWidth(),
inputSize.getHeight());
// We fit the aspect ratio of TextureView to the size of preview we picked.
final int orientation = getResources().getConfiguration().orientation;
if (orientation == Configuration.ORIENTATION_LANDSCAPE) {
textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight());
} else {
textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth());
}
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
} catch (final NullPointerException e) {
// Currently an NPE is thrown when the Camera2 API is used but is not supported on the
// device this code runs on.
// TODO(andrewharp): abstract ErrorDialog/RuntimeException handling out into new method and
// reuse throughout app.
ErrorDialog.newInstance(getString(R.string.camera_error))
.show(getChildFragmentManager(), FRAGMENT_DIALOG);
throw new RuntimeException(getString(R.string.camera_error));
}
cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation);
}
/**
* Opens the camera specified by {@link CameraConnectionFragment#cameraId}.
*/
private void openCamera(final int width, final int height) {
setUpCameraOutputs();
configureTransform(width, height);
final Activity activity = getActivity();
final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
try {
if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) {
throw new RuntimeException("Time out waiting to lock camera opening.");
}
manager.openCamera(cameraId, stateCallback, backgroundHandler);
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
} catch (final InterruptedException e) {
throw new RuntimeException("Interrupted while trying to lock camera opening.", e);
}
}
/**
* Closes the current {@link CameraDevice}.
*/
private void closeCamera() {
try {
cameraOpenCloseLock.acquire();
if (null != captureSession) {
captureSession.close();
captureSession = null;
}
if (null != cameraDevice) {
cameraDevice.close();
cameraDevice = null;
}
if (null != previewReader) {
previewReader.close();
previewReader = null;
}
} catch (final InterruptedException e) {
throw new RuntimeException("Interrupted while trying to lock camera closing.", e);
} finally {
cameraOpenCloseLock.release();
}
}
/**
* Starts a background thread and its {@link Handler}.
*/
private void startBackgroundThread() {
backgroundThread = new HandlerThread("ImageListener");
backgroundThread.start();
backgroundHandler = new Handler(backgroundThread.getLooper());
}
/**
* Stops the background thread and its {@link Handler}.
*/
private void stopBackgroundThread() {
backgroundThread.quitSafely();
try {
backgroundThread.join();
backgroundThread = null;
backgroundHandler = null;
} catch (final InterruptedException e) {
LOGGER.e(e, "Exception!");
}
}
private final CameraCaptureSession.CaptureCallback captureCallback =
new CameraCaptureSession.CaptureCallback() {
@Override
public void onCaptureProgressed(
final CameraCaptureSession session,
final CaptureRequest request,
final CaptureResult partialResult) {}
@Override
public void onCaptureCompleted(
final CameraCaptureSession session,
final CaptureRequest request,
final TotalCaptureResult result) {}
};
/**
* Creates a new {@link CameraCaptureSession} for camera preview.
*/
private void createCameraPreviewSession() {
try {
final SurfaceTexture texture = textureView.getSurfaceTexture();
assert texture != null;
// We configure the size of default buffer to be the size of camera preview we want.
texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight());
// This is the output Surface we need to start preview.
final Surface surface = new Surface(texture);
// We set up a CaptureRequest.Builder with the output Surface.
previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
previewRequestBuilder.addTarget(surface);
LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight());
// Create the reader for the preview frames.
previewReader =
ImageReader.newInstance(
previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
previewRequestBuilder.addTarget(previewReader.getSurface());
// Here, we create a CameraCaptureSession for camera preview.
cameraDevice.createCaptureSession(
Arrays.asList(surface, previewReader.getSurface()),
new CameraCaptureSession.StateCallback() {
@Override
public void onConfigured(final CameraCaptureSession cameraCaptureSession) {
// The camera is already closed
if (null == cameraDevice) {
return;
}
// When the session is ready, we start displaying the preview.
captureSession = cameraCaptureSession;
try {
// Auto focus should be continuous for camera preview.
previewRequestBuilder.set(
CaptureRequest.CONTROL_AF_MODE,
CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE);
// Flash is automatically enabled when necessary.
previewRequestBuilder.set(
CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH);
// Finally, we start displaying the camera preview.
previewRequest = previewRequestBuilder.build();
captureSession.setRepeatingRequest(
previewRequest, captureCallback, backgroundHandler);
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
}
}
@Override
public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) {
showToast("Failed");
}
},
null);
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
}
}
/**
* Configures the necessary {@link android.graphics.Matrix} transformation to {@code textureView}.
* This method should be called after the camera preview size is determined in
* setUpCameraOutputs and also the size of {@code textureView} is fixed.
*
* @param viewWidth The width of {@code textureView}
* @param viewHeight The height of {@code textureView}
*/
private void configureTransform(final int viewWidth, final int viewHeight) {
final Activity activity = getActivity();
if (null == textureView || null == previewSize || null == activity) {
return;
}
final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation();
final Matrix matrix = new Matrix();
final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight);
final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth());
final float centerX = viewRect.centerX();
final float centerY = viewRect.centerY();
if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) {
bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY());
matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL);
final float scale =
Math.max(
(float) viewHeight / previewSize.getHeight(),
(float) viewWidth / previewSize.getWidth());
matrix.postScale(scale, scale, centerX, centerY);
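      // Surface.ROTATION_90 == 1 and Surface.ROTATION_270 == 3, so this rotates by -90 or +90
      // degrees respectively to counteract the display rotation.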
matrix.postRotate(90 * (rotation - 2), centerX, centerY);
} else if (Surface.ROTATION_180 == rotation) {
matrix.postRotate(180, centerX, centerY);
}
textureView.setTransform(matrix);
}
/**
* Compares two {@code Size}s based on their areas.
*/
static class CompareSizesByArea implements Comparator<Size> {
@Override
public int compare(final Size lhs, final Size rhs) {
// We cast here to ensure the multiplications won't overflow
return Long.signum(
(long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight());
}
}
/**
* Shows an error message dialog.
*/
public static class ErrorDialog extends DialogFragment {
private static final String ARG_MESSAGE = "message";
public static ErrorDialog newInstance(final String message) {
final ErrorDialog dialog = new ErrorDialog();
final Bundle args = new Bundle();
args.putString(ARG_MESSAGE, message);
dialog.setArguments(args);
return dialog;
}
@Override
public Dialog onCreateDialog(final Bundle savedInstanceState) {
final Activity activity = getActivity();
return new AlertDialog.Builder(activity)
.setMessage(getArguments().getString(ARG_MESSAGE))
.setPositiveButton(
android.R.string.ok,
new DialogInterface.OnClickListener() {
@Override
public void onClick(final DialogInterface dialogInterface, final int i) {
activity.finish();
}
})
.create();
}
}
}
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.graphics.Bitmap;
import android.graphics.RectF;
import java.util.List;
/**
* Generic interface for interacting with different recognition engines.
*/
public interface Classifier {
/**
* An immutable result returned by a Classifier describing what was recognized.
*/
public class Recognition {
/**
* A unique identifier for what has been recognized. Specific to the class, not the instance of
* the object.
*/
private final String id;
/**
* Display name for the recognition.
*/
private final String title;
/**
* A sortable score for how good the recognition is relative to others. Higher should be better.
*/
private final Float confidence;
/** Optional location within the source image for the location of the recognized object. */
private RectF location;
public Recognition(
final String id, final String title, final Float confidence, final RectF location) {
this.id = id;
this.title = title;
this.confidence = confidence;
this.location = location;
}
public String getId() {
return id;
}
public String getTitle() {
return title;
}
public Float getConfidence() {
return confidence;
}
public RectF getLocation() {
return new RectF(location);
}
public void setLocation(RectF location) {
this.location = location;
}
@Override
public String toString() {
String resultString = "";
if (id != null) {
resultString += "[" + id + "] ";
}
if (title != null) {
resultString += title + " ";
}
if (confidence != null) {
resultString += String.format("(%.1f%%) ", confidence * 100.0f);
}
if (location != null) {
resultString += location + " ";
}
return resultString.trim();
}
}
List<Recognition> recognizeImage(Bitmap bitmap);
void enableStatLogging(final boolean debug);
String getStatString();
void close();
}
/*
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.demo;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
import android.view.Surface;
import java.util.List;
import java.util.Vector;
import org.tensorflow.demo.OverlayView.DrawCallback;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
protected static final boolean SAVE_PREVIEW_BITMAP = false;
private ResultsView resultsView;
private Bitmap rgbFrameBitmap = null;
private Bitmap croppedBitmap = null;
private Bitmap cropCopyBitmap = null;
private long lastProcessingTimeMs;
// These are the settings for the original v1 Inception model. If you want to
// use a model that's been produced from the TensorFlow for Poets codelab,
// you'll need to set IMAGE_SIZE = 299, IMAGE_MEAN = 128, IMAGE_STD = 128,
// INPUT_NAME = "Mul", and OUTPUT_NAME = "final_result".
// You'll also need to update the MODEL_FILE and LABEL_FILE paths to point to
// the ones you produced.
//
// To use v3 Inception model, strip the DecodeJpeg Op from your retrained
// model first:
//
// python strip_unused.py \
// --input_graph=<retrained-pb-file> \
// --output_graph=<your-stripped-pb-file> \
// --input_node_names="Mul" \
// --output_node_names="final_result" \
// --input_binary=true
private static final int INPUT_SIZE = 224;
private static final int IMAGE_MEAN = 117;
private static final float IMAGE_STD = 1;
private static final String INPUT_NAME = "input";
private static final String OUTPUT_NAME = "output";
private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
private static final String LABEL_FILE =
"file:///android_asset/imagenet_comp_graph_label_strings.txt";
private static final boolean MAINTAIN_ASPECT = true;
private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
private Integer sensorOrientation;
private Classifier classifier;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
private BorderedText borderedText;
@Override
protected int getLayoutId() {
return R.layout.camera_connection_fragment;
}
@Override
protected Size getDesiredPreviewFrameSize() {
return DESIRED_PREVIEW_SIZE;
}
private static final float TEXT_SIZE_DIP = 10;
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx = TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
classifier =
TensorFlowImageClassifier.create(
getAssets(),
MODEL_FILE,
LABEL_FILE,
INPUT_SIZE,
IMAGE_MEAN,
IMAGE_STD,
INPUT_NAME,
OUTPUT_NAME);
previewWidth = size.getWidth();
previewHeight = size.getHeight();
sensorOrientation = rotation - getScreenOrientation();
LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
INPUT_SIZE, INPUT_SIZE,
sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
addCallback(
new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
renderDebug(canvas);
}
});
}
@Override
protected void processImage() {
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
final Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
// For examining the actual TF input.
if (SAVE_PREVIEW_BITMAP) {
ImageUtils.saveBitmap(croppedBitmap);
}
runInBackground(
new Runnable() {
@Override
public void run() {
final long startTime = SystemClock.uptimeMillis();
final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
LOGGER.i("Detect: %s", results);
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
if (resultsView == null) {
resultsView = (ResultsView) findViewById(R.id.results);
}
resultsView.setResults(results);
requestRender();
readyForNextImage();
}
});
}
@Override
public void onSetDebug(boolean debug) {
classifier.enableStatLogging(debug);
}
private void renderDebug(final Canvas canvas) {
if (!isDebug()) {
return;
}
final Bitmap copy = cropCopyBitmap;
if (copy != null) {
final Matrix matrix = new Matrix();
final float scaleFactor = 2;
matrix.postScale(scaleFactor, scaleFactor);
matrix.postTranslate(
canvas.getWidth() - copy.getWidth() * scaleFactor,
canvas.getHeight() - copy.getHeight() * scaleFactor);
canvas.drawBitmap(copy, matrix, new Paint());
final Vector<String> lines = new Vector<String>();
if (classifier != null) {
String statString = classifier.getStatString();
String[] statLines = statString.split("\n");
for (String line : statLines) {
lines.add(line);
}
}
lines.add("Frame: " + previewWidth + "x" + previewHeight);
lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
lines.add("Rotation: " + sensorOrientation);
lines.add("Inference time: " + lastProcessingTimeMs + "ms");
borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
}
}
}
/*
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.demo;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
import android.view.Surface;
import android.widget.Toast;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
import org.tensorflow.demo.OverlayView.DrawCallback;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.tracking.MultiBoxTracker;
import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
/**
* An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
* objects.
*/
public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
// Configuration values for the prepackaged multibox model.
private static final int MB_INPUT_SIZE = 224;
private static final int MB_IMAGE_MEAN = 128;
private static final float MB_IMAGE_STD = 128;
private static final String MB_INPUT_NAME = "ResizeBilinear";
private static final String MB_OUTPUT_LOCATIONS_NAME = "output_locations/Reshape";
private static final String MB_OUTPUT_SCORES_NAME = "output_scores/Reshape";
private static final String MB_MODEL_FILE = "file:///android_asset/multibox_model.pb";
private static final String MB_LOCATION_FILE =
"file:///android_asset/multibox_location_priors.txt";
private static final int TF_OD_API_INPUT_SIZE = 300;
private static final String TF_OD_API_MODEL_FILE =
"file:///android_asset/ssd_mobilenet_v1_android_export.pb";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
// Configuration values for the YOLO detector (pointed below at a yolov3 graph). Note that the
// graph is not included with TensorFlow and must be manually placed in the assets/ directory by
// the user. Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted
// e.g. via DarkFlow (https://github.com/thtrieu/darkflow). Sample command for tiny-yolo-voc:
// ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise
private static final String YOLO_MODEL_FILE = "file:///android_asset/yolov3.pb";
private static final int YOLO_INPUT_SIZE = 416;
private static final String YOLO_INPUT_NAME = "input";
private static final String YOLO_OUTPUT_NAMES = "output";
private static final int YOLO_BLOCK_SIZE = 32;
// Which detection model to use: TensorFlow Object Detection API frozen checkpoints, the legacy
// Multibox model (trained using an older version of the API), or YOLO. MODE below selects the
// one that is loaded.
private enum DetectorMode {
TF_OD_API, MULTIBOX, YOLO;
}
private static final DetectorMode MODE = DetectorMode.YOLO;
// Minimum detection confidence to track a detection.
private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f;
private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f;
private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f;
private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO;
private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
private static final boolean SAVE_PREVIEW_BITMAP = false;
private static final float TEXT_SIZE_DIP = 10;
private Integer sensorOrientation;
private Classifier detector;
private long lastProcessingTimeMs;
private Bitmap rgbFrameBitmap = null;
private Bitmap croppedBitmap = null;
private Bitmap cropCopyBitmap = null;
private boolean computingDetection = false;
private long timestamp = 0;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
private MultiBoxTracker tracker;
private byte[] luminanceCopy;
private BorderedText borderedText;
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx =
TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
tracker = new MultiBoxTracker(this);
int cropSize = TF_OD_API_INPUT_SIZE;
if (MODE == DetectorMode.YOLO) {
detector =
TensorFlowYoloDetector.create(
getAssets(),
YOLO_MODEL_FILE,
YOLO_INPUT_SIZE,
YOLO_INPUT_NAME,
YOLO_OUTPUT_NAMES,
YOLO_BLOCK_SIZE);
cropSize = YOLO_INPUT_SIZE;
} else if (MODE == DetectorMode.MULTIBOX) {
detector =
TensorFlowMultiBoxDetector.create(
getAssets(),
MB_MODEL_FILE,
MB_LOCATION_FILE,
MB_IMAGE_MEAN,
MB_IMAGE_STD,
MB_INPUT_NAME,
MB_OUTPUT_LOCATIONS_NAME,
MB_OUTPUT_SCORES_NAME);
cropSize = MB_INPUT_SIZE;
} else {
try {
detector = TensorFlowObjectDetectionAPIModel.create(
getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
cropSize = TF_OD_API_INPUT_SIZE;
} catch (final IOException e) {
LOGGER.e(e, "Exception initializing classifier!");
Toast toast =
Toast.makeText(
getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
toast.show();
finish();
}
}
previewWidth = size.getWidth();
previewHeight = size.getHeight();
sensorOrientation = rotation - getScreenOrientation();
LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
frameToCropTransform =
ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
cropSize, cropSize,
sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
trackingOverlay.addCallback(
new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
tracker.draw(canvas);
if (isDebug()) {
tracker.drawDebug(canvas);
}
}
});
addCallback(
new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
if (!isDebug()) {
return;
}
final Bitmap copy = cropCopyBitmap;
if (copy == null) {
return;
}
final int backgroundColor = Color.argb(100, 0, 0, 0);
canvas.drawColor(backgroundColor);
final Matrix matrix = new Matrix();
final float scaleFactor = 2;
matrix.postScale(scaleFactor, scaleFactor);
matrix.postTranslate(
canvas.getWidth() - copy.getWidth() * scaleFactor,
canvas.getHeight() - copy.getHeight() * scaleFactor);
canvas.drawBitmap(copy, matrix, new Paint());
final Vector<String> lines = new Vector<String>();
if (detector != null) {
final String statString = detector.getStatString();
final String[] statLines = statString.split("\n");
for (final String line : statLines) {
lines.add(line);
}
}
lines.add("");
lines.add("Frame: " + previewWidth + "x" + previewHeight);
lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
lines.add("Rotation: " + sensorOrientation);
lines.add("Inference time: " + lastProcessingTimeMs + "ms");
borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
}
});
}
OverlayView trackingOverlay;
@Override
protected void processImage() {
++timestamp;
final long currTimestamp = timestamp;
byte[] originalLuminance = getLuminance();
tracker.onFrame(
previewWidth,
previewHeight,
getLuminanceStride(),
sensorOrientation,
originalLuminance,
timestamp);
trackingOverlay.postInvalidate();
// No mutex needed as this method is not reentrant.
if (computingDetection) {
readyForNextImage();
return;
}
computingDetection = true;
LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
if (luminanceCopy == null) {
luminanceCopy = new byte[originalLuminance.length];
}
System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length);
readyForNextImage();
final Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
// For examining the actual TF input.
if (SAVE_PREVIEW_BITMAP) {
ImageUtils.saveBitmap(croppedBitmap);
}
runInBackground(
new Runnable() {
@Override
public void run() {
LOGGER.i("Running detection on image " + currTimestamp);
final long startTime = SystemClock.uptimeMillis();
final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
final Canvas canvas = new Canvas(cropCopyBitmap);
final Paint paint = new Paint();
paint.setColor(Color.RED);
paint.setStyle(Style.STROKE);
paint.setStrokeWidth(2.0f);
float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
switch (MODE) {
case TF_OD_API:
minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
break;
case MULTIBOX:
minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX;
break;
case YOLO:
minimumConfidence = MINIMUM_CONFIDENCE_YOLO;
break;
}
final List<Classifier.Recognition> mappedRecognitions =
new LinkedList<Classifier.Recognition>();
for (final Classifier.Recognition result : results) {
final RectF location = result.getLocation();
if (location != null && result.getConfidence() >= minimumConfidence) {
canvas.drawRect(location, paint);
cropToFrameTransform.mapRect(location);
result.setLocation(location);
mappedRecognitions.add(result);
}
}
tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp);
trackingOverlay.postInvalidate();
requestRender();
computingDetection = false;
}
});
}
@Override
protected int getLayoutId() {
return R.layout.camera_connection_fragment_tracking;
}
@Override
protected Size getDesiredPreviewFrameSize() {
return DESIRED_PREVIEW_SIZE;
}
@Override
public void onSetDebug(final boolean debug) {
detector.enableStatLogging(debug);
}
}
package org.tensorflow.demo;
/*
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import android.app.Fragment;
import android.graphics.SurfaceTexture;
import android.hardware.Camera;
import android.hardware.Camera.CameraInfo;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Size;
import android.util.SparseIntArray;
import android.view.LayoutInflater;
import android.view.Surface;
import android.view.TextureView;
import android.view.View;
import android.view.ViewGroup;
import java.io.IOException;
import java.util.List;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
public class LegacyCameraConnectionFragment extends Fragment {
private Camera camera;
private static final Logger LOGGER = new Logger();
private Camera.PreviewCallback imageListener;
private Size desiredSize;
/**
* The layout identifier to inflate for this Fragment.
*/
private int layout;
public LegacyCameraConnectionFragment(
final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) {
this.imageListener = imageListener;
this.layout = layout;
this.desiredSize = desiredSize;
}
/**
* Conversion from screen rotation to JPEG orientation.
*/
private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
static {
ORIENTATIONS.append(Surface.ROTATION_0, 90);
ORIENTATIONS.append(Surface.ROTATION_90, 0);
ORIENTATIONS.append(Surface.ROTATION_180, 270);
ORIENTATIONS.append(Surface.ROTATION_270, 180);
}
/**
* {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a
* {@link TextureView}.
*/
private final TextureView.SurfaceTextureListener surfaceTextureListener =
new TextureView.SurfaceTextureListener() {
@Override
public void onSurfaceTextureAvailable(
final SurfaceTexture texture, final int width, final int height) {
int index = getCameraId();
camera = Camera.open(index);
try {
Camera.Parameters parameters = camera.getParameters();
List<String> focusModes = parameters.getSupportedFocusModes();
if (focusModes != null
&& focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) {
parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE);
}
List<Camera.Size> cameraSizes = parameters.getSupportedPreviewSizes();
Size[] sizes = new Size[cameraSizes.size()];
int i = 0;
for (Camera.Size size : cameraSizes) {
sizes[i++] = new Size(size.width, size.height);
}
Size previewSize =
CameraConnectionFragment.chooseOptimalSize(
sizes, desiredSize.getWidth(), desiredSize.getHeight());
parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight());
camera.setDisplayOrientation(90);
camera.setParameters(parameters);
camera.setPreviewTexture(texture);
} catch (IOException exception) {
camera.release();
}
camera.setPreviewCallbackWithBuffer(imageListener);
Camera.Size s = camera.getParameters().getPreviewSize();
camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]);
textureView.setAspectRatio(s.height, s.width);
camera.startPreview();
}
@Override
public void onSurfaceTextureSizeChanged(
final SurfaceTexture texture, final int width, final int height) {}
@Override
public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) {
return true;
}
@Override
public void onSurfaceTextureUpdated(final SurfaceTexture texture) {}
};
/**
* An {@link AutoFitTextureView} for camera preview.
*/
private AutoFitTextureView textureView;
/**
* An additional thread for running tasks that shouldn't block the UI.
*/
private HandlerThread backgroundThread;
@Override
public View onCreateView(
final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) {
return inflater.inflate(layout, container, false);
}
@Override
public void onViewCreated(final View view, final Bundle savedInstanceState) {
textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
}
@Override
public void onActivityCreated(final Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
}
@Override
public void onResume() {
super.onResume();
startBackgroundThread();
// When the screen is turned off and turned back on, the SurfaceTexture is already
// available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open
// a camera and start preview from here (otherwise, we wait until the surface is ready in
// the SurfaceTextureListener).
if (textureView.isAvailable()) {
camera.startPreview();
} else {
textureView.setSurfaceTextureListener(surfaceTextureListener);
}
}
@Override
public void onPause() {
stopCamera();
stopBackgroundThread();
super.onPause();
}
/**
* Starts a background thread.
*/
private void startBackgroundThread() {
backgroundThread = new HandlerThread("CameraBackground");
backgroundThread.start();
}
/**
* Stops the background thread.
*/
private void stopBackgroundThread() {
backgroundThread.quitSafely();
try {
backgroundThread.join();
backgroundThread = null;
} catch (final InterruptedException e) {
LOGGER.e(e, "Exception!");
}
}
protected void stopCamera() {
if (camera != null) {
camera.stopPreview();
camera.setPreviewCallback(null);
camera.release();
camera = null;
}
}
private int getCameraId() {
CameraInfo ci = new CameraInfo();
for (int i = 0; i < Camera.getNumberOfCameras(); i++) {
Camera.getCameraInfo(i, ci);
if (ci.facing == CameraInfo.CAMERA_FACING_BACK)
return i;
}
return -1; // No camera found
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.content.Context;
import android.graphics.Canvas;
import android.util.AttributeSet;
import android.view.View;
import java.util.LinkedList;
import java.util.List;
/**
* A simple View providing a render callback to other classes.
*/
public class OverlayView extends View {
private final List<DrawCallback> callbacks = new LinkedList<DrawCallback>();
public OverlayView(final Context context, final AttributeSet attrs) {
super(context, attrs);
}
/**
* Interface defining the callback for client classes.
*/
public interface DrawCallback {
public void drawCallback(final Canvas canvas);
}
public void addCallback(final DrawCallback callback) {
callbacks.add(callback);
}
@Override
public synchronized void draw(final Canvas canvas) {
for (final DrawCallback callback : callbacks) {
callback.drawCallback(canvas);
}
}
}
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.content.Context;
import android.graphics.Canvas;
import android.graphics.Paint;
import android.util.AttributeSet;
import android.util.TypedValue;
import android.view.View;
import org.tensorflow.demo.Classifier.Recognition;
import java.util.List;
public class RecognitionScoreView extends View implements ResultsView {
private static final float TEXT_SIZE_DIP = 24;
private List<Recognition> results;
private final float textSizePx;
private final Paint fgPaint;
private final Paint bgPaint;
public RecognitionScoreView(final Context context, final AttributeSet set) {
super(context, set);
textSizePx =
TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
fgPaint = new Paint();
fgPaint.setTextSize(textSizePx);
bgPaint = new Paint();
bgPaint.setColor(0xcc4285f4);
}
@Override
public void setResults(final List<Recognition> results) {
this.results = results;
postInvalidate();
}
@Override
public void onDraw(final Canvas canvas) {
final int x = 10;
int y = (int) (fgPaint.getTextSize() * 1.5f);
canvas.drawPaint(bgPaint);
if (results != null) {
for (final Recognition recog : results) {
canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint);
y += fgPaint.getTextSize() * 1.5f;
}
}
}
}
/*
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.demo;
import android.util.Log;
import android.util.Pair;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
/** Reads in results from an instantaneous audio recognition model and smoothes them over time. */
public class RecognizeCommands {
// Configuration settings.
private List<String> labels = new ArrayList<String>();
private long averageWindowDurationMs;
private float detectionThreshold;
private int suppressionMs;
private int minimumCount;
private long minimumTimeBetweenSamplesMs;
// Working variables.
private Deque<Pair<Long, float[]>> previousResults = new ArrayDeque<Pair<Long, float[]>>();
private String previousTopLabel;
private int labelsCount;
private long previousTopLabelTime;
private float previousTopLabelScore;
private static final String SILENCE_LABEL = "_silence_";
private static final long MINIMUM_TIME_FRACTION = 4;
public RecognizeCommands(
List<String> inLabels,
long inAverageWindowDurationMs,
float inDetectionThreshold,
int inSuppressionMS,
int inMinimumCount,
long inMinimumTimeBetweenSamplesMS) {
labels = inLabels;
averageWindowDurationMs = inAverageWindowDurationMs;
detectionThreshold = inDetectionThreshold;
suppressionMs = inSuppressionMS;
minimumCount = inMinimumCount;
labelsCount = inLabels.size();
previousTopLabel = SILENCE_LABEL;
previousTopLabelTime = Long.MIN_VALUE;
previousTopLabelScore = 0.0f;
minimumTimeBetweenSamplesMs = inMinimumTimeBetweenSamplesMS;
}
/** Holds information about what's been recognized. */
public static class RecognitionResult {
public final String foundCommand;
public final float score;
public final boolean isNewCommand;
public RecognitionResult(String inFoundCommand, float inScore, boolean inIsNewCommand) {
foundCommand = inFoundCommand;
score = inScore;
isNewCommand = inIsNewCommand;
}
}
private static class ScoreForSorting implements Comparable<ScoreForSorting> {
public final float score;
public final int index;
public ScoreForSorting(float inScore, int inIndex) {
score = inScore;
index = inIndex;
}
@Override
public int compareTo(ScoreForSorting other) {
if (this.score > other.score) {
return -1;
} else if (this.score < other.score) {
return 1;
} else {
return 0;
}
}
}
public RecognitionResult processLatestResults(float[] currentResults, long currentTimeMS) {
if (currentResults.length != labelsCount) {
throw new RuntimeException(
"The results for recognition should contain "
+ labelsCount
+ " elements, but there are "
+ currentResults.length);
}
if ((!previousResults.isEmpty()) && (currentTimeMS < previousResults.getFirst().first)) {
throw new RuntimeException(
"You must feed results in increasing time order, but received a timestamp of "
+ currentTimeMS
+ " that was earlier than the previous one of "
+ previousResults.getFirst().first);
}
final int howManyResults = previousResults.size();
// Ignore any results that are coming in too frequently.
if (howManyResults > 1) {
final long timeSinceMostRecent = currentTimeMS - previousResults.getLast().first;
if (timeSinceMostRecent < minimumTimeBetweenSamplesMs) {
return new RecognitionResult(previousTopLabel, previousTopLabelScore, false);
}
}
// Add the latest results to the head of the queue.
previousResults.addLast(new Pair<Long, float[]>(currentTimeMS, currentResults));
// Prune any earlier results that are too old for the averaging window.
final long timeLimit = currentTimeMS - averageWindowDurationMs;
while (previousResults.getFirst().first < timeLimit) {
previousResults.removeFirst();
}
// If there are too few results, assume the result will be unreliable and
// bail.
final long earliestTime = previousResults.getFirst().first;
final long samplesDuration = currentTimeMS - earliestTime;
if ((howManyResults < minimumCount)
|| (samplesDuration < (averageWindowDurationMs / MINIMUM_TIME_FRACTION))) {
Log.v("RecognizeResult", "Too few results");
return new RecognitionResult(previousTopLabel, 0.0f, false);
}
// Calculate the average score across all the results in the window.
float[] averageScores = new float[labelsCount];
for (Pair<Long, float[]> previousResult : previousResults) {
final float[] scoresTensor = previousResult.second;
int i = 0;
while (i < scoresTensor.length) {
averageScores[i] += scoresTensor[i] / howManyResults;
++i;
}
}
// Sort the averaged results in descending score order.
ScoreForSorting[] sortedAverageScores = new ScoreForSorting[labelsCount];
for (int i = 0; i < labelsCount; ++i) {
sortedAverageScores[i] = new ScoreForSorting(averageScores[i], i);
}
Arrays.sort(sortedAverageScores);
// See if the latest top score is enough to trigger a detection.
final int currentTopIndex = sortedAverageScores[0].index;
final String currentTopLabel = labels.get(currentTopIndex);
final float currentTopScore = sortedAverageScores[0].score;
// If we've recently had another label trigger, assume one that occurs too
// soon afterwards is a bad result.
long timeSinceLastTop;
if (previousTopLabel.equals(SILENCE_LABEL) || (previousTopLabelTime == Long.MIN_VALUE)) {
timeSinceLastTop = Long.MAX_VALUE;
} else {
timeSinceLastTop = currentTimeMS - previousTopLabelTime;
}
boolean isNewCommand;
if ((currentTopScore > detectionThreshold) && (timeSinceLastTop > suppressionMs)) {
previousTopLabel = currentTopLabel;
previousTopLabelTime = currentTimeMS;
previousTopLabelScore = currentTopScore;
isNewCommand = true;
} else {
isNewCommand = false;
}
return new RecognitionResult(currentTopLabel, currentTopScore, isNewCommand);
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import org.tensorflow.demo.Classifier.Recognition;
import java.util.List;
public interface ResultsView {
public void setResults(final List<Recognition> results);
}
/*
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Demonstrates how to run an audio recognition model in Android.
This example loads a simple speech recognition model trained by the tutorial at
https://www.tensorflow.org/tutorials/audio_training
The model files should be downloaded automatically from the TensorFlow website,
but if you have a custom model you can update the LABEL_FILENAME and
MODEL_FILENAME constants to point to your own files.
The example application displays a list view with all of the known audio labels,
and highlights each one when it thinks it has detected one through the
microphone. The averaging of results to give a more reliable signal happens in
the RecognizeCommands helper class.
*/
package org.tensorflow.demo;
import android.animation.AnimatorInflater;
import android.animation.AnimatorSet;
import android.app.Activity;
import android.content.pm.PackageManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ArrayAdapter;
import android.widget.Button;
import android.widget.ListView;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.R;
/**
* An activity that listens for audio and then uses a TensorFlow model to detect particular classes,
* by default a small set of action words.
*/
public class SpeechActivity extends Activity {
// Constants that control the behavior of the recognition code and model
// settings. See the audio recognition tutorial for a detailed explanation of
// all these, but you should customize them to match your training settings if
// you are running your own model.
private static final int SAMPLE_RATE = 16000;
private static final int SAMPLE_DURATION_MS = 1000;
private static final int RECORDING_LENGTH = (int) (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000);
private static final long AVERAGE_WINDOW_DURATION_MS = 500;
private static final float DETECTION_THRESHOLD = 0.70f;
private static final int SUPPRESSION_MS = 1500;
private static final int MINIMUM_COUNT = 3;
private static final long MINIMUM_TIME_BETWEEN_SAMPLES_MS = 30;
private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_labels.txt";
private static final String MODEL_FILENAME = "file:///android_asset/conv_actions_frozen.pb";
private static final String INPUT_DATA_NAME = "decoded_sample_data:0";
private static final String SAMPLE_RATE_NAME = "decoded_sample_data:1";
private static final String OUTPUT_SCORES_NAME = "labels_softmax";
// UI elements.
private static final int REQUEST_RECORD_AUDIO = 13;
private Button quitButton;
private ListView labelsListView;
private static final String LOG_TAG = SpeechActivity.class.getSimpleName();
// Working variables.
short[] recordingBuffer = new short[RECORDING_LENGTH];
int recordingOffset = 0;
boolean shouldContinue = true;
private Thread recordingThread;
boolean shouldContinueRecognition = true;
private Thread recognitionThread;
private final ReentrantLock recordingBufferLock = new ReentrantLock();
private TensorFlowInferenceInterface inferenceInterface;
private List<String> labels = new ArrayList<String>();
private List<String> displayedLabels = new ArrayList<>();
private RecognizeCommands recognizeCommands = null;
@Override
protected void onCreate(Bundle savedInstanceState) {
// Set up the UI.
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_speech);
quitButton = (Button) findViewById(R.id.quit);
quitButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
moveTaskToBack(true);
android.os.Process.killProcess(android.os.Process.myPid());
System.exit(1);
}
});
labelsListView = (ListView) findViewById(R.id.list_view);
// Load the labels for the model, but only display those that don't start
// with an underscore.
String actualFilename = LABEL_FILENAME.split("file:///android_asset/")[1];
Log.i(LOG_TAG, "Reading labels from: " + actualFilename);
BufferedReader br = null;
try {
br = new BufferedReader(new InputStreamReader(getAssets().open(actualFilename)));
String line;
while ((line = br.readLine()) != null) {
labels.add(line);
if (line.charAt(0) != '_') {
displayedLabels.add(line.substring(0, 1).toUpperCase() + line.substring(1));
}
}
br.close();
} catch (IOException e) {
throw new RuntimeException("Problem reading label file!", e);
}
// Build a list view based on these labels.
ArrayAdapter<String> arrayAdapter =
new ArrayAdapter<String>(this, R.layout.list_text_item, displayedLabels);
labelsListView.setAdapter(arrayAdapter);
// Set up an object to smooth recognition results to increase accuracy.
recognizeCommands =
new RecognizeCommands(
labels,
AVERAGE_WINDOW_DURATION_MS,
DETECTION_THRESHOLD,
SUPPRESSION_MS,
MINIMUM_COUNT,
MINIMUM_TIME_BETWEEN_SAMPLES_MS);
// Load the TensorFlow model.
inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME);
// Start the recording and recognition threads.
requestMicrophonePermission();
startRecording();
startRecognition();
}
private void requestMicrophonePermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestPermissions(
new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
}
}
@Override
public void onRequestPermissionsResult(
int requestCode, String[] permissions, int[] grantResults) {
if (requestCode == REQUEST_RECORD_AUDIO
&& grantResults.length > 0
&& grantResults[0] == PackageManager.PERMISSION_GRANTED) {
startRecording();
startRecognition();
}
}
public synchronized void startRecording() {
if (recordingThread != null) {
return;
}
shouldContinue = true;
recordingThread =
new Thread(
new Runnable() {
@Override
public void run() {
record();
}
});
recordingThread.start();
}
public synchronized void stopRecording() {
if (recordingThread == null) {
return;
}
shouldContinue = false;
recordingThread = null;
}
private void record() {
android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);
// Estimate the buffer size we'll need for this device.
int bufferSize =
AudioRecord.getMinBufferSize(
SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT);
if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) {
bufferSize = SAMPLE_RATE * 2;
}
short[] audioBuffer = new short[bufferSize / 2];
AudioRecord record =
new AudioRecord(
MediaRecorder.AudioSource.DEFAULT,
SAMPLE_RATE,
AudioFormat.CHANNEL_IN_MONO,
AudioFormat.ENCODING_PCM_16BIT,
bufferSize);
if (record.getState() != AudioRecord.STATE_INITIALIZED) {
Log.e(LOG_TAG, "Audio Record can't initialize!");
return;
}
record.startRecording();
Log.v(LOG_TAG, "Start recording");
// Loop, gathering audio data and copying it to a round-robin buffer.
while (shouldContinue) {
int numberRead = record.read(audioBuffer, 0, audioBuffer.length);
int maxLength = recordingBuffer.length;
int newRecordingOffset = recordingOffset + numberRead;
int secondCopyLength = Math.max(0, newRecordingOffset - maxLength);
int firstCopyLength = numberRead - secondCopyLength;
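      // The copy is split so that writes wrap around the end of recordingBuffer, treating it as a
      // ring buffer: firstCopyLength samples go at recordingOffset, the remainder continue at 0.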
// We store off all the data for the recognition thread to access. The ML
// thread will copy out of this buffer into its own, while holding the
// lock, so this should be thread safe.
recordingBufferLock.lock();
try {
System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength);
System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength);
recordingOffset = newRecordingOffset % maxLength;
} finally {
recordingBufferLock.unlock();
}
}
record.stop();
record.release();
}
public synchronized void startRecognition() {
if (recognitionThread != null) {
return;
}
shouldContinueRecognition = true;
recognitionThread =
new Thread(
new Runnable() {
@Override
public void run() {
recognize();
}
});
recognitionThread.start();
}
public synchronized void stopRecognition() {
if (recognitionThread == null) {
return;
}
shouldContinueRecognition = false;
recognitionThread = null;
}
private void recognize() {
Log.v(LOG_TAG, "Start recognition");
short[] inputBuffer = new short[RECORDING_LENGTH];
float[] floatInputBuffer = new float[RECORDING_LENGTH];
float[] outputScores = new float[labels.size()];
String[] outputScoresNames = new String[] {OUTPUT_SCORES_NAME};
int[] sampleRateList = new int[] {SAMPLE_RATE};
// Loop, grabbing recorded data and running the recognition model on it.
while (shouldContinueRecognition) {
// The recording thread places data in this round-robin buffer, so lock to
// make sure there's no writing happening and then copy it to our own
// local version.
recordingBufferLock.lock();
try {
int maxLength = recordingBuffer.length;
int firstCopyLength = maxLength - recordingOffset;
int secondCopyLength = recordingOffset;
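        // Copy the ring buffer out in chronological order: the oldest samples start at
        // recordingOffset and the newest end just before it.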
System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, firstCopyLength);
System.arraycopy(recordingBuffer, 0, inputBuffer, firstCopyLength, secondCopyLength);
} finally {
recordingBufferLock.unlock();
}
// We need to feed in float values between -1.0f and 1.0f, so divide the
// signed 16-bit inputs.
for (int i = 0; i < RECORDING_LENGTH; ++i) {
floatInputBuffer[i] = inputBuffer[i] / 32767.0f;
}
// Run the model.
inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList);
inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1);
inferenceInterface.run(outputScoresNames);
inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores);
// Use the smoother to figure out if we've had a real recognition event.
long currentTime = System.currentTimeMillis();
final RecognizeCommands.RecognitionResult result =
recognizeCommands.processLatestResults(outputScores, currentTime);
runOnUiThread(
new Runnable() {
@Override
public void run() {
// If we do have a new command, highlight the right list entry.
if (!result.foundCommand.startsWith("_") && result.isNewCommand) {
int labelIndex = -1;
for (int i = 0; i < labels.size(); ++i) {
if (labels.get(i).equals(result.foundCommand)) {
labelIndex = i;
}
}
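              // labelIndex indexes the full label list; the list view was built from
              // displayedLabels, which skips the leading underscore labels (e.g. "_silence_"
              // and "_unknown_"), hence the offset of 2 below.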
final View labelView = labelsListView.getChildAt(labelIndex - 2);
AnimatorSet colorAnimation =
(AnimatorSet)
AnimatorInflater.loadAnimator(
SpeechActivity.this, R.animator.color_animation);
colorAnimation.setTarget(labelView);
colorAnimation.start();
}
}
});
try {
// We don't need to run too frequently, so snooze for a bit.
Thread.sleep(MINIMUM_TIME_BETWEEN_SAMPLES_MS);
} catch (InterruptedException e) {
// Ignore
}
}
Log.v(LOG_TAG, "End recognition");
}
}
/*
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.demo;
import android.app.UiModeManager;
import android.content.Context;
import android.content.res.AssetManager;
import android.content.res.Configuration;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.BitmapFactory;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.Rect;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.Bundle;
import android.os.SystemClock;
import android.util.DisplayMetrics;
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
import android.view.KeyEvent;
import android.view.MotionEvent;
import android.view.View;
import android.view.View.OnClickListener;
import android.view.View.OnTouchListener;
import android.view.ViewGroup;
import android.widget.BaseAdapter;
import android.widget.Button;
import android.widget.GridView;
import android.widget.ImageView;
import android.widget.RelativeLayout;
import android.widget.Toast;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Vector;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.OverlayView.DrawCallback;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
/**
* Sample activity that stylizes the camera preview according to "A Learned Representation For
* Artistic Style" (https://arxiv.org/abs/1610.07629)
*/
public class StylizeActivity extends CameraActivity implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb";
private static final String INPUT_NODE = "input";
private static final String STYLE_NODE = "style_num";
private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid";
private static final int NUM_STYLES = 26;
private static final boolean SAVE_PREVIEW_BITMAP = false;
// Whether to actively manipulate non-selected sliders so that the sum of activations always appears
// to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless.
private static final boolean NORMALIZE_SLIDERS = true;
private static final float TEXT_SIZE_DIP = 12;
private static final boolean DEBUG_MODEL = false;
private static final int[] SIZES = {128, 192, 256, 384, 512, 720};
private static final Size DESIRED_PREVIEW_SIZE = new Size(1280, 720);
// Start at a medium size, but let the user step up through smaller sizes so they don't get
// immediately stuck processing a large image.
private int desiredSizeIndex = -1;
private int desiredSize = 256;
private int initializedSize = 0;
private Integer sensorOrientation;
private long lastProcessingTimeMs;
private Bitmap rgbFrameBitmap = null;
private Bitmap croppedBitmap = null;
private Bitmap cropCopyBitmap = null;
private final float[] styleVals = new float[NUM_STYLES];
private int[] intValues;
private float[] floatValues;
private int frameNum = 0;
private Bitmap textureCopyBitmap;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
private BorderedText borderedText;
private TensorFlowInferenceInterface inferenceInterface;
private int lastOtherStyle = 1;
private boolean allZero = false;
private ImageGridAdapter adapter;
private GridView grid;
private final OnTouchListener gridTouchAdapter =
new OnTouchListener() {
ImageSlider slider = null;
@Override
public boolean onTouch(final View v, final MotionEvent event) {
switch (event.getActionMasked()) {
case MotionEvent.ACTION_DOWN:
for (int i = 0; i < NUM_STYLES; ++i) {
final ImageSlider child = adapter.items[i];
final Rect rect = new Rect();
child.getHitRect(rect);
if (rect.contains((int) event.getX(), (int) event.getY())) {
slider = child;
slider.setHilighted(true);
}
}
break;
case MotionEvent.ACTION_MOVE:
if (slider != null) {
final Rect rect = new Rect();
slider.getHitRect(rect);
final float newSliderVal =
(float)
Math.min(
1.0,
Math.max(
0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight()));
setStyle(slider, newSliderVal);
}
break;
case MotionEvent.ACTION_UP:
if (slider != null) {
slider.setHilighted(false);
slider = null;
}
break;
default: // fall out
}
return true;
}
};
@Override
public void onCreate(final Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
}
@Override
protected int getLayoutId() {
return R.layout.camera_connection_fragment_stylize;
}
@Override
protected Size getDesiredPreviewFrameSize() {
return DESIRED_PREVIEW_SIZE;
}
public static Bitmap getBitmapFromAsset(final Context context, final String filePath) {
final AssetManager assetManager = context.getAssets();
Bitmap bitmap = null;
try {
final InputStream inputStream = assetManager.open(filePath);
bitmap = BitmapFactory.decodeStream(inputStream);
} catch (final IOException e) {
LOGGER.e("Error opening bitmap!", e);
}
return bitmap;
}
private class ImageSlider extends ImageView {
private float value = 0.0f;
private boolean hilighted = false;
private final Paint boxPaint;
private final Paint linePaint;
public ImageSlider(final Context context) {
super(context);
value = 0.0f;
boxPaint = new Paint();
boxPaint.setColor(Color.BLACK);
boxPaint.setAlpha(128);
linePaint = new Paint();
linePaint.setColor(Color.WHITE);
linePaint.setStrokeWidth(10.0f);
linePaint.setStyle(Style.STROKE);
}
@Override
public void onDraw(final Canvas canvas) {
super.onDraw(canvas);
final float y = (1.0f - value) * canvas.getHeight();
// If all sliders are zero, don't bother shading anything.
if (!allZero) {
canvas.drawRect(0, 0, canvas.getWidth(), y, boxPaint);
}
if (value > 0.0f) {
canvas.drawLine(0, y, canvas.getWidth(), y, linePaint);
}
if (hilighted) {
canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint);
}
}
@Override
protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
super.onMeasure(widthMeasureSpec, heightMeasureSpec);
setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
}
public void setValue(final float value) {
this.value = value;
postInvalidate();
}
public void setHilighted(final boolean highlighted) {
this.hilighted = highlighted;
this.postInvalidate();
}
}
private class ImageGridAdapter extends BaseAdapter {
final ImageSlider[] items = new ImageSlider[NUM_STYLES];
final ArrayList<Button> buttons = new ArrayList<>();
{
final Button sizeButton =
new Button(StylizeActivity.this) {
@Override
protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
super.onMeasure(widthMeasureSpec, heightMeasureSpec);
setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
}
};
sizeButton.setText("" + desiredSize);
sizeButton.setOnClickListener(
new OnClickListener() {
@Override
public void onClick(final View v) {
desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length;
desiredSize = SIZES[desiredSizeIndex];
sizeButton.setText("" + desiredSize);
sizeButton.postInvalidate();
}
});
final Button saveButton =
new Button(StylizeActivity.this) {
@Override
protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
super.onMeasure(widthMeasureSpec, heightMeasureSpec);
setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
}
};
saveButton.setText("save");
saveButton.setTextSize(12);
saveButton.setOnClickListener(
new OnClickListener() {
@Override
public void onClick(final View v) {
if (textureCopyBitmap != null) {
// TODO(andrewharp): Save as jpeg with guaranteed unique filename.
ImageUtils.saveBitmap(textureCopyBitmap, "stylized" + frameNum + ".png");
Toast.makeText(
StylizeActivity.this,
"Saved image to: /sdcard/tensorflow/" + "stylized" + frameNum + ".png",
Toast.LENGTH_LONG)
.show();
}
}
});
buttons.add(sizeButton);
buttons.add(saveButton);
for (int i = 0; i < NUM_STYLES; ++i) {
LOGGER.v("Creating item %d", i);
if (items[i] == null) {
final ImageSlider slider = new ImageSlider(StylizeActivity.this);
final Bitmap bm =
getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg");
slider.setImageBitmap(bm);
items[i] = slider;
}
}
}
@Override
public int getCount() {
return buttons.size() + NUM_STYLES;
}
@Override
public Object getItem(final int position) {
if (position < buttons.size()) {
return buttons.get(position);
} else {
return items[position - buttons.size()];
}
}
@Override
public long getItemId(final int position) {
return getItem(position).hashCode();
}
@Override
public View getView(final int position, final View convertView, final ViewGroup parent) {
if (convertView != null) {
return convertView;
}
return (View) getItem(position);
}
}
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx = TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
previewWidth = size.getWidth();
previewHeight = size.getHeight();
final Display display = getWindowManager().getDefaultDisplay();
final int screenOrientation = display.getRotation();
LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
sensorOrientation = rotation + screenOrientation;
addCallback(
new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
renderDebug(canvas);
}
});
adapter = new ImageGridAdapter();
grid = (GridView) findViewById(R.id.grid_layout);
grid.setAdapter(adapter);
grid.setOnTouchListener(gridTouchAdapter);
// Change UI on Android TV
UiModeManager uiModeManager = (UiModeManager) getSystemService(UI_MODE_SERVICE);
if (uiModeManager.getCurrentModeType() == Configuration.UI_MODE_TYPE_TELEVISION) {
DisplayMetrics displayMetrics = new DisplayMetrics();
getWindowManager().getDefaultDisplay().getMetrics(displayMetrics);
int styleSelectorHeight = displayMetrics.heightPixels;
int styleSelectorWidth = displayMetrics.widthPixels - styleSelectorHeight;
RelativeLayout.LayoutParams layoutParams = new RelativeLayout.LayoutParams(styleSelectorWidth, ViewGroup.LayoutParams.MATCH_PARENT);
// Calculate the number of styles per row so that all of the styles fit on screen without scrolling
int numOfStylePerRow = 3;
while (styleSelectorWidth / numOfStylePerRow * Math.ceil((float) (adapter.getCount() - 2) / numOfStylePerRow) > styleSelectorHeight) {
numOfStylePerRow++;
}
grid.setNumColumns(numOfStylePerRow);
layoutParams.addRule(RelativeLayout.ALIGN_PARENT_RIGHT);
grid.setLayoutParams(layoutParams);
adapter.buttons.clear();
}
setStyle(adapter.items[0], 1.0f);
}
private void setStyle(final ImageSlider slider, final float value) {
slider.setValue(value);
if (NORMALIZE_SLIDERS) {
// Slider values correspond directly to the input tensor values, and normalization is
// maintained visually by adjusting the non-selected sliders.
float otherSum = 0.0f;
for (int i = 0; i < NUM_STYLES; ++i) {
if (adapter.items[i] != slider) {
otherSum += adapter.items[i].value;
}
}
if (otherSum > 0.0) {
float highestOtherVal = 0;
final float factor = otherSum > 0.0f ? (1.0f - value) / otherSum : 0.0f;
for (int i = 0; i < NUM_STYLES; ++i) {
final ImageSlider child = adapter.items[i];
if (child == slider) {
continue;
}
final float newVal = child.value * factor;
child.setValue(newVal > 0.01f ? newVal : 0.0f);
if (child.value > highestOtherVal) {
lastOtherStyle = i;
highestOtherVal = child.value;
}
}
} else {
// Everything else is 0, so just pick a suitable slider to push up when the
// selected one goes down.
if (adapter.items[lastOtherStyle] == slider) {
lastOtherStyle = (lastOtherStyle + 1) % NUM_STYLES;
}
adapter.items[lastOtherStyle].setValue(1.0f - value);
}
}
final boolean lastAllZero = allZero;
float sum = 0.0f;
for (int i = 0; i < NUM_STYLES; ++i) {
sum += adapter.items[i].value;
}
allZero = sum == 0.0f;
// Now update the values used for the input tensor. If nothing is set, mix in everything
// equally. Otherwise everything is normalized to sum to 1.0.
for (int i = 0; i < NUM_STYLES; ++i) {
styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum;
if (lastAllZero != allZero) {
adapter.items[i].postInvalidate();
}
}
}
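// Illustrative helper (not part of the original demo): the same normalization applied to
// styleVals above, extracted as a standalone sketch. Given raw slider values it returns
// weights that sum to 1.0, or a uniform mix if every slider is zero.
private static float[] normalizedStyleWeights(final float[] rawValues) {
final float[] weights = new float[rawValues.length];
float sum = 0.0f;
for (final float v : rawValues) {
sum += v;
}
for (int i = 0; i < rawValues.length; ++i) {
weights[i] = sum == 0.0f ? 1.0f / rawValues.length : rawValues[i] / sum;
}
return weights;
}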
private void resetPreviewBuffers() {
croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
desiredSize, desiredSize,
sensorOrientation, true);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
intValues = new int[desiredSize * desiredSize];
floatValues = new float[desiredSize * desiredSize * 3];
initializedSize = desiredSize;
}
@Override
protected void processImage() {
if (desiredSize != initializedSize) {
LOGGER.i(
"Initializing at size preview size %dx%d, stylize size %d",
previewWidth, previewHeight, desiredSize);
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
desiredSize, desiredSize,
sensorOrientation, true);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
intValues = new int[desiredSize * desiredSize];
floatValues = new float[desiredSize * desiredSize * 3];
initializedSize = desiredSize;
}
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
final Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
// For examining the actual TF input.
if (SAVE_PREVIEW_BITMAP) {
ImageUtils.saveBitmap(croppedBitmap);
}
runInBackground(
new Runnable() {
@Override
public void run() {
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
final long startTime = SystemClock.uptimeMillis();
stylizeImage(croppedBitmap);
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
textureCopyBitmap = Bitmap.createBitmap(croppedBitmap);
requestRender();
readyForNextImage();
}
});
if (desiredSize != initializedSize) {
resetPreviewBuffers();
}
}
private void stylizeImage(final Bitmap bitmap) {
++frameNum;
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
if (DEBUG_MODEL) {
// Create a white square that steps through a black background 1 pixel per frame.
final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth();
final int centerY = bitmap.getHeight() / 2;
final int squareSize = 10;
for (int i = 0; i < intValues.length; ++i) {
final int x = i % bitmap.getWidth();
final int y = i / bitmap.getWidth(); // Row index; the bitmap is square here, but width is the correct divisor.
final float val =
Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f;
floatValues[i * 3] = val;
floatValues[i * 3 + 1] = val;
floatValues[i * 3 + 2] = val;
}
} else {
for (int i = 0; i < intValues.length; ++i) {
final int val = intValues[i];
floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
}
}
// Copy the input data into TensorFlow.
LOGGER.i("Width: %s , Height: %s", bitmap.getWidth(), bitmap.getHeight());
inferenceInterface.feed(
INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3);
inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES);
inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug());
inferenceInterface.fetch(OUTPUT_NODE, floatValues);
for (int i = 0; i < intValues.length; ++i) {
intValues[i] =
0xFF000000
| (((int) (floatValues[i * 3] * 255)) << 16)
| (((int) (floatValues[i * 3 + 1] * 255)) << 8)
| ((int) (floatValues[i * 3 + 2] * 255));
}
bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
}
private void renderDebug(final Canvas canvas) {
// TODO(andrewharp): move result display to its own View instead of using debug overlay.
final Bitmap texture = textureCopyBitmap;
if (texture != null) {
final Matrix matrix = new Matrix();
final float scaleFactor =
DEBUG_MODEL
? 4.0f
: Math.min(
(float) canvas.getWidth() / texture.getWidth(),
(float) canvas.getHeight() / texture.getHeight());
matrix.postScale(scaleFactor, scaleFactor);
canvas.drawBitmap(texture, matrix, new Paint());
}
if (!isDebug()) {
return;
}
final Bitmap copy = cropCopyBitmap;
if (copy == null) {
return;
}
canvas.drawColor(0x55000000);
final Matrix matrix = new Matrix();
final float scaleFactor = 2;
matrix.postScale(scaleFactor, scaleFactor);
matrix.postTranslate(
canvas.getWidth() - copy.getWidth() * scaleFactor,
canvas.getHeight() - copy.getHeight() * scaleFactor);
canvas.drawBitmap(copy, matrix, new Paint());
final Vector<String> lines = new Vector<>();
final String[] statLines = inferenceInterface.getStatString().split("\n");
Collections.addAll(lines, statLines);
lines.add("");
lines.add("Frame: " + previewWidth + "x" + previewHeight);
lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
lines.add("Rotation: " + sensorOrientation);
lines.add("Inference time: " + lastProcessingTimeMs + "ms");
lines.add("Desired size: " + desiredSize);
lines.add("Initialized size: " + initializedSize);
borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
}
@Override
public boolean onKeyDown(int keyCode, KeyEvent event) {
int moveOffset = 0;
switch (keyCode) {
case KeyEvent.KEYCODE_DPAD_LEFT:
moveOffset = -1;
break;
case KeyEvent.KEYCODE_DPAD_RIGHT:
moveOffset = 1;
break;
case KeyEvent.KEYCODE_DPAD_UP:
moveOffset = -1 * grid.getNumColumns();
break;
case KeyEvent.KEYCODE_DPAD_DOWN:
moveOffset = grid.getNumColumns();
break;
default:
return super.onKeyDown(keyCode, event);
}
// Get the currently selected style: the one with the highest slider value. Iterate over
// NUM_STYLES rather than adapter.getCount(), since adapter.items only holds the styles
// while getCount() also counts the buttons.
int currentSelect = 0;
float highestValue = 0;
for (int i = 0; i < NUM_STYLES; i++) {
if (adapter.items[i].value > highestValue) {
currentSelect = i;
highestValue = adapter.items[i].value;
}
}
setStyle(adapter.items[(currentSelect + moveOffset + NUM_STYLES) % NUM_STYLES], 1);
return true;
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Trace;
import android.util.Log;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;
import org.tensorflow.Operation;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/** A classifier specialized to label images using TensorFlow. */
public class TensorFlowImageClassifier implements Classifier {
private static final String TAG = "TensorFlowImageClassifier";
// Only return this many results with at least this confidence.
private static final int MAX_RESULTS = 3;
private static final float THRESHOLD = 0.1f;
// Config values.
private String inputName;
private String outputName;
private int inputSize;
private int imageMean;
private float imageStd;
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private int[] intValues;
private float[] floatValues;
private float[] outputs;
private String[] outputNames;
private boolean logStats = false;
private TensorFlowInferenceInterface inferenceInterface;
private TensorFlowImageClassifier() {}
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of the label file for classes.
* @param inputSize The input size. A square image of inputSize x inputSize is assumed.
* @param imageMean The assumed mean of the image values.
* @param imageStd The assumed std of the image values.
* @param inputName The label of the image input node.
* @param outputName The label of the output node.
* @return A Classifier instance. A RuntimeException is thrown if the label file cannot be read.
*/
public static Classifier create(
AssetManager assetManager,
String modelFilename,
String labelFilename,
int inputSize,
int imageMean,
float imageStd,
String inputName,
String outputName) {
TensorFlowImageClassifier c = new TensorFlowImageClassifier();
c.inputName = inputName;
c.outputName = outputName;
// Read the label names into memory.
// TODO(andrewharp): make this handle non-assets.
String actualFilename = labelFilename.split("file:///android_asset/")[1];
Log.i(TAG, "Reading labels from: " + actualFilename);
BufferedReader br = null;
try {
br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
String line;
while ((line = br.readLine()) != null) {
c.labels.add(line);
}
br.close();
} catch (IOException e) {
throw new RuntimeException("Problem reading label file!" , e);
}
c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
final Operation operation = c.inferenceInterface.graphOperation(outputName);
final int numClasses = (int) operation.output(0).shape().size(1);
Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
// Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
// the placeholder node for input in the graphdef typically used does not specify a shape, so it
// must be passed in as a parameter.
c.inputSize = inputSize;
c.imageMean = imageMean;
c.imageStd = imageStd;
// Pre-allocate buffers.
c.outputNames = new String[] {outputName};
c.intValues = new int[inputSize * inputSize];
c.floatValues = new float[inputSize * inputSize * 3];
c.outputs = new float[numClasses];
return c;
}
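// Usage sketch (illustrative only; the asset paths, input size, mean/std and node names
// below are placeholders, not values shipped with this demo):
//   Classifier classifier =
//       TensorFlowImageClassifier.create(
//           getAssets(),
//           "file:///android_asset/my_model.pb",
//           "file:///android_asset/my_labels.txt",
//           224, 117, 1.0f, "input", "output");
//   List<Recognition> results = classifier.recognizeImage(croppedBitmap);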
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap");
// Preprocess the image data from 0-255 int to normalized float based
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < intValues.length; ++i) {
final int val = intValues[i];
floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
}
Trace.endSection();
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
Trace.endSection();
// Run the inference call.
Trace.beginSection("run");
inferenceInterface.run(outputNames, logStats);
Trace.endSection();
// Copy the output Tensor back into the output array.
Trace.beginSection("fetch");
inferenceInterface.fetch(outputName, outputs);
Trace.endSection();
// Find the best classifications.
PriorityQueue<Recognition> pq =
new PriorityQueue<Recognition>(
3,
new Comparator<Recognition>() {
@Override
public int compare(Recognition lhs, Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
for (int i = 0; i < outputs.length; ++i) {
if (outputs[i] > THRESHOLD) {
pq.add(
new Recognition(
"" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
}
}
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
for (int i = 0; i < recognitionsSize; ++i) {
recognitions.add(pq.poll());
}
Trace.endSection(); // "recognizeImage"
return recognitions;
}
@Override
public void enableStatLogging(boolean logStats) {
this.logStats = logStats;
}
@Override
public String getStatString() {
return inferenceInterface.getStatString();
}
@Override
public void close() {
inferenceInterface.close();
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.StringTokenizer;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.env.Logger;
/**
* A detector for general purpose object detection as described in Scalable Object Detection using
* Deep Neural Networks (https://arxiv.org/abs/1312.2249).
*/
public class TensorFlowMultiBoxDetector implements Classifier {
private static final Logger LOGGER = new Logger();
// Only return this many results.
private static final int MAX_RESULTS = Integer.MAX_VALUE;
// Config values.
private String inputName;
private int inputSize;
private int imageMean;
private float imageStd;
// Pre-allocated buffers.
private int[] intValues;
private float[] floatValues;
private float[] outputLocations;
private float[] outputScores;
private String[] outputNames;
private int numLocations;
private boolean logStats = false;
private TensorFlowInferenceInterface inferenceInterface;
private float[] boxPriors;
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param locationFilename The filepath of the box prior (location encoding) file.
* @param imageMean The assumed mean of the image values.
* @param imageStd The assumed std of the image values.
* @param inputName The label of the image input node.
* @param outputLocationsName The label of the output node for the box location encodings.
* @param outputScoresName The label of the output node for the box scores.
*/
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String locationFilename,
final int imageMean,
final float imageStd,
final String inputName,
final String outputLocationsName,
final String outputScoresName) {
final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
final Graph g = d.inferenceInterface.graph();
d.inputName = inputName;
// The inputName node has a shape of [N, H, W, C], where
// N is the batch size
// H = W are the height and width
// C is the number of channels (3 for our purposes - RGB)
final Operation inputOp = g.operation(inputName);
if (inputOp == null) {
throw new RuntimeException("Failed to find input Node '" + inputName + "'");
}
d.inputSize = (int) inputOp.output(0).shape().size(1);
d.imageMean = imageMean;
d.imageStd = imageStd;
// The outputScoresName node has a shape of [N, NumLocations], where N
// is the batch size.
final Operation outputOp = g.operation(outputScoresName);
if (outputOp == null) {
throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'");
}
d.numLocations = (int) outputOp.output(0).shape().size(1);
d.boxPriors = new float[d.numLocations * 8];
try {
d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
} catch (final IOException e) {
throw new RuntimeException("Error initializing box priors from " + locationFilename);
}
// Pre-allocate buffers.
d.outputNames = new String[] {outputLocationsName, outputScoresName};
d.intValues = new int[d.inputSize * d.inputSize];
d.floatValues = new float[d.inputSize * d.inputSize * 3];
d.outputScores = new float[d.numLocations];
d.outputLocations = new float[d.numLocations * 4];
return d;
}
private TensorFlowMultiBoxDetector() {}
private void loadCoderOptions(
final AssetManager assetManager, final String locationFilename, final float[] boxPriors)
throws IOException {
// Try to be intelligent about opening from assets or sdcard depending on prefix.
final String assetPrefix = "file:///android_asset/";
InputStream is;
if (locationFilename.startsWith(assetPrefix)) {
is = assetManager.open(locationFilename.split(assetPrefix)[1]);
} else {
is = new FileInputStream(locationFilename);
}
// Read values. Number of values per line doesn't matter, as long as they are separated
// by commas and/or whitespace, and there are exactly numLocations * 8 values total.
// Values are in the order mean, std for each consecutive corner of each box, for a total of 8
// per location.
final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
int priorIndex = 0;
String line;
while ((line = reader.readLine()) != null) {
final StringTokenizer st = new StringTokenizer(line, ", ");
while (st.hasMoreTokens()) {
final String token = st.nextToken();
try {
final float number = Float.parseFloat(token);
boxPriors[priorIndex++] = number;
} catch (final NumberFormatException e) {
// Silently ignore.
}
}
}
if (priorIndex != boxPriors.length) {
throw new RuntimeException(
"BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length);
}
}
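// Each box prior stores a (mean, stdDev) pair for each of its four coordinates, so a
// coordinate decodes as clamp(encoding * stdDev + mean, 0.0, 1.0). For example, an
// encoding of 0.5 with mean 0.2 and stdDev 0.1 decodes to 0.25.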
private float[] decodeLocationsEncoding(final float[] locationEncoding) {
final float[] locations = new float[locationEncoding.length];
boolean nonZero = false;
for (int i = 0; i < numLocations; ++i) {
for (int j = 0; j < 4; ++j) {
final float currEncoding = locationEncoding[4 * i + j];
nonZero = nonZero || currEncoding != 0.0f;
final float mean = boxPriors[i * 8 + j * 2];
final float stdDev = boxPriors[i * 8 + j * 2 + 1];
float currentLocation = currEncoding * stdDev + mean;
currentLocation = Math.max(currentLocation, 0.0f);
currentLocation = Math.min(currentLocation, 1.0f);
locations[4 * i + j] = currentLocation;
}
}
if (!nonZero) {
LOGGER.w("No non-zero encodings; check log for inference errors.");
}
return locations;
}
private float[] decodeScoresEncoding(final float[] scoresEncoding) {
final float[] scores = new float[scoresEncoding.length];
for (int i = 0; i < scoresEncoding.length; ++i) {
scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i])));
}
return scores;
}
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap");
// Preprocess the image data from 0-255 int to normalized float based
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < intValues.length; ++i) {
floatValues[i * 3 + 0] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd;
floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd;
floatValues[i * 3 + 2] = ((intValues[i] & 0xFF) - imageMean) / imageStd;
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
Trace.endSection();
// Run the inference call.
Trace.beginSection("run");
inferenceInterface.run(outputNames, logStats);
Trace.endSection();
// Copy the output Tensor back into the output array.
Trace.beginSection("fetch");
final float[] outputScoresEncoding = new float[numLocations];
final float[] outputLocationsEncoding = new float[numLocations * 4];
inferenceInterface.fetch(outputNames[0], outputLocationsEncoding);
inferenceInterface.fetch(outputNames[1], outputScoresEncoding);
Trace.endSection();
outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
outputScores = decodeScoresEncoding(outputScoresEncoding);
// Find the best detections.
final PriorityQueue<Recognition> pq =
new PriorityQueue<Recognition>(
1,
new Comparator<Recognition>() {
@Override
public int compare(final Recognition lhs, final Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
// Scale them back to the input size.
for (int i = 0; i < outputScores.length; ++i) {
final RectF detection =
new RectF(
outputLocations[4 * i] * inputSize,
outputLocations[4 * i + 1] * inputSize,
outputLocations[4 * i + 2] * inputSize,
outputLocations[4 * i + 3] * inputSize);
pq.add(new Recognition("" + i, null, outputScores[i], detection));
}
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
recognitions.add(pq.poll());
}
Trace.endSection(); // "recognizeImage"
return recognitions;
}
@Override
public void enableStatLogging(final boolean logStats) {
this.logStats = logStats;
}
@Override
public String getStatString() {
return inferenceInterface.getStatString();
}
@Override
public void close() {
inferenceInterface.close();
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.env.Logger;
/**
* Wrapper for frozen detection models trained using the TensorFlow Object Detection API:
* github.com/tensorflow/models/tree/master/research/object_detection
*/
public class TensorFlowObjectDetectionAPIModel implements Classifier {
private static final Logger LOGGER = new Logger();
// Only return this many results.
private static final int MAX_RESULTS = 100;
// Config values.
private String inputName;
private int inputSize;
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private int[] intValues;
private byte[] byteValues;
private float[] outputLocations;
private float[] outputScores;
private float[] outputClasses;
private float[] outputNumDetections;
private String[] outputNames;
private boolean logStats = false;
private TensorFlowInferenceInterface inferenceInterface;
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of the label file for classes.
* @param inputSize The input size. A square image of inputSize x inputSize is assumed.
* @throws IOException if the model or label file cannot be read.
*/
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String labelFilename,
final int inputSize) throws IOException {
final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();
String actualFilename = labelFilename.split("file:///android_asset/")[1];
InputStream labelsInput = assetManager.open(actualFilename);
BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
String line;
while ((line = br.readLine()) != null) {
LOGGER.w(line);
d.labels.add(line);
}
br.close();
d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
final Graph g = d.inferenceInterface.graph();
d.inputName = "image_tensor";
// The inputName node has a shape of [N, H, W, C], where
// N is the batch size
// H = W are the height and width
// C is the number of channels (3 for our purposes - RGB)
final Operation inputOp = g.operation(d.inputName);
if (inputOp == null) {
throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
}
d.inputSize = inputSize;
// The outputScoresName node has a shape of [N, NumLocations], where N
// is the batch size.
final Operation outputOp1 = g.operation("detection_scores");
if (outputOp1 == null) {
throw new RuntimeException("Failed to find output Node 'detection_scores'");
}
final Operation outputOp2 = g.operation("detection_boxes");
if (outputOp2 == null) {
throw new RuntimeException("Failed to find output Node 'detection_boxes'");
}
final Operation outputOp3 = g.operation("detection_classes");
if (outputOp3 == null) {
throw new RuntimeException("Failed to find output Node 'detection_classes'");
}
// Pre-allocate buffers.
d.outputNames = new String[] {"detection_boxes", "detection_scores",
"detection_classes", "num_detections"};
d.intValues = new int[d.inputSize * d.inputSize];
d.byteValues = new byte[d.inputSize * d.inputSize * 3];
d.outputScores = new float[MAX_RESULTS];
d.outputLocations = new float[MAX_RESULTS * 4];
d.outputClasses = new float[MAX_RESULTS];
d.outputNumDetections = new float[1];
return d;
}
private TensorFlowObjectDetectionAPIModel() {}
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap");
// Preprocess the image data, extracting the R, G and B bytes from each packed int pixel
// of the form 0x00RRGGBB.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < intValues.length; ++i) {
byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
Trace.endSection();
// Run the inference call.
Trace.beginSection("run");
inferenceInterface.run(outputNames, logStats);
Trace.endSection();
// Copy the output Tensor back into the output array.
Trace.beginSection("fetch");
outputLocations = new float[MAX_RESULTS * 4];
outputScores = new float[MAX_RESULTS];
outputClasses = new float[MAX_RESULTS];
outputNumDetections = new float[1];
inferenceInterface.fetch(outputNames[0], outputLocations);
inferenceInterface.fetch(outputNames[1], outputScores);
inferenceInterface.fetch(outputNames[2], outputClasses);
inferenceInterface.fetch(outputNames[3], outputNumDetections);
Trace.endSection();
// Find the best detections.
final PriorityQueue<Recognition> pq =
new PriorityQueue<Recognition>(
1,
new Comparator<Recognition>() {
@Override
public int compare(final Recognition lhs, final Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
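// The detection boxes come back as [ymin, xmin, ymax, xmax] in normalized coordinates,
// so reorder them to (left, top, right, bottom) when building each RectF.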
// Scale them back to the input size.
for (int i = 0; i < outputScores.length; ++i) {
final RectF detection =
new RectF(
outputLocations[4 * i + 1] * inputSize,
outputLocations[4 * i] * inputSize,
outputLocations[4 * i + 3] * inputSize,
outputLocations[4 * i + 2] * inputSize);
pq.add(
new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection));
}
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
recognitions.add(pq.poll());
}
Trace.endSection(); // "recognizeImage"
return recognitions;
}
@Override
public void enableStatLogging(final boolean logStats) {
this.logStats = logStats;
}
@Override
public String getStatString() {
return inferenceInterface.getStatString();
}
@Override
public void close() {
inferenceInterface.close();
}
}
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.env.SplitTimer;
/** An object detector that uses TF and a YOLO model to detect objects. */
public class TensorFlowYoloDetector implements Classifier {
private static final Logger LOGGER = new Logger();
// Only return this many results.
private static final int MAX_RESULTS = 5;
private static final int NUM_CLASSES = 1;
private static final int NUM_BOXES_PER_BLOCK = 5;
// TODO(andrewharp): allow loading anchors and classes
// from files.
private static final double[] ANCHORS = {
1.08, 1.19,
3.42, 4.41,
6.63, 11.38,
9.42, 5.11,
16.62, 10.52
};
private static final String[] LABELS = {
"dog"
};
// Config values.
private String inputName;
private int inputSize;
// Pre-allocated buffers.
private int[] intValues;
private float[] floatValues;
private String[] outputNames;
private int blockSize;
private boolean logStats = false;
private TensorFlowInferenceInterface inferenceInterface;
/** Initializes a native TensorFlow session for classifying images. */
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final int inputSize,
final String inputName,
final String outputName,
final int blockSize) {
TensorFlowYoloDetector d = new TensorFlowYoloDetector();
d.inputName = inputName;
d.inputSize = inputSize;
// Pre-allocate buffers.
d.outputNames = outputName.split(",");
d.intValues = new int[inputSize * inputSize];
d.floatValues = new float[inputSize * inputSize * 3];
d.blockSize = blockSize;
d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
return d;
}
private TensorFlowYoloDetector() {}
private float expit(final float x) {
return (float) (1. / (1. + Math.exp(-x)));
}
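// Standard numerically stable softmax: subtracting the max before exponentiating leaves
// the result unchanged (exp(v - m) / sum(exp(v - m)) == exp(v) / sum(exp(v))) but keeps
// the exponentials from overflowing for large logits.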
private void softmax(final float[] vals) {
float max = Float.NEGATIVE_INFINITY;
for (final float val : vals) {
max = Math.max(max, val);
}
float sum = 0.0f;
for (int i = 0; i < vals.length; ++i) {
vals[i] = (float) Math.exp(vals[i] - max);
sum += vals[i];
}
for (int i = 0; i < vals.length; ++i) {
vals[i] = vals[i] / sum;
}
}
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
final SplitTimer timer = new SplitTimer("recognizeImage");
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap");
// Preprocess the image data from 0-255 int to normalized float based
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < intValues.length; ++i) {
floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f;
floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f;
floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f;
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
Trace.endSection();
timer.endSplit("ready for inference");
// Run the inference call.
Trace.beginSection("run");
inferenceInterface.run(outputNames, logStats);
Trace.endSection();
timer.endSplit("ran inference");
// Copy the output Tensor back into the output array.
Trace.beginSection("fetch");
final int gridWidth = bitmap.getWidth() / blockSize;
final int gridHeight = bitmap.getHeight() / blockSize;
final float[] output =
new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];
inferenceInterface.fetch(outputNames[0], output);
Trace.endSection();
// Find the best detections.
final PriorityQueue<Recognition> pq =
new PriorityQueue<Recognition>(
1,
new Comparator<Recognition>() {
@Override
public int compare(final Recognition lhs, final Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
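// Decode each grid cell / anchor prediction into a box in input-image coordinates:
//   centerX = (x + sigmoid(tx)) * blockSize
//   centerY = (y + sigmoid(ty)) * blockSize
//   width   = exp(tw) * anchorWidth  * blockSize
//   height  = exp(th) * anchorHeight * blockSize
// The objectness score is sigmoid(to), and the class probabilities come from a softmax
// over the remaining NUM_CLASSES outputs.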
for (int y = 0; y < gridHeight; ++y) {
for (int x = 0; x < gridWidth; ++x) {
for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
final int offset =
(gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y
+ (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x
+ (NUM_CLASSES + 5) * b;
final float xPos = (x + expit(output[offset + 0])) * blockSize;
final float yPos = (y + expit(output[offset + 1])) * blockSize;
final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * blockSize;
final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * blockSize;
final RectF rect =
new RectF(
Math.max(0, xPos - w / 2),
Math.max(0, yPos - h / 2),
Math.min(bitmap.getWidth() - 1, xPos + w / 2),
Math.min(bitmap.getHeight() - 1, yPos + h / 2));
final float confidence = expit(output[offset + 4]);
int detectedClass = -1;
float maxClass = 0;
final float[] classes = new float[NUM_CLASSES];
for (int c = 0; c < NUM_CLASSES; ++c) {
classes[c] = output[offset + 5 + c];
}
softmax(classes);
for (int c = 0; c < NUM_CLASSES; ++c) {
if (classes[c] > maxClass) {
detectedClass = c;
maxClass = classes[c];
}
}
final float confidenceInClass = maxClass * confidence;
if (confidenceInClass > 0.01) {
LOGGER.i(
"%s (%d) %f %s", LABELS[detectedClass], detectedClass, confidenceInClass, rect);
pq.add(new Recognition("" + offset, LABELS[detectedClass], confidenceInClass, rect));
}
}
}
}
timer.endSplit("decoded results");
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
recognitions.add(pq.poll());
}
Trace.endSection(); // "recognizeImage"
timer.endSplit("processed results");
return recognitions;
}
@Override
public void enableStatLogging(final boolean logStats) {
this.logStats = logStats;
}
@Override
public String getStatString() {
return inferenceInterface.getStatString();
}
@Override
public void close() {
inferenceInterface.close();
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.env;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Paint.Align;
import android.graphics.Paint.Style;
import android.graphics.Rect;
import android.graphics.Typeface;
import java.util.Vector;
/**
* A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas.
*/
public class BorderedText {
private final Paint interiorPaint;
private final Paint exteriorPaint;
private final float textSize;
/**
* Creates a left-aligned bordered text object with a white interior, and a black exterior with
* the specified text size.
*
* @param textSize text size in pixels
*/
public BorderedText(final float textSize) {
this(Color.WHITE, Color.BLACK, textSize);
}
/**
* Create a bordered text object with the specified interior and exterior colors, text size and
* alignment.
*
* @param interiorColor the interior text color
* @param exteriorColor the exterior text color
* @param textSize text size in pixels
*/
public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) {
interiorPaint = new Paint();
interiorPaint.setTextSize(textSize);
interiorPaint.setColor(interiorColor);
interiorPaint.setStyle(Style.FILL);
interiorPaint.setAntiAlias(false);
interiorPaint.setAlpha(255);
exteriorPaint = new Paint();
exteriorPaint.setTextSize(textSize);
exteriorPaint.setColor(exteriorColor);
exteriorPaint.setStyle(Style.FILL_AND_STROKE);
exteriorPaint.setStrokeWidth(textSize / 8);
exteriorPaint.setAntiAlias(false);
exteriorPaint.setAlpha(255);
this.textSize = textSize;
}
public void setTypeface(Typeface typeface) {
interiorPaint.setTypeface(typeface);
exteriorPaint.setTypeface(typeface);
}
public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
canvas.drawText(text, posX, posY, exteriorPaint);
canvas.drawText(text, posX, posY, interiorPaint);
}
public void drawLines(Canvas canvas, final float posX, final float posY, Vector<String> lines) {
int lineNum = 0;
for (final String line : lines) {
drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line);
++lineNum;
}
}
public void setInteriorColor(final int color) {
interiorPaint.setColor(color);
}
public void setExteriorColor(final int color) {
exteriorPaint.setColor(color);
}
public float getTextSize() {
return textSize;
}
public void setAlpha(final int alpha) {
interiorPaint.setAlpha(alpha);
exteriorPaint.setAlpha(alpha);
}
public void getTextBounds(
final String line, final int index, final int count, final Rect lineBounds) {
interiorPaint.getTextBounds(line, index, count, lineBounds);
}
public void setTextAlign(final Align align) {
interiorPaint.setTextAlign(align);
exteriorPaint.setTextAlign(align);
}
}
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.env;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.os.Environment;
import java.io.File;
import java.io.FileOutputStream;
/**
* Utility class for manipulating images.
**/
public class ImageUtils {
@SuppressWarnings("unused")
private static final Logger LOGGER = new Logger();
static {
try {
System.loadLibrary("tensorflow_demo");
} catch (UnsatisfiedLinkError e) {
LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable.");
}
}
/**
* Utility method to compute the allocated size in bytes of a YUV420SP image
* of the given dimensions.
*/
public static int getYUVByteSize(final int width, final int height) {
// The luminance plane requires 1 byte per pixel.
final int ySize = width * height;
// The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up.
// Each 2x2 block takes 2 bytes to encode, one each for U and V.
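// For example, a 640x480 frame needs 640 * 480 = 307200 luminance bytes plus
// 320 * 240 * 2 = 153600 chrominance bytes, 460800 bytes in total.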
final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2;
return ySize + uvSize;
}
/**
* Saves a Bitmap object to disk for analysis.
*
* @param bitmap The bitmap to save.
*/
public static void saveBitmap(final Bitmap bitmap) {
saveBitmap(bitmap, "preview.png");
}
/**
* Saves a Bitmap object to disk for analysis.
*
* @param bitmap The bitmap to save.
* @param filename The location to save the bitmap to.
*/
public static void saveBitmap(final Bitmap bitmap, final String filename) {
final String root =
Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow";
LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root);
final File myDir = new File(root);
if (!myDir.mkdirs()) {
// Note that mkdirs() also returns false when the directory already exists, so this is
// not necessarily an error.
LOGGER.i("Make dir failed");
}
final String fname = filename;
final File file = new File(myDir, fname);
if (file.exists()) {
file.delete();
}
try {
final FileOutputStream out = new FileOutputStream(file);
bitmap.compress(Bitmap.CompressFormat.PNG, 99, out);
out.flush();
out.close();
} catch (final Exception e) {
LOGGER.e(e, "Exception!");
}
}
// This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
// are normalized to eight bits.
static final int kMaxChannelValue = 262143;
// Always prefer the native implementation if available.
private static boolean useNativeConversion = true;
public static void convertYUV420SPToARGB8888(
byte[] input,
int width,
int height,
int[] output) {
if (useNativeConversion) {
try {
ImageUtils.convertYUV420SPToARGB8888(input, output, width, height, false);
return;
} catch (UnsatisfiedLinkError e) {
LOGGER.w(
"Native YUV420SP -> RGB implementation not found, falling back to Java implementation");
useNativeConversion = false;
}
}
// Java implementation of the YUV420SP to ARGB8888 conversion.
final int frameSize = width * height;
for (int j = 0, yp = 0; j < height; j++) {
int uvp = frameSize + (j >> 1) * width;
int u = 0;
int v = 0;
for (int i = 0; i < width; i++, yp++) {
int y = 0xff & input[yp];
if ((i & 1) == 0) {
v = 0xff & input[uvp++];
u = 0xff & input[uvp++];
}
output[yp] = YUV2RGB(y, u, v);
}
}
}
private static int YUV2RGB(int y, int u, int v) {
// Adjust and check YUV values
y = (y - 16) < 0 ? 0 : (y - 16);
u -= 128;
v -= 128;
// This is the floating point equivalent. We do the conversion in integer
// because some Android devices do not have floating point in hardware.
// nR = (int)(1.164 * nY + 2.018 * nU);
// nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
// nB = (int)(1.164 * nY + 1.596 * nV);
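// The integer constants below are those coefficients scaled by 1024 (e.g. 1.164 * 1024 ~= 1192,
// 1.596 * 1024 ~= 1634); the shifts in the return statement divide by 1024 again and place each
// clamped channel into its byte of the ARGB result.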
int y1192 = 1192 * y;
int r = (y1192 + 1634 * v);
int g = (y1192 - 833 * v - 400 * u);
int b = (y1192 + 2066 * u);
// Clipping RGB values to be inside boundaries [ 0 , kMaxChannelValue ]
r = r > kMaxChannelValue ? kMaxChannelValue : (r < 0 ? 0 : r);
g = g > kMaxChannelValue ? kMaxChannelValue : (g < 0 ? 0 : g);
b = b > kMaxChannelValue ? kMaxChannelValue : (b < 0 ? 0 : b);
return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff);
}
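  // Note on the fixed-point math above (added for clarity): the multipliers are the BT.601
  // coefficients scaled by 1024 (1192 ~ 1.164, 1634 ~ 1.596, 833 ~ 0.813, 400 ~ 0.391,
  // 2066 ~ 2.018), so r, g and b are 18-bit values clamped to [0, kMaxChannelValue]. The final
  // shifts keep only the top 8 bits of each channel and place them in the usual ARGB positions:
  // red at bits 16-23 (<< 6), green at bits 8-15 (>> 2) and blue at bits 0-7 (>> 10).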
public static void convertYUV420ToARGB8888(
byte[] yData,
byte[] uData,
byte[] vData,
int width,
int height,
int yRowStride,
int uvRowStride,
int uvPixelStride,
int[] out) {
if (useNativeConversion) {
try {
convertYUV420ToARGB8888(
yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false);
return;
} catch (UnsatisfiedLinkError e) {
LOGGER.w(
"Native YUV420 -> RGB implementation not found, falling back to Java implementation");
useNativeConversion = false;
}
}
int yp = 0;
for (int j = 0; j < height; j++) {
int pY = yRowStride * j;
int pUV = uvRowStride * (j >> 1);
for (int i = 0; i < width; i++) {
int uv_offset = pUV + (i >> 1) * uvPixelStride;
out[yp++] = YUV2RGB(
0xff & yData[pY + i],
0xff & uData[uv_offset],
0xff & vData[uv_offset]);
}
}
}
/**
* Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The
* input and output must already be allocated and non-null. For efficiency, no error checking is
* performed.
*
* @param input The array of YUV 4:2:0 input data.
* @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
* @param width The width of the input image.
* @param height The height of the input image.
* @param halfSize If true, downsample to 50% in each dimension, otherwise not.
*/
private static native void convertYUV420SPToARGB8888(
byte[] input, int[] output, int width, int height, boolean halfSize);
/**
   * Converts YUV420 image data with separate Y, U and V planes to ARGB 8888 data using the
   * supplied width, height and strides. The input and output must already be allocated and
   * non-null. For efficiency, no error checking is performed.
   *
   * @param y The Y (luminance) plane.
   * @param u The U (chrominance) plane.
   * @param v The V (chrominance) plane.
   * @param output A pre-allocated array for the ARGB 8:8:8:8 output data.
   * @param width The width of the input image.
   * @param height The height of the input image.
   * @param yRowStride The row stride of the Y plane.
   * @param uvRowStride The row stride of the U and V planes.
   * @param uvPixelStride The pixel stride within a row of the U and V planes.
   * @param halfSize If true, downsample to 50% in each dimension, otherwise not.
*/
private static native void convertYUV420ToARGB8888(
byte[] y,
byte[] u,
byte[] v,
int[] output,
int width,
int height,
int yRowStride,
int uvRowStride,
int uvPixelStride,
boolean halfSize);
/**
* Converts YUV420 semi-planar data to RGB 565 data using the supplied width
* and height. The input and output must already be allocated and non-null.
* For efficiency, no error checking is performed.
*
* @param input The array of YUV 4:2:0 input data.
* @param output A pre-allocated array for the RGB 5:6:5 output data.
* @param width The width of the input image.
* @param height The height of the input image.
*/
private static native void convertYUV420SPToRGB565(
byte[] input, byte[] output, int width, int height);
/**
* Converts 32-bit ARGB8888 image data to YUV420SP data. This is useful, for
* instance, in creating data to feed the classes that rely on raw camera
* preview frames.
*
* @param input An array of input pixels in ARGB8888 format.
* @param output A pre-allocated array for the YUV420SP output data.
* @param width The width of the input image.
* @param height The height of the input image.
*/
private static native void convertARGB8888ToYUV420SP(
int[] input, byte[] output, int width, int height);
/**
* Converts 16-bit RGB565 image data to YUV420SP data. This is useful, for
* instance, in creating data to feed the classes that rely on raw camera
* preview frames.
*
* @param input An array of input pixels in RGB565 format.
* @param output A pre-allocated array for the YUV420SP output data.
* @param width The width of the input image.
* @param height The height of the input image.
*/
private static native void convertRGB565ToYUV420SP(
byte[] input, byte[] output, int width, int height);
/**
* Returns a transformation matrix from one reference frame into another.
* Handles cropping (if maintaining aspect ratio is desired) and rotation.
*
* @param srcWidth Width of source frame.
* @param srcHeight Height of source frame.
* @param dstWidth Width of destination frame.
* @param dstHeight Height of destination frame.
* @param applyRotation Amount of rotation to apply from one frame to another.
* Must be a multiple of 90.
* @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant,
* cropping the image if necessary.
* @return The transformation fulfilling the desired requirements.
*/
public static Matrix getTransformationMatrix(
final int srcWidth,
final int srcHeight,
final int dstWidth,
final int dstHeight,
final int applyRotation,
final boolean maintainAspectRatio) {
final Matrix matrix = new Matrix();
if (applyRotation != 0) {
if (applyRotation % 90 != 0) {
LOGGER.w("Rotation of %d % 90 != 0", applyRotation);
}
// Translate so center of image is at origin.
matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);
// Rotate around origin.
matrix.postRotate(applyRotation);
}
// Account for the already applied rotation, if any, and then determine how
// much scaling is needed for each axis.
final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;
final int inWidth = transpose ? srcHeight : srcWidth;
final int inHeight = transpose ? srcWidth : srcHeight;
// Apply scaling if necessary.
if (inWidth != dstWidth || inHeight != dstHeight) {
final float scaleFactorX = dstWidth / (float) inWidth;
final float scaleFactorY = dstHeight / (float) inHeight;
if (maintainAspectRatio) {
        // Scale by the larger factor so that dst is filled completely while
        // maintaining the aspect ratio. Some of the image may fall off the edge.
final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
matrix.postScale(scaleFactor, scaleFactor);
} else {
// Scale exactly to fill dst from src.
matrix.postScale(scaleFactorX, scaleFactorY);
}
}
if (applyRotation != 0) {
// Translate back from origin centered reference to destination frame.
matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
}
return matrix;
}
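  // Hypothetical usage sketch (not part of the original code; sizes and rotation are
  // illustrative): to map a 90-degree-rotated 640x480 camera frame onto a 300x300 model input
  // while cropping to preserve aspect ratio, one might write:
  //
  //   Matrix frameToCrop = ImageUtils.getTransformationMatrix(640, 480, 300, 300, 90, true);
  //   Matrix cropToFrame = new Matrix();
  //   frameToCrop.invert(cropToFrame); // maps detection boxes back into frame coordinates
  //   canvas.drawBitmap(frameBitmap, frameToCrop, null);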
}
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.env;
import android.util.Log;
import java.util.HashSet;
import java.util.Set;
/**
 * Wrapper for the platform log function; allows convenient message prefixing and log disabling.
*/
public final class Logger {
private static final String DEFAULT_TAG = "tensorflow";
private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG;
// Classes to be ignored when examining the stack trace
private static final Set<String> IGNORED_CLASS_NAMES;
static {
IGNORED_CLASS_NAMES = new HashSet<String>(3);
IGNORED_CLASS_NAMES.add("dalvik.system.VMStack");
IGNORED_CLASS_NAMES.add("java.lang.Thread");
IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName());
}
private final String tag;
private final String messagePrefix;
private int minLogLevel = DEFAULT_MIN_LOG_LEVEL;
/**
* Creates a Logger using the class name as the message prefix.
*
* @param clazz the simple name of this class is used as the message prefix.
*/
public Logger(final Class<?> clazz) {
this(clazz.getSimpleName());
}
/**
* Creates a Logger using the specified message prefix.
*
* @param messagePrefix is prepended to the text of every message.
*/
public Logger(final String messagePrefix) {
this(DEFAULT_TAG, messagePrefix);
}
/**
* Creates a Logger with a custom tag and a custom message prefix. If the message prefix
   * is set to {@code null}, the caller's class name is used as the prefix.
   *
   * @param tag identifies the source of a log message.
   * @param messagePrefix prepended to every message if non-null. If null, the caller's class name
   *     is used instead.
*/
public Logger(final String tag, final String messagePrefix) {
this.tag = tag;
final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix;
this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix;
}
/**
* Creates a Logger using the caller's class name as the message prefix.
*/
public Logger() {
this(DEFAULT_TAG, null);
}
/**
   * Creates a Logger using the caller's class name as the message prefix and the given minimum
   * log level.
*/
public Logger(final int minLogLevel) {
this(DEFAULT_TAG, null);
this.minLogLevel = minLogLevel;
}
public void setMinLogLevel(final int minLogLevel) {
this.minLogLevel = minLogLevel;
}
public boolean isLoggable(final int logLevel) {
return logLevel >= minLogLevel || Log.isLoggable(tag, logLevel);
}
/**
* Return caller's simple name.
*
* Android getStackTrace() returns an array that looks like this:
* stackTrace[0]: dalvik.system.VMStack
* stackTrace[1]: java.lang.Thread
* stackTrace[2]: com.google.android.apps.unveil.env.UnveilLogger
* stackTrace[3]: com.google.android.apps.unveil.BaseApplication
*
* This function returns the simple version of the first non-filtered name.
*
* @return caller's simple name
*/
private static String getCallerSimpleName() {
// Get the current callstack so we can pull the class of the caller off of it.
final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
for (final StackTraceElement elem : stackTrace) {
final String className = elem.getClassName();
if (!IGNORED_CLASS_NAMES.contains(className)) {
// We're only interested in the simple name of the class, not the complete package.
final String[] classParts = className.split("\\.");
return classParts[classParts.length - 1];
}
}
return Logger.class.getSimpleName();
}
private String toMessage(final String format, final Object... args) {
return messagePrefix + (args.length > 0 ? String.format(format, args) : format);
}
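  // Illustrative example (not part of the original code): a Logger constructed as
  // new Logger("CameraActivity") and called as logger.i("Frame size %dx%d", 640, 480) emits
  // "CameraActivity: Frame size 640x480" under the "tensorflow" tag, provided Log.INFO is at or
  // above the configured minimum log level.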
public void v(final String format, final Object... args) {
if (isLoggable(Log.VERBOSE)) {
Log.v(tag, toMessage(format, args));
}
}
public void v(final Throwable t, final String format, final Object... args) {
if (isLoggable(Log.VERBOSE)) {
Log.v(tag, toMessage(format, args), t);
}
}
public void d(final String format, final Object... args) {
if (isLoggable(Log.DEBUG)) {
Log.d(tag, toMessage(format, args));
}
}
public void d(final Throwable t, final String format, final Object... args) {
if (isLoggable(Log.DEBUG)) {
Log.d(tag, toMessage(format, args), t);
}
}
public void i(final String format, final Object... args) {
if (isLoggable(Log.INFO)) {
Log.i(tag, toMessage(format, args));
}
}
public void i(final Throwable t, final String format, final Object... args) {
if (isLoggable(Log.INFO)) {
Log.i(tag, toMessage(format, args), t);
}
}
public void w(final String format, final Object... args) {
if (isLoggable(Log.WARN)) {
Log.w(tag, toMessage(format, args));
}
}
public void w(final Throwable t, final String format, final Object... args) {
if (isLoggable(Log.WARN)) {
Log.w(tag, toMessage(format, args), t);
}
}
public void e(final String format, final Object... args) {
if (isLoggable(Log.ERROR)) {
Log.e(tag, toMessage(format, args));
}
}
public void e(final Throwable t, final String format, final Object... args) {
if (isLoggable(Log.ERROR)) {
Log.e(tag, toMessage(format, args), t);
}
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.env;
import android.graphics.Bitmap;
import android.text.TextUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* Size class independent of a Camera object.
*/
public class Size implements Comparable<Size>, Serializable {
// 1.4 went out with this UID so we'll need to maintain it to preserve pending queries when
// upgrading.
public static final long serialVersionUID = 7689808733290872361L;
public final int width;
public final int height;
public Size(final int width, final int height) {
this.width = width;
this.height = height;
}
public Size(final Bitmap bmp) {
this.width = bmp.getWidth();
this.height = bmp.getHeight();
}
/**
* Rotate a size by the given number of degrees.
* @param size Size to rotate.
* @param rotation Degrees {0, 90, 180, 270} to rotate the size.
* @return Rotated size.
*/
public static Size getRotatedSize(final Size size, final int rotation) {
if (rotation % 180 != 0) {
// The phone is portrait, therefore the camera is sideways and frame should be rotated.
return new Size(size.height, size.width);
}
return size;
}
public static Size parseFromString(String sizeString) {
if (TextUtils.isEmpty(sizeString)) {
return null;
}
sizeString = sizeString.trim();
// The expected format is "<width>x<height>".
final String[] components = sizeString.split("x");
if (components.length == 2) {
try {
final int width = Integer.parseInt(components[0]);
final int height = Integer.parseInt(components[1]);
return new Size(width, height);
} catch (final NumberFormatException e) {
return null;
}
} else {
return null;
}
}
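  // Illustrative examples (not part of the original code): parseFromString("640x480") yields a
  // Size with width 640 and height 480, while malformed inputs such as "640x" or "640*480"
  // return null; sizeStringToList("320x240,640x480") therefore produces a two-element list.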
public static List<Size> sizeStringToList(final String sizes) {
final List<Size> sizeList = new ArrayList<Size>();
if (sizes != null) {
final String[] pairs = sizes.split(",");
for (final String pair : pairs) {
final Size size = Size.parseFromString(pair);
if (size != null) {
sizeList.add(size);
}
}
}
return sizeList;
}
public static String sizeListToString(final List<Size> sizes) {
String sizesString = "";
if (sizes != null && sizes.size() > 0) {
sizesString = sizes.get(0).toString();
for (int i = 1; i < sizes.size(); i++) {
sizesString += "," + sizes.get(i).toString();
}
}
return sizesString;
}
public final float aspectRatio() {
return (float) width / (float) height;
}
@Override
public int compareTo(final Size other) {
return width * height - other.width * other.height;
}
@Override
public boolean equals(final Object other) {
if (other == null) {
return false;
}
if (!(other instanceof Size)) {
return false;
}
final Size otherSize = (Size) other;
return (width == otherSize.width && height == otherSize.height);
}
@Override
public int hashCode() {
return width * 32713 + height;
}
@Override
public String toString() {
return dimensionsAsString(width, height);
}
public static final String dimensionsAsString(final int width, final int height) {
return width + "x" + height;
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.env;
import android.os.SystemClock;
/**
* A simple utility timer for measuring CPU time and wall-clock splits.
*/
public class SplitTimer {
private final Logger logger;
private long lastWallTime;
private long lastCpuTime;
public SplitTimer(final String name) {
logger = new Logger(name);
newSplit();
}
public void newSplit() {
lastWallTime = SystemClock.uptimeMillis();
lastCpuTime = SystemClock.currentThreadTimeMillis();
}
public void endSplit(final String splitName) {
final long currWallTime = SystemClock.uptimeMillis();
final long currCpuTime = SystemClock.currentThreadTimeMillis();
logger.i(
"%s: cpu=%dms wall=%dms",
splitName, currCpuTime - lastCpuTime, currWallTime - lastWallTime);
lastWallTime = currWallTime;
lastCpuTime = currCpuTime;
}
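  // Hypothetical usage sketch (not part of the original code):
  //
  //   SplitTimer timer = new SplitTimer("detector");
  //   runInference();               // hypothetical expensive call
  //   timer.endSplit("inference");  // logs e.g. "detector: inference: cpu=87ms wall=102ms"
  //   drawResults();                // hypothetical
  //   timer.endSplit("rendering");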
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.tracking;
import android.content.Context;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Cap;
import android.graphics.Paint.Join;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.text.TextUtils;
import android.util.Pair;
import android.util.TypedValue;
import android.widget.Toast;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import org.tensorflow.demo.Classifier.Recognition;
import org.tensorflow.demo.env.BorderedText;
import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;
/**
* A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
* objects to new detections.
*/
public class MultiBoxTracker {
private final Logger logger = new Logger();
private static final float TEXT_SIZE_DIP = 18;
// Maximum percentage of a box that can be overlapped by another box at detection time. Otherwise
// the lower scored box (new or old) will be removed.
private static final float MAX_OVERLAP = 0.2f;
private static final float MIN_SIZE = 16.0f;
// Allow replacement of the tracked box with new results if
// correlation has dropped below this level.
private static final float MARGINAL_CORRELATION = 0.75f;
// Consider object to be lost if correlation falls below this threshold.
private static final float MIN_CORRELATION = 0.3f;
private static final int[] COLORS = {
Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE,
Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"),
Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"),
Color.parseColor("#AA33AA"), Color.parseColor("#0D0068")
};
private final Queue<Integer> availableColors = new LinkedList<Integer>();
public ObjectTracker objectTracker;
final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();
private static class TrackedRecognition {
ObjectTracker.TrackedObject trackedObject;
RectF location;
float detectionConfidence;
int color;
String title;
}
private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();
private final Paint boxPaint = new Paint();
private final float textSizePx;
private final BorderedText borderedText;
private Matrix frameToCanvasMatrix;
private int frameWidth;
private int frameHeight;
private int sensorOrientation;
private Context context;
public MultiBoxTracker(final Context context) {
this.context = context;
for (final int color : COLORS) {
availableColors.add(color);
}
boxPaint.setColor(Color.RED);
boxPaint.setStyle(Style.STROKE);
boxPaint.setStrokeWidth(12.0f);
boxPaint.setStrokeCap(Cap.ROUND);
boxPaint.setStrokeJoin(Join.ROUND);
boxPaint.setStrokeMiter(100);
textSizePx =
TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
}
private Matrix getFrameToCanvasMatrix() {
return frameToCanvasMatrix;
}
public synchronized void drawDebug(final Canvas canvas) {
final Paint textPaint = new Paint();
textPaint.setColor(Color.WHITE);
textPaint.setTextSize(60.0f);
final Paint boxPaint = new Paint();
boxPaint.setColor(Color.RED);
boxPaint.setAlpha(200);
boxPaint.setStyle(Style.STROKE);
for (final Pair<Float, RectF> detection : screenRects) {
final RectF rect = detection.second;
canvas.drawRect(rect, boxPaint);
canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
}
if (objectTracker == null) {
return;
}
// Draw correlations.
for (final TrackedRecognition recognition : trackedObjects) {
final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
}
}
final Matrix matrix = getFrameToCanvasMatrix();
objectTracker.drawDebug(canvas, matrix);
}
public synchronized void trackResults(
final List<Recognition> results, final byte[] frame, final long timestamp) {
logger.i("Processing %d results from %d", results.size(), timestamp);
processResults(timestamp, results, frame);
}
public synchronized void draw(final Canvas canvas) {
final boolean rotated = sensorOrientation % 180 == 90;
final float multiplier =
Math.min(canvas.getHeight() / (float) (rotated ? frameWidth : frameHeight),
canvas.getWidth() / (float) (rotated ? frameHeight : frameWidth));
frameToCanvasMatrix =
ImageUtils.getTransformationMatrix(
frameWidth,
frameHeight,
(int) (multiplier * (rotated ? frameHeight : frameWidth)),
(int) (multiplier * (rotated ? frameWidth : frameHeight)),
sensorOrientation,
false);
for (final TrackedRecognition recognition : trackedObjects) {
final RectF trackedPos =
(objectTracker != null)
? recognition.trackedObject.getTrackedPositionInPreviewFrame()
: new RectF(recognition.location);
getFrameToCanvasMatrix().mapRect(trackedPos);
boxPaint.setColor(recognition.color);
final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);
final String labelString =
!TextUtils.isEmpty(recognition.title)
? String.format("%s %.2f", recognition.title, recognition.detectionConfidence)
: String.format("%.2f", recognition.detectionConfidence);
borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
}
}
private boolean initialized = false;
public synchronized void onFrame(
final int w,
final int h,
final int rowStride,
final int sensorOrientation,
final byte[] frame,
final long timestamp) {
if (objectTracker == null && !initialized) {
ObjectTracker.clearInstance();
logger.i("Initializing ObjectTracker: %dx%d", w, h);
objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
frameWidth = w;
frameHeight = h;
this.sensorOrientation = sensorOrientation;
initialized = true;
if (objectTracker == null) {
String message =
"Object tracking support not found. "
+ "See tensorflow/examples/android/README.md for details.";
Toast.makeText(context, message, Toast.LENGTH_LONG).show();
logger.e(message);
}
}
if (objectTracker == null) {
return;
}
objectTracker.nextFrame(frame, null, timestamp, null, true);
// Clean up any objects not worth tracking any more.
final LinkedList<TrackedRecognition> copyList =
new LinkedList<TrackedRecognition>(trackedObjects);
for (final TrackedRecognition recognition : copyList) {
final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
final float correlation = trackedObject.getCurrentCorrelation();
if (correlation < MIN_CORRELATION) {
logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
trackedObject.stopTracking();
trackedObjects.remove(recognition);
availableColors.add(recognition.color);
}
}
}
private void processResults(
final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
final List<Pair<Float, Recognition>> rectsToTrack = new LinkedList<Pair<Float, Recognition>>();
screenRects.clear();
final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());
for (final Recognition result : results) {
if (result.getLocation() == null) {
continue;
}
final RectF detectionFrameRect = new RectF(result.getLocation());
final RectF detectionScreenRect = new RectF();
rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);
logger.v(
"Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect);
screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));
if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
logger.w("Degenerate rectangle! " + detectionFrameRect);
continue;
}
rectsToTrack.add(new Pair<Float, Recognition>(result.getConfidence(), result));
}
if (rectsToTrack.isEmpty()) {
logger.v("Nothing to track, aborting.");
return;
}
if (objectTracker == null) {
trackedObjects.clear();
for (final Pair<Float, Recognition> potential : rectsToTrack) {
final TrackedRecognition trackedRecognition = new TrackedRecognition();
trackedRecognition.detectionConfidence = potential.first;
trackedRecognition.location = new RectF(potential.second.getLocation());
trackedRecognition.trackedObject = null;
trackedRecognition.title = potential.second.getTitle();
trackedRecognition.color = COLORS[trackedObjects.size()];
trackedObjects.add(trackedRecognition);
if (trackedObjects.size() >= COLORS.length) {
break;
}
}
return;
}
logger.i("%d rects to track", rectsToTrack.size());
for (final Pair<Float, Recognition> potential : rectsToTrack) {
handleDetection(originalFrame, timestamp, potential);
}
}
private void handleDetection(
final byte[] frameCopy, final long timestamp, final Pair<Float, Recognition> potential) {
final ObjectTracker.TrackedObject potentialObject =
objectTracker.trackObject(potential.second.getLocation(), timestamp, frameCopy);
final float potentialCorrelation = potentialObject.getCurrentCorrelation();
logger.v(
"Tracked object went from %s to %s with correlation %.2f",
potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);
if (potentialCorrelation < MARGINAL_CORRELATION) {
logger.v("Correlation too low to begin tracking %s.", potentialObject);
potentialObject.stopTracking();
return;
}
final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();
float maxIntersect = 0.0f;
// This is the current tracked object whose color we will take. If left null we'll take the
// first one from the color queue.
TrackedRecognition recogToReplace = null;
// Look for intersections that will be overridden by this object or an intersection that would
// prevent this one from being placed.
for (final TrackedRecognition trackedRecognition : trackedObjects) {
final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
final RectF intersection = new RectF();
final boolean intersects = intersection.setIntersect(a, b);
final float intersectArea = intersection.width() * intersection.height();
final float totalArea = a.width() * a.height() + b.width() * b.height() - intersectArea;
final float intersectOverUnion = intersectArea / totalArea;
// If there is an intersection with this currently tracked box above the maximum overlap
// percentage allowed, either the new recognition needs to be dismissed or the old
// recognition needs to be removed and possibly replaced with the new one.
if (intersects && intersectOverUnion > MAX_OVERLAP) {
if (potential.first < trackedRecognition.detectionConfidence
&& trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
// If track for the existing object is still going strong and the detection score was
// good, reject this new object.
potentialObject.stopTracking();
return;
} else {
removeList.add(trackedRecognition);
// Let the previously tracked object with max intersection amount donate its color to
// the new object.
if (intersectOverUnion > maxIntersect) {
maxIntersect = intersectOverUnion;
recogToReplace = trackedRecognition;
}
}
}
}
    // If we're already tracking the maximum number of objects and no intersecting boxes were
    // found to bump off, pick the worst currently tracked object to remove, provided it is also
    // worse than this candidate object.
if (availableColors.isEmpty() && removeList.isEmpty()) {
for (final TrackedRecognition candidate : trackedObjects) {
if (candidate.detectionConfidence < potential.first) {
if (recogToReplace == null
|| candidate.detectionConfidence < recogToReplace.detectionConfidence) {
// Save it so that we use this color for the new object.
recogToReplace = candidate;
}
}
}
if (recogToReplace != null) {
logger.v("Found non-intersecting object to remove.");
removeList.add(recogToReplace);
} else {
logger.v("No non-intersecting object found to remove");
}
}
// Remove everything that got intersected.
for (final TrackedRecognition trackedRecognition : removeList) {
logger.v(
"Removing tracked object %s with detection confidence %.2f, correlation %.2f",
trackedRecognition.trackedObject,
trackedRecognition.detectionConfidence,
trackedRecognition.trackedObject.getCurrentCorrelation());
trackedRecognition.trackedObject.stopTracking();
trackedObjects.remove(trackedRecognition);
if (trackedRecognition != recogToReplace) {
availableColors.add(trackedRecognition.color);
}
}
if (recogToReplace == null && availableColors.isEmpty()) {
logger.e("No room to track this object, aborting.");
potentialObject.stopTracking();
return;
}
// Finally safe to say we can track this object.
logger.v(
"Tracking object %s (%s) with detection confidence %.2f at position %s",
potentialObject,
potential.second.getTitle(),
potential.first,
potential.second.getLocation());
final TrackedRecognition trackedRecognition = new TrackedRecognition();
trackedRecognition.detectionConfidence = potential.first;
trackedRecognition.trackedObject = potentialObject;
trackedRecognition.title = potential.second.getTitle();
// Use the color from a replaced object before taking one from the color queue.
trackedRecognition.color =
recogToReplace != null ? recogToReplace.color : availableColors.poll();
trackedObjects.add(trackedRecognition);
}
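  // Hypothetical call sequence for this tracker (added for clarity): a camera activity would
  // typically call onFrame(width, height, rowStride, sensorOrientation, luminanceBytes,
  // timestamp) for every preview frame, trackResults(detections, frameCopy, timestamp) whenever
  // the detector returns new boxes, and draw(canvas) from the overlay view to render the
  // tracked, color-coded boxes.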
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.demo.tracking;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.PointF;
import android.graphics.RectF;
import android.graphics.Typeface;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import javax.microedition.khronos.opengles.GL10;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.demo.env.Size;
/**
* True object detector/tracker class that tracks objects across consecutive preview frames.
* It provides a simplified Java interface to the analogous native object defined by
* jni/client_vision/tracking/object_tracker.*.
*
* Currently, the ObjectTracker is a singleton due to native code restrictions, and so must
* be allocated by ObjectTracker.getInstance(). In addition, release() should be called
* as soon as the ObjectTracker is no longer needed, and before a new one is created.
*
* nextFrame() should be called as new frames become available, preferably as often as possible.
*
* After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects
* are associated with the ObjectTracker that created them, and are only valid while that
* ObjectTracker still exists.
*/
public class ObjectTracker {
private static final Logger LOGGER = new Logger();
private static boolean libraryFound = false;
static {
try {
System.loadLibrary("tensorflow_demo");
libraryFound = true;
} catch (UnsatisfiedLinkError e) {
LOGGER.e("libtensorflow_demo.so not found, tracking unavailable");
}
}
private static final boolean DRAW_TEXT = false;
/**
* How many history points to keep track of and draw in the red history line.
*/
private static final int MAX_DEBUG_HISTORY_SIZE = 30;
/**
* How many frames of optical flow deltas to record.
* TODO(andrewharp): Push this down to the native level so it can be polled
   * efficiently into an array for upload, instead of keeping a duplicate
* copy in Java.
*/
private static final int MAX_FRAME_HISTORY_SIZE = 200;
private static final int DOWNSAMPLE_FACTOR = 2;
private final byte[] downsampledFrame;
protected static ObjectTracker instance;
private final Map<String, TrackedObject> trackedObjects;
private long lastTimestamp;
private FrameChange lastKeypoints;
private final Vector<PointF> debugHistory;
private final LinkedList<TimestampedDeltas> timestampedDeltas;
protected final int frameWidth;
protected final int frameHeight;
private final int rowStride;
protected final boolean alwaysTrack;
private static class TimestampedDeltas {
final long timestamp;
final byte[] deltas;
public TimestampedDeltas(final long timestamp, final byte[] deltas) {
this.timestamp = timestamp;
this.deltas = deltas;
}
}
/**
* A simple class that records keypoint information, which includes
* local location, score and type. This will be used in calculating
* FrameChange.
*/
public static class Keypoint {
public final float x;
public final float y;
public final float score;
public final int type;
public Keypoint(final float x, final float y) {
this.x = x;
this.y = y;
this.score = 0;
this.type = -1;
}
public Keypoint(final float x, final float y, final float score, final int type) {
this.x = x;
this.y = y;
this.score = score;
this.type = type;
}
Keypoint delta(final Keypoint other) {
return new Keypoint(this.x - other.x, this.y - other.y);
}
}
/**
 * A simple class that computes the delta between two Keypoints.
 * It is used when calculating the frame translation delta
 * for optical flow.
*/
public static class PointChange {
public final Keypoint keypointA;
public final Keypoint keypointB;
Keypoint pointDelta;
private final boolean wasFound;
public PointChange(final float x1, final float y1,
final float x2, final float y2,
final float score, final int type,
final boolean wasFound) {
this.wasFound = wasFound;
keypointA = new Keypoint(x1, y1, score, type);
keypointB = new Keypoint(x2, y2);
}
public Keypoint getDelta() {
if (pointDelta == null) {
pointDelta = keypointB.delta(keypointA);
}
return pointDelta;
}
}
/** A class that records a timestamped frame translation delta for optical flow. */
public static class FrameChange {
public static final int KEYPOINT_STEP = 7;
public final Vector<PointChange> pointDeltas;
private final float minScore;
private final float maxScore;
public FrameChange(final float[] framePoints) {
float minScore = 100.0f;
float maxScore = -100.0f;
pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP);
for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) {
final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR;
final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR;
final boolean wasFound = framePoints[i + 2] > 0.0f;
final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR;
final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR;
final float score = framePoints[i + 5];
final int type = (int) framePoints[i + 6];
minScore = Math.min(minScore, score);
maxScore = Math.max(maxScore, score);
pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound));
}
this.minScore = minScore;
this.maxScore = maxScore;
}
}
public static synchronized ObjectTracker getInstance(
final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
if (!libraryFound) {
LOGGER.e(
"Native object tracking support not found. "
+ "See tensorflow/examples/android/README.md for details.");
return null;
}
if (instance == null) {
instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack);
instance.init();
} else {
throw new RuntimeException(
"Tried to create a new objectracker before releasing the old one!");
}
return instance;
}
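  // Illustrative lifecycle sketch (not part of the original code; a 640x480 preview with a
  // 640-byte row stride is assumed):
  //
  //   ObjectTracker tracker = ObjectTracker.getInstance(640, 480, 640, true);
  //   if (tracker != null) { // null when the native library is unavailable
  //     tracker.nextFrame(frameBytes, null, timestampNs, null, false);
  //     ObjectTracker.TrackedObject obj = tracker.trackObject(boxInFrame, timestampNs, frameBytes);
  //     // ... use obj.getTrackedPositionInPreviewFrame(), then obj.stopTracking() ...
  //     ObjectTracker.clearInstance(); // release before creating a new instance
  //   }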
public static synchronized void clearInstance() {
if (instance != null) {
instance.release();
}
}
protected ObjectTracker(
final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
this.frameWidth = frameWidth;
this.frameHeight = frameHeight;
this.rowStride = rowStride;
this.alwaysTrack = alwaysTrack;
this.timestampedDeltas = new LinkedList<TimestampedDeltas>();
trackedObjects = new HashMap<String, TrackedObject>();
debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE);
downsampledFrame =
new byte
[(frameWidth + DOWNSAMPLE_FACTOR - 1)
/ DOWNSAMPLE_FACTOR
* (frameHeight + DOWNSAMPLE_FACTOR - 1)
/ DOWNSAMPLE_FACTOR];
}
protected void init() {
// The native tracker never sees the full frame, so pre-scale dimensions
// by the downsample factor.
initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack);
}
private final float[] matrixValues = new float[9];
private long downsampledTimestamp;
@SuppressWarnings("unused")
public synchronized void drawOverlay(final GL10 gl,
final Size cameraViewSize, final Matrix matrix) {
final Matrix tempMatrix = new Matrix(matrix);
tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR);
tempMatrix.getValues(matrixValues);
drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues);
}
public synchronized void nextFrame(
final byte[] frameData, final byte[] uvData,
final long timestamp, final float[] transformationMatrix,
final boolean updateDebugInfo) {
if (downsampledTimestamp != timestamp) {
ObjectTracker.downsampleImageNative(
frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
downsampledTimestamp = timestamp;
}
// Do Lucas Kanade using the fullframe initializer.
nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix);
timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR)));
while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) {
timestampedDeltas.removeFirst();
}
for (final TrackedObject trackedObject : trackedObjects.values()) {
trackedObject.updateTrackedPosition();
}
if (updateDebugInfo) {
updateDebugHistory();
}
lastTimestamp = timestamp;
}
public synchronized void release() {
releaseMemoryNative();
synchronized (ObjectTracker.class) {
instance = null;
}
}
private void drawHistoryDebug(final Canvas canvas) {
drawHistoryPoint(
canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2);
}
private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) {
final Paint p = new Paint();
p.setAntiAlias(false);
p.setTypeface(Typeface.SERIF);
p.setColor(Color.RED);
p.setStrokeWidth(2.0f);
// Draw the center circle.
p.setColor(Color.GREEN);
canvas.drawCircle(startX, startY, 3.0f, p);
p.setColor(Color.RED);
// Iterate through in backwards order.
synchronized (debugHistory) {
final int numPoints = debugHistory.size();
float lastX = startX;
float lastY = startY;
for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) {
final PointF delta = debugHistory.get(numPoints - keypointNum - 1);
final float newX = lastX + delta.x;
final float newY = lastY + delta.y;
canvas.drawLine(lastX, lastY, newX, newY, p);
lastX = newX;
lastY = newY;
}
}
}
private static int floatToChar(final float value) {
return Math.max(0, Math.min((int) (value * 255.999f), 255));
}
private void drawKeypointsDebug(final Canvas canvas) {
final Paint p = new Paint();
if (lastKeypoints == null) {
return;
}
final int keypointSize = 3;
final float minScore = lastKeypoints.minScore;
final float maxScore = lastKeypoints.maxScore;
for (final PointChange keypoint : lastKeypoints.pointDeltas) {
if (keypoint.wasFound) {
final int r =
floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore));
final int b =
floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore));
final int color = 0xFF000000 | (r << 16) | b;
p.setColor(color);
final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y,
keypoint.keypointB.x, keypoint.keypointB.y};
canvas.drawRect(screenPoints[2] - keypointSize,
screenPoints[3] - keypointSize,
screenPoints[2] + keypointSize,
screenPoints[3] + keypointSize, p);
p.setColor(Color.CYAN);
canvas.drawLine(screenPoints[2], screenPoints[3],
screenPoints[0], screenPoints[1], p);
if (DRAW_TEXT) {
p.setColor(Color.WHITE);
canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score,
keypoint.keypointA.x, keypoint.keypointA.y, p);
}
} else {
p.setColor(Color.YELLOW);
final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y};
canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p);
}
}
}
private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX,
final float positionY, final float radius) {
final RectF currPosition = getCurrentPosition(timestamp,
new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius));
return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY);
}
  private synchronized RectF getCurrentPosition(final long timestamp, final RectF oldPosition) {
final RectF downscaledFrameRect = downscaleRect(oldPosition);
final float[] delta = new float[4];
getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top,
downscaledFrameRect.right, downscaledFrameRect.bottom, delta);
final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
return upscaleRect(newPosition);
}
private void updateDebugHistory() {
lastKeypoints = new FrameChange(getKeypointsNative(false));
if (lastTimestamp == 0) {
return;
}
final PointF delta =
getAccumulatedDelta(
lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100);
synchronized (debugHistory) {
debugHistory.add(delta);
while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) {
debugHistory.remove(0);
}
}
}
public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) {
canvas.save();
canvas.setMatrix(frameToCanvas);
drawHistoryDebug(canvas);
drawKeypointsDebug(canvas);
canvas.restore();
}
public Vector<String> getDebugText() {
final Vector<String> lines = new Vector<String>();
if (lastKeypoints != null) {
lines.add("Num keypoints " + lastKeypoints.pointDeltas.size());
lines.add("Min score: " + lastKeypoints.minScore);
lines.add("Max score: " + lastKeypoints.maxScore);
}
return lines;
}
public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) {
final List<byte[]> frameDeltas = new ArrayList<byte[]>();
while (timestampedDeltas.size() > 0) {
final TimestampedDeltas currentDeltas = timestampedDeltas.peek();
if (currentDeltas.timestamp <= endFrameTime) {
frameDeltas.add(currentDeltas.deltas);
timestampedDeltas.removeFirst();
} else {
break;
}
}
return frameDeltas;
}
private RectF downscaleRect(final RectF fullFrameRect) {
return new RectF(
fullFrameRect.left / DOWNSAMPLE_FACTOR,
fullFrameRect.top / DOWNSAMPLE_FACTOR,
fullFrameRect.right / DOWNSAMPLE_FACTOR,
fullFrameRect.bottom / DOWNSAMPLE_FACTOR);
}
private RectF upscaleRect(final RectF downsampledFrameRect) {
return new RectF(
downsampledFrameRect.left * DOWNSAMPLE_FACTOR,
downsampledFrameRect.top * DOWNSAMPLE_FACTOR,
downsampledFrameRect.right * DOWNSAMPLE_FACTOR,
downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR);
}
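  // Note (added for clarity): with DOWNSAMPLE_FACTOR = 2, downscaleRect maps a full-frame
  // RectF(100, 60, 300, 220) to RectF(50, 30, 150, 110), and upscaleRect maps the native
  // tracker's downsampled coordinates back into preview-frame space.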
/**
   * A TrackedObject represents a native TrackedObject and provides access to the relevant
   * native tracking information available after every frame update. TrackedObjects may be
   * safely passed around and accessed externally, but become invalid after stopTracking()
   * is called or the ObjectTracker that created them is deactivated.
*
* @author andrewharp@google.com (Andrew Harp)
*/
public class TrackedObject {
private final String id;
private long lastExternalPositionTime;
private RectF lastTrackedPosition;
private boolean visibleInLastFrame;
private boolean isDead;
TrackedObject(final RectF position, final long timestamp, final byte[] data) {
isDead = false;
id = Integer.toString(this.hashCode());
lastExternalPositionTime = timestamp;
synchronized (ObjectTracker.this) {
registerInitialAppearance(position, data);
setPreviousPosition(position, timestamp);
trackedObjects.put(id, this);
}
}
public void stopTracking() {
checkValidObject();
synchronized (ObjectTracker.this) {
isDead = true;
forgetNative(id);
trackedObjects.remove(id);
}
}
public float getCurrentCorrelation() {
checkValidObject();
return ObjectTracker.this.getCurrentCorrelation(id);
}
void registerInitialAppearance(final RectF position, final byte[] data) {
final RectF externalPosition = downscaleRect(position);
registerNewObjectWithAppearanceNative(id,
externalPosition.left, externalPosition.top,
externalPosition.right, externalPosition.bottom,
data);
}
synchronized void setPreviousPosition(final RectF position, final long timestamp) {
checkValidObject();
synchronized (ObjectTracker.this) {
if (lastExternalPositionTime > timestamp) {
LOGGER.w("Tried to use older position time!");
return;
}
final RectF externalPosition = downscaleRect(position);
lastExternalPositionTime = timestamp;
setPreviousPositionNative(id,
externalPosition.left, externalPosition.top,
externalPosition.right, externalPosition.bottom,
lastExternalPositionTime);
updateTrackedPosition();
}
}
void setCurrentPosition(final RectF position) {
checkValidObject();
final RectF downsampledPosition = downscaleRect(position);
synchronized (ObjectTracker.this) {
setCurrentPositionNative(id,
downsampledPosition.left, downsampledPosition.top,
downsampledPosition.right, downsampledPosition.bottom);
}
}
private synchronized void updateTrackedPosition() {
checkValidObject();
final float[] delta = new float[4];
getTrackedPositionNative(id, delta);
lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
visibleInLastFrame = isObjectVisible(id);
}
public synchronized RectF getTrackedPositionInPreviewFrame() {
checkValidObject();
if (lastTrackedPosition == null) {
return null;
}
return upscaleRect(lastTrackedPosition);
}
synchronized long getLastExternalPositionTime() {
return lastExternalPositionTime;
}
public synchronized boolean visibleInLastPreviewFrame() {
return visibleInLastFrame;
}
private void checkValidObject() {
if (isDead) {
throw new RuntimeException("TrackedObject already removed from tracking!");
} else if (ObjectTracker.this != instance) {
throw new RuntimeException("TrackedObject created with another ObjectTracker!");
}
}
}
public synchronized TrackedObject trackObject(
final RectF position, final long timestamp, final byte[] frameData) {
if (downsampledTimestamp != timestamp) {
ObjectTracker.downsampleImageNative(
frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
downsampledTimestamp = timestamp;
}
return new TrackedObject(position, timestamp, downsampledFrame);
}
public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) {
return new TrackedObject(position, lastTimestamp, frameData);
}
/** ********************* NATIVE CODE ************************************ */
/** This will contain an opaque pointer to the native ObjectTracker */
private long nativeObjectTracker;
private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack);
protected native void registerNewObjectWithAppearanceNative(
String objectId, float x1, float y1, float x2, float y2, byte[] data);
protected native void setPreviousPositionNative(
String objectId, float x1, float y1, float x2, float y2, long timestamp);
protected native void setCurrentPositionNative(
String objectId, float x1, float y1, float x2, float y2);
protected native void forgetNative(String key);
protected native String getModelIdNative(String key);
protected native boolean haveObject(String key);
protected native boolean isObjectVisible(String key);
protected native float getCurrentCorrelation(String key);
protected native float getMatchScore(String key);
protected native void getTrackedPositionNative(String key, float[] points);
protected native void nextFrameNative(
byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix);
protected native void releaseMemoryNative();
protected native void getCurrentPositionNative(long timestamp,
final float positionX1, final float positionY1,
final float positionX2, final float positionY2,
final float[] delta);
protected native byte[] getKeypointsPacked(float scaleFactor);
protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints);
protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas);
protected static native void downsampleImageNative(
int width, int height, int rowStride, byte[] input, int factor, byte[] output);
}