使用 tf.distribute.Strategy 进行自定义训练

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

本教程演示了如何使用 tf.distribute.Strategy(一个 TensorFlow API,它提供了一个抽象来 将您的训练分布 到多个处理单元(GPU、多台机器或 TPU)上)以及自定义训练循环。在本示例中,您将在包含 70,000 张大小为 28 x 28 的图像的 Fashion MNIST 数据集 上训练一个简单的卷积神经网络。

自定义训练循环 提供了灵活性,并能更好地控制训练。它们还使调试模型和训练循环变得更加容易。

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)

下载 Fashion MNIST 数据集

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

创建策略以分布变量和图

如何使用 tf.distribute.MirroredStrategy 策略?

  • 所有变量和模型图都会在副本之间复制。
  • 输入均匀地分布在副本之间。
  • 每个副本计算其接收到的输入的损失和梯度。
  • 梯度通过 **求和** 在所有副本之间同步。
  • 同步后,对每个副本上变量的副本进行相同的更新。
# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

设置输入管道

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

创建数据集并进行分布

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

创建模型

使用 tf.keras.Sequential 创建模型。你也可以使用 模型子类化 API函数式 API 来完成此操作。

def create_model():
  regularizer = tf.keras.regularizers.L2(1e-5)
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3,
                             activation='relu',
                             kernel_regularizer=regularizer),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3,
                             activation='relu',
                             kernel_regularizer=regularizer),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64,
                            activation='relu',
                            kernel_regularizer=regularizer),
      tf.keras.layers.Dense(10, kernel_regularizer=regularizer)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

定义损失函数

回想一下,损失函数由一到两部分组成

  • **预测损失** 衡量模型预测与一批训练示例的训练标签之间的差距。它针对每个标记示例进行计算,然后通过计算平均值在批次中进行减少。
  • 可选地,可以将 **正则化损失** 项添加到预测损失中,以引导模型远离过度拟合训练数据。一个常见的选择是 L2 正则化,它在所有模型权重的平方和上添加一个小的固定倍数,与示例数量无关。上面的模型使用 L2 正则化来演示它在下面训练循环中的处理方式。

对于在具有单个 GPU/CPU 的单台机器上进行训练,这将按如下方式工作

  • 针对批次中的每个示例计算预测损失,在批次中求和,然后除以批次大小。
  • 将正则化损失添加到预测损失中。
  • 计算总损失相对于每个模型权重的梯度,优化器根据相应的梯度更新每个模型权重。

使用 tf.distribute.Strategy,输入批次将在副本之间拆分。例如,假设你有 4 个 GPU,每个 GPU 上都有一个模型副本。一批 256 个输入示例均匀地分布在 4 个副本之间,因此每个副本获得一个大小为 64 的批次:我们有 256 = 4*64,或者通常 GLOBAL_BATCH_SIZE = num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

每个副本计算其获得的训练示例的损失,并计算损失相对于每个模型权重的梯度。优化器负责在使用这些梯度更新每个副本上模型权重的副本之前,**将这些梯度在副本之间加起来**。

那么,使用 tf.distribute.Strategy 时,损失应该如何计算?

  • 每个副本计算其分配的所有示例的预测损失,将结果加起来并除以 num_replicas_in_sync * BATCH_SIZE_PER_REPLICA,或者等效地,GLOBAL_BATCH_SIZE
  • 每个副本计算正则化损失(如果有)并将其除以 num_replicas_in_sync

与非分布式训练相比,所有每个副本的损失项都按 1/num_replicas_in_sync 的因子缩小。另一方面,所有损失项(或者更确切地说,它们的梯度)都在该数量的副本之间求和,然后优化器应用它们。实际上,每个副本上的优化器使用的梯度与 GLOBAL_BATCH_SIZE 的非分布式计算发生时相同。这与 Keras Model.fit 的分布式和非分布式行为一致。有关更大的全局批次大小如何能够扩展学习率的教程,请参阅 使用 Keras 进行分布式训练

如何在 TensorFlow 中做到这一点?

  • 损失减少和缩放在 Keras Model.compileModel.fit 中自动完成

  • 如果你正在编写自定义训练循环(如本教程中),你应该使用 tf.nn.compute_average_loss 将每个示例的损失加起来,并将总和除以全局批次大小,该函数接受每个示例的损失和可选的样本权重作为参数,并返回缩放后的损失。

  • 如果使用 tf.keras.losses 类(如以下示例中),则需要显式指定损失减少为 NONESUM 之一。默认的 AUTOSUM_OVER_BATCH_SIZEModel.fit 之外是不允许的。

    • AUTO 不允许,因为用户应该明确考虑他们想要进行的减少,以确保它在分布式情况下是正确的。
    • SUM_OVER_BATCH_SIZE 不允许,因为目前它只会除以每个副本的批次大小,并将除以副本数量留给用户,这可能很容易被忽略。因此,你需要自己显式地进行减少。
  • 如果你正在为具有非空 Model.losses 列表(例如,权重正则化器)的模型编写自定义训练循环,你应该将它们加起来并将总和除以副本数量。你可以使用 tf.nn.scale_regularization_loss 函数来完成此操作。模型代码本身不知道副本数量。

    但是,模型可以使用 Keras API(例如 Layer.add_loss(...)Layer(activity_regularizer=...))定义依赖于输入的正则化损失。对于 Layer.add_loss(...),建模代码需要执行将每个示例的求和项除以每个副本 (!) 的批次大小,例如,通过使用 tf.math.reduce_mean()

with strategy.scope():
  # Set reduction to `NONE` so you can do the reduction yourself.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions, model_losses):
    per_example_loss = loss_object(labels, predictions)
    loss = tf.nn.compute_average_loss(per_example_loss)
    if model_losses:
      loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
    return loss

特殊情况

高级用户还应考虑以下特殊情况。

  • 短于 GLOBAL_BATCH_SIZE 的输入批次在多个地方会产生令人不快的极端情况。在实践中,通过使用 Dataset.repeat().batch() 允许批次跨越时期边界,并通过步数而不是数据集结束来定义近似时期,通常可以最好地避免它们。或者,Dataset.batch(drop_remainder=True) 保持时期的概念,但会删除最后几个示例。

    为了说明,此示例采用更难的路线,并允许短批次,以便每个训练时期只包含每个训练示例一次。

    tf.nn.compute_average_loss() 应该使用哪个分母?

    • 默认情况下,在上面的示例代码中,以及在 Keras.fit() 中,预测损失的总和将除以 num_replicas_in_sync 乘以副本上看到的实际批次大小(空批次被静默忽略)。这保留了预测损失与正则化损失之间的平衡。它特别适用于使用依赖于输入的正则化损失的模型。简单的 L2 正则化只是将权重衰减叠加到预测损失的梯度上,不太需要这种平衡。
    • 在实践中,许多自定义训练循环将一个常量 Python 值传递到 tf.nn.compute_average_loss(..., global_batch_size=GLOBAL_BATCH_SIZE) 中,以将其用作分母。这保留了批次之间训练示例的相对权重。如果没有它,短批次中较小的分母会有效地提高这些批次中示例的权重。(在 TensorFlow 2.13 之前,这也需要避免在某些副本收到实际批次大小为零的情况下出现 NaN。)

    如果避免短批次,如上所述,这两种选项是等效的。

  • 多维 labels 需要你将 per_example_loss 在每个示例中的预测数量上取平均值。考虑一个针对输入图像的所有像素的分类任务,其中 predictions 的形状为 (batch_size, H, W, n_classes),而 labels 的形状为 (batch_size, H, W)。你需要更新 per_example_loss,如下所示:per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

定义用于跟踪损失和精度的指标

这些指标跟踪测试损失以及训练和测试精度。你可以使用 .result() 在任何时候获取累积的统计信息。

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

训练循环

# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions, model.losses)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                         train_accuracy.result() * 100, test_loss.result(),
                         test_accuracy.result() * 100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

上面示例中需要注意的事项

恢复最新的检查点并进行测试

使用 tf.distribute.Strategy 检查点的模型可以有或没有策略恢复。

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))

遍历数据集的替代方法

使用迭代器

如果您想遍历给定数量的步骤而不是整个数据集,您可以使用 iter 调用创建迭代器,并显式地对迭代器调用 next。您可以选择在 tf.function 内部和外部遍历数据集。以下是一个使用迭代器在 tf.function 外部遍历数据集的小片段。

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
  train_accuracy.reset_states()

tf.function 内部迭代

您也可以使用 for x in ... 结构或通过创建像上面一样的迭代器,在 tf.function 内部遍历整个输入 train_dist_dataset。以下示例演示了使用 @tf.function 装饰器包装一个训练周期,并在函数内部遍历 train_dist_dataset

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))

  train_accuracy.reset_states()

跟踪跨副本的训练损失

由于进行了损失缩放计算,因此不建议使用 tf.keras.metrics.Mean 跟踪跨不同副本的训练损失。

例如,如果您运行具有以下特征的训练作业

  • 两个副本
  • 每个副本处理两个样本
  • 每个副本上的结果损失值:[2, 3] 和 [4, 5]
  • 全局批次大小 = 4

使用损失缩放,您通过将损失值相加,然后除以全局批次大小来计算每个副本的每个样本损失值。在本例中:(2 + 3) / 4 = 1.25(4 + 5) / 4 = 2.25

如果您使用 tf.keras.metrics.Mean 跟踪跨两个副本的损失,结果将不同。在本例中,您最终得到 total 为 3.50,count 为 2,这导致在对指标调用 result()total/count = 1.75。使用 tf.keras.Metrics 计算的损失会额外缩放一个因子,该因子等于同步副本的数量。

指南和示例

以下是一些使用分布式策略和自定义训练循环的示例

  1. 分布式训练指南
  2. DenseNet 使用 MirroredStrategy 的示例。
  3. BERT 使用 MirroredStrategyTPUStrategy 训练的示例。此示例对于理解如何在分布式训练期间从检查点加载和生成周期性检查点等特别有用。
  4. NCF 使用 MirroredStrategy 训练的示例,可以使用 keras_use_ctl 标志启用。
  5. NMT 使用 MirroredStrategy 训练的示例。

您可以在 分布式策略指南 中的“示例和教程”部分找到更多示例。

下一步