迁移检查点保存

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

持续保存“最佳”模型或模型权重/参数有很多好处。这些好处包括能够跟踪训练进度并从不同的保存状态加载保存的模型。

在 TensorFlow 1 中,要使用 tf.estimator.Estimator API 在训练/验证期间配置检查点保存,您需要在 tf.estimator.RunConfig 中指定一个计划,或者使用 tf.estimator.CheckpointSaverHook。本指南演示了如何从该工作流程迁移到 TensorFlow 2 Keras API。

在 TensorFlow 2 中,您可以通过多种方式配置 tf.keras.callbacks.ModelCheckpoint

  • 根据使用 save_best_only=True 参数监控的指标保存“最佳”版本,其中 monitor 可以是例如 'loss''val_loss''accuracy''val_accuracy'
  • 以特定频率持续保存(使用 save_freq 参数)。
  • 通过将 save_weights_only 设置为 True,仅保存权重/参数,而不是整个模型。

有关更多详细信息,请参阅 tf.keras.callbacks.ModelCheckpoint API 文档以及 保存和加载模型 教程中的“在训练期间保存检查点”部分。在 保存和加载 Keras 模型 指南的“TF 检查点格式”部分中了解有关检查点格式的更多信息。此外,要添加容错,您可以使用 tf.keras.callbacks.BackupAndRestoretf.train.Checkpoint 进行手动检查点。在 容错迁移指南 中了解有关更多信息。

Keras 回调 是在内置 Keras Model.fit/Model.evaluate/Model.predict API 中训练/评估/预测期间的不同点调用的对象。在指南末尾的“后续步骤”部分中了解有关更多信息。

设置

从导入和简单的演示数据集开始

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1:使用 tf.estimator API 保存检查点

此 TensorFlow 1 示例展示了如何配置 tf.estimator.RunConfig 以使用 tf.estimator.Estimator API 在训练/评估期间的每个步骤保存检查点

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config = config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)

tf1.estimator.train_and_evaluate(estimator=classifier,
                                train_spec=train_spec,
                                eval_spec=eval_spec)
%ls {classifier.model_dir}

TensorFlow 2:使用 Keras 回调为 Model.fit 保存检查点

在 TensorFlow 2 中,当您使用内置 Keras Model.fit(或 Model.evaluate)进行训练/评估时,您可以配置 tf.keras.callbacks.ModelCheckpoint,然后将其传递给 Model.fit(或 Model.evaluate)的 callbacks 参数。(在 API 文档和 使用内置方法进行训练和评估 指南的“使用回调”部分中了解有关更多信息。)

在下面的示例中,您将使用 tf.keras.callbacks.ModelCheckpoint 回调将检查点存储在临时目录中

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
%ls {model_checkpoint_callback.filepath}

后续步骤

了解有关检查点的更多信息

了解更多关于回调的信息

您可能还会发现以下与迁移相关的资源有用