Author: fchollet
Introduction
When you're doing supervised learning, you can use `fit()` and everything works smoothly.

When you need to write your own training loop from scratch, you can use `GradientTape` and take control of every little detail.

But what if you need a custom training algorithm, yet you still want to benefit from the convenient features of `fit()`, such as callbacks, built-in distribution support, or step fusing?
A core principle of Keras is **progressive disclosure of complexity**. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.
When you need to customize what `fit()` does, you should **override the training step function of the `Model` class**. This is the function that `fit()` calls for every batch of data. You will then be able to call `fit()` as usual, and it will be running your own learning algorithm.
Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building `Sequential` models, Functional API models, or subclassed models.
Let's see how that works.
Setup
Requires TensorFlow 2.8 or later.
```python
import tensorflow as tf
from tensorflow import keras
```
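A quick runtime check, in case you want to verify your environment (a small addition, not part of the original guide):

```python
# Sanity check (assumption: this guide targets TF >= 2.8).
assert tuple(int(v) for v in tf.__version__.split(".")[:2]) >= (2, 8), tf.__version__
```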
A first simple example
Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We just override the method `train_step(self, data)`.
- We return a dictionary mapping metric names (including the loss) to their current value.

The input argument `data` is what gets passed to `fit()` as training data:

- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple `(x, y)`.
- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be what gets yielded by `dataset` at each batch.
In the body of the `train_step()` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we compute the loss via `self.compute_loss()`**, which wraps the loss function(s) that were passed to `compile()`.
Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`, to update the state of the metrics that were passed in `compile()`, and we query results from `self.metrics` at the end to retrieve their current value.
```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
```
Let's try this out:
```python
import numpy as np

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
```
```
Epoch 1/3
32/32 [==============================] - 3s 2ms/step - loss: 1.6446
Epoch 2/3
32/32 [==============================] - 0s 2ms/step - loss: 0.7554
Epoch 3/3
32/32 [==============================] - 0s 2ms/step - loss: 0.3924
<keras.src.callbacks.History at 0x7fef5c11ba30>
```
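As noted above, the same `train_step` also works if you feed `fit()` a `tf.data.Dataset`, since `data` then becomes whatever the dataset yields at each batch. A quick sketch (the dataset construction here is illustrative, not from the original guide):

```python
# The same CustomModel, trained from a tf.data.Dataset.
# Inside train_step, `data` will be the (x, y) tuple yielded per batch.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
model.fit(dataset, epochs=1)
```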
Going lower-level
Naturally, you could simply skip passing a loss function in `compile()`, and instead do everything manually in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:
- We start by creating `Metric` instances to track our loss and an MAE score (in `__init__()`).
- We implement a custom `train_step()` that updates the state of these metrics (by calling `update_state()` on them), then queries them (via `result()`) to return their current average value, to be displayed by the progress bar and passed to any callback.
- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise, calling `result()` would return an average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want reset in the `metrics` property of the model. The model will call `reset_states()` on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to `evaluate()`.
```python
class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss
            loss = keras.losses.mean_squared_error(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {"loss": self.loss_tracker.result(), "mae": self.mae_metric.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
```
```
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 0.3240 - mae: 0.4583
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2416 - mae: 0.3984
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2340 - mae: 0.3919
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2274 - mae: 0.3870
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2197 - mae: 0.3808
<keras.src.callbacks.History at 0x7fef3c130b20>
```
Supporting `sample_weight` & `class_weight`
You may have noticed that our first basic example didn't make any mention of sample weighting. If you want to support the `fit()` arguments `sample_weight` and `class_weight`, you'd simply do the following:
- Unpack `sample_weight` from the `data` argument.
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply it manually if you don't rely on `compile()` for losses & metrics).
- That's it.
```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
```
```
Epoch 1/3
32/32 [==============================] - 0s 2ms/step - loss: 0.1298
Epoch 2/3
32/32 [==============================] - 0s 2ms/step - loss: 0.1179
Epoch 3/3
32/32 [==============================] - 0s 2ms/step - loss: 0.1121
<keras.src.callbacks.History at 0x7fef3c168100>
```
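What about `class_weight`? When you pass `class_weight` to `fit()`, Keras converts it into per-sample weights before your `train_step` sees the data, so the unpacking above covers it as well. If you ever wanted to build such weights yourself (say, for a classifier with integer labels), here is a minimal, hypothetical sketch; the `weight_table` lookup is an assumption for illustration, not part of the original guide:

```python
# Hypothetical sketch: deriving per-sample weights from a class_weight
# dict, assuming integer class labels of shape (batch_size,).
class_weight = {0: 1.0, 1: 2.0}  # example mapping
weight_table = tf.constant([class_weight[k] for k in sorted(class_weight)])

y_int = tf.constant([0, 1, 1, 0])  # example labels
sample_weight = tf.gather(weight_table, y_int)  # -> [1.0, 2.0, 2.0, 1.0]
```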
Providing your own evaluation step
What if you want to do the same for calls to `model.evaluate()`? Then you would override `test_step` in exactly the same way. Here's what it looks like:
```python
class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Updates the metrics tracking the loss
        self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name != "loss":
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
```
```
32/32 [==============================] - 0s 1ms/step - loss: 0.9028
0.9028095006942749
```
Wrapping up: an end-to-end GAN example
Let's walk through an end-to-end example that leverages everything you just learned.
Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and "real").
- One optimizer for each network.
- A loss function to train the discriminator.
```python
from tensorflow.keras import layers

# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
```
Here's a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in 17 lines in `train_step`:
```python
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }
```
Let's test-drive it:
```python
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
```
```
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
100/100 [==============================] - 8s 15ms/step - d_loss: 0.4372 - g_loss: 0.8775
<keras.src.callbacks.History at 0x7feee42ff190>
```
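Once training has run, you can sample from the generator directly; a small sketch (not part of the original guide):

```python
# Generate a few images from random latent vectors using the trained generator.
random_latent_vectors = tf.random.normal(shape=(3, latent_dim))
generated_images = gan.generator(random_latent_vectors)
print(generated_images.shape)  # (3, 28, 28, 1)
```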
The ideas behind deep learning are simple, so why should their implementation be painful?