TensorFlow Lite 元数据写入器 API

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

TensorFlow Lite 模型元数据 是一种标准的模型描述格式。它包含关于一般模型信息、输入/输出和关联文件的丰富语义,这使得模型具有自描述性和可交换性。

模型元数据目前在以下两个主要用例中使用

  1. 使用 TensorFlow Lite 任务库代码生成工具 轻松进行模型推理。 模型元数据包含推理过程中所需的强制信息,例如图像分类中的标签文件、音频分类中音频输入的采样率以及自然语言模型中处理输入字符串的标记器类型。

  2. 使模型创建者能够包含文档,例如模型输入/输出的描述或如何使用模型。模型用户可以通过可视化工具(如 Netron)查看这些文档。

TensorFlow Lite 元数据写入器 API 提供了一个易于使用的 API,用于为 TFLite 任务库支持的流行机器学习任务创建模型元数据。此笔记本展示了如何为以下任务填充元数据的示例

用于 BERT 自然语言分类器和 BERT 问答器的元数据写入器即将推出。

如果您想为不支持的用例添加元数据,请使用 Flatbuffers Python API。请参阅 此处 的教程。

先决条件

安装 TensorFlow Lite 支持 Pypi 包。

pip install tflite-support-nightly

为任务库和代码生成创建模型元数据

图像分类器

有关支持的模型格式的更多详细信息,请参阅 图像分类器模型兼容性要求

步骤 1:导入所需的包。

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils

步骤 2:下载示例图像分类器 mobilenet_v2_1.0_224.tflite标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt

步骤 3:创建元数据写入器并填充。

ImageClassifierWriter = image_classifier.MetadataWriter
_MODEL_PATH = "mobilenet_v2_1.0_224.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "mobilenet_labels.txt"
_SAVE_TO_PATH = "mobilenet_v2_1.0_224_metadata.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://tensorflowcn.cn/lite/models/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ImageClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

目标检测器

有关支持的模型格式的更多详细信息,请参阅 目标检测器模型兼容性要求

步骤 1:导入所需的包。

from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

步骤 2:下载示例目标检测器,ssd_mobilenet_v1.tflite,以及 标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/ssd_mobilenet_v1.tflite -o ssd_mobilenet_v1.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/object_detector/labelmap.txt -o ssd_mobilenet_labels.txt

步骤 3:创建元数据写入器并填充。

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "ssd_mobilenet_v1.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "ssd_mobilenet_labels.txt"
_SAVE_TO_PATH = "ssd_mobilenet_v1_metadata.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://tensorflowcn.cn/lite/models/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

图像分割器

有关支持的模型格式的更多详细信息,请参阅 图像分割器模型兼容性要求

步骤 1:导入所需的包。

from tflite_support.metadata_writers import image_segmenter
from tflite_support.metadata_writers import writer_utils

步骤 2:下载示例图像分割器,deeplabv3.tflite,以及 标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/deeplabv3.tflite -o deeplabv3.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_segmenter/labelmap.txt -o deeplabv3_labels.txt

步骤 3:创建元数据写入器并填充。

ImageSegmenterWriter = image_segmenter.MetadataWriter
_MODEL_PATH = "deeplabv3.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "deeplabv3_labels.txt"
_SAVE_TO_PATH = "deeplabv3_metadata.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://tensorflowcn.cn/lite/models/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ImageSegmenterWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

自然语言分类器

有关支持的模型格式的更多详细信息,请参阅 自然语言分类器模型兼容性要求

步骤 1:导入所需的包。

from tflite_support.metadata_writers import nl_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

步骤 2:下载示例自然语言分类器,movie_review.tflite标签文件,以及 词汇文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/movie_review.tflite -o movie_review.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier/labels.txt -o movie_review_labels.txt
curl -L https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/nl_classifier/vocab.txt -o movie_review_vocab.txt

步骤 3:创建元数据写入器并填充。

NLClassifierWriter = nl_classifier.MetadataWriter
_MODEL_PATH = "movie_review.tflite"
# Task Library expects label files and vocab files that are in the same formats
# as the ones below.
_LABEL_FILE = "movie_review_labels.txt"
_VOCAB_FILE = "movie_review_vocab.txt"
# NLClassifier supports tokenize input string using the regex tokenizer. See
# more details about how to set up RegexTokenizer below:
# https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py#L130
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_SAVE_TO_PATH = "moview_review_metadata.tflite"

# Create the metadata writer.
writer = nl_classifier.MetadataWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH),
    metadata_info.RegexTokenizerMd(_DELIM_REGEX_PATTERN, _VOCAB_FILE),
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

音频分类器

有关支持的模型格式的更多详细信息,请参阅 音频分类器模型兼容性要求

步骤 1:导入所需的包。

from tflite_support.metadata_writers import audio_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

步骤 2:下载示例音频分类器,yamnet.tflite,以及 标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_wavin_quantized_mel_relu6.tflite -o yamnet.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/audio_classifier/yamnet_521_labels.txt -o yamnet_labels.txt

步骤 3:创建元数据写入器并填充。

AudioClassifierWriter = audio_classifier.MetadataWriter
_MODEL_PATH = "yamnet.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "yamnet_labels.txt"
# Expected sampling rate of the input audio buffer.
_SAMPLE_RATE = 16000
# Expected number of channels of the input audio buffer. Note, Task library only
# support single channel so far.
_CHANNELS = 1
_SAVE_TO_PATH = "yamnet_metadata.tflite"

# Create the metadata writer.
writer = AudioClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), _SAMPLE_RATE, _CHANNELS, [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

使用语义信息创建模型元数据

您可以通过元数据写入器 API 填写有关模型和每个张量的更多描述性信息,以帮助提高模型理解。这可以通过每个元数据写入器中的“create_from_metadata_info”方法完成。通常,您可以通过“create_from_metadata_info”的参数填写数据,即 general_mdinput_mdoutput_md。请参阅以下示例,以了解如何为图像分类器创建丰富的模型元数据。

步骤 1:导入所需的包。

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata_schema_py_generated as _metadata_fb

步骤 2:下载示例图像分类器 mobilenet_v2_1.0_224.tflite标签文件

curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite -o mobilenet_v2_1.0_224.tflite
curl -L https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt -o mobilenet_labels.txt

步骤 3:创建模型和张量信息。

model_buffer = writer_utils.load_file("mobilenet_v2_1.0_224.tflite")

# Create general model information.
general_md = metadata_info.GeneralMd(
    name="ImageClassifier",
    version="v1",
    description=("Identify the most prominent object in the image from a "
                 "known set of categories."),
    author="TensorFlow Lite",
    licenses="Apache License. Version 2.0")

# Create input tensor information.
input_md = metadata_info.InputImageTensorMd(
    name="input image",
    description=("Input image to be classified. The expected image is "
                 "128 x 128, with three channels (red, blue, and green) per "
                 "pixel. Each element in the tensor is a value between min and "
                 "max, where (per-channel) min is [0] and max is [255]."),
    norm_mean=[127.5],
    norm_std=[127.5],
    color_space_type=_metadata_fb.ColorSpaceType.RGB,
    tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])

# Create output tensor information.
output_md = metadata_info.ClassificationTensorMd(
    name="probability",
    description="Probabilities of the 1001 labels respectively.",
    label_files=[
        metadata_info.LabelFileMd(file_path="mobilenet_labels.txt",
                                  locale="en")
    ],
    tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])

步骤 4:创建元数据写入器并填充。

ImageClassifierWriter = image_classifier.MetadataWriter
# Create the metadata writer.
writer = ImageClassifierWriter.create_from_metadata_info(
    model_buffer, general_md, input_md, output_md)

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

读取填充到模型中的元数据。

您可以通过以下代码显示 TFLite 模型中的元数据和关联文件

from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file("mobilenet_v2_1.0_224_metadata.tflite")
print("Metadata populated:")
print(displayer.get_metadata_json())

print("Associated file(s) populated:")
for file_name in displayer.get_packed_associated_file_list():
  print("file name: ", file_name)
  print("file content:")
  print(displayer.get_associated_file_buffer(file_name))