移动应用程序开发人员通常与类型化对象(如位图)或基本类型(如整数)进行交互。但是,运行设备上机器学习模型的 TensorFlow Lite 解释器 API 使用 ByteBuffer 形式的张量,这可能难以调试和操作。 TensorFlow Lite Android 支持库 旨在帮助处理 TensorFlow Lite 模型的输入和输出,并使 TensorFlow Lite 解释器更易于使用。
入门
导入 Gradle 依赖项和其他设置
将 .tflite
模型文件复制到将运行模型的 Android 模块的 assets 目录。指定该文件不应压缩,并将 TensorFlow Lite 库添加到模块的 build.gradle
文件中
android {
// Other settings
// Specify tflite file should not be compressed for the app apk
aaptOptions {
noCompress "tflite"
}
}
dependencies {
// Other dependencies
// Import tflite dependencies
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
// The GPU delegate library is optional. Depend on it as needed.
implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly-SNAPSHOT'
implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly-SNAPSHOT'
}
探索 MavenCentral 上托管的 TensorFlow Lite 支持库 AAR 以获取支持库的不同版本。
基本图像操作和转换
TensorFlow Lite 支持库具有一套基本图像操作方法,例如裁剪和调整大小。要使用它,请创建一个 ImagePreprocessor
并添加所需的运算。要将图像转换为 TensorFlow Lite 解释器所需的张量格式,请创建一个 TensorImage
作为输入
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
// Initialization code
// Create an ImageProcessor with all ops required. For more ops, please
// refer to the ImageProcessor Architecture section in this README.
ImageProcessor imageProcessor =
new ImageProcessor.Builder()
.add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
.build();
// Create a TensorImage object. This creates the tensor of the corresponding
// tensor type (uint8 in this case) that the TensorFlow Lite interpreter needs.
TensorImage tensorImage = new TensorImage(DataType.UINT8);
// Analysis code for every frame
// Preprocess the image
tensorImage.load(bitmap);
tensorImage = imageProcessor.process(tensorImage);
DataType
可以通过 元数据提取器库 以及其他模型信息读取。
基本音频数据处理
TensorFlow Lite 支持库还定义了一个 TensorAudio
类,它包装了一些基本的音频数据处理方法。它主要与 AudioRecord 一起使用,并在环形缓冲区中捕获音频样本。
import android.media.AudioRecord;
import org.tensorflow.lite.support.audio.TensorAudio;
// Create an `AudioRecord` instance.
AudioRecord record = AudioRecord(...)
// Create a `TensorAudio` object from Android AudioFormat.
TensorAudio tensorAudio = new TensorAudio(record.getFormat(), size)
// Load all audio samples available in the AudioRecord without blocking.
tensorAudio.load(record)
// Get the `TensorBuffer` for inference.
TensorBuffer buffer = tensorAudio.getTensorBuffer()
创建输出对象并运行模型
在运行模型之前,我们需要创建将存储结果的容器对象
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
// Create a container for the result and specify that this is a quantized model.
// Hence, the 'DataType' is defined as UINT8 (8-bit unsigned integer)
TensorBuffer probabilityBuffer =
TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);
加载模型并运行推理
import java.nio.MappedByteBuffer;
import org.tensorflow.lite.InterpreterFactory;
import org.tensorflow.lite.InterpreterApi;
// Initialise the model
try{
MappedByteBuffer tfliteModel
= FileUtil.loadMappedFile(activity,
"mobilenet_v1_1.0_224_quant.tflite");
InterpreterApi tflite = new InterpreterFactory().create(
tfliteModel, new InterpreterApi.Options());
} catch (IOException e){
Log.e("tfliteSupport", "Error reading model", e);
}
// Running inference
if(null != tflite) {
tflite.run(tImage.getBuffer(), probabilityBuffer.getBuffer());
}
访问结果
开发人员可以通过 probabilityBuffer.getFloatArray()
直接访问输出。如果模型生成量化输出,请记住转换结果。对于 MobileNet 量化模型,开发人员需要将每个输出值除以 255,以获得每个类别的概率范围,从 0(最不可能)到 1(最可能)。
可选:将结果映射到标签
开发人员还可以选择将结果映射到标签。首先,将包含标签的文本文件复制到模块的 assets 目录。接下来,使用以下代码加载标签文件
import org.tensorflow.lite.support.common.FileUtil;
final String ASSOCIATED_AXIS_LABELS = "labels.txt";
List<String> associatedAxisLabels = null;
try {
associatedAxisLabels = FileUtil.loadLabels(this, ASSOCIATED_AXIS_LABELS);
} catch (IOException e) {
Log.e("tfliteSupport", "Error reading label file", e);
}
以下代码段演示了如何将概率与类别标签关联起来
import java.util.Map;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.label.TensorLabel;
// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
new TensorProcessor.Builder().add(new NormalizeOp(0, 255)).build();
if (null != associatedAxisLabels) {
// Map of labels and their corresponding probability
TensorLabel labels = new TensorLabel(associatedAxisLabels,
probabilityProcessor.process(probabilityBuffer));
// Create a map to access the result based on label
Map<String, Float> floatMap = labels.getMapWithFloatValue();
}
当前用例覆盖范围
当前版本的 TensorFlow Lite 支持库涵盖了
- 常见的 tflite 模型输入和输出数据类型(浮点数、uint8、图像、音频和这些对象的数组)。
- 基本图像操作(裁剪图像、调整大小和旋转)。
- 规范化和量化
- 文件实用程序
未来版本将改进对文本相关应用程序的支持。
ImageProcessor 架构
ImageProcessor
的设计允许在构建过程中预先定义图像操作并进行优化。 ImageProcessor
目前支持三种基本预处理操作,如以下代码段中的三个注释所述
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.common.ops.QuantizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;
int width = bitmap.getWidth();
int height = bitmap.getHeight();
int size = height > width ? width : height;
ImageProcessor imageProcessor =
new ImageProcessor.Builder()
// Center crop the image to the largest square possible
.add(new ResizeWithCropOrPadOp(size, size))
// Resize using Bilinear or Nearest neighbour
.add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR));
// Rotation counter-clockwise in 90 degree increments
.add(new Rot90Op(rotateDegrees / 90))
.add(new NormalizeOp(127.5, 127.5))
.add(new QuantizeOp(128.0, 1/128.0))
.build();
有关规范化和量化的更多详细信息,请参阅 此处。
支持库的最终目标是支持所有 tf.image
转换。这意味着转换将与 TensorFlow 相同,并且实现将独立于操作系统。
开发者也可以创建自定义处理器。在这些情况下,与训练过程保持一致非常重要 - 也就是说,相同的预处理应该应用于训练和推理,以提高可重复性。
量化
在初始化输入或输出对象(例如 TensorImage
或 TensorBuffer
)时,需要指定它们的类型为 DataType.UINT8
或 DataType.FLOAT32
。
TensorImage tensorImage = new TensorImage(DataType.UINT8);
TensorBuffer probabilityBuffer =
TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);
可以使用 TensorProcessor
对输入张量进行量化或对输出张量进行反量化。例如,在处理量化的输出 TensorBuffer
时,开发者可以使用 DequantizeOp
将结果反量化为 0 到 1 之间的浮点概率。
import org.tensorflow.lite.support.common.TensorProcessor;
// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
new TensorProcessor.Builder().add(new DequantizeOp(0, 1/255.0)).build();
TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer);
可以通过 元数据提取器库 读取张量的量化参数。