从头开始编写训练循环

作者: fchollet

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 在 keras.io 上查看

设置

import tensorflow as tf
import keras
from keras import layers
import numpy as np

简介

Keras 提供默认的训练和评估循环,fit()evaluate()。它们的用法在指南 使用内置方法进行训练和评估 中介绍。

如果您想在仍然利用 fit() 的便利性的同时自定义模型的学习算法(例如,使用 fit() 训练 GAN),您可以对 Model 类进行子类化并实现自己的 train_step() 方法,该方法在 fit() 期间反复调用。这在指南 自定义 fit() 中的行为 中介绍。

现在,如果您希望对训练和评估进行非常底层的控制,您应该从头开始编写自己的训练和评估循环。这就是本指南的主题。

使用 GradientTape:第一个端到端示例

GradientTape 范围内调用模型可以使您检索层可训练权重相对于损失值的梯度。使用优化器实例,您可以使用这些梯度来更新这些变量(可以使用 model.trainable_weights 检索)。

让我们考虑一个简单的 MNIST 模型

inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)

让我们使用自定义训练循环使用小批量梯度对其进行训练。

首先,我们需要一个优化器、一个损失函数和一个数据集

# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

这是我们的训练循环

  • 我们打开一个 for 循环,该循环遍历各个时期
  • 对于每个时期,我们打开一个 for 循环,该循环遍历数据集中的各个批次
  • 对于每个批次,我们打开一个 GradientTape() 范围
  • 在这个范围内,我们调用模型(前向传播)并计算损失
  • 在模型范围之外,我们获取模型权重相对于损失的梯度。
  • 最后,我们使用优化器根据梯度更新模型的权重。
epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))
Start of epoch 0
WARNING:tensorflow:5 out of the last 5 calls to <function _BaseOptimizer._update_step_xla at 0x7f51fe36a4c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://tensorflowcn.cn/guide/function#controlling_retracing and https://tensorflowcn.cn/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 6 calls to <function _BaseOptimizer._update_step_xla at 0x7f51fe36a4c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://tensorflowcn.cn/guide/function#controlling_retracing and https://tensorflowcn.cn/api_docs/python/tf/function for  more details.
Training loss (for one batch) at step 0: 131.3794
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.2871
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.2652
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.8800
Seen so far: 38464 samples

Start of epoch 1
Training loss (for one batch) at step 0: 0.8296
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.3322
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.0486
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.6610
Seen so far: 38464 samples

指标的低级处理

让我们在这个基本循环中添加指标监控。

您可以在从头开始编写的此类训练循环中轻松地重复使用内置指标(或您编写的自定义指标)。以下是流程

  • 在循环开始时实例化指标
  • 在每个批次之后调用 metric.update_state()
  • 当您需要显示指标的当前值时,调用 metric.result()
  • 当您需要清除指标的状态时,调用 metric.reset_states()(通常在每个 epoch 结束时)

让我们使用这些知识来计算每个 epoch 结束时验证数据的 SparseCategoricalAccuracy

# Get model
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

这是我们的训练和评估循环

import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
Start of epoch 0
Training loss (for one batch) at step 0: 106.2691
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.9259
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.9347
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.7641
Seen so far: 38464 samples
Training acc over epoch: 0.7332
Validation acc: 0.8325
Time taken: 10.95s

Start of epoch 1
Training loss (for one batch) at step 0: 0.5238
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.7125
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.5705
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.6006
Seen so far: 38464 samples
Training acc over epoch: 0.8424
Validation acc: 0.8525
Time taken: 10.59s

使用 tf.function 加速您的训练步骤

TensorFlow 2 中的默认运行时是 急切执行。因此,我们上面的训练循环急切地执行。

这对于调试非常有用,但图形编译具有明显的性能优势。将您的计算描述为静态图形使框架能够应用全局性能优化。当框架被限制为贪婪地执行一个接一个的操作时,这是不可能的,并且对接下来会发生什么一无所知。

您可以将任何以张量作为输入的函数编译成静态图形。只需在它上面添加一个 @tf.function 装饰器,如下所示

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

让我们对评估步骤做同样的事情

@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

现在,让我们使用这个编译的训练步骤重新运行我们的训练循环

import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
Start of epoch 0
Training loss (for one batch) at step 0: 0.5162
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.4599
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.3975
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.2557
Seen so far: 38464 samples
Training acc over epoch: 0.8747
Validation acc: 0.8545
Time taken: 1.85s

Start of epoch 1
Training loss (for one batch) at step 0: 0.6145
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.3751
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.3464
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.4128
Seen so far: 38464 samples
Training acc over epoch: 0.8919
Validation acc: 0.8996
Time taken: 1.34s

快多了,不是吗?

模型跟踪的损失的低级处理

层和模型递归地跟踪在正向传递期间由调用 self.add_loss(value) 的层创建的任何损失。在正向传递结束时,可以通过属性 model.losses 获取结果的标量损失值列表。

如果您想使用这些损失组件,您应该将它们加起来并添加到训练步骤中的主要损失中。

考虑这一层,它创建了一个活动正则化损失

@keras.saving.register_keras_serializable()
class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs

让我们构建一个非常简单的模型来使用它

inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

现在我们的训练步骤应该如下所示

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

总结

现在您已经了解了使用内置训练循环和从头开始编写自己的训练循环的所有知识。

最后,这是一个将您在本指南中学习到的所有内容结合在一起的简单端到端示例:在 MNIST 数字上训练的 DCGAN。

端到端示例:从头开始的 GAN 训练循环

您可能熟悉生成对抗网络 (GAN)。GAN 可以生成看起来几乎真实的全新图像,方法是学习图像训练数据集的潜在分布(图像的“潜在空间”)。

GAN 由两部分组成:“生成器”模型,它将潜在空间中的点映射到图像空间中的点,“鉴别器”模型,一个分类器,它可以区分真实图像(来自训练数据集)和假图像(生成器网络的输出)。

GAN 训练循环如下所示

1) 训练鉴别器。- 在潜在空间中采样一批随机点。- 通过“生成器”模型将点转换为假图像。- 获取一批真实图像并将它们与生成的图像组合在一起。- 训练“鉴别器”模型以对生成的图像与真实图像进行分类。

2) 训练生成器。- 在潜在空间中采样随机点。- 通过“生成器”网络将点转换为假图像。- 获取一批真实图像并将它们与生成的图像组合在一起。- 训练“生成器”模型以“欺骗”鉴别器并将假图像分类为真实图像。

有关 GAN 工作原理的更详细概述,请参阅 使用 Python 进行深度学习

让我们实现这个训练循环。首先,创建一个旨在对假数字与真数字进行分类的鉴别器

discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 14, 14, 64)        640       
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 14, 14, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 128)         73856     
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 7, 7, 128)         0         
                                                                 
 global_max_pooling2d (Glob  (None, 128)               0         
 alMaxPooling2D)                                                 
                                                                 
 dense_4 (Dense)             (None, 1)                 129       
                                                                 
=================================================================
Total params: 74625 (291.50 KB)
Trainable params: 74625 (291.50 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

然后让我们创建一个生成器网络,它将潜在向量转换为形状为 (28, 28, 1) 的输出(表示 MNIST 数字)

latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

这是关键部分:训练循环。如您所见,它非常简单。训练步骤函数只占 17 行。

# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
    return d_loss, g_loss, generated_images

让我们通过在图像批次上重复调用 train_step 来训练我们的 GAN。

由于我们的鉴别器和生成器是卷积神经网络,因此您需要在 GPU 上运行此代码。

import os

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print("\nStart epoch", epoch)

    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

        # Logging.
        if step % 200 == 0:
            # Print metrics
            print("discriminator loss at step %d: %.2f" % (step, d_loss))
            print("adversarial loss at step %d: %.2f" % (step, g_loss))

            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, "generated_img" + str(step) + ".png"))

        # To limit execution time we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break
Start epoch 0
discriminator loss at step 0: 0.72
adversarial loss at step 0: 0.72

就是这样!您将在 Colab GPU 上训练大约 30 秒后获得看起来不错的假 MNIST 数字。