迁移模型检查点

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看 下载笔记本

概述

本指南假设您有一个模型,该模型使用 tf.compat.v1.Saver 保存和加载检查点,并且您希望迁移代码以使用 TF2 tf.train.Checkpoint API,或者在您的 TF2 模型中使用预先存在的检查点。

以下是您可能会遇到的一些常见情况

情况 1

存在来自先前训练运行的 TF1 检查点,需要加载或转换为 TF2。

情况 2

您正在调整模型,这可能会导致变量名称和路径发生变化(例如,当逐步从 get_variable 迁移到显式 tf.Variable 创建时),并且您希望在此过程中保持现有检查点的保存/加载。

请参阅有关 如何在模型迁移期间保持检查点兼容性 的部分

情况 3

您正在将训练代码和检查点迁移到 TF2,但您的推理管道目前仍需要 TF1 检查点(为了生产稳定性)。

选项 1

在训练时保存 TF1 和 TF2 检查点。

选项 2

将 TF2 检查点转换为 TF1。


以下示例展示了在 TF1/TF2 中保存和加载检查点的所有组合,因此您在确定如何迁移模型时具有一定的灵活性。

设置

import tensorflow as tf
import tensorflow.compat.v1 as tf1

def print_checkpoint(save_path):
  reader = tf.train.load_checkpoint(save_path)
  shapes = reader.get_variable_to_shape_map()
  dtypes = reader.get_variable_to_dtype_map()
  print(f"Checkpoint at '{save_path}':")
  for key in shapes:
    print(f"  (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
          f"value={reader.get_tensor(key)})")

从 TF1 到 TF2 的更改

如果您想知道 TF1 和 TF2 之间发生了哪些变化,以及我们所说的“基于名称的”(TF1)与“基于对象的”(TF2)检查点之间的区别,则本节内容供您参考。

这两种类型的检查点实际上以相同的格式保存,本质上是一个键值表。区别在于键的生成方式。

基于名称的检查点中的键是变量的名称。基于对象的检查点中的键是指从根对象到变量的路径(以下示例将帮助您更好地理解这意味着什么)。

首先,保存一些检查点

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    saver.save(sess, 'tf1-ckpt')

print_checkpoint('tf1-ckpt')
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(7.0, name='c')

ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)

如果您查看 tf2-ckpt 中的键,它们都引用了每个变量的对象路径。例如,变量 avariables 列表中的第一个元素,因此它的键变为 variables/0/...(您可以随意忽略 .ATTRIBUTES/VARIABLE_VALUE 常量)。

仔细检查以下 Checkpoint 对象

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])

尝试使用以下代码段进行实验,并查看检查点键如何随对象结构而变化

module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, 
                                c=c,
                                module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)

为什么 TF2 使用这种机制?

由于 TF2 中不再存在全局图,因此变量名称不可靠,并且在不同程序之间可能不一致。TF2 鼓励面向对象的建模方法,其中变量由层拥有,而层由模型拥有。

variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer

如何在模型迁移期间保持检查点兼容性

迁移过程中的一个重要步骤是 *确保所有变量都初始化为正确的值*,这反过来可以让你验证操作/函数是否执行了正确的计算。为了实现这一点,你必须考虑迁移各个阶段的模型之间的 **检查点兼容性**。本质上,本节回答了这样一个问题,*如何在更改模型的同时继续使用相同的检查点*。

以下是三种保持检查点兼容性的方法,按灵活性递增的顺序排列

  1. 模型具有与以前 **相同的变量名称**。
  2. 模型具有不同的变量名称,并维护一个 **分配映射**,该映射将检查点中的变量名称映射到新名称。
  3. 模型具有不同的变量名称,并维护一个 **TF2 检查点对象**,该对象存储所有变量。

当变量名称匹配时

长标题:当变量名称匹配时如何重新使用检查点。

简短答案:你可以使用 tf1.train.Savertf.train.Checkpoint 直接加载现有的检查点。


如果你使用的是 tf.compat.v1.keras.utils.track_tf1_style_variables,那么它将确保你的模型变量名称与以前相同。你也可以手动确保变量名称匹配。

当迁移后的模型中的变量名称匹配时,你可以直接使用 tf.train.Checkpointtf.compat.v1.train.Saver 加载检查点。这两个 API 都与急切模式和图模式兼容,因此你可以在迁移的任何阶段使用它们。

以下是使用相同检查点和不同模型的示例。首先,使用 tf1.train.Saver 保存 TF1 检查点

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)

以下示例使用 tf.compat.v1.Saver 在急切模式下加载检查点

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]")

# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)

下一个代码段使用 TF2 API tf.train.Checkpoint 加载检查点

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')

print("Variable names: ")
print(f"  a.name = {a.name}")
print(f"  b.name = {b.name}")
print(f"  c.name = {c.name}")
print(f"  c_2.name = {c_2.name}")

# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")

TF2 中的变量名称

  • 变量仍然都具有一个 name 参数,你可以设置它。
  • Keras 模型也接受一个 name 参数,它们将其设置为其变量的前缀。
  • v1.name_scope 函数可用于设置变量名称前缀。这与 tf.variable_scope 非常不同。它只影响名称,不跟踪变量和重用。

tf.compat.v1.keras.utils.track_tf1_style_variables 装饰器是一个垫片,它通过保持 tf.variable_scopetf.compat.v1.get_variable 的命名和重用语义不变,帮助你维护变量名称和 TF1 检查点兼容性。有关更多信息,请参阅 模型映射指南

注意 1:如果你使用的是垫片,请使用 TF2 API 加载你的检查点(即使使用预训练的 TF1 检查点)。

请参阅 *检查点 Keras* 部分。

注意 2:从 get_variable 迁移到 tf.Variable

如果你的垫片装饰的层或模块包含一些使用 tf.Variable 而不是 tf.compat.v1.get_variable 并以面向对象的方式作为属性/跟踪附加的变量(或 Keras 层/模型),那么它们在 TF1.x 图/会话中与急切执行期间可能具有不同的变量命名语义。

简而言之,*在 TF2 中运行时,名称可能不是你期望的*。

维护分配映射

分配映射通常用于在 TF1 模型之间传输权重,并且如果变量名称发生更改,也可以在模型迁移期间使用它们。

你可以将这些映射与 tf.compat.v1.train.init_from_checkpointtf.compat.v1.train.Savertf.train.load_checkpoint 一起使用,将权重加载到变量或作用域名称可能已更改的模型中。

本节中的示例将使用先前保存的检查点

print_checkpoint('tf1-ckpt')

使用 init_from_checkpoint 加载

tf1.train.init_from_checkpoint 必须在图/会话中调用,因为它将值放在变量初始化器中,而不是创建赋值操作。

你可以使用 assignment_map 参数配置变量的加载方式。从文档中

分配映射支持以下语法

  • 'checkpoint_scope_name/': 'scope_name/' - 将从 checkpoint_scope_name 加载当前 scope_name 中的所有变量,并具有匹配的张量名称。
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - 将从 checkpoint_scope_name/some_other_variable 初始化 scope_name/variable_name 变量。
  • 'scope_variable_name': variable - 将使用检查点中的张量 'scope_variable_name' 初始化给定的 tf.Variable 对象。
  • 'scope_variable_name': list(variable) - 将使用检查点中的张量 'scope_variable_name' 初始化已分区变量列表。
  • '/': 'scope_name/' - 将从检查点的根目录(例如,无作用域)加载当前 scope_name 中的所有变量。
# Restoring with tf1.train.init_from_checkpoint:

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # The assignment map will remap all variables in the checkpoint to the
    # new scope:
    tf1.train.init_from_checkpoint(
        'tf1-ckpt',
        assignment_map={'/': 'new_scope/'})
    # `init_from_checkpoint` adds the initializers to these variables.
    # Use `sess.run` to run these initializers.
    sess.run(tf1.global_variables_initializer())

    print("Restored [a, b, c]: ", sess.run([a, b, c]))

使用 tf1.train.Saver 加载

init_from_checkpoint 不同,tf.compat.v1.train.Saver 在图模式和急切模式下都运行。 var_list 参数可以选择接受字典,但它必须将变量名称映射到 tf.Variable 对象。

# Restoring with tf1.train.Saver (works in both graph and eager):

# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

使用 tf.train.load_checkpoint 加载

如果你需要对变量值进行精确控制,则可以使用此选项。同样,这在图模式和急切模式下都有效。

# Restoring with tf.train.load_checkpoint (works in both graph and eager):

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # It may be easier writing a loop if your model has a lot of variables.
    reader = tf.train.load_checkpoint('tf1-ckpt')
    sess.run(a.assign(reader.get_tensor('a')))
    sess.run(b.assign(reader.get_tensor('b')))
    sess.run(c.assign(reader.get_tensor('scoped/c')))
    print("Restored [a, b, c]: ", sess.run([a, b, c]))

维护 TF2 检查点对象

如果变量和作用域名称在迁移期间可能发生很大变化,那么请使用 tf.train.Checkpoint 和 TF2 检查点。TF2 使用 **对象结构** 而不是变量名称(有关更多详细信息,请参阅 *从 TF1 到 TF2 的更改*)。

简而言之,在创建 tf.train.Checkpoint 以保存或恢复检查点时,请确保它使用相同的 **排序**(对于列表)和 **键**(对于字典和 Checkpoint 初始化器的关键字参数)。以下是一些检查点兼容性的示例

ckpt = tf.train.Checkpoint(foo=[var_a, var_b])

# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])

# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])

以下代码示例展示了如何使用“相同”的 tf.train.Checkpoint 加载具有不同名称的变量。首先,保存 TF2 检查点

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("[a, b, c]: ", sess.run([a, b, c]))

    # Save a TF2 checkpoint
    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    tf2_ckpt_path = ckpt.save('tf2-ckpt')
    print_checkpoint(tf2_ckpt_path)

即使变量/作用域名称发生变化,你也可以继续使用 tf.train.Checkpoint

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.variable_scope('different_scope'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))

    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    # `assert_consumed` validates that all checkpoint objects are restored from
    # the checkpoint. `run_restore_ops` is required when running in a TF1
    # session.
    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()

    # Removing `assert_consumed` is fine if you want to skip the validation.
    # ckpt.restore(tf2_ckpt_path).run_restore_ops()

    print("Restored [a, b, c]: ", sess.run([a, b, c]))

在急切模式下

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])

ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

估计器中的 TF2 检查点

以上各节介绍了如何在迁移模型时保持检查点兼容性。这些概念也适用于估计器模型,尽管保存/加载检查点的方式略有不同。当你将估计器模型迁移到使用 TF2 API 时,你可能希望在 *模型仍在使用估计器* 的情况下,从 TF1 检查点切换到 TF2 检查点。本节介绍了如何做到这一点。

tf.estimator.EstimatorMonitoredSession 具有称为 scaffold 的保存机制,这是一个 tf.compat.v1.train.Scaffold 对象。 Scaffold 可以包含 tf1.train.Savertf.train.Checkpoint,这使 EstimatorMonitoredSession 能够保存 TF1 或 TF2 样式的检查点。

# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=None
  )

!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=tf1.train.Scaffold(saver=ckpt)
  )

!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
                             warm_start_from='est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)  

assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4

在从 est-tf1 热启动后,再训练 5 步,v 的最终值应为 16。训练步骤值不会从 warm_start 检查点中继承。

检查点 Keras

使用 Keras 构建的模型仍然使用 tf1.train.Savertf.train.Checkpoint 加载预先存在的权重。当你的模型完全迁移后,请切换到使用 model.save_weightsmodel.load_weights,尤其是在训练时使用 ModelCheckpoint 回调时。

关于检查点和 Keras,你应该了解一些事项

初始化与构建

Keras 模型和层必须经过 **两个步骤** 才能完全创建。第一步是 Python 对象的 *初始化*:layer = tf.keras.layers.Dense(x)。第二步是 *构建* 步骤,在此步骤中,大多数权重实际上都是创建的:layer.build(input_shape)。你也可以通过调用模型或运行单个 trainevalpredict 步骤来构建模型(仅第一次)。

如果你发现 model.load_weights(path).assert_consumed() 正在引发错误,那么可能是模型/层尚未构建。

Keras 使用 TF2 检查点

tf.train.Checkpoint(model).write 等效于 model.save_weightstf.train.Checkpoint(model).readmodel.load_weights 也是如此。请注意,Checkpoint(model) != Checkpoint(model=model)

TF2 检查点与 Keras 的 build() 步骤一起使用

tf.train.Checkpoint.restore 拥有一个名为“延迟恢复”的机制,它允许 tf.Module 和 Keras 对象在变量尚未创建的情况下存储变量值。这使得“已初始化”模型能够在加载权重后进行“构建”。

m = YourKerasModel()
status = m.load_weights(path)

# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)

status.assert_consumed()

由于这种机制的存在,我们强烈建议您在 Keras 模型中使用 TF2 检查点加载 API(即使是在将预先存在的 TF1 检查点恢复到 模型映射垫片 中时)。在 检查点指南 中了解更多信息。

代码片段

以下代码片段展示了检查点保存 API 中的 TF1/TF2 版本兼容性。

在 TF2 中保存 TF1 检查点

a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(3.0, name='c')

saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)

在 TF2 中加载 TF1 检查点

a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

在 TF1 中保存 TF2 检查点

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
    print_checkpoint(tf2_in_tf1_path)

在 TF1 中加载 TF2 检查点

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
    print("Restored [a, b, c]: ", sess.run([a, b, c]))

检查点转换

您可以通过加载和重新保存检查点来在 TF1 和 TF2 之间转换检查点。另一种方法是使用 tf.train.load_checkpoint,如下面的代码所示。

将 TF1 检查点转换为 TF2

def convert_tf1_to_tf2(checkpoint_path, output_prefix):
  """Converts a TF1 checkpoint to TF2.

  To load the converted checkpoint, you must build a dictionary that maps
  variable names to variable objects.
  ```
  ckpt = tf.train.Checkpoint(vars={name: variable})  
  ckpt.restore(converted_ckpt_path)

    ```

    Args:
      checkpoint_path: Path to the TF1 checkpoint.
      output_prefix: Path prefix to the converted checkpoint.

    Returns:
      Path to the converted checkpoint.
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      vars[key] = tf.Variable(reader.get_tensor(key))
    return tf.train.Checkpoint(vars=vars).save(output_prefix)
  ```

Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2`:


确保在 在 TF2 中保存 TF1 检查点 中运行代码片段。

print_checkpoint('tf1-ckpt-saved-in-eager') converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 'converted-tf1-to-tf2') print("\n[已转换]") print_checkpoint(converted_path)

尝试加载转换后的检查点。

a = tf.Variable(0.) b = tf.Variable(0.) c = tf.Variable(0.) ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c}) ckpt.restore(converted_path).assert_consumed() print("\n已恢复 [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()]) ```

将 TF2 检查点转换为 TF1

def convert_tf2_to_tf1(checkpoint_path, output_prefix):
  """Converts a TF2 checkpoint to TF1.

  The checkpoint must be saved using a 
  `tf.train.Checkpoint(var_list={name: variable})`

  To load the converted checkpoint with `tf.compat.v1.Saver`:
  ```
  saver = tf.compat.v1.train.Saver(var_list={name: variable}) 

  # An alternative, if the variable names match the keys:
  saver = tf.compat.v1.train.Saver(var_list=[variables]) 
  saver.restore(sess, output_path)

    ```
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      # Get the "name" from the 
      if key.startswith('var_list/'):
        var_name = key.split('/')[1]
        # TF2 checkpoint keys use '/', so if they appear in the user-defined name,
        # they are escaped to '.S'.
        var_name = var_name.replace('.S', '/')
        vars[var_name] = tf.Variable(reader.get_tensor(key))

    return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)
  ```

Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1`:


确保在 在 TF1 中保存 TF2 检查点 中运行代码片段。

print_checkpoint('tf2-ckpt-saved-in-session-1') converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1', 'converted-tf2-to-tf1') print("\n[已转换]") print_checkpoint(converted_path)

尝试加载转换后的检查点。

with tf.Graph().as_default() as g: a = tf1.get_variable('a', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) b = tf1.get_variable('b', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.variable_scope('scoped'): c = tf1.get_variable('c', shape=[], dtype=tf.float32, initializer=tf1.constant_initializer(0)) with tf1.Session() as sess: saver = tf1.train.Saver([a, b, c]) saver.restore(sess, converted_path) print("\n已恢复 [a, b, c]: ", sess.run([a, b, c])) ```