使用 TensorFlow Lite 支持库处理输入和输出数据

移动应用程序开发人员通常与类型化对象(如位图)或基本类型(如整数)进行交互。但是,运行设备上机器学习模型的 TensorFlow Lite 解释器 API 使用 ByteBuffer 形式的张量,这可能难以调试和操作。 TensorFlow Lite Android 支持库 旨在帮助处理 TensorFlow Lite 模型的输入和输出,并使 TensorFlow Lite 解释器更易于使用。

入门

导入 Gradle 依赖项和其他设置

.tflite 模型文件复制到将运行模型的 Android 模块的 assets 目录。指定该文件不应压缩,并将 TensorFlow Lite 库添加到模块的 build.gradle 文件中

android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import tflite dependencies
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
    // The GPU delegate library is optional. Depend on it as needed.
    implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly-SNAPSHOT'
    implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly-SNAPSHOT'
}

探索 MavenCentral 上托管的 TensorFlow Lite 支持库 AAR 以获取支持库的不同版本。

基本图像操作和转换

TensorFlow Lite 支持库具有一套基本图像操作方法,例如裁剪和调整大小。要使用它,请创建一个 ImagePreprocessor 并添加所需的运算。要将图像转换为 TensorFlow Lite 解释器所需的张量格式,请创建一个 TensorImage 作为输入

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;

// Initialization code
// Create an ImageProcessor with all ops required. For more ops, please
// refer to the ImageProcessor Architecture section in this README.
ImageProcessor imageProcessor =
    new ImageProcessor.Builder()
        .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
        .build();

// Create a TensorImage object. This creates the tensor of the corresponding
// tensor type (uint8 in this case) that the TensorFlow Lite interpreter needs.
TensorImage tensorImage = new TensorImage(DataType.UINT8);

// Analysis code for every frame
// Preprocess the image
tensorImage.load(bitmap);
tensorImage = imageProcessor.process(tensorImage);

DataType 可以通过 元数据提取器库 以及其他模型信息读取。

基本音频数据处理

TensorFlow Lite 支持库还定义了一个 TensorAudio 类,它包装了一些基本的音频数据处理方法。它主要与 AudioRecord 一起使用,并在环形缓冲区中捕获音频样本。

import android.media.AudioRecord;
import org.tensorflow.lite.support.audio.TensorAudio;

// Create an `AudioRecord` instance.
AudioRecord record = AudioRecord(...)

// Create a `TensorAudio` object from Android AudioFormat.
TensorAudio tensorAudio = new TensorAudio(record.getFormat(), size)

// Load all audio samples available in the AudioRecord without blocking.
tensorAudio.load(record)

// Get the `TensorBuffer` for inference.
TensorBuffer buffer = tensorAudio.getTensorBuffer()

创建输出对象并运行模型

在运行模型之前,我们需要创建将存储结果的容器对象

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

// Create a container for the result and specify that this is a quantized model.
// Hence, the 'DataType' is defined as UINT8 (8-bit unsigned integer)
TensorBuffer probabilityBuffer =
    TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);

加载模型并运行推理

import java.nio.MappedByteBuffer;
import org.tensorflow.lite.InterpreterFactory;
import org.tensorflow.lite.InterpreterApi;

// Initialise the model
try{
    MappedByteBuffer tfliteModel
        = FileUtil.loadMappedFile(activity,
            "mobilenet_v1_1.0_224_quant.tflite");
    InterpreterApi tflite = new InterpreterFactory().create(
        tfliteModel, new InterpreterApi.Options());
} catch (IOException e){
    Log.e("tfliteSupport", "Error reading model", e);
}

// Running inference
if(null != tflite) {
    tflite.run(tImage.getBuffer(), probabilityBuffer.getBuffer());
}

访问结果

开发人员可以通过 probabilityBuffer.getFloatArray() 直接访问输出。如果模型生成量化输出,请记住转换结果。对于 MobileNet 量化模型,开发人员需要将每个输出值除以 255,以获得每个类别的概率范围,从 0(最不可能)到 1(最可能)。

可选:将结果映射到标签

开发人员还可以选择将结果映射到标签。首先,将包含标签的文本文件复制到模块的 assets 目录。接下来,使用以下代码加载标签文件

import org.tensorflow.lite.support.common.FileUtil;

final String ASSOCIATED_AXIS_LABELS = "labels.txt";
List<String> associatedAxisLabels = null;

try {
    associatedAxisLabels = FileUtil.loadLabels(this, ASSOCIATED_AXIS_LABELS);
} catch (IOException e) {
    Log.e("tfliteSupport", "Error reading label file", e);
}

以下代码段演示了如何将概率与类别标签关联起来

import java.util.Map;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.label.TensorLabel;

// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
    new TensorProcessor.Builder().add(new NormalizeOp(0, 255)).build();

if (null != associatedAxisLabels) {
    // Map of labels and their corresponding probability
    TensorLabel labels = new TensorLabel(associatedAxisLabels,
        probabilityProcessor.process(probabilityBuffer));

    // Create a map to access the result based on label
    Map<String, Float> floatMap = labels.getMapWithFloatValue();
}

当前用例覆盖范围

当前版本的 TensorFlow Lite 支持库涵盖了

  • 常见的 tflite 模型输入和输出数据类型(浮点数、uint8、图像、音频和这些对象的数组)。
  • 基本图像操作(裁剪图像、调整大小和旋转)。
  • 规范化和量化
  • 文件实用程序

未来版本将改进对文本相关应用程序的支持。

ImageProcessor 架构

ImageProcessor 的设计允许在构建过程中预先定义图像操作并进行优化。 ImageProcessor 目前支持三种基本预处理操作,如以下代码段中的三个注释所述

import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.common.ops.QuantizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;

int width = bitmap.getWidth();
int height = bitmap.getHeight();

int size = height > width ? width : height;

ImageProcessor imageProcessor =
    new ImageProcessor.Builder()
        // Center crop the image to the largest square possible
        .add(new ResizeWithCropOrPadOp(size, size))
        // Resize using Bilinear or Nearest neighbour
        .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR));
        // Rotation counter-clockwise in 90 degree increments
        .add(new Rot90Op(rotateDegrees / 90))
        .add(new NormalizeOp(127.5, 127.5))
        .add(new QuantizeOp(128.0, 1/128.0))
        .build();

有关规范化和量化的更多详细信息,请参阅 此处

支持库的最终目标是支持所有 tf.image 转换。这意味着转换将与 TensorFlow 相同,并且实现将独立于操作系统。

开发者也可以创建自定义处理器。在这些情况下,与训练过程保持一致非常重要 - 也就是说,相同的预处理应该应用于训练和推理,以提高可重复性。

量化

在初始化输入或输出对象(例如 TensorImageTensorBuffer)时,需要指定它们的类型为 DataType.UINT8DataType.FLOAT32

TensorImage tensorImage = new TensorImage(DataType.UINT8);
TensorBuffer probabilityBuffer =
    TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);

可以使用 TensorProcessor 对输入张量进行量化或对输出张量进行反量化。例如,在处理量化的输出 TensorBuffer 时,开发者可以使用 DequantizeOp 将结果反量化为 0 到 1 之间的浮点概率。

import org.tensorflow.lite.support.common.TensorProcessor;

// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
    new TensorProcessor.Builder().add(new DequantizeOp(0, 1/255.0)).build();
TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer);

可以通过 元数据提取器库 读取张量的量化参数。