TFLite 作者工具

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

TensorFlow Lite 作者 API 提供了一种方法来维护您的 tf.function 模型与 TensorFlow Lite 兼容。

设置

import tensorflow as tf

TensorFlow 到 TensorFlow Lite 的兼容性问题

如果您想在设备上使用您的 TF 模型,您需要将其转换为 TFLite 模型,以便从 TFLite 解释器中使用它。在转换过程中,您可能会遇到兼容性错误,因为 TFLite 内置运算符集不支持 TensorFlow 运算符。

这是一个令人讨厌的问题。您如何在模型创作时尽早检测到它?

请注意,以下代码将在 converter.convert() 调用时失败。

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
# Convert the tf.function
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [f.get_concrete_function()], f)
try:
  fb_model = converter.convert()
except Exception as e:
  print(f"Got an exception: {e}")

简单的目标感知创作使用

我们引入了创作 API,以便在模型创作时检测 TensorFlow Lite 兼容性问题。

您只需要添加 @tf.lite.experimental.authoring.compatible 装饰器来包装您的 tf.function 模型以检查 TFLite 兼容性。

在此之后,当您评估模型时,将自动检查兼容性。

@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")

如果发现任何 TensorFlow Lite 兼容性问题,它将显示 COMPATIBILITY WARNINGCOMPATIBILITY ERROR,并提供有问题的运算符的确切位置。在本例中,它显示了 tf.Cosh 运算符在您的 tf.function 模型中的位置。

您还可以使用 <function_name>.get_compatibility_log() 方法检查兼容性日志。

compatibility_log = '\n'.join(f.get_compatibility_log())
print (f"compatibility_log = {compatibility_log}")

为不兼容性引发异常

您可以为 @tf.lite.experimental.authoring.compatible 装饰器提供一个选项。当您尝试评估装饰的模型时,raise_exception 选项会为您提供一个异常。

@tf.lite.experimental.authoring.compatible(raise_exception=True)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
try:
  result = f(tf.constant([0.0]))
  print (f"result = {result}")
except Exception as e:
  print(f"Got an exception: {e}")

指定“选择 TF 运算符”使用

如果您已经了解 选择 TF 运算符 的使用,您可以通过设置 converter_target_spec 将其告知创作 API。它与您将用于 tf.lite.TFLiteConverter API 的 tf.lite.TargetSpec 对象相同。

target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec, raise_exception=True)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
  return tf.cosh(x)

# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")

检查 GPU 兼容性

如果您想确保您的模型与 TensorFlow Lite 的 GPU 代理 兼容,您可以设置 experimental_supported_backends 属性,该属性属于 tf.lite.TargetSpec

以下示例展示了如何确保模型与 GPU 代理兼容。请注意,此模型存在兼容性问题,因为它使用了带有 tf.slice 运算符的二维张量和不支持的 tf.cosh 运算符。您将看到两个带有位置信息的 COMPATIBILITY WARNING

target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
target_spec.experimental_supported_backends = ["GPU"]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[4, 4], dtype=tf.float32)
])
def func(x):
  y = tf.cosh(x)
  return y + tf.slice(x, [1, 1], [1, 1])

result = func(tf.ones(shape=(4,4), dtype=tf.float32))

了解更多

有关更多信息,请参阅