训练检查点

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

短语“保存 TensorFlow 模型”通常意味着两件事之一

  1. 检查点,或
  2. SavedModel。

检查点捕获模型使用的所有参数(tf.Variable 对象)的精确值。检查点不包含模型定义的计算的任何描述,因此通常仅在使用保存的参数值的源代码可用时才有用。

另一方面,SavedModel 格式除了参数值(检查点)之外,还包含模型定义的计算的序列化描述。此格式的模型独立于创建模型的源代码。因此,它们适合通过 TensorFlow Serving、TensorFlow Lite、TensorFlow.js 或其他编程语言(C、C++、Java、Go、Rust、C# 等 TensorFlow API)中的程序进行部署。

本指南介绍了用于写入和读取检查点的 API。

设置

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.keras 训练 API 保存

请参阅 tf.keras 关于保存和恢复的指南。

tf.keras.Model.save_weights 保存 TensorFlow 检查点。

net.save_weights('easy_checkpoint')

写入检查点

TensorFlow 模型的持久状态存储在 tf.Variable 对象中。这些对象可以直接构造,但通常是通过高级 API(如 tf.keras.layerstf.keras.Model)创建的。

管理变量的最简单方法是将它们附加到 Python 对象,然后引用这些对象。

tf.train.Checkpointtf.keras.layers.Layertf.keras.Model 的子类会自动跟踪分配给其属性的变量。以下示例构建了一个简单的线性模型,然后写入包含模型所有变量值的检查点。

您可以使用 Model.save_weights 轻松保存模型检查点。

手动检查点

设置

为了帮助演示 tf.train.Checkpoint 的所有功能,请定义一个玩具数据集和优化步骤

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

创建检查点对象

使用 tf.train.Checkpoint 对象手动创建检查点,将您要检查点的对象设置为该对象的属性。

tf.train.CheckpointManager 也可以帮助管理多个检查点。

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

训练和检查点模型

以下训练循环创建模型和优化器的实例,然后将它们收集到一个 tf.train.Checkpoint 对象中。它在每个数据批次上循环调用训练步骤,并定期将检查点写入磁盘。

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)

恢复并继续训练

在第一个训练周期之后,您可以传递一个新的模型和管理器,但可以从您停止的地方继续训练。

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)

tf.train.CheckpointManager 对象会删除旧的检查点。在上面,它被配置为只保留三个最新的检查点。

print(manager.checkpoints)  # List the three remaining checkpoints

这些路径,例如 './tf_ckpts/ckpt-10',不是磁盘上的文件。相反,它们是 index 文件和一个或多个包含变量值的数据文件的首缀。这些前缀被分组到一个单独的 checkpoint 文件 ('./tf_ckpts/checkpoint') 中,CheckpointManager 在其中保存其状态。

ls ./tf_ckpts

加载机制

TensorFlow 通过遍历一个带有命名边的有向图来匹配变量和检查点值,从要加载的对象开始。边名通常来自对象中的属性名,例如 self.l1 = tf.keras.layers.Dense(5) 中的 "l1"tf.train.Checkpoint 使用其关键字参数名,例如 tf.train.Checkpoint(step=...) 中的 "step"

上面示例的依赖关系图如下所示

Visualization of the dependency graph for the example training loop

优化器以红色显示,常规变量以蓝色显示,优化器槽变量以橙色显示。其他节点(例如,表示 tf.train.Checkpoint)以黑色显示。

槽变量是优化器状态的一部分,但它是为特定变量创建的。例如,上面的 'm' 边对应于动量,Adam 优化器为每个变量跟踪动量。只有当变量和优化器都将被保存时,槽变量才会保存在检查点中,因此虚线边。

调用 restore 在一个 tf.train.Checkpoint 对象上会将请求的恢复排队,并在 Checkpoint 对象中存在匹配路径时立即恢复变量值。例如,您可以通过重建一条通往网络和层的路径来加载上面定义的模型中的偏差。

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.

这些新对象的依赖关系图是您上面写入的较大检查点的较小子图。它只包含偏差和一个保存计数器,该计数器由 tf.train.Checkpoint 用于对检查点进行编号。

Visualization of a subgraph for the bias variable

restore 返回一个状态对象,该对象具有可选断言。新 Checkpoint 中创建的所有对象都已恢复,因此 status.assert_existing_objects_matched 通过。

status.assert_existing_objects_matched()

检查点中还有许多未匹配的对象,包括层的内核和优化器的变量。 status.assert_consumed 仅在检查点和程序完全匹配时才通过,并且在此处会抛出异常。

延迟恢复

TensorFlow 中的 Layer 对象可能会将变量的创建延迟到第一次调用,此时输入形状可用。例如,Dense 层内核的形状取决于层的输入和输出形状,因此作为构造函数参数所需的输出形状不足以单独创建变量。由于调用 Layer 也会读取变量的值,因此恢复必须发生在变量创建和第一次使用之间。

为了支持这种习惯用法,tf.train.Checkpoint 会延迟尚未匹配变量的恢复。

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored

手动检查检查点

tf.train.load_checkpoint 返回一个 CheckpointReader,它提供对检查点内容的更低级访问。它包含从每个变量的键到检查点中每个变量的形状和类型的映射。变量的键是其对象路径,就像上面显示的图中一样。

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())

因此,如果您对 net.l1.kernel 的值感兴趣,您可以使用以下代码获取该值

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)

它还提供了一个 get_tensor 方法,允许您检查变量的值

reader.get_tensor(key)

对象跟踪

检查点通过“跟踪”其属性中设置的任何变量或可跟踪对象来保存和恢复 tf.Variable 对象的值。在执行保存时,变量会从所有可达到的跟踪对象中递归收集。

与直接属性赋值(如 self.l1 = tf.keras.layers.Dense(5))一样,将列表和字典分配给属性将跟踪其内容。

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

您可能会注意到列表和字典的包装器对象。这些包装器是底层数据结构的可检查点版本。就像基于属性的加载一样,这些包装器会在变量添加到容器后立即恢复其值。

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()

可跟踪对象包括 tf.train.Checkpointtf.Module 及其子类(例如 keras.layers.Layerkeras.Model)以及识别的 Python 容器

  • dict(和 collections.OrderedDict
  • list
  • tuple(和 collections.namedtupletyping.NamedTuple

其他容器类型不受支持,包括

  • collections.defaultdict
  • set

所有其他 Python 对象都被忽略,包括

  • int
  • string
  • float

总结

TensorFlow 对象提供了一种简单的自动机制,用于保存和恢复它们使用的变量的值。