Android 快速入门

本页面介绍如何使用 TensorFlow Lite 构建 Android 应用程序,以分析实时相机馈送并识别对象。此机器学习用例称为目标检测。示例应用程序使用 TensorFlow Lite 视觉任务库 通过 Google Play 服务 来启用目标检测机器学习模型的执行,这是使用 TensorFlow Lite 构建 ML 应用程序的推荐方法。

Object detection animated demo

设置和运行示例

在本练习的第一部分,从 GitHub 下载 示例代码,并使用 Android Studio 运行它。本文档的以下部分将探讨代码示例的相关部分,以便您可以将它们应用到自己的 Android 应用程序中。您需要安装以下版本的这些工具

  • Android Studio 4.2 或更高版本
  • Android SDK 版本 21 或更高版本

获取示例代码

创建示例代码的本地副本,以便您可以构建和运行它。

要克隆和设置示例代码

  1. 克隆 git 存储库
    git clone https://github.com/tensorflow/examples.git
    
  2. 配置您的 git 实例以使用稀疏检出,以便您只有目标检测示例应用程序的文件
    cd examples
    git sparse-checkout init --cone
    git sparse-checkout set lite/examples/object_detection/android_play_services
    

导入并运行项目

使用 Android Studio 从下载的示例代码创建项目,构建项目并运行它。

要导入和构建示例代码项目

  1. 启动 Android Studio
  2. 从 Android Studio 的欢迎页面中,选择导入项目,或选择文件 > 新建 > 导入项目
  3. 导航到包含 build.gradle 文件的示例代码目录 (...examples/lite/examples/object_detection/android_play_services/build.gradle) 并选择该目录。

选择此目录后,Android Studio 会创建一个新项目并构建它。构建完成后,Android Studio 会在构建输出状态面板中显示 BUILD SUCCESSFUL 消息。

要运行项目

  1. 从 Android Studio 中,通过选择运行 > 运行…MainActivity 来运行项目。
  2. 选择一个带有摄像头的已连接 Android 设备来测试应用程序。

示例应用程序的工作原理

示例应用程序使用预训练的目标检测模型,例如 mobilenetv1.tflite,以 TensorFlow Lite 格式在来自 Android 设备摄像头的实时视频流中查找对象。此功能的代码主要位于以下文件中

  • ObjectDetectorHelper.kt - 初始化运行时环境,启用硬件加速,并运行目标检测 ML 模型。
  • CameraFragment.kt - 构建相机图像数据流,准备模型数据,并显示目标检测结果。

接下来的部分将向您展示这些代码文件的关键组件,以便您可以修改 Android 应用程序以添加此功能。

构建应用程序

以下部分解释了构建自己的 Android 应用程序并运行示例应用程序中显示的模型的关键步骤。这些说明使用前面显示的示例应用程序作为参考点。

添加项目依赖项

在您的基本 Android 应用程序中,添加运行 TensorFlow Lite 机器学习模型和访问 ML 数据实用程序函数的项目依赖项。这些实用程序函数将图像等数据转换为可以由模型处理的张量数据格式。

示例应用程序使用来自 Google Play 服务 的 TensorFlow Lite 用于视觉的任务库 来启用目标检测机器学习模型的执行。以下说明解释了如何将所需的库依赖项添加到您自己的 Android 应用程序项目中。

要添加模块依赖项

  1. 在使用 TensorFlow Lite 的 Android 应用程序模块中,更新模块的 build.gradle 文件以包含以下依赖项。在示例代码中,此文件位于此处:...examples/lite/examples/object_detection/android_play_services/app/build.gradle

    ...
    dependencies {
    ...
        // Tensorflow Lite dependencies
        implementation 'org.tensorflow:tensorflow-lite-task-vision-play-services:0.4.2'
        implementation 'com.google.android.gms:play-services-tflite-gpu:16.1.0'
    ...
    }
    
  2. 在 Android Studio 中,通过选择以下选项同步项目依赖项:文件 > 使用 Gradle 文件同步项目

初始化 Google Play 服务

当您使用 Google Play 服务 运行 TensorFlow Lite 模型时,您必须在使用它之前初始化该服务。如果您想使用该服务的硬件加速支持,例如 GPU 加速,您还需要在初始化过程中启用该支持。

要使用 Google Play 服务初始化 TensorFlow Lite

  1. 创建一个 TfLiteInitializationOptions 对象并修改它以启用 GPU 支持

    val options = TfLiteInitializationOptions.builder()
        .setEnableGpuDelegateSupport(true)
        .build()
    
  2. 使用 TfLiteVision.initialize() 方法启用 Play 服务运行时的使用,并设置一个侦听器以验证它是否已成功加载

    TfLiteVision.initialize(context, options).addOnSuccessListener {
        objectDetectorListener.onInitialized()
    }.addOnFailureListener {
        // Called if the GPU Delegate is not supported on the device
        TfLiteVision.initialize(context).addOnSuccessListener {
            objectDetectorListener.onInitialized()
        }.addOnFailureListener{
            objectDetectorListener.onError("TfLiteVision failed to initialize: "
                    + it.message)
        }
    }
    

初始化 ML 模型解释器

通过加载模型文件并设置模型参数来初始化 TensorFlow Lite 机器学习模型解释器。TensorFlow Lite 模型包含一个 .tflite 文件,其中包含模型代码。您应该将模型存储在开发项目的 src/main/assets 目录中,例如

.../src/main/assets/mobilenetv1.tflite`

要初始化模型

  1. 将一个 .tflite 模型文件添加到开发项目的 src/main/assets 目录中,例如 ssd_mobilenet_v1
  2. modelName 变量设置为指定 ML 模型的文件名

    val modelName = "mobilenetv1.tflite"
    
  3. 设置模型选项,例如预测阈值和结果集大小

    val optionsBuilder =
        ObjectDetector.ObjectDetectorOptions.builder()
            .setScoreThreshold(threshold)
            .setMaxResults(maxResults)
    
  4. 使用选项启用 GPU 加速,并允许代码在设备不支持加速的情况下正常失败

    try {
        optionsBuilder.useGpu()
    } catch(e: Exception) {
        objectDetectorListener.onError("GPU is not supported on this device")
    }
    
    
  5. 使用此对象中的设置来构造一个 TensorFlow Lite ObjectDetector 对象,其中包含模型

    objectDetector =
        ObjectDetector.createFromFileAndOptions(
            context, modelName, optionsBuilder.build())
    

有关使用 TensorFlow Lite 的硬件加速委托的更多信息,请参阅 TensorFlow Lite 委托

准备模型数据

通过将现有数据(例如图像)转换为 张量 数据格式来准备用于解释的数据,以便模型可以处理它。张量中的数据必须具有与用于训练模型的数据格式匹配的特定维度或形状。根据您使用的模型,您可能需要转换数据以适合模型的预期。示例应用程序使用 ImageAnalysis 对象从相机子系统中提取图像帧。

要准备用于模型处理的数据

  1. 构建一个 ImageAnalysis 对象以提取所需格式的图像

    imageAnalyzer =
        ImageAnalysis.Builder()
            .setTargetAspectRatio(AspectRatio.RATIO_4_3)
            .setTargetRotation(fragmentCameraBinding.viewFinder.display.rotation)
            .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
            .setOutputImageFormat(OUTPUT_IMAGE_FORMAT_RGBA_8888)
            .build()
            ...
    
  2. 将分析器连接到相机子系统并创建一个位图缓冲区来包含从相机接收的数据

            .also {
            it.setAnalyzer(cameraExecutor) { image ->
                if (!::bitmapBuffer.isInitialized) {
                    bitmapBuffer = Bitmap.createBitmap(
                        image.width,
                        image.height,
                        Bitmap.Config.ARGB_8888
                    )
                }
                detectObjects(image)
            }
        }
    
  3. 提取模型所需的特定图像数据,并传递图像旋转信息

    private fun detectObjects(image: ImageProxy) {
        // Copy out RGB bits to the shared bitmap buffer
        image.use { bitmapBuffer.copyPixelsFromBuffer(image.planes[0].buffer) }
        val imageRotation = image.imageInfo.rotationDegrees
        objectDetectorHelper.detect(bitmapBuffer, imageRotation)
    }    
    
  4. 完成任何最终的数据转换并将图像数据添加到 TensorImage 对象中,如示例应用程序的 ObjectDetectorHelper.detect() 方法中所示

    val imageProcessor = ImageProcessor.Builder().add(Rot90Op(-imageRotation / 90)).build()
    
    // Preprocess the image and convert it into a TensorImage for detection.
    val tensorImage = imageProcessor.process(TensorImage.fromBitmap(image))
    

运行预测

创建具有正确格式的图像数据的 TensorImage 对象后,您可以对该数据运行模型以生成预测或推断。在示例应用程序中,此代码包含在 ObjectDetectorHelper.detect() 方法中。

要运行模型并从图像数据生成预测

  • 通过将图像数据传递给您的预测函数来运行预测

    val results = objectDetector?.detect(tensorImage)
    

处理模型输出

在您对目标检测模型运行图像数据后,它会生成一个预测结果列表,您的应用程序代码必须通过执行其他业务逻辑、向用户显示结果或采取其他操作来处理这些结果。示例应用程序中的目标检测模型会生成一个预测列表和已检测对象的边界框。在示例应用程序中,预测结果将传递给侦听器对象以进行进一步处理并显示给用户。

要处理模型预测结果

  1. 使用侦听器模式将结果传递给您的应用程序代码或用户界面对象。示例应用程序使用此模式将检测结果从 ObjectDetectorHelper 对象传递到 CameraFragment 对象

    objectDetectorListener.onResults( // instance of CameraFragment
        results,
        inferenceTime,
        tensorImage.height,
        tensorImage.width)
    
  2. 对结果采取行动,例如向用户显示预测。示例应用程序在 CameraPreview 对象上绘制一个叠加层以显示结果

    override fun onResults(
      results: MutableList<Detection>?,
      inferenceTime: Long,
      imageHeight: Int,
      imageWidth: Int
    ) {
        activity?.runOnUiThread {
            fragmentCameraBinding.bottomSheetLayout.inferenceTimeVal.text =
                String.format("%d ms", inferenceTime)
    
            // Pass necessary information to OverlayView for drawing on the canvas
            fragmentCameraBinding.overlay.setResults(
                results ?: LinkedList<Detection>(),
                imageHeight,
                imageWidth
            )
    
            // Force a redraw
            fragmentCameraBinding.overlay.invalidate()
        }
    }
    

后续步骤