此页面说明了在将 TensorFlow 代码从 TensorFlow 1 迁移到 TensorFlow 2 时如何继续使用 TensorFlow Hub。它补充了 TensorFlow 的通用迁移指南。
对于 TF2,TF Hub 已从用于构建tf.compat.v1.Graph
(类似于tf.contrib.v1.layers
)的旧版hub.Module
API 中切换出来。现在,有一个hub.KerasLayer
,可与其他 Keras 层一起用于构建tf.keras.Model
(通常在 TF2 的新急切执行环境中),以及用于低级 TensorFlow 代码的底层hub.load()
方法。
hub.Module
API 仍然可以在tensorflow_hub
库中使用,以便在 TF1 和 TF2 的 TF1 兼容模式中使用。它只能加载TF1 Hub 格式中的模型。
新 API hub.load()
和hub.KerasLayer
适用于 TensorFlow 1.15(在急切模式和图形模式中)和 TensorFlow 2。此新 API 可以加载新的TF2 SavedModel资产,并且在模型兼容性指南中列出了限制后,可以加载 TF1 Hub 格式中的旧版模型。
一般来说,建议尽可能使用新 API。
新 API 的摘要
hub.load()
是一个新的低级别函数,用于从 TensorFlow Hub(或兼容服务)加载 SavedModel。它封装了 TF2 的 tf.saved_model.load()
;TensorFlow 的 SavedModel 指南 介绍了您可以使用结果执行的操作。
m = hub.load(handle)
outputs = m(inputs)
hub.KerasLayer
类调用 hub.load()
,并调整结果以与其他 Keras 层一起在 Keras 中使用。(它甚至可能是以其他方式使用的已加载 SavedModel 的一个便捷封装器。)
model = tf.keras.Sequential([
hub.KerasLayer(handle),
...])
许多教程展示了这些 API 的实际应用。以下是一些示例
在 Estimator 训练中使用新 API
如果您在 Estimator 中使用 TF2 SavedModel 来使用参数服务器(或在 TF1 会话中使用放置在远程设备上的变量)进行训练,则需要在 tf.Session 的 ConfigProto 中设置 experimental.share_cluster_devices_in_session
,否则您将收到一条错误消息,例如“分配的设备 '/job:ps/replica:0/task:0/device:CPU:0' 与任何设备不匹配”。
必要的选项可以像这样设置
session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)
从 TF2.2 开始,此选项不再是实验性的,并且可以删除 .experimental
部分。
在 TF1 Hub 格式中加载旧版模型
可能会出现这种情况:对于您的用例,新的 TF2 SavedModel 尚未可用,并且您需要加载 TF1 Hub 格式中的旧版模型。从 tensorflow_hub
版本 0.7 开始,您可以将 TF1 Hub 格式中的旧版模型与 hub.KerasLayer
一起使用,如下所示
m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)
此外,KerasLayer
公开了指定 tags
、signature
、output_key
和 signature_outputs_as_dict
的功能,以便更具体地使用 TF1 Hub 格式中的旧版模型和旧版 SavedModel。
有关 TF1 Hub 格式兼容性的更多信息,请参阅 模型兼容性指南。
使用更低级别的 API
可以通过 tf.saved_model.load
加载旧版 TF1 Hub 格式模型。代替
# DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)
建议使用
# TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)
在这些示例中,m.signatures
是一个 TensorFlow 具体函数 字典,按签名名称进行键控。调用此类函数会计算其所有输出,即使未使用。(这与 TF1 图模式的惰性求值不同。)