在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
模型进度可以在训练期间和训练后保存。这意味着模型可以从上次中断的地方继续训练,避免长时间的训练时间。保存还意味着您可以分享您的模型,其他人可以重现您的工作。在发布研究模型和技术时,大多数机器学习从业者会分享
- 创建模型的代码,以及
- 模型的训练权重或参数
分享这些数据有助于其他人了解模型的工作原理,并尝试使用新数据进行测试。
选项
根据您使用的 API,有不同的方法可以保存 TensorFlow 模型。本指南使用 tf.keras——一个在 TensorFlow 中构建和训练模型的高级 API。本教程中使用的新的高级 .keras
格式建议用于保存 Keras 对象,因为它提供了强大的、基于名称的保存,通常比低级或旧版格式更容易调试。有关更高级的保存或序列化工作流程,尤其是涉及自定义对象的那些工作流程,请参阅 保存和加载 Keras 模型指南。有关其他方法,请参阅 使用 SavedModel 格式指南。
设置
安装和导入
安装并导入 TensorFlow 和依赖项
pip install pyyaml h5py # Required to save models in HDF5 format
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
获取示例数据集
为了演示如何保存和加载权重,您将使用 MNIST 数据集。为了加快这些运行速度,请使用前 1000 个示例
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
定义模型
首先构建一个简单的顺序模型
# Define a simple sequential model
def create_model():
model = tf.keras.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
在训练期间保存检查点
您可以使用训练过的模型而无需重新训练,或者在训练过程中断的情况下从中断的地方继续训练。 tf.keras.callbacks.ModelCheckpoint
回调允许您在训练期间和训练结束时持续保存模型。
检查点回调用法
创建一个 tf.keras.callbacks.ModelCheckpoint
回调,该回调仅在训练期间保存权重
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
这将创建一个 TensorFlow 检查点文件的单个集合,这些文件在每个 epoch 结束时更新
os.listdir(checkpoint_dir)
只要两个模型共享相同的架构,您就可以在它们之间共享权重。因此,当从仅权重恢复模型时,请创建一个与原始模型具有相同架构的模型,然后设置其权重。
现在重建一个新的、未经训练的模型,并在测试集上对其进行评估。未经训练的模型将在机会水平(约 10% 的准确率)上执行
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
然后从检查点加载权重并重新评估
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
检查点回调选项
回调提供了一些选项来为检查点提供唯一的名称并调整检查点频率。
训练一个新模型,并在每五个 epoch 保存一次具有唯一名称的检查点
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
# Calculate the number of batches per epoch
import math
n_batches = len(train_images) / batch_size
n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer
# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*n_batches)
# Create a new model instance
model = create_model()
# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
现在,查看生成的检查点并选择最新的检查点
os.listdir(checkpoint_dir)
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
为了测试,重置模型,并加载最新的检查点
# Create a new model instance
model = create_model()
# Load the previously saved weights
model.load_weights(latest)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
这些文件是什么?
上面的代码将权重存储到一个 检查点 格式文件的集合中,这些文件仅包含以二进制格式训练的权重。检查点包含
- 一个或多个包含模型权重的分片。
- 一个索引文件,指示哪些权重存储在哪个分片中。
如果您在一台机器上训练模型,您将有一个后缀为 .data-00000-of-00001
的分片。
手动保存权重
要手动保存权重,请使用 tf.keras.Model.save_weights
。默认情况下,tf.keras
——尤其是 Model.save_weights
方法——使用 TensorFlow 检查点 格式,扩展名为 .ckpt
。要以 HDF5 格式保存,扩展名为 .h5
,请参阅 保存和加载模型 指南。
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
保存整个模型
调用 tf.keras.Model.save
以将模型的架构、权重和训练配置保存在单个 model.keras
zip 存档中。
整个模型可以保存为三种不同的文件格式(新的 .keras
格式和两种旧版格式:SavedModel
和 HDF5
)。将模型保存为 path/to/model.keras
会自动以最新格式保存。
您可以切换到 SavedModel 格式,方法是
- 将
save_format='tf'
传递给save()
- 传递一个没有扩展名的文件名
您可以切换到 H5 格式,方法是
- 将
save_format='h5'
传递给save()
- 传递一个以
.h5
结尾的文件名
保存一个功能齐全的模型非常有用——您可以在 TensorFlow 中加载它们。 TensorFlow.js (Saved Model,HDF5),然后在 Web 浏览器中训练和运行它们,或者将它们转换为使用 TensorFlow Lite 在移动设备上运行 (Saved Model,HDF5)
*自定义对象(例如,子类模型或层)在保存和加载时需要特别注意。请参阅下面的保存自定义对象部分。
新的高级 .keras
格式
新的 Keras v3 保存格式,以 .keras
扩展名标记,是一种更简单、更高效的格式,它实现了基于名称的保存,确保您加载的内容与您从 Python 的角度保存的内容完全一致。这使得调试变得容易得多,并且是 Keras 的推荐格式。
以下部分说明了如何在 .keras
格式中保存和恢复模型。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a `.keras` zip archive.
model.save('my_model.keras')
从 .keras
zip 存档中重新加载一个新的 Keras 模型
new_model = tf.keras.models.load_model('my_model.keras')
# Show the model architecture
new_model.summary()
尝试使用加载的模型运行评估和预测
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
SavedModel 格式
SavedModel 格式是序列化模型的另一种方法。以这种格式保存的模型可以使用 tf.keras.models.load_model
恢复,并且与 TensorFlow Serving 兼容。 SavedModel 指南 详细介绍了如何 serve/inspect
SavedModel。以下部分说明了保存和恢复模型的步骤。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
SavedModel 格式是一个目录,其中包含一个 protobuf 二进制文件和一个 TensorFlow 检查点。检查保存的模型目录
# my_model directory
ls saved_model
# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
从保存的模型中重新加载一个新的 Keras 模型
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
恢复的模型使用与原始模型相同的参数进行编译。尝试使用加载的模型运行评估和预测
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
HDF5 格式
Keras 使用 HDF5 标准提供了一个基本的旧版高级保存格式。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
现在,从该文件重新创建模型
# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')
# Show the model architecture
new_model.summary()
检查其准确性
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
Keras 通过检查模型的架构来保存模型。此技术保存所有内容
- 权重值
- 模型的架构
- 模型的训练配置(您传递给
.compile()
方法的内容) - 优化器及其状态(如果有)(这使您能够从中断的地方重新开始训练)
Keras 无法保存 v1.x
优化器(来自 tf.compat.v1.train
),因为它们与检查点不兼容。对于 v1.x 优化器,您需要在加载后重新编译模型——从而丢失优化器的状态。
保存自定义对象
如果您使用的是 SavedModel 格式,则可以跳过本节。高级 .keras
/HDF5 格式与低级 SavedModel 格式之间的主要区别在于,.keras
/HDF5 格式使用对象配置来保存模型架构,而 SavedModel 保存执行图。因此,SavedModel 能够保存自定义对象,例如子类模型和自定义层,而无需原始代码。但是,调试低级 SavedModel 可能更困难,因此我们建议使用高级 .keras
格式,因为它具有基于名称的 Keras 本地性质。
要将自定义对象保存到 .keras
和 HDF5,您必须执行以下操作
- 在您的对象中定义一个
get_config
方法,并可选地定义一个from_config
类方法。get_config(self)
返回一个 JSON 可序列化的参数字典,这些参数需要重新创建对象。from_config(cls, config)
使用从get_config
返回的配置来创建一个新对象。默认情况下,此函数将使用配置作为初始化关键字参数 (return cls(**config)
)。
- 通过以下三种方式之一将自定义对象传递给模型
- 使用
@tf.keras.utils.register_keras_serializable
装饰器注册自定义对象。(推荐) - 在加载模型时直接将对象传递给
custom_objects
参数。该参数必须是一个字典,将字符串类名映射到 Python 类。例如,tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
- 使用一个
tf.keras.utils.custom_object_scope
,其中包含custom_objects
字典参数中的对象,并将tf.keras.models.load_model(path)
调用放在范围内。
- 使用
请参阅 从头开始编写层和模型 教程,以了解自定义对象和 get_config
的示例。
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.