使用 TensorFlow Hub 从 TF1 迁移到 TF2

此页面说明了在将 TensorFlow 代码从 TensorFlow 1 迁移到 TensorFlow 2 时如何继续使用 TensorFlow Hub。它补充了 TensorFlow 的通用迁移指南

对于 TF2,TF Hub 已从用于构建tf.compat.v1.Graph(类似于tf.contrib.v1.layers)的旧版hub.Module API 中切换出来。现在,有一个hub.KerasLayer,可与其他 Keras 层一起用于构建tf.keras.Model(通常在 TF2 的新急切执行环境中),以及用于低级 TensorFlow 代码的底层hub.load()方法。

hub.Module API 仍然可以在tensorflow_hub库中使用,以便在 TF1 和 TF2 的 TF1 兼容模式中使用。它只能加载TF1 Hub 格式中的模型。

新 API hub.load()hub.KerasLayer适用于 TensorFlow 1.15(在急切模式和图形模式中)和 TensorFlow 2。此新 API 可以加载新的TF2 SavedModel资产,并且在模型兼容性指南中列出了限制后,可以加载 TF1 Hub 格式中的旧版模型。

一般来说,建议尽可能使用新 API。

新 API 的摘要

hub.load() 是一个新的低级别函数,用于从 TensorFlow Hub(或兼容服务)加载 SavedModel。它封装了 TF2 的 tf.saved_model.load();TensorFlow 的 SavedModel 指南 介绍了您可以使用结果执行的操作。

m = hub.load(handle)
outputs = m(inputs)

hub.KerasLayer 类调用 hub.load(),并调整结果以与其他 Keras 层一起在 Keras 中使用。(它甚至可能是以其他方式使用的已加载 SavedModel 的一个便捷封装器。)

model = tf.keras.Sequential([
    hub.KerasLayer(handle),
    ...])

许多教程展示了这些 API 的实际应用。以下是一些示例

在 Estimator 训练中使用新 API

如果您在 Estimator 中使用 TF2 SavedModel 来使用参数服务器(或在 TF1 会话中使用放置在远程设备上的变量)进行训练,则需要在 tf.Session 的 ConfigProto 中设置 experimental.share_cluster_devices_in_session,否则您将收到一条错误消息,例如“分配的设备 '/job:ps/replica:0/task:0/device:CPU:0' 与任何设备不匹配”。

必要的选项可以像这样设置

session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)

从 TF2.2 开始,此选项不再是实验性的,并且可以删除 .experimental 部分。

在 TF1 Hub 格式中加载旧版模型

可能会出现这种情况:对于您的用例,新的 TF2 SavedModel 尚未可用,并且您需要加载 TF1 Hub 格式中的旧版模型。从 tensorflow_hub 版本 0.7 开始,您可以将 TF1 Hub 格式中的旧版模型与 hub.KerasLayer 一起使用,如下所示

m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)

此外,KerasLayer 公开了指定 tagssignatureoutput_keysignature_outputs_as_dict 的功能,以便更具体地使用 TF1 Hub 格式中的旧版模型和旧版 SavedModel。

有关 TF1 Hub 格式兼容性的更多信息,请参阅 模型兼容性指南

使用更低级别的 API

可以通过 tf.saved_model.load 加载旧版 TF1 Hub 格式模型。代替

# DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)

建议使用

# TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)

在这些示例中,m.signatures 是一个 TensorFlow 具体函数 字典,按签名名称进行键控。调用此类函数会计算其所有输出,即使未使用。(这与 TF1 图模式的惰性求值不同。)