由于 TensorFlow Lite 内置运算符库仅支持有限数量的 TensorFlow 运算符,因此并非所有模型都可转换。有关详细信息,请参阅 运算符兼容性。
为了允许转换,用户可以在 TensorFlow Lite 中提供他们自己对不受支持的 TensorFlow 运算符的自定义实现,称为自定义运算符。如果您希望将一系列不受支持(或支持)的 TensorFlow 运算符组合成单个融合的优化自定义运算符,请参阅 运算符融合。
使用自定义运算符包括四个步骤。
创建 TensorFlow 模型。 确保保存的模型(或图形定义)引用了正确命名的 TensorFlow Lite 运算符。
转换为 TensorFlow Lite 模型。 确保设置了正确的 TensorFlow Lite 转换器属性,以便成功转换模型。
创建并注册运算符。 这样 TensorFlow Lite 运行时就知道如何将图中的运算符和参数映射到可执行的 C/C++ 代码。
测试和分析您的运算符。 如果您只想测试自定义运算符,最好创建一个仅包含自定义运算符的模型,并使用 benchmark_model 程序。
让我们逐步了解一个使用自定义运算符 tf.atan
(命名为 Atan
,请参阅 #create_a_tensorflow_model)运行模型的端到端示例,该运算符在 TensorFlow 中受支持,但在 TensorFlow Lite 中不受支持。
TensorFlow 文本运算符是自定义运算符的一个示例。有关代码示例,请参阅 将 TF 文本转换为 TF Lite 教程。
示例:自定义 Atan
运算符
让我们逐步了解一个支持 TensorFlow Lite 没有的 TensorFlow 运算符的示例。假设我们使用 Atan
运算符,并且我们正在为函数 y = atan(x + offset)
构建一个非常简单的模型,其中 offset
是可训练的。
创建 TensorFlow 模型
以下代码片段训练了一个简单的 TensorFlow 模型。该模型只包含一个名为 Atan
的自定义运算符,它是一个函数 y = atan(x + offset)
,其中 offset
是可训练的。
import tensorflow as tf
# Define training dataset and variables
x = [-8, 0.5, 2, 2.2, 201]
y = [-1.4288993, 0.98279375, 1.2490457, 1.2679114, 1.5658458]
offset = tf.Variable(0.0)
# Define a simple model which just contains a custom operator named `Atan`
@tf.function(input_signature=[tf.TensorSpec.from_tensor(tf.constant(x))])
def atan(x):
return tf.atan(x + offset, name="Atan")
# Train model
optimizer = tf.optimizers.Adam(0.01)
def train(x, y):
with tf.GradientTape() as t:
predicted_y = atan(x)
loss = tf.reduce_sum(tf.square(predicted_y - y))
grads = t.gradient(loss, [offset])
optimizer.apply_gradients(zip(grads, [offset]))
for i in range(1000):
train(x, y)
print("The actual offset is: 1.0")
print("The predicted offset is:", offset.numpy())
The actual offset is: 1.0
The predicted offset is: 0.99999905
此时,如果您尝试使用默认转换器标志生成 TensorFlow Lite 模型,您将收到以下错误消息
Error:
error: 'tf.Atan' op is neither a custom op nor a flex op.
转换为 TensorFlow Lite 模型
通过将转换器属性 allow_custom_ops
设置为如下所示,创建一个包含自定义操作符的 TensorFlow Lite 模型。
converter = tf.lite.TFLiteConverter.from_concrete_functions([atan.get_concrete_function()], atan) converter.allow_custom_ops = True tflite_model = converter.convert()
此时,如果您使用以下命令运行它,使用默认解释器
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
您仍然会收到错误
Encountered unresolved custom op: Atan.
创建并注册操作符。
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_opaque.h"
TensorFlow Lite 自定义操作符使用简单的纯 C API 定义,该 API 由一个不透明类型 (TfLiteRegistrationExternal
) 和相关函数组成。
TfLiteRegistrationExternal
是一个不透明类型
typedef struct TfLiteRegistrationExternal TfLiteRegistrationExternal;
TfLiteRegistrationExternal
存储操作符的标识和实现。(请注意,操作符与其操作数不同,操作数存储在调用操作符的节点的 TF Lite 图节点中。)
此类型的实例使用对 TfLiteRegistrationExternalCreate
的调用构建,可以通过调用 TfLiteRegistrationExternalDelete
销毁。
操作符的标识通过构造函数 TfLiteRegistrationExternalCreate
的参数设置
TfLiteRegistrationExternal*
TfLiteRegistrationExternalCreate(
TfLiteBuiltinOperator builtin_code, // Normally `TfLiteBuiltinCustom`.
const char* custom_name, // The name of the custom op.
int version // Normally `1` for the first version of a custom op.
);
操作符实现可以定义具有以下签名的“方法”。所有这些方法都是可选的,但为了成功评估操作符,操作符实现需要定义并设置(使用 setter 函数)至少 Prepare
和 Invoke
方法。
// Initializes the op from serialized data.
void* Init(TfLiteOpaqueContext* context, const char* buffer, size_t length);
// Deallocates the op.
// The pointer `buffer` is the data previously returned by an Init invocation.
void Free(TfLiteOpaqueContext* context, void* buffer);
// Called when the inputs that this node depends on have been resized.
TfLiteStatus Prepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
// Called when the node is executed. (Should read node inputs and write to
// node outputs).
TfLiteStatus Invoke(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);
// Retrieves the async kernel.
TfLiteAsyncKernel AsyncKernel(TfLiteOpaqueContext* context,
TfLiteOpaqueNode* node);
您在操作实现中的函数 *名称*(或 C++ 的命名空间前缀)不必与上面代码段中的函数名称匹配,因为 TF Lite 自定义操作 API 仅使用它们的地址。实际上,我们建议您在匿名命名空间中或作为静态函数声明它们。
但最好将您的操作符名称作为命名空间或前缀包含在这些函数名称中
C++
namespace my_namespace::my_custom_op { void* Init(TfLiteOpaqueContext* context, const char* buffer, size_t length) { ... } // ... plus definitions of Free, Prepare, and Invoke ... }
C
void* MyCustomOpInit(TfLiteOpaqueContext* context, const char* buffer, size_t length) { ... } // ... plus definitions of MyCustomOpFree, MyCustomOpPrepare, and // MyCustomOpInvoke.
由于这是一个 C API,这些“方法”在 TfLiteRegistrationExternal
类型中实现为 C 函数指针,这些指针通过将您实现函数的地址传递给相应的 setter 函数 TfLiteRegistrationExternalSet
MethodName 来设置。
void TfLiteRegistrationExternalSetInit(
TfLiteRegistrationExternal* registration,
void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
size_t length));
void TfLiteRegistrationExternalSetFree(
TfLiteRegistrationExternal* registration,
void (*free)(TfLiteOpaqueContext* context, void* data));
void TfLiteRegistrationExternalSetPrepare(
TfLiteRegistrationExternal* registration,
TfLiteStatus (*prepare)(TfLiteOpaqueContext* context,
TfLiteOpaqueNode* node));
void TfLiteRegistrationExternalSetInvoke(
TfLiteRegistrationExternal* registration,
TfLiteStatus (*invoke)(TfLiteOpaqueContext* context,
TfLiteOpaqueNode* node));
void TfLiteRegistrationExternalSetAsyncKernel(
TfLiteRegistrationExternal* registration,
struct TfLiteAsyncKernel* (*async_kernel)(TfLiteOpaqueContext* context,
TfLiteOpaqueNode* node));
有关 TfLiteContext
和 TfLiteNode
的详细信息,请参阅 common.h
。 TfLiteContext
提供错误报告功能以及对全局对象的访问,包括所有张量。 TfLiteNode
允许操作符实现访问其输入和输出。
当解释器加载模型时,它会为图中的每个节点调用一次 Init()
方法。如果操作在图中多次使用,则会多次调用给定的 Init()
。对于自定义操作,将提供一个配置缓冲区,其中包含一个将参数名称映射到其值的 flexbuffer。内置操作的缓冲区为空,因为解释器已经解析了操作参数。需要状态的内核实现应在此处初始化状态并将所有权转移给调用者。对于每个 Init()
调用,都会有一个相应的 Free()
调用,允许实现处置它们可能在 Init()
中分配的缓冲区。
每当输入张量大小调整时,解释器都会遍历图,通知实现更改。这使它们有机会调整其内部缓冲区,检查输入形状和类型的有效性,并重新计算输出形状。所有这些都通过 Prepare()
方法完成,实现可以使用 TfLiteOpaqueNodeGetUserData(node)
访问其状态。
最后,每次推理运行时,解释器都会遍历图,调用 Invoke()
方法,并且状态也可以作为 TfLiteOpaqueNodeGetUserData(node)
获得。
自定义操作可以通过定义这些“方法”函数来实现,然后定义一个函数,该函数返回通过调用 TfLiteRegistrationExternalCreate
然后调用相关 setter 方法构建的 TfLiteRegistrationExternal
实例
C++
namespace my_namespace::my_custom_op { namespace { void* Init(TfLiteOpaqueContext* context, const char* buffer, size_t length) { ... } void Free(TfLiteOpaqueContext* context, void* buffer) { ... } TfLiteStatus Prepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { ... } TfLiteStatus Invoke(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {... } }; const TfLiteRegistrationExternal* MyCustomOpRegistrationExternal() { // Singleton instance, intentionally never destroyed. static const TfLiteRegistrationExternal* my_custom_op = ()[] { TfLiteRegistrationExternal* r = TfLiteRegistrationExternalCreate( kTfLiteBuiltinCustom, "MyCustomOp", /*version=*/ 1); TfLiteRegistrationExternalSetInit(r, Init); TfLiteRegistrationExternalSetFree(r, Free); TfLiteRegistrationExternalSetPrepare(r, Prepare); TfLiteRegistrationExternalSetInvoke(r, Eval); return r; }; return my_custom_op; } const TfLiteRegistration* MyCustomOpRegistration() { static const TfLiteRegistration my_custom_op { .registration_external = MyCustomOpRegistrationExternal(); }; return my_custom_op; } } // namespace my_namespace
C
static void* MyCustomOpInit(TfLiteOpaqueContext* context, const char* buffer, size_t length) { ... } static void MyCustomOpFree(TfLiteOpaqueContext* context, void* buffer) { ... } static TfLiteStatus MyCustomOpPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { ... } static TfLiteStatus MyCustomOpInvoke(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {... } static TfLiteRegistrationExternal* MyCustomOpCreate() { const TfLiteRegistrationExternal* r = TfLiteRegistrationExternalCreate( kTfLiteBuiltinCustom, "MyCustomOp", /*version=*/ 1); TfLiteRegistrationExternalSetInit(r, MyCustomOpInit); TfLiteRegistrationExternalSetFree(r, MyCustomOpFree); TfLiteRegistrationExternalSetPrepare(r, MyCustomOpPrepare); TfLiteRegistrationExternalSetInvoke(r, MyCustomOpEval); return r; } const TfLiteRegistrationExternal* MyCustomOpRegistrationExternal() { // Singleton instance, intentionally never destroyed. static const TfLiteRegistrationExternal* my_custom_op = MyCustomOpCreate(); return my_custom_op; } const TfLiteRegistration MyCustomOpRegistration() { static const TfLiteRegistration my_custom_op { .registration_external = MyCustomOpRegistrationExternal(); }; return my_custom_op; }
请注意,注册不是自动的,应该显式调用您的 MyCustomOpRegistration
函数(有关详细信息,请参见下文)。虽然标准 BuiltinOpResolver
(从 :builtin_ops
目标获得)负责内置函数的注册,但自定义操作必须在单独的自定义库中收集。
在 TensorFlow Lite 运行时定义内核
为了在 TensorFlow Lite 中使用操作,我们只需要定义两个函数 (Prepare
和 Eval
),以及第三个函数来构建 TfLiteRegistrationExternal
C++
namespace atan_op { namespace { TfLiteStatus AtanPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumInputs(node), 1); TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumOutputs(node), 1); const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0); TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0); int num_dims = TfLiteOpaqueTensorNumDimensions(input); TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims); for (int i=0; i < num_dims; ++i) { output_size->data[i] = input->dims->data[i]; } return TfLiteOpaqueContextResizeTensor(context, output, output_size); } TfLiteStatus AtanEval(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0); TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0); float* input_data = static_cast(TfLiteOpaqueTensorData(input)); float* output_data = static_cast (TfLiteOpaqueTensorData(output)); size_t count = 1; int num_dims = TfLiteOpaqueTensorNumDimensions(input); for (int i = 0; i < num_dims; ++i) { count *= input->dims->data[i]; } for (size_t i = 0; i < count; ++i) { output_data[i] = atan(input_data[i]); } return kTfLiteOk; } } // anonymous namespace const TfLiteRegistrationExternal* AtanOpRegistrationExternal() { // Singleton instance, intentionally never destroyed. static const TfLiteRegistrationExternal* atan_op = ()[] { auto* r = TfLiteRegistrationExternalCreate( kTfLiteBuiltinCustom, "ATAN", /*version=*/ 1); TfLiteRegistrationExternalSetPrepare(r, Prepare); TfLiteRegistrationExternalSetInvoke(r, Eval); return r; }; return atan_op; } const TfLiteRegistration AtanOpRegistration() { static const TfLiteRegistration atan_op { .registration_external = AtanOpRegistrationExternal(); }; return atan_op; } } // namespace atan_op
C
static TfLiteStatus AtanPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumInputs(node), 1); TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumOutputs(node), 1); const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0); TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0); int num_dims = TfLiteOpaqueTensorNumDimensions(input); TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims); for (int i = 0; i < num_dims; ++i) { output_size->data[i] = input->dims->data[i]; } return TfLiteOpaqueContextResizeTensor(context, output, output_size); } static TfLiteStatus AtanEval(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0); TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0); float* input_data = static_cast(TfLiteOpaqueTensorData(input)); float* output_data = static_cast (TfLiteOpaqueTensorData(output)); size_t count = 1; int num_dims = TfLiteOpaqueTensorNumDimensions(input); for (int i = 0; i < num_dims; ++i) { count *= input->dims->data[i]; } for (size_t i = 0; i < count; ++i) { output_data[i] = atan(input_data[i]); } return kTfLiteOk; } static const TfLiteRegistrationExternal* AtanOpCreate() { TfLiteRegistrationExternal* r = TfLiteRegistrationExternalCreate( kTfLiteBuiltinCustom, "ATAN", /*version=*/ 1); TfLiteRegistrationExternalSetPrepare(r, Prepare); TfLiteRegistrationExternalSetInvoke(r, Eval); return r; } const TfLiteRegistrationExternal* AtanOpRegistrationExternal() { // Singleton instance, intentionally never destroyed. static const TfLiteRegistrationExternal* atan_op = AtanOpCreate(); return atan_op; } const TfLiteRegistration AtanOpRegistration() { static const TfLiteRegistration atan_op { .registration_external = AtanOpRegistrationExternal(); }; return atan_op; }
在初始化 OpResolver
时,将自定义操作添加到解析器中(有关示例,请参见下文)。这将向 Tensorflow Lite 注册操作符,以便 TensorFlow Lite 可以使用新实现。请注意,TfLiteRegistration
中的最后两个参数对应于您为自定义操作定义的 AtanPrepare
和 AtanEval
函数。如果您使用 AtanInit
和 AtanFree
函数分别初始化操作中使用的变量并释放空间,那么它们将被添加到 TfLiteRegistration
的前两个参数中;在本例中,这些参数设置为 nullptr
。
向内核库注册操作符
现在,我们需要向内核库注册操作符。这使用 OpResolver
完成。在幕后,解释器将加载一个内核库,该库将被分配来执行模型中的每个操作符。虽然默认库只包含内置内核,但可以使用自定义库操作符替换/增强它。
将操作符代码和名称转换为实际代码的 OpResolver
类定义如下
class OpResolver {
public:
virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
virtual TfLiteRegistration* FindOp(const char* op) const = 0;
...
};
请注意,为了向后兼容,此类使用旧的具体类型 TfLiteRegistration
而不是不透明类型 TfLiteRegistrationExternal
,但 TfLiteRegistration
结构包含一个类型为 TfLiteRegistrationExternal*
的 registration_external
字段。
MutableOpResolver
和 BuiltinOpResolver
类派生自 OpResolver
class MutableOpResolver : public OpResolver {
public:
MutableOpResolver(); // Constructs an initially empty op resolver.
void AddBuiltin(tflite::BuiltinOperator op, const TfLiteRegistration* registration) = 0;
void AddCustom(const char* op, const TfLiteRegistration* registration) = 0;
void AddAll(const MutableOpResolver& other);
...
};
class BuiltinOpResolver : public MutableOpResolver {
public:
BuiltinOpResolver(); // Constructs an op resolver with all the builtin ops.
};
常规使用(不使用自定义操作)要求您使用 BuiltinOpResolver
并编写
tflite::ops::builtin::BuiltinOpResolver resolver;
要添加上面创建的自定义操作,您可以改用 MutableOpResolver
,并在将解析器传递给 InterpreterBuilder
之前调用 AddCustom
tflite::ops::builtin::MutableOpResolver resolver;
resolver.AddAll(tflite::ops::builtin::BuiltinOpResolver());
resolver.AddCustom("Atan", AtanOpRegistration());
如果内置操作集被认为太大,则可以根据给定的操作子集(可能仅包含给定模型中包含的操作)代码生成新的 OpResolver
。这等效于 TensorFlow 的选择性注册(并且 tools
目录中提供了它的简单版本)。
如果您想在 Java 中定义自定义操作符,您目前需要构建自己的自定义 JNI 层并编译自己的 AAR 在此 jni 代码中。类似地,如果您希望在 Python 中定义这些操作符,您可以将注册信息放在 Python 包装器代码 中。
请注意,可以按照上述类似过程来支持一组操作而不是单个操作。只需添加您需要的 AddCustom
操作符即可。此外,MutableOpResolver
还允许您使用 AddBuiltin
覆盖内置函数的实现。
测试和分析您的操作符
要使用 TensorFlow Lite 基准测试工具分析您的操作,您可以使用 TensorFlow Lite 的 基准模型工具。为了测试目的,您可以通过将适当的 AddCustom
调用(如上所示)添加到 register.cc 来使您的本地 TensorFlow Lite 构建了解您的自定义操作。
最佳实践
谨慎优化内存分配和释放。在
Prepare
中分配内存比在Invoke
中更有效,在循环之前分配内存比在每次迭代中分配内存更好。使用临时张量数据而不是自己进行 malloc(参见项目 2)。尽可能使用指针/引用而不是复制。如果数据结构在整个操作期间持续存在,我们建议使用临时张量预先分配内存。您可能需要使用 OpData 结构来引用其他函数中的张量索引。请参阅 卷积内核 中的示例。下面是一个示例代码片段。
struct MyOpData { int temp_tensor_index; ... }; void* Init(TfLiteOpaqueContext* context, const char* buffer, size_t length) { auto* op_data = new MyOpData{}; ... return op_data; } void Free(TfLiteOpaqueContext* context, void* buffer) { ... delete reinterpret_cast<MyOpData*>(buffer); } TfLiteStatus Prepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { ... auto* op_data = reinterpret_cast<MyOpData*>(TfLiteOpaqueNodeGetUserData(node)); const int num_temporaries = 1; int temporary_tensor_indices[num_temporaries]; TfLiteOpaqueTensorBuilder* builder = TfLiteOpaqueTensorBuilderCreate(); TfLiteOpaqueTensorBuilderSetType(builder, kTfLiteFloat32); TfLiteOpaqueTensorBuilderSetAllocationType(builder, kTfLiteArenaRw); TfLiteOpaqueContextAddTensor(context, builder, &temporary_tensor_indices[0]); TfLiteOpaqueTensorBuilderDelete(builder); TfLiteOpaqueNodeSetTemporaries(node, temporary_tensor_indices, num_temporaries); op_data->temp_tensor_index = temporary_tensor_indices[0]; ... return kTfLiteOk; } TfLiteStatus Invoke(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) { ... auto* op_data = reinterpret_cast<MyOpData*>( TfLiteOpaqueNodeGetUserData(node)); TfLiteOpaqueTensor* temp_tensor = TfLiteOpaqueContextGetOpaqueTensor(context, op_data->temp_tensor_index); TF_LITE_OPAQUE_ENSURE(context, TfLiteTensorType(temp_tensor) == kTfLiteFloat32); TF_LITE_OPAQUE_ENSURE(context, TfLiteTensorGetAllocationType(temp_Tensor) == kTfLiteArenaRw); void *temp_data = TfLiteTensorData(temp_tensor); TF_LITE_OPAQUE_ENSURE(context, temp_data != nullptr); ... return kTfLiteOk; }
如果不会浪费太多内存,建议使用静态固定大小数组(或在
Resize
中预先分配的std::vector
)而不是在每次执行迭代中使用动态分配的std::vector
。避免实例化不存在的标准库容器模板,因为它们会影响二进制大小。例如,如果您在操作中需要一个在其他内核中不存在的
std::map
,则可以使用具有直接索引映射的std::vector
,同时保持二进制大小较小。请查看其他内核的使用情况以获得洞察力(或询问)。检查
malloc
返回的内存指针。如果此指针为nullptr
,则不应使用该指针执行任何操作。如果您在函数中malloc
并出现错误退出,请在退出之前释放内存。使用
TF_LITE_OPAQUE_ENSURE(context, condition)
检查特定条件。当使用TF_LITE_OPAQUE_ENSURE
时,您的代码不得留下悬挂的内存,即这些宏应在分配任何将泄漏的资源之前使用。