作者: fchollet
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 在 keras.io 上查看 |
简介
通常有两种方法可以将计算分布到多个设备上
数据并行,其中单个模型在多个设备或多台机器上复制。它们中的每一个都处理不同的数据批次,然后合并它们的结果。这种设置存在许多变体,它们在不同的模型副本如何合并结果、它们是否在每个批次都保持同步或它们是否更松散耦合等方面有所不同。
模型并行,其中单个模型的不同部分在不同的设备上运行,共同处理单个数据批次。这最适合具有自然并行架构的模型,例如具有多个分支的模型。
本指南重点介绍数据并行,特别是同步数据并行,其中模型的不同副本在处理完每个批次后保持同步。同步性使模型收敛行为与您在单设备训练中看到的行为相同。
具体来说,本指南将教您如何使用 tf.distribute
API 在多个 GPU 上训练 Keras 模型,对您的代码进行最小的更改,在以下两种设置中
- 在安装在单台机器上的多个 GPU 上(通常为 2 到 8 个)。这是研究人员和小规模行业工作流中最常见的设置。
- 在许多机器的集群上,每个机器都托管一个或多个 GPU(多工作器分布式训练)。对于大规模行业工作流来说,这是一个不错的设置,例如使用 20-100 个 GPU 在数千万张图像上训练高分辨率图像分类模型。
设置
import tensorflow as tf
import keras
单主机、多设备同步训练
在这种设置中,您有一台机器,上面安装了多个 GPU(通常为 2 到 8 个)。每个设备将运行模型的副本(称为**副本**)。为简单起见,在下文中,我们将假设我们正在处理 8 个 GPU,这不会影响一般性。
工作原理
在训练的每个步骤中
- 当前的数据批次(称为**全局批次**)将被分成 8 个不同的子批次(称为**本地批次**)。例如,如果全局批次有 512 个样本,则 8 个本地批次中的每一个将有 64 个样本。
- 8 个副本中的每一个都独立地处理一个本地批次:它们运行正向传递,然后是反向传递,输出权重相对于本地批次上模型损失的梯度。
- 源自本地梯度的权重更新将在 8 个副本之间有效地合并。由于这是在每一步结束时完成的,因此副本始终保持同步。
在实践中,同步更新模型副本权重的过程是在每个单独的权重变量级别处理的。这是通过**镜像变量**对象完成的。
如何使用它
要使用 Keras 模型进行单主机多设备同步训练,您将使用tf.distribute.MirroredStrategy
API。以下是它的工作原理
- 实例化一个
MirroredStrategy
,可以选择配置要使用的特定设备(默认情况下,策略将使用所有可用的 GPU)。 - 使用策略对象打开一个范围,并在该范围内创建所有包含变量的 Keras 对象。通常,这意味着在分布范围内部**创建和编译模型**。
- 像往常一样通过
fit()
训练模型。
重要的是,我们建议您使用tf.data.Dataset
对象在多设备或分布式工作流程中加载数据。
从结构上看,它看起来像这样
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
# Everything that creates variables should be under the strategy scope.
# In general this is only model construction & `compile()`.
model = Model(...)
model.compile(...)
# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)
# Test the model on all available devices.
model.evaluate(test_dataset)
这是一个简单的端到端可运行示例
def get_compiled_model():
# Make a simple 2-layer densely-connected neural network.
inputs = keras.Input(shape=(784,))
x = keras.layers.Dense(256, activation="relu")(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
return model
def get_dataset():
batch_size = 32
num_val_samples = 10000
# Return the MNIST dataset in the form of a `tf.data.Dataset`.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Preprocess the data (these are Numpy arrays)
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")
# Reserve num_val_samples samples for validation
x_val = x_train[-num_val_samples:]
y_val = y_train[-num_val_samples:]
x_train = x_train[:-num_val_samples]
y_train = y_train[:-num_val_samples]
return (
tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
)
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
# Everything that creates variables should be under the strategy scope.
# In general this is only model construction & `compile()`.
model = get_compiled_model()
# Train the model on all available devices.
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)
# Test the model on all available devices.
model.evaluate(test_dataset)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Number of devices: 4 2023-07-19 11:35:32.379801: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 50000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } Epoch 1/2 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 1556/1563 [============================>.] - ETA: 0s - loss: 0.2236 - sparse_categorical_accuracy: 0.9328INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 2023-07-19 11:35:46.769935: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:2" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } 1563/1563 [==============================] - 16s 7ms/step - loss: 0.2238 - sparse_categorical_accuracy: 0.9328 - val_loss: 0.1347 - val_sparse_categorical_accuracy: 0.9592 Epoch 2/2 1563/1563 [==============================] - 11s 7ms/step - loss: 0.0940 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0984 - val_sparse_categorical_accuracy: 0.9684 2023-07-19 11:35:59.993148: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:4" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } 313/313 [==============================] - 2s 4ms/step - loss: 0.1057 - sparse_categorical_accuracy: 0.9676 [0.10571097582578659, 0.9675999879837036]
使用回调来确保容错
在使用分布式训练时,您应该始终确保您有一个策略来从故障中恢复(容错)。处理此问题的最简单方法是将ModelCheckpoint
回调传递给fit()
,以便定期保存您的模型(例如,每 100 个批次或每个时期)。然后,您可以从保存的模型重新开始训练。
这是一个简单的示例
import os
from tensorflow import keras
# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
def make_or_restore_model():
# Either restore the latest model, or create a fresh one
# if there is no checkpoint available.
checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
if checkpoints:
latest_checkpoint = max(checkpoints, key=os.path.getctime)
print("Restoring from", latest_checkpoint)
return keras.models.load_model(latest_checkpoint)
print("Creating a new model")
return get_compiled_model()
def run_training(epochs=1):
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
# Open a strategy scope and create/restore the model
with strategy.scope():
model = make_or_restore_model()
callbacks = [
# This callback saves a SavedModel every epoch
# We include the current epoch in the folder name.
keras.callbacks.ModelCheckpoint(
filepath=checkpoint_dir + "/ckpt-{epoch}", save_freq="epoch"
)
]
model.fit(
train_dataset,
epochs=epochs,
callbacks=callbacks,
validation_data=val_dataset,
verbose=2,
)
# Running the first time creates the model
run_training(epochs=1)
# Calling the same function again will resume from where we left off
run_training(epochs=1)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Creating a new model 2023-07-19 11:36:01.811216: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 50000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 2023-07-19 11:36:13.671835: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:2" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets 1563/1563 - 14s - loss: 0.2268 - sparse_categorical_accuracy: 0.9322 - val_loss: 0.1148 - val_sparse_categorical_accuracy: 0.9656 - 14s/epoch - 9ms/step INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Restoring from ./ckpt/ckpt-1 2023-07-19 11:36:16.521031: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 50000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 2023-07-19 11:36:28.440092: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:2" } } attr { key: "output_shapes" value { list { shape { dim { size: 784 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets INFO:tensorflow:Assets written to: ./ckpt/ckpt-1/assets 1563/1563 - 13s - loss: 0.0974 - sparse_categorical_accuracy: 0.9703 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9724 - 13s/epoch - 9ms/step
tf.data
性能提示
在进行分布式训练时,加载数据的效率通常会变得至关重要。以下是一些提示,以确保您的tf.data
管道运行得尽可能快。
关于数据集批处理的说明
在创建数据集时,确保它使用全局批次大小进行批处理。例如,如果您的 8 个 GPU 中的每一个都能够运行一个包含 64 个样本的批次,那么您可以使用 512 的全局批次大小。
调用dataset.cache()
如果您在数据集上调用.cache()
,则其数据将在运行第一个数据迭代后被缓存。每次后续迭代都将使用缓存的数据。缓存可以在内存中(默认)或您指定的本地文件中。
这可以在以下情况下提高性能
- 您的数据预计不会在迭代之间发生变化
- 您正在从远程分布式文件系统读取数据
- 您正在从本地磁盘读取数据,但您的数据可以放入内存,并且您的工作流程明显受 IO 限制(例如,读取和解码图像文件)。
调用dataset.prefetch(buffer_size)
您几乎应该始终在创建数据集后调用.prefetch(buffer_size)
。这意味着您的数据管道将与您的模型异步运行,新的样本将在当前批次样本用于训练模型时预处理并存储在缓冲区中。当当前批次结束时,下一个批次将在 GPU 内存中预取。
多工作器分布式同步训练
工作原理
在这种设置中,您有多台机器(称为**工作器**),每台机器上都安装了一个或多个 GPU。与单主机训练发生的情况非常相似,每个可用的 GPU 将运行一个模型副本,并且每个副本的变量的值在每个批次之后保持同步。
重要的是,当前的实现假设所有工作器都具有相同数量的 GPU(同构集群)。
如何使用它
- 设置一个集群(我们在下面提供指针)。
- 在每个工作器上设置适当的
TF_CONFIG
环境变量。这告诉工作器它的角色是什么以及如何与其对等方通信。 - 在每个工作器上,在
MultiWorkerMirroredStrategy
对象的范围内运行您的模型构建和编译代码,类似于我们在单主机训练中所做的。 - 在指定的评估器机器上运行评估代码。
设置集群
首先,设置一个集群(机器集合)。每台机器都应该单独设置,以便能够运行您的模型(通常,每台机器将运行相同的 Docker 镜像)并能够访问您的数据源(例如,GCS)。
集群管理超出了本指南的范围。这是一份文档,可以帮助您入门。您还可以查看Kubeflow。
设置TF_CONFIG
环境变量
虽然在每个工作器上运行的代码几乎与单主机工作流程中使用的代码相同(除了使用不同的tf.distribute
策略对象),但单主机工作流程和多工作器工作流程之间的一个显着区别是,您需要在集群中运行的每台机器上设置一个TF_CONFIG
环境变量。
TF_CONFIG
环境变量是一个 JSON 字符串,它指定
- 集群配置,以及构成集群的机器的地址和端口列表
- 工作器的“任务”,即这台特定机器在集群中要扮演的角色。
TF_CONFIG 的一个示例是
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': ["localhost:12345", "localhost:23456"]
},
'task': {'type': 'worker', 'index': 0}
})
在多工作器同步训练设置中,机器的有效角色(任务类型)是“工作器”和“评估器”。
例如,如果您有 8 台机器,每台机器有 4 个 GPU,那么您可以有 7 个工作器和一个评估器。
- 工作器训练模型,每个工作器处理全局批次的子批次。
- 其中一个工作器(工作器 0)将充当“首席”,这是一种特殊的工作器,负责保存日志和检查点以供以后重用(通常保存到云存储位置)。
- 评估器运行一个连续循环,该循环加载首席工作器保存的最新检查点,对其运行评估(与其他工作器异步),并写入评估日志(例如,TensorBoard 日志)。
在每个工作器上运行代码
您将在每个工作器(包括首席)上运行训练代码,并在评估器上运行评估代码。
训练代码基本上与您在单主机设置中使用的代码相同,只是使用MultiWorkerMirroredStrategy
而不是MirroredStrategy
。
每个工作器将运行相同的代码(减去下面说明的差异),包括相同的回调。
评估器将简单地使用MirroredStrategy
(因为它在单台机器上运行,不需要与其他机器通信)并调用model.evaluate()
。它将加载首席工作器保存到云存储位置的最新检查点,并将评估日志保存到与首席日志相同的存储位置。
示例:在多工作器设置中运行的代码
在首席(工作器 0)上
# Set TF_CONFIG
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': ["localhost:12345", "localhost:23456"]
},
'task': {'type': 'worker', 'index': 0}
})
# Open a strategy scope and create/restore the model.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
model = make_or_restore_model()
callbacks = [
# This callback saves a SavedModel every 100 batches
keras.callbacks.ModelCheckpoint(filepath='path/to/cloud/location/ckpt',
save_freq=100),
keras.callbacks.TensorBoard('path/to/cloud/location/tb/')
]
model.fit(train_dataset,
callbacks=callbacks,
...)
在其他工作器上
# Set TF_CONFIG
worker_index = 1 # For instance
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': ["localhost:12345", "localhost:23456"]
},
'task': {'type': 'worker', 'index': worker_index}
})
# Open a strategy scope and create/restore the model.
# You can restore from the checkpoint saved by the chief.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
model = make_or_restore_model()
callbacks = [
keras.callbacks.ModelCheckpoint(filepath='local/path/ckpt', save_freq=100),
keras.callbacks.TensorBoard('local/path/tb/')
]
model.fit(train_dataset,
callbacks=callbacks,
...)
在评估器上
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = make_or_restore_model() # Restore from the checkpoint saved by the chief.
results = model.evaluate(val_dataset)
# Then, log the results on a shared location, write TensorBoard logs, etc