在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
持续保存“最佳”模型或模型权重/参数有很多好处。这些好处包括能够跟踪训练进度并从不同的保存状态加载保存的模型。
在 TensorFlow 1 中,要使用 tf.estimator.Estimator
API 在训练/验证期间配置检查点保存,您需要在 tf.estimator.RunConfig
中指定一个计划,或者使用 tf.estimator.CheckpointSaverHook
。本指南演示了如何从该工作流程迁移到 TensorFlow 2 Keras API。
在 TensorFlow 2 中,您可以通过多种方式配置 tf.keras.callbacks.ModelCheckpoint
- 根据使用
save_best_only=True
参数监控的指标保存“最佳”版本,其中monitor
可以是例如'loss'
、'val_loss'
、'accuracy'
或'val_accuracy'
。 - 以特定频率持续保存(使用
save_freq
参数)。 - 通过将
save_weights_only
设置为True
,仅保存权重/参数,而不是整个模型。
有关更多详细信息,请参阅 tf.keras.callbacks.ModelCheckpoint
API 文档以及 保存和加载模型 教程中的“在训练期间保存检查点”部分。在 保存和加载 Keras 模型 指南的“TF 检查点格式”部分中了解有关检查点格式的更多信息。此外,要添加容错,您可以使用 tf.keras.callbacks.BackupAndRestore
或 tf.train.Checkpoint
进行手动检查点。在 容错迁移指南 中了解有关更多信息。
Keras 回调 是在内置 Keras Model.fit
/Model.evaluate
/Model.predict
API 中训练/评估/预测期间的不同点调用的对象。在指南末尾的“后续步骤”部分中了解有关更多信息。
设置
从导入和简单的演示数据集开始
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1:使用 tf.estimator API 保存检查点
此 TensorFlow 1 示例展示了如何配置 tf.estimator.RunConfig
以使用 tf.estimator.Estimator
API 在训练/评估期间的每个步骤保存检查点
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[256, 32],
optimizer=tf1.train.AdamOptimizer(0.001),
n_classes=10,
dropout=0.2,
model_dir=path,
config = config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train.astype(np.int32),
num_epochs=10,
batch_size=50,
shuffle=True,
)
test_input_fn = tf1.estimator.inputs.numpy_input_fn(
x={"x": x_test},
y=y_test.astype(np.int32),
num_epochs=10,
shuffle=False
)
train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
steps=10,
throttle_secs=0)
tf1.estimator.train_and_evaluate(estimator=classifier,
train_spec=train_spec,
eval_spec=eval_spec)
%ls {classifier.model_dir}
TensorFlow 2:使用 Keras 回调为 Model.fit 保存检查点
在 TensorFlow 2 中,当您使用内置 Keras Model.fit
(或 Model.evaluate
)进行训练/评估时,您可以配置 tf.keras.callbacks.ModelCheckpoint
,然后将其传递给 Model.fit
(或 Model.evaluate
)的 callbacks
参数。(在 API 文档和 使用内置方法进行训练和评估 指南的“使用回调”部分中了解有关更多信息。)
在下面的示例中,您将使用 tf.keras.callbacks.ModelCheckpoint
回调将检查点存储在临时目录中
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
steps_per_execution=10)
log_dir = tempfile.mkdtemp()
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=log_dir)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[model_checkpoint_callback])
%ls {model_checkpoint_callback.filepath}
后续步骤
了解有关检查点的更多信息
- API 文档:
tf.keras.callbacks.ModelCheckpoint
- 教程:保存和加载模型(训练期间保存检查点部分)
- 指南:保存和加载 Keras 模型(TF 检查点格式部分)
了解更多关于回调的信息
- API 文档:
tf.keras.callbacks.Callback
- 指南:编写自己的回调
- 指南:使用内置方法进行训练和评估(使用回调部分)
您可能还会发现以下与迁移相关的资源有用
- 该容错迁移指南:
tf.keras.callbacks.BackupAndRestore
用于Model.fit
,或tf.train.Checkpoint
和tf.train.CheckpointManager
API 用于自定义训练循环 - 该提前停止迁移指南:
tf.keras.callbacks.EarlyStopping
是一个内置的提前停止回调 - 该TensorBoard 迁移指南:TensorBoard 允许跟踪和显示指标
- 该LoggingTensorHook 和 StopAtStepHook 到 Keras 回调迁移指南
- 该SessionRunHook 到 Keras 回调指南