在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
TensorFlow Lite 支持将 TensorFlow 模型的输入/输出规范转换为 TensorFlow Lite 模型。输入/输出规范称为“签名”。在构建 SavedModel 或创建具体函数时,可以指定签名。
TensorFlow Lite 中的签名提供以下功能
- 它们通过尊重 TensorFlow 模型的签名来指定转换后的 TensorFlow Lite 模型的输入和输出。
- 允许单个 TensorFlow Lite 模型支持多个入口点。
签名由三个部分组成
- 输入:从签名中的输入名称到输入张量的输入映射。
- 输出:从签名中的输出名称到输出张量的输出映射。
- 签名键:标识图入口点的名称。
设置
import tensorflow as tf
示例模型
假设我们有一个 TensorFlow 模型,它包含两个任务,例如编码和解码。
class Model(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def encode(self, x):
result = tf.strings.as_string(x)
return {
"encoded_result": result
}
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def decode(self, x):
result = tf.strings.to_number(x)
return {
"decoded_result": result
}
从签名的角度来看,上述 TensorFlow 模型可以概括如下
签名
- 键:encode
- 输入:{"x"}
- 输出:{"encoded_result"}
签名
- 键:decode
- 输入:{"x"}
- 输出:{"decoded_result"}
转换包含签名的模型
TensorFlow Lite 转换器 API 将上述签名信息引入转换后的 TensorFlow Lite 模型。
此转换功能从 TensorFlow 2.7.0 版本开始在所有转换器 API 上可用。请参阅示例用法。
从 SavedModel
model = Model()
# Save the model
SAVED_MODEL_PATH = 'content/saved_models/coding'
tf.saved_model.save(
model, SAVED_MODEL_PATH,
signatures={
'encode': model.encode.get_concrete_function(),
'decode': model.decode.get_concrete_function()
})
# Convert the saved model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()
# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
从 Keras 模型
# Generate a Keras model.
keras_model = tf.keras.Sequential(
[
tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),
tf.keras.layers.Dense(1, activation='relu', name='output'),
]
)
# Convert the keras model using TFLiteConverter.
# Keras model converter API uses the default signature automatically.
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()
# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
从具体函数
model = Model()
# Convert the concrete functions using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[model.encode.get_concrete_function(),
model.decode.get_concrete_function()], model)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()
# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
运行签名
TensorFlow 推理 API 支持基于签名的执行
- 通过签名指定的输入和输出的名称访问输入/输出张量。
- 分别运行图的每个入口点,由签名键标识。
- 支持 SavedModel 的初始化过程。
目前提供 Java、C++ 和 Python 语言绑定。请参阅以下部分的示例。
Java
try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
// Run encoding signature.
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", input);
Map<String, Object> outputs = new HashMap<>();
outputs.put("encoded_result", encoded_result);
interpreter.runSignature(inputs, outputs, "encode");
// Run decoding signature.
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", encoded_result);
Map<String, Object> outputs = new HashMap<>();
outputs.put("decoded_result", decoded_result);
interpreter.runSignature(inputs, outputs, "decode");
}
C++
SignatureRunner* encode_runner =
interpreter->GetSignatureRunner("encode");
encode_runner->ResizeInputTensor("x", {100});
encode_runner->AllocateTensors();
TfLiteTensor* input_tensor = encode_runner->input_tensor("x");
float* input = GetTensorData<float>(input_tensor);
// Fill `input`.
encode_runner->Invoke();
const TfLiteTensor* output_tensor = encode_runner->output_tensor(
"encoded_result");
float* output = GetTensorData<float>(output_tensor);
// Access `output`.
Python
# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Print the signatures from the converted model
signatures = interpreter.get_signature_list()
print('Signature:', signatures)
# encode and decode are callable with input as arguments.
encode = interpreter.get_signature_runner('encode')
decode = interpreter.get_signature_runner('decode')
# 'encoded' and 'decoded' are dictionaries with all outputs from the inference.
input = tf.constant([1, 2, 3], dtype=tf.float32)
print('Input:', input)
encoded = encode(x=input)
print('Encoded result:', encoded)
decoded = decode(x=encoded['encoded_result'])
print('Decoded result:', decoded)
已知限制
- 由于 TFLite 解释器不保证线程安全,因此来自同一解释器的签名运行器不会并发执行。
- 目前尚不支持 iOS/Swift。
更新
- 版本 2.7
- 已实现多签名功能。
- 版本二中的所有转换器 API 都生成支持签名的 TensorFlow Lite 模型。
- 版本 2.5
- 签名功能可通过
from_saved_model
转换器 API 使用。
- 签名功能可通过