在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
TensorFlow Lite 作者 API 提供了一种方法来维护您的 tf.function
模型与 TensorFlow Lite 兼容。
设置
import tensorflow as tf
TensorFlow 到 TensorFlow Lite 的兼容性问题
如果您想在设备上使用您的 TF 模型,您需要将其转换为 TFLite 模型,以便从 TFLite 解释器中使用它。在转换过程中,您可能会遇到兼容性错误,因为 TFLite 内置运算符集不支持 TensorFlow 运算符。
这是一个令人讨厌的问题。您如何在模型创作时尽早检测到它?
请注意,以下代码将在 converter.convert()
调用时失败。
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
# Convert the tf.function
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[f.get_concrete_function()], f)
try:
fb_model = converter.convert()
except Exception as e:
print(f"Got an exception: {e}")
简单的目标感知创作使用
我们引入了创作 API,以便在模型创作时检测 TensorFlow Lite 兼容性问题。
您只需要添加 @tf.lite.experimental.authoring.compatible
装饰器来包装您的 tf.function
模型以检查 TFLite 兼容性。
在此之后,当您评估模型时,将自动检查兼容性。
@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
如果发现任何 TensorFlow Lite 兼容性问题,它将显示 COMPATIBILITY WARNING
或 COMPATIBILITY ERROR
,并提供有问题的运算符的确切位置。在本例中,它显示了 tf.Cosh
运算符在您的 tf.function 模型中的位置。
您还可以使用 <function_name>.get_compatibility_log()
方法检查兼容性日志。
compatibility_log = '\n'.join(f.get_compatibility_log())
print (f"compatibility_log = {compatibility_log}")
为不兼容性引发异常
您可以为 @tf.lite.experimental.authoring.compatible
装饰器提供一个选项。当您尝试评估装饰的模型时,raise_exception
选项会为您提供一个异常。
@tf.lite.experimental.authoring.compatible(raise_exception=True)
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
try:
result = f(tf.constant([0.0]))
print (f"result = {result}")
except Exception as e:
print(f"Got an exception: {e}")
指定“选择 TF 运算符”使用
如果您已经了解 选择 TF 运算符 的使用,您可以通过设置 converter_target_spec
将其告知创作 API。它与您将用于 tf.lite.TFLiteConverter API 的 tf.lite.TargetSpec 对象相同。
target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec, raise_exception=True)
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
检查 GPU 兼容性
如果您想确保您的模型与 TensorFlow Lite 的 GPU 代理 兼容,您可以设置 experimental_supported_backends
属性,该属性属于 tf.lite.TargetSpec。
以下示例展示了如何确保模型与 GPU 代理兼容。请注意,此模型存在兼容性问题,因为它使用了带有 tf.slice 运算符的二维张量和不支持的 tf.cosh 运算符。您将看到两个带有位置信息的 COMPATIBILITY WARNING
。
target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]
target_spec.experimental_supported_backends = ["GPU"]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec)
@tf.function(input_signature=[
tf.TensorSpec(shape=[4, 4], dtype=tf.float32)
])
def func(x):
y = tf.cosh(x)
return y + tf.slice(x, [1, 1], [1, 1])
result = func(tf.ones(shape=(4,4), dtype=tf.float32))
了解更多
有关更多信息,请参阅