Google Play 服务 Java API 中的 TensorFlow Lite

除了原生 API 之外,还可以使用 Java API 访问 Google Play 服务中的 TensorFlow Lite。特别是,Google Play 服务中的 TensorFlow Lite 可通过 TensorFlow Lite 任务 APITensorFlow Lite 解释器 API 使用。任务库为使用视觉、音频和文本数据的常见机器学习任务提供了经过优化的开箱即用模型接口。TensorFlow 运行时提供的 TensorFlow Lite 解释器 API 提供了一个更通用的接口,用于构建和运行机器学习模型。

以下部分提供了有关如何在 Google Play 服务中使用 TensorFlow Lite 的解释器和任务库 API 的说明。虽然应用程序可以使用解释器 API 和任务库 API,但大多数应用程序应该只使用其中一组 API。

使用任务库 API

TensorFlow Lite 任务 API 封装了解释器 API,并为使用视觉、音频和文本数据的常见机器学习任务提供了一个高级编程接口。如果您的应用程序需要 支持的任务 之一,则应使用任务 API。

1. 添加项目依赖项

您的项目依赖项取决于您的机器学习用例。任务 API 包含以下库

  • 视觉库:org.tensorflow:tensorflow-lite-task-vision-play-services
  • 音频库:org.tensorflow:tensorflow-lite-task-audio-play-services
  • 文本库:org.tensorflow:tensorflow-lite-task-text-play-services

将其中一个依赖项添加到您的应用程序项目代码中,以访问 TensorFlow Lite 的 Play 服务 API。例如,使用以下代码来实现视觉任务

dependencies {
...
    implementation 'org.tensorflow:tensorflow-lite-task-vision-play-services:0.4.2'
...
}

2. 添加 TensorFlow Lite 的初始化

使用 TensorFlow Lite API 之前,初始化 Google Play 服务 API 的 TensorFlow Lite 组件。以下示例初始化视觉库

Kotlin

init {
  TfLiteVision.initialize(context)
}

3. 运行推理

初始化 TensorFlow Lite 组件后,调用 detect() 方法以生成推理。detect() 方法中的确切代码因库和用例而异。以下示例适用于使用 TfLiteVision 库的简单目标检测用例

Kotlin

fun detect(...) {
  if (!TfLiteVision.isInitialized()) {
    Log.e(TAG, "detect: TfLiteVision is not initialized yet")
    return
  }

  if (objectDetector == null) {
    setupObjectDetector()
  }

  ...

}

根据数据格式,您可能还需要在 detect() 方法中预处理和转换数据,然后才能生成推理。例如,目标检测器的图像数据需要以下内容

val imageProcessor = ImageProcessor.Builder().add(Rot90Op(-imageRotation / 90)).build()
val tensorImage = imageProcessor.process(TensorImage.fromBitmap(image))
val results = objectDetector?.detect(tensorImage)

使用解释器 API

解释器 API 比任务库 API 提供了更多控制和灵活性。如果您的机器学习任务不受任务库支持,或者您需要一个更通用的接口来构建和运行机器学习模型,则应使用解释器 API。

1. 添加项目依赖项

将以下依赖项添加到您的应用程序项目代码中,以访问 TensorFlow Lite 的 Play 服务 API

dependencies {
...
    // Tensorflow Lite dependencies for Google Play services
    implementation 'com.google.android.gms:play-services-tflite-java:16.0.1'
    // Optional: include Tensorflow Lite Support Library
    implementation 'com.google.android.gms:play-services-tflite-support:16.0.1'
...
}

2. 添加 TensorFlow Lite 的初始化

在使用 TensorFlow Lite API 之前,初始化 Google Play 服务 API 的 TensorFlow Lite 组件

Kotlin

val initializeTask: Task<Void> by lazy { TfLite.initialize(this) }

Java

Task<Void> initializeTask = TfLite.initialize(context);

3. 创建解释器并设置运行时选项

使用 InterpreterApi.create() 创建解释器,并通过调用 InterpreterApi.Options.setRuntime() 将其配置为使用 Google Play 服务运行时,如以下示例代码所示

Kotlin

import org.tensorflow.lite.InterpreterApi
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime
...
private lateinit var interpreter: InterpreterApi
...
initializeTask.addOnSuccessListener {
  val interpreterOption =
    InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
  interpreter = InterpreterApi.create(
    modelBuffer,
    interpreterOption
  )}
  .addOnFailureListener { e ->
    Log.e("Interpreter", "Cannot initialize interpreter", e)
  }

Java

import org.tensorflow.lite.InterpreterApi
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime
...
private InterpreterApi interpreter;
...
initializeTask.addOnSuccessListener(a -> {
    interpreter = InterpreterApi.create(modelBuffer,
      new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY));
  })
  .addOnFailureListener(e -> {
    Log.e("Interpreter", String.format("Cannot initialize interpreter: %s",
          e.getMessage()));
  });

您应该使用上面的实现,因为它避免了阻塞 Android 用户界面线程。如果您需要更密切地管理线程执行,可以向解释器创建添加 Tasks.await() 调用

Kotlin

import androidx.lifecycle.lifecycleScope
...
lifecycleScope.launchWhenStarted { // uses coroutine
  initializeTask.await()
}

Java

@BackgroundThread
InterpreterApi initializeInterpreter() {
    Tasks.await(initializeTask);
    return InterpreterApi.create(...);
}

4. 运行推理

使用您创建的 interpreter 对象,调用 run() 方法以生成推理。

Kotlin

interpreter.run(inputBuffer, outputBuffer)

Java

interpreter.run(inputBuffer, outputBuffer);

硬件加速

TensorFlow Lite 允许您使用专用硬件处理器(例如图形处理单元 (GPU))来加速模型的性能。您可以使用称为 委托 的硬件驱动程序来利用这些专用处理器。您可以在 Google Play 服务中的 TensorFlow Lite 中使用以下硬件加速委托

  • GPU 委托(推荐) - 此委托通过 Google Play 服务提供,并与 Task API 和 Interpreter API 的 Play 服务版本一样,是动态加载的。

  • NNAPI 委托 - 此委托在您的 Android 开发项目中作为包含的库依赖项提供,并捆绑到您的应用程序中。

有关使用 TensorFlow Lite 进行硬件加速的更多信息,请参阅 TensorFlow Lite 委托 页面。

检查设备兼容性

并非所有设备都支持使用 TFLite 进行 GPU 硬件加速。为了减轻错误和潜在的崩溃,请使用 TfLiteGpu.isGpuDelegateAvailable 方法来检查设备是否与 GPU 委托兼容。

使用此方法确认设备是否与 GPU 兼容,并在不支持 GPU 时使用 CPU 或 NNAPI 委托作为后备。

useGpuTask = TfLiteGpu.isGpuDelegateAvailable(context)

获得类似 useGpuTask 的变量后,您可以使用它来确定设备是否使用 GPU 委托。以下示例展示了如何使用 Task 库和 Interpreter API 来实现这一点。

使用 Task API

Kotlin

lateinit val optionsTask = useGpuTask.continueWith { task ->
  val baseOptionsBuilder = BaseOptions.builder()
  if (task.result) {
    baseOptionsBuilder.useGpu()
  }
 ObjectDetectorOptions.builder()
          .setBaseOptions(baseOptionsBuilder.build())
          .setMaxResults(1)
          .build()
}
    

Java

Task<ObjectDetectorOptions> optionsTask = useGpuTask.continueWith({ task ->
  BaseOptions baseOptionsBuilder = BaseOptions.builder();
  if (task.getResult()) {
    baseOptionsBuilder.useGpu();
  }
  return ObjectDetectorOptions.builder()
          .setBaseOptions(baseOptionsBuilder.build())
          .setMaxResults(1)
          .build()
});
    

使用 Interpreter API

Kotlin

val interpreterTask = useGpuTask.continueWith { task ->
  val interpreterOptions = InterpreterApi.Options()
      .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
  if (task.result) {
      interpreterOptions.addDelegateFactory(GpuDelegateFactory())
  }
  InterpreterApi.create(FileUtil.loadMappedFile(context, MODEL_PATH), interpreterOptions)
}
    

Java

Task<InterpreterApi.Options> interpreterOptionsTask = useGpuTask.continueWith({ task ->
  InterpreterApi.Options options =
      new InterpreterApi.Options().setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY);
  if (task.getResult()) {
     options.addDelegateFactory(new GpuDelegateFactory());
  }
  return options;
});
    

使用 Task 库 API 的 GPU

要使用 Task API 的 GPU 委托

  1. 更新项目依赖项以使用来自 Play 服务的 GPU 委托

    implementation 'com.google.android.gms:play-services-tflite-gpu:16.1.0'
    
  2. 使用 setEnableGpuDelegateSupport 初始化 GPU 委托。例如,您可以使用以下方法为 TfLiteVision 初始化 GPU 委托

    Kotlin

        TfLiteVision.initialize(context, TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build())
        

    Java

        TfLiteVision.initialize(context, TfLiteInitializationOptions.builder().setEnableGpuDelegateSupport(true).build());
        
  3. 使用 BaseOptions 启用 GPU 委托选项

    Kotlin

        val baseOptions = BaseOptions.builder().useGpu().build()
        

    Java

        BaseOptions baseOptions = BaseOptions.builder().useGpu().build();
        
  4. 使用 .setBaseOptions 配置选项。例如,您可以使用以下方法在 ObjectDetector 中设置 GPU

    Kotlin

        val options =
            ObjectDetectorOptions.builder()
                .setBaseOptions(baseOptions)
                .setMaxResults(1)
                .build()
        

    Java

        ObjectDetectorOptions options =
            ObjectDetectorOptions.builder()
                .setBaseOptions(baseOptions)
                .setMaxResults(1)
                .build();
        

使用 Interpreter API 的 GPU

要使用 Interpreter API 的 GPU 委托

  1. 更新项目依赖项以使用来自 Play 服务的 GPU 委托

    implementation 'com.google.android.gms:play-services-tflite-gpu:16.1.0'
    
  2. 在 TFlite 初始化中启用 GPU 委托选项

    Kotlin

        TfLite.initialize(context,
          TfLiteInitializationOptions.builder()
           .setEnableGpuDelegateSupport(true)
           .build())
        

    Java

        TfLite.initialize(context,
          TfLiteInitializationOptions.builder()
           .setEnableGpuDelegateSupport(true)
           .build());
        
  3. 在解释器选项中启用 GPU 委托:通过在 InterpreterApi.Options() 中调用 addDelegateFactory() 将委托工厂设置为 GpuDelegateFactory

    Kotlin

        val interpreterOption = InterpreterApi.Options()
         .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
         .addDelegateFactory(GpuDelegateFactory())
        

    Java

        Options interpreterOption = InterpreterApi.Options()
          .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY)
          .addDelegateFactory(new GpuDelegateFactory());
        

从独立 TensorFlow Lite 迁移

如果您计划将您的应用程序从独立 TensorFlow Lite 迁移到 Play 服务 API,请查看以下有关更新应用程序项目代码的额外指南

  1. 查看本页的 限制 部分,以确保您的用例受支持。
  2. 在更新代码之前,请对您的模型进行性能和准确性检查,尤其是在使用低于 2.1 版本的 TensorFlow Lite 时,以便您有一个基线与新实现进行比较。
  3. 如果您已将所有代码迁移到使用 Play 服务 API 的 TensorFlow Lite,则应从您的 build.gradle 文件中删除现有的 TensorFlow Lite 运行时库 依赖项(具有 org.tensorflow:tensorflow-lite:* 的条目),以便您可以减小应用程序大小。
  4. 在您的代码中识别所有 new Interpreter 对象创建的出现,并修改每个对象,使其使用 InterpreterApi.create() 调用。新的 TfLite.initialize 是异步的,这意味着在大多数情况下它不是直接替换:您必须注册一个监听器,以便在调用完成时收到通知。请参考 步骤 3 代码中的代码片段。
  5. import org.tensorflow.lite.InterpreterApi;import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime; 添加到使用 org.tensorflow.lite.Interpreterorg.tensorflow.lite.InterpreterApi 类的任何源文件。
  6. 如果对 InterpreterApi.create() 的任何调用只有一个参数,请将 new InterpreterApi.Options() 附加到参数列表。
  7. .setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY) 附加到对 InterpreterApi.create() 的任何调用的最后一个参数。
  8. org.tensorflow.lite.Interpreter 类的所有其他出现替换为 org.tensorflow.lite.InterpreterApi

如果您想并排使用独立 TensorFlow Lite 和 Play 服务 API,则必须使用 TensorFlow Lite 2.9(或更高版本)。TensorFlow Lite 2.8 和更早版本与 Play 服务 API 版本不兼容。