使用 Keras 和 MultiWorkerMirroredStrategy 的自定义训练循环

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

概述

本教程演示了如何使用 Keras 模型以及 自定义训练循环 使用 tf.distribute.Strategy API 进行多工作器分布式训练。训练循环通过 tf.distribute.MultiWorkerMirroredStrategy 进行分布,因此 tf.keras 模型(设计为在 单工作器 上运行)可以无缝地在多个工作器上运行,只需进行最少的代码更改。自定义训练循环提供了灵活性,并对训练提供了更大的控制,同时还简化了模型调试。详细了解 编写基本训练循环从头开始编写训练循环 以及 自定义训练

如果您正在寻找如何使用 MultiWorkerMirroredStrategytf.keras.Model.fit,请参考此 教程

TensorFlow 中的分布式训练 指南可用于概述 TensorFlow 支持的分布式策略,对于那些希望更深入了解 tf.distribute.Strategy API 的人来说。

设置

首先,导入一些必要的库。

import json
import os
import sys

在导入 TensorFlow 之前,对环境进行一些更改

  • 禁用所有 GPU。这可以防止所有工作器尝试使用同一个 GPU 导致的错误。在实际应用中,每个工作器都应该在不同的机器上。
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  • 重置 'TF_CONFIG' 环境变量(稍后您将看到更多关于它的内容)。
os.environ.pop('TF_CONFIG', None)
  • 确保当前目录在 Python 的路径上。这允许笔记本导入稍后由 %%writefile 编写的文件。
if '.' not in sys.path:
  sys.path.insert(0, '.')

现在导入 TensorFlow。

import tensorflow as tf

数据集和模型定义

接下来,创建一个名为 mnist.py 的文件,其中包含一个简单的模型和数据集设置。此 Python 文件将被本教程中的工作器进程使用。

%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # You need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000)
  return train_dataset

def dataset_fn(global_batch_size, input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = mnist_dataset(batch_size)
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  return dataset

def build_cnn_model():
  regularizer = tf.keras.regularizers.L2(1e-5)
  return tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3,
                             activation='relu',
                             kernel_regularizer=regularizer),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128,
                            activation='relu',
                            kernel_regularizer=regularizer),
      tf.keras.layers.Dense(10, kernel_regularizer=regularizer)
  ])

多工作器配置

现在让我们进入多工作器训练的世界。在 TensorFlow 中,'TF_CONFIG' 环境变量是多机训练的必要条件。每台机器可能扮演不同的角色。下面使用的 'TF_CONFIG' 变量是一个 JSON 字符串,它指定了集群中每个工作器的集群配置。这是使用 cluster_resolver.TFConfigClusterResolver 指定集群的默认方法,但 distribute.cluster_resolver 模块中还有其他选项可用。了解有关在 分布式训练指南 中设置 'TF_CONFIG' 变量的更多信息。

描述您的集群

这是一个配置示例

tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}

请注意,tf_config 只是 Python 中的一个局部变量。要将其用于训练配置,请将其序列化为 JSON 并将其放置在 'TF_CONFIG' 环境变量中。以下是同一个 'TF_CONFIG' 序列化为 JSON 字符串的形式

json.dumps(tf_config)

'TF_CONFIG' 有两个组成部分:'cluster''task'

  • 'cluster' 对所有工作器都相同,并提供有关训练集群的信息,它是一个包含不同类型作业(例如 'worker')的字典。在使用 MultiWorkerMirroredStrategy 的多工作器训练中,通常有一个 'worker' 承担更多责任,例如保存检查点和为 TensorBoard 写入摘要文件,除了常规 'worker' 所做的工作之外。这样的工作器被称为 'chief' 工作器,通常将 'index' 为 0 的 'worker' назначен为 chief worker

  • 'task' 提供当前任务的信息,每个工作器都不相同。它指定了该工作器的 'type''index'

在此示例中,您将任务 'type' 设置为 'worker',并将任务 'index' 设置为 0。这台机器是第一个工作器,并将被指定为 chief 工作器,比其他工作器做更多工作。请注意,其他机器也需要设置 'TF_CONFIG' 环境变量,它应该具有相同的 'cluster' 字典,但不同的任务 'type' 或任务 'index',具体取决于这些机器的角色。

为了说明目的,本教程展示了如何在一个 'localhost' 上设置两个工作器的 'TF_CONFIG'。在实践中,用户会在外部 IP 地址/端口上创建多个工作器,并在每个工作器上适当地设置 'TF_CONFIG'

此示例使用两个工作器。第一个工作器的 'TF_CONFIG' 如上所示。对于第二个工作器,请设置 tf_config['task']['index']=1

笔记本中的环境变量和子进程

子进程继承其父进程的环境变量。因此,如果您在此 Jupyter Notebook 进程中设置了一个环境变量

os.environ['GREETINGS'] = 'Hello TensorFlow!'

那么您就可以从子进程中访问该环境变量

echo ${GREETINGS}

在下一节中,您将使用它将 'TF_CONFIG' 传递给工作器子进程。您永远不会真正以这种方式启动作业,但这对于本教程的目的来说已经足够了:演示一个最小的多工作器示例。

MultiWorkerMirroredStrategy

在训练模型之前,首先创建一个 tf.distribute.MultiWorkerMirroredStrategy 实例

strategy = tf.distribute.MultiWorkerMirroredStrategy()

使用 tf.distribute.Strategy.scope 指定在构建模型时应使用策略。这允许策略控制诸如变量放置之类的事情——它将在所有工作器上的每个设备上创建模型层中所有变量的副本。

import mnist
with strategy.scope():
  # Model building needs to be within `strategy.scope()`.
  multi_worker_model = mnist.build_cnn_model()

自动将您的数据跨工作器分片

在多工作器训练中,需要进行数据集分片以确保收敛和可重复性。分片意味着将整个数据集的子集分配给每个工作器——它有助于创建类似于在单个工作器上训练的体验。在下面的示例中,您依赖于 tf.distribute 的默认自动分片策略。您也可以通过设置 tf.data.experimental.AutoShardPolicy 来自定义它。 tf.data.experimental.DistributeOptions。要了解更多信息,请参阅 分布式输入教程 中的分片部分。

per_worker_batch_size = 64
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers

with strategy.scope():
  multi_worker_dataset = strategy.distribute_datasets_from_function(
      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))

定义自定义训练循环并训练模型

指定优化器

with strategy.scope():
  # The creation of optimizer and train_accuracy needs to be in
  # `strategy.scope()` as well, since they create variables.
  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')

使用 tf.function 定义训练步骤

@tf.function
def train_step(iterator):
  """Training step function."""

  def step_fn(inputs):
    """Per-Replica step function."""
    x, y = inputs
    with tf.GradientTape() as tape:
      predictions = multi_worker_model(x, training=True)
      per_example_loss = tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True,
          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
      loss = tf.nn.compute_average_loss(per_example_loss)
      model_losses = multi_worker_model.losses
      if model_losses:
        loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))

    grads = tape.gradient(loss, multi_worker_model.trainable_variables)
    optimizer.apply_gradients(
        zip(grads, multi_worker_model.trainable_variables))
    train_accuracy.update_state(y, predictions)
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

检查点保存和恢复

当您编写自定义训练循环时,您需要手动处理 检查点保存,而不是依赖 Keras 回调。请注意,对于 MultiWorkerMirroredStrategy,保存检查点或完整模型需要所有工作器的参与,因为尝试仅在 chief 工作器上保存可能会导致死锁。工作器还需要写入不同的路径以避免相互覆盖。以下是如何配置目录的示例

from multiprocessing import util
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')

def _is_chief(task_type, task_id, cluster_spec):
  return (task_type is None
          or task_type == 'chief'
          or (task_type == 'worker'
              and task_id == 0
              and "chief" not in cluster_spec.as_dict()))

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id, cluster_spec):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id, cluster_spec):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

创建一个 tf.train.Checkpoint 来跟踪模型,该模型由 tf.train.CheckpointManager 管理,以便只保留最新的检查点

epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64),
    name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this
# illustrative example you did not set `'TF_CONFIG'` before initializing the
# strategy. Check out the next section for "real-world" usage.
cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])

checkpoint = tf.train.Checkpoint(
    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)

write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,
                                      cluster_spec)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

现在,当您需要恢复检查点时,可以使用方便的 tf.train.latest_checkpoint 函数(或通过调用 tf.train.CheckpointManager.restore_or_initialize)找到保存的最新检查点。

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
  checkpoint.restore(latest_checkpoint)

恢复检查点后,您可以继续使用自定义训练循环进行训练。

num_epochs = 3
num_steps_per_epoch = 70

while epoch.numpy() < num_epochs:
  iterator = iter(multi_worker_dataset)
  total_loss = 0.0
  num_batches = 0

  while step_in_epoch.numpy() < num_steps_per_epoch:
    total_loss += train_step(iterator)
    num_batches += 1
    step_in_epoch.assign_add(1)

  train_loss = total_loss / num_batches
  print('Epoch: %d, accuracy: %f, train_loss: %f.'
                %(epoch.numpy(), train_accuracy.result(), train_loss))

  train_accuracy.reset_states()

  # Once the `CheckpointManager` is set up, you're now ready to save, and remove
  # the checkpoints non-chief workers saved.
  checkpoint_manager.save()
  if not _is_chief(task_type, task_id, cluster_spec):
    tf.io.gfile.rmtree(write_checkpoint_dir)

  epoch.assign_add(1)
  step_in_epoch.assign(0)

完整代码一览

总结迄今为止讨论的所有过程

  1. 您创建工作器进程。
  2. 'TF_CONFIG' 传递给工作器进程。
  3. 让每个工作器进程运行下面包含训练代码的脚本。

文件:main.py

当前目录现在包含这两个 Python 文件

ls *.py

因此,将 'TF_CONFIG' 序列化为 JSON 并将其添加到环境变量中

os.environ['TF_CONFIG'] = json.dumps(tf_config)

现在,您可以启动一个工作器进程,它将运行 main.py 并使用 'TF_CONFIG'

# first kill any previous runs
%killbgscripts
python main.py &> job_0.log

关于上述命令,需要注意以下几点

  1. 它使用 %%bash,这是一个 笔记本“魔法”,用于运行一些 bash 命令。
  2. 它使用 --bg 标志在后台运行 bash 进程,因为此工作器不会终止。它在开始之前等待所有工作器。

后台工作器进程不会将输出打印到此笔记本。 &> 将其输出重定向到一个文件,以便您可以检查发生了什么。

等待几秒钟让进程启动

import time
time.sleep(20)

现在,检查到目前为止写入工作器日志文件的内容

cat job_0.log

日志文件的最后一行应该显示:Started server with target: grpc://127.0.0.1:12345。第一个工作器现在已准备就绪,并且正在等待所有其他工作器准备就绪以继续。

更新第二个工作器进程要获取的 tf_config

tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

现在启动第二个工作器。由于所有工作器都处于活动状态,这将启动训练(因此不需要将此进程置于后台)

python main.py > /dev/null 2>&1

如果您重新检查第一个工作器写入的日志,请注意它参与了训练该模型

cat job_0.log
# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts

深入了解多工作器训练

本教程演示了多工作器设置的自定义训练循环工作流程。有关其他主题的详细说明,请参阅适用于自定义训练循环的 使用 Keras 的多工作器训练 (tf.keras.Model.fit) 教程。

了解更多信息

  1. TensorFlow 中的分布式训练 指南概述了可用的分布式策略。
  2. 官方模型,其中许多模型可以配置为运行多个分布式策略。
  3. 性能部分tf.function 指南中提供了有关其他策略和 工具 的信息,您可以使用这些工具来优化 TensorFlow 模型的性能。