可重用 SavedModels

简介

TensorFlow Hub 托管 TensorFlow 2 的 SavedModels,以及其他资产。它们可以通过 obj = hub.load(url) [了解更多] 加载回 Python 程序。返回的 objtf.saved_model.load() 的结果(参见 TensorFlow 的 SavedModel 指南)。此对象可以具有任意属性,这些属性是 tf.functions、tf.Variables(从其预训练值初始化)、其他资源以及递归地更多此类对象。

此页面描述了加载的 obj 要在 TensorFlow Python 程序中重用,需要实现的接口。符合此接口的 SavedModels 称为可重用 SavedModels

重用意味着在 obj 周围构建一个更大的模型,包括对其进行微调的能力。微调意味着在周围模型的一部分中进一步训练加载的 obj 中的权重。损失函数和优化器由周围模型确定;obj 仅定义输入到输出激活的映射(“前向传递”),可能包括诸如 dropout 或批归一化之类的技术。

TensorFlow Hub 团队建议在所有旨在以上述方式重用的 SavedModels 中实现可重用 SavedModel 接口tensorflow_hub 库中的许多实用程序,尤其是 hub.KerasLayer,要求 SavedModels 实现它。

与 SignatureDefs 的关系

此接口在 tf.functions 和其他 TF2 功能方面的定义与 SavedModel 的签名是分开的,签名自 TF1 以来一直可用,并且在 TF2 中继续用于推理(例如,将 SavedModels 部署到 TF Serving 或 TF Lite)。用于推理的签名不足以支持微调,而 tf.function 为重用模型提供了更自然、更具表现力的 Python API

与模型构建库的关系

可重用 SavedModel 仅使用 TensorFlow 2 原语,独立于任何特定的模型构建库,如 Keras 或 Sonnet。这有助于跨模型构建库进行重用,不受对原始模型构建代码的依赖关系的影响。

将可重用 SavedModels 加载到任何给定的模型构建库中或从任何给定的模型构建库中保存它们,都需要进行一定程度的调整。对于 Keras,hub.KerasLayer 提供加载功能,而 Keras 在 SavedModel 格式中的内置保存功能已针对 TF2 进行了重新设计,旨在提供此接口的超集(参见 RFC,来自 2019 年 5 月)。

与特定于任务的“常见 SavedModel API”的关系

此页面上的接口定义允许任意数量和类型的输入和输出。TF Hub 的常见 SavedModel API 使用特定任务的用法约定来细化此通用接口,以使模型易于互换。

接口定义

属性

可重用 SavedModel 是一个 TensorFlow 2 SavedModel,因此 obj = tf.saved_model.load(...) 返回一个具有以下属性的对象

  • __call__。必需。一个 tf.function,它根据以下规范实现模型的计算(“前向传递”)。

  • variables:一个 tf.Variable 对象列表,列出任何可能的 __call__ 调用使用的所有变量,包括可训练变量和不可训练变量。

    如果列表为空,则可以省略。

  • trainable_variables:一个 tf.Variable 对象列表,使得对于所有元素,v.trainable 为 true。这些变量必须是 variables 的子集。这些是在微调对象时要训练的变量。SavedModel 创建者可以选择在此处省略一些最初可训练的变量,以指示这些变量在微调期间不应修改。

    如果列表为空,则可以省略,特别是如果 SavedModel 不支持微调。

  • regularization_losses:一个 tf.function 列表,每个函数都接受零个输入并返回一个单一的标量浮点张量。对于微调,建议 SavedModel 用户将这些作为额外的正则化项包含到损失中(在最简单的情况下,无需进一步缩放)。通常,这些用于表示权重正则化器。(由于缺乏输入,这些 tf.functions 无法表达活动正则化器。)

    如果列表为空,则可以省略,特别是如果 SavedModel 不支持微调或不希望规定权重正则化。

__call__ 函数

已恢复的 SavedModel obj 具有一个 obj.__call__ 属性,该属性是一个已恢复的 tf.function,并允许 obj 如下调用。

概要(伪代码)

outputs = obj(inputs, trainable=..., **kwargs)

参数

参数如下。

  • 有一个位置参数,它是 SavedModel 的输入激活批次的必需参数。其类型是以下之一

    • 单个输入的单个张量,
    • 一个张量列表,用于未命名输入的有序序列,
    • 一个张量字典,由一组特定的输入名称作为键。

    (此接口的未来修订版可能允许更通用的嵌套。)SavedModel 创建者选择其中之一以及张量形状和数据类型。在有用的地方,形状的某些维度应该是不确定的(特别是批次大小)。

  • 可能有一个可选的关键字参数 training,它接受一个 Python 布尔值,TrueFalse。默认值为 False。如果模型支持微调,并且如果它的计算在两者之间有所不同(例如,在 dropout 和批次归一化中),则使用此参数实现这种区别。否则,此参数可能不存在。

    不需要 __call__ 接受一个 Tensor 值的 training 参数。调用者有责任使用 tf.cond()(如果需要)在它们之间进行调度。

  • SavedModel 创建者可以选择接受更多具有特定名称的可选 kwargs

    • 对于 Tensor 值参数,SavedModel 创建者定义了它们允许的数据类型和形状。 tf.function 接受一个 Python 默认值,该值在使用 tf.TensorSpec 输入跟踪的参数上。此类参数可用于允许自定义 __call__ 中涉及的数值超参数(例如,dropout 率)。

    • 对于 Python 值参数,SavedModel 创建者定义了它们允许的值。此类参数可用作标志,以便在跟踪的函数中进行离散选择(但要注意跟踪的组合爆炸)。

恢复的 __call__ 函数必须为所有允许的参数组合提供跟踪。在 TrueFalse 之间翻转 training 不应改变参数的允许性。

结果

调用 objoutputs 可以是

  • 单个输出的单个张量,
  • 一个张量列表,用于未命名输出的有序序列,
  • 一个张量字典,由一组特定的输出名称作为键。

(此接口的未来修订版可能允许更通用的嵌套。)返回类型可能会根据 Python 值 kwargs 而有所不同。这允许标志生成额外的输出。SavedModel 创建者定义输出数据类型和形状及其对输入的依赖关系。

命名可调用对象

一个可重用 SavedModel 可以通过将它们放入命名子对象中来提供多个模型片段,例如,obj.fooobj.bar 等等。每个子对象都提供一个 __call__ 方法和关于该模型片段的变量等的辅助属性。对于上面的示例,将有 obj.foo.__call__obj.foo.variables 等等。

请注意,此接口 *不* 涵盖将裸 tf.function 直接添加为 tf.foo 的方法。

可重用 SavedModel 的用户只应处理一层嵌套(obj.bar 但不是 obj.bar.baz)。(此接口的未来修订版可能允许更深的嵌套,并且可能放弃顶层对象本身可调用的要求。)

结束语

与进程内 API 的关系

本文档描述了一个 Python 类的接口,该接口由 tf.function 和 tf.Variable 等基元组成,这些基元在通过 tf.saved_model.save()tf.saved_model.load() 进行序列化后仍然存在。但是,该接口已经存在于传递给 tf.saved_model.save() 的原始对象上。对该接口的适应使模型片段能够在单个 TensorFlow 程序内的模型构建 API 之间进行交换。