作者: fchollet
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 在 keras.io 上查看 |
设置
import numpy as np
import tensorflow as tf
from tensorflow import keras
2023-10-03 11:11:08.160283: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-10-03 11:11:08.160349: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-10-03 11:11:08.160404: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
简介
迁移学习是指将在一个问题上学习到的特征应用于另一个类似的新问题。例如,从学习识别浣熊的模型中获得的特征可能有助于启动一个旨在识别狸猫的模型。
迁移学习通常用于数据集数据量不足以从头开始训练完整模型的任务。
深度学习中迁移学习最常见的形式是以下工作流程
- 从先前训练的模型中获取层。
- 冻结它们,以避免在未来的训练回合中破坏它们包含的任何信息。
- 在冻结层之上添加一些新的可训练层。它们将学习将旧特征转换为对新数据集的预测。
- 在您的数据集上训练新层。
最后,可选步骤是微调,它包括解冻您上面获得的整个模型(或部分模型),并在新数据上以非常低的学习率重新训练它。这可以通过逐步调整预训练特征以适应新数据来实现有意义的改进。
首先,我们将详细介绍 Keras 的 trainable
API,它构成了大多数迁移学习和微调工作流程的基础。
然后,我们将通过获取在 ImageNet 数据集上预训练的模型,并在 Kaggle 的“猫与狗”分类数据集上对其进行重新训练来演示典型的工作流程。
这是从 使用 Python 进行深度学习 和 2016 年的博客文章 “使用很少的数据构建强大的图像分类模型” 中改编而来。
冻结层:了解 trainable
属性
层和模型具有三个权重属性
weights
是层的权重变量列表。trainable_weights
是指在训练过程中通过梯度下降法更新以最小化损失的权重列表。non_trainable_weights
是指在训练过程中不会被更新的权重列表。通常它们会在模型的前向传播过程中被更新。
例如:Dense
层有两个可训练权重(kernel 和 bias)。
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2 trainable_weights: 2 non_trainable_weights: 0 2023-10-03 11:11:10.677246: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflowcn.cn/install/gpu for how to download and setup the required libraries for your platform. Skipping registering GPU devices...
一般来说,所有权重都是可训练权重。唯一具有不可训练权重的内置层是 BatchNormalization
层。它使用不可训练权重来跟踪训练过程中输入的均值和方差。要了解如何在自定义层中使用不可训练权重,请参阅 从头开始编写新层的指南。
例如:BatchNormalization
层有两个可训练权重和两个不可训练权重。
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4 trainable_weights: 2 non_trainable_weights: 2
层和模型还具有一个布尔属性 trainable
。它的值可以更改。将 layer.trainable
设置为 False
会将所有层的权重从可训练权重移动到不可训练权重。这被称为“冻结”层:冻结层的状态在训练过程中不会更新(无论是使用 fit()
训练还是使用任何依赖于 trainable_weights
来应用梯度更新的自定义循环)。
例如:将 trainable
设置为 False
。
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2 trainable_weights: 0 non_trainable_weights: 2
当可训练权重变为不可训练权重时,它的值在训练过程中不再更新。
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 324ms/step - loss: 0.0178
不要将 layer.trainable
属性与 layer.__call__()
中的参数 training
混淆(它控制层应该以推理模式还是训练模式运行其前向传播)。有关更多信息,请参阅 Keras 常见问题解答。
递归设置 trainable
属性
如果您在模型或任何具有子层的层上设置 trainable = False
,所有子层也将变为不可训练。
示例
inner_model = keras.Sequential(
[
keras.Input(shape=(3,)),
keras.layers.Dense(3, activation="relu"),
keras.layers.Dense(3, activation="relu"),
]
)
model = keras.Sequential(
[
keras.Input(shape=(3,)),
inner_model,
keras.layers.Dense(3, activation="sigmoid"),
]
)
model.trainable = False # Freeze the outer model
assert inner_model.trainable == False # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False # `trainable` is propagated recursively
典型的迁移学习工作流程
这将我们引向如何在 Keras 中实现典型的迁移学习工作流程。
- 实例化一个基础模型,并将预训练权重加载到其中。
- 通过设置
trainable = False
来冻结基础模型中的所有层。 - 在基础模型中一个(或多个)层的输出之上创建一个新模型。
- 在您的新数据集上训练您的新模型。
请注意,另一种更轻量级的流程也可以是
- 实例化一个基础模型,并将预训练权重加载到其中。
- 将您的新数据集通过它,并记录基础模型中一个(或多个)层的输出。这被称为 **特征提取**。
- 将该输出用作新的小模型的输入数据。
第二种工作流程的一个主要优势是,您只需在数据上运行一次基础模型,而不是在每个训练时期运行一次。因此它更快且更便宜。
但是,第二种工作流程的一个问题是,它不允许您在训练期间动态修改新模型的输入数据,例如,在进行数据增强时需要这样做。迁移学习通常用于您的新数据集数据量太少而无法从头开始训练一个完整模型的任务,在这种情况下,数据增强非常重要。因此,在接下来的内容中,我们将重点关注第一个工作流程。
以下是 Keras 中第一个工作流程的样子
首先,使用预训练权重实例化一个基础模型。
base_model = keras.applications.Xception(
weights='imagenet', # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False) # Do not include the ImageNet classifier at the top.
然后,冻结基础模型。
base_model.trainable = False
在顶部创建一个新模型。
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
在新的数据上训练模型。
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
微调
一旦您的模型在新数据上收敛,您可以尝试解冻基础模型的全部或部分,并以非常低的学习率重新训练整个模型。
这是一个可选的最后一步,它可能会为您带来增量改进。它也可能导致快速过拟合 - 请记住这一点。
至关重要的是,只有在具有冻结层的模型训练到收敛后才执行此步骤。如果您将随机初始化的可训练层与保存预训练特征的可训练层混合,随机初始化的层将在训练期间导致非常大的梯度更新,这将破坏您的预训练特征。
在这个阶段使用非常低的学习率也很重要,因为您正在训练一个比第一轮训练大得多的模型,并且在通常非常小的数据集上进行训练。因此,如果您应用较大的权重更新,您有快速过拟合的风险。在这里,您只想以增量方式重新调整预训练权重。
这是如何实现整个基础模型的微调
# Unfreeze the base model
base_model.trainable = True
# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5), # Very low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
关于 compile()
和 trainable
的重要说明
在模型上调用 compile()
意味着“冻结”该模型的行为。这意味着,在编译模型时 trainable
属性值应在该模型的整个生命周期内保留,直到再次调用 compile
。因此,如果您更改任何 trainable
值,请确保再次在您的模型上调用 compile()
,以使您的更改生效。
关于 BatchNormalization
层的重要说明
许多图像模型包含 BatchNormalization
层。该层在所有可想象的方面都是一个特例。以下是一些需要注意的事项。
BatchNormalization
包含 2 个在训练过程中更新的不可训练权重。这些是跟踪输入的均值和方差的变量。- 当您设置
bn_layer.trainable = False
时,BatchNormalization
层将以推理模式运行,并且不会更新其均值和方差统计信息。这与其他层通常的情况不同,因为 权重可训练性和推理/训练模式是两个正交的概念。但在BatchNormalization
层的情况下,两者是相关的。 - 当您解冻包含
BatchNormalization
层的模型以进行微调时,您应该通过在调用基础模型时传递training=False
来将BatchNormalization
层保持在推理模式。否则,应用于不可训练权重的更新将突然破坏模型迄今为止学到的内容。
您将在本指南末尾的端到端示例中看到这种模式的实际应用。
使用自定义训练循环进行迁移学习和微调
如果您使用的是自己的低级训练循环,而不是 fit()
,工作流程基本上保持不变。您应该注意,在应用梯度更新时,只考虑列表 model.trainable_weights
。
# Create base model
base_model = keras.applications.Xception(
weights='imagenet',
input_shape=(150, 150, 3),
include_top=False)
# Freeze base model
base_model.trainable = False
# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()
# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
# Open a GradientTape.
with tf.GradientTape() as tape:
# Forward pass.
predictions = model(inputs)
# Compute the loss value for this batch.
loss_value = loss_fn(targets, predictions)
# Get gradients of loss wrt the *trainable* weights.
gradients = tape.gradient(loss_value, model.trainable_weights)
# Update the weights of the model.
optimizer.apply_gradients(zip(gradients, model.trainable_weights))
微调也是如此。
端到端示例:在猫狗数据集上微调图像分类模型
为了巩固这些概念,让我们带您完成一个具体的端到端迁移学习和微调示例。我们将加载在 ImageNet 上预训练的 Xception 模型,并在 Kaggle 的“猫狗”分类数据集上使用它。
获取数据
首先,让我们使用 TFDS 获取猫狗数据集。如果您有自己的数据集,您可能希望使用实用程序 keras.utils.image_dataset_from_directory
从磁盘上的一组图像(按类别特定文件夹排列)生成类似的标记数据集对象。
当使用非常小的数据集时,迁移学习最有用。为了使我们的数据集保持较小,我们将使用原始训练数据(25,000 张图像)的 40% 用于训练,10% 用于验证,10% 用于测试。
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
"cats_vs_dogs",
# Reserve 10% for validation and 10% for test
split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
as_supervised=True, # Include labels
)
print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
"Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305 Number of validation samples: 2326 Number of test samples: 2326
这些是训练数据集中前 9 张图像 - 正如您所见,它们的大小各不相同。
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image)
plt.title(int(label))
plt.axis("off")
我们还可以看到,标签 1 是“狗”,标签 0 是“猫”。
标准化数据
我们的原始图像具有各种尺寸。此外,每个像素都包含 3 个介于 0 到 255 之间的整数值(RGB 色阶值)。这不太适合馈送神经网络。我们需要做两件事
- 标准化为固定图像大小。我们选择 150x150。
- 将像素值归一化为 -1 到 1 之间。我们将使用模型本身中的
Normalization
层来完成此操作。
一般来说,开发以原始数据作为输入的模型是一个好习惯,而不是开发以预处理数据作为输入的模型。原因是,如果您的模型期望预处理数据,那么无论何时您导出模型以在其他地方使用它(在 Web 浏览器中,在移动应用程序中),您都需要重新实现完全相同的数据预处理管道。这很快就会变得非常棘手。因此,我们应该在到达模型之前进行尽可能少的数据预处理。
在这里,我们将在数据管道中进行图像调整大小(因为深度神经网络只能处理连续的数据批次),并且将在创建模型时作为模型的一部分进行输入值缩放。
让我们将图像调整大小为 150x150
size = (150, 150)
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))
此外,让我们对数据进行批处理,并使用缓存和预取来优化加载速度。
batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)
使用随机数据增强
当您没有大型图像数据集时,最好通过对训练图像应用随机但现实的变换(例如随机水平翻转或小的随机旋转)来人为地引入样本多样性。这有助于模型接触到训练数据的不同方面,同时减缓过拟合。
from tensorflow import keras
from tensorflow.keras import layers
data_augmentation = keras.Sequential(
[
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
]
)
让我们可视化第一批中的第一张图像在各种随机变换后的样子
import numpy as np
for images, labels in train_ds.take(1):
plt.figure(figsize=(10, 10))
first_image = images[0]
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
augmented_image = data_augmentation(
tf.expand_dims(first_image, 0), training=True
)
plt.imshow(augmented_image[0].numpy().astype("int32"))
plt.title(int(labels[0]))
plt.axis("off")
2023-10-03 11:11:16.151536: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
构建模型
现在让我们构建一个遵循我们之前解释的蓝图的模型。
请注意
- 我们添加了一个
Rescaling
层来将输入值(最初在[0, 255]
范围内)缩放到[-1, 1]
范围内。 - 我们在分类层之前添加了一个
Dropout
层,用于正则化。 - 我们确保在调用基础模型时传递
training=False
,以便它以推理模式运行,这样即使我们在解冻基础模型以进行微调后,批归一化统计信息也不会更新。
base_model = keras.applications.Xception(
weights="imagenet", # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False,
) # Do not include the ImageNet classifier at the top.
# Freeze the base_model
base_model.trainable = False
# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs) # Apply random data augmentation
# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5 83683744/83683744 [==============================] - 0s 0us/step Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 150, 150, 3)] 0 sequential_3 (Sequential) (None, 150, 150, 3) 0 rescaling (Rescaling) (None, 150, 150, 3) 0 xception (Functional) (None, 5, 5, 2048) 20861480 global_average_pooling2d ( (None, 2048) 0 GlobalAveragePooling2D) dropout (Dropout) (None, 2048) 0 dense_7 (Dense) (None, 1) 2049 ================================================================= Total params: 20863529 (79.59 MB) Trainable params: 2049 (8.00 KB) Non-trainable params: 20861480 (79.58 MB) _________________________________________________________________
训练顶层
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20 142/291 [=============>................] - ETA: 35s - loss: 0.2211 - binary_accuracy: 0.8963 Corrupt JPEG data: 65 extraneous bytes before marker 0xd9 258/291 [=========================>....] - ETA: 7s - loss: 0.1831 - binary_accuracy: 0.9172 Corrupt JPEG data: 239 extraneous bytes before marker 0xd9 272/291 [===========================>..] - ETA: 4s - loss: 0.1797 - binary_accuracy: 0.9190 Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9 275/291 [===========================>..] - ETA: 3s - loss: 0.1789 - binary_accuracy: 0.9194 Corrupt JPEG data: 228 extraneous bytes before marker 0xd9 291/291 [==============================] - ETA: 0s - loss: 0.1771 - binary_accuracy: 0.9205 Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9 291/291 [==============================] - 89s 295ms/step - loss: 0.1771 - binary_accuracy: 0.9205 - val_loss: 0.0835 - val_binary_accuracy: 0.9652 Epoch 2/20 291/291 [==============================] - 85s 293ms/step - loss: 0.1197 - binary_accuracy: 0.9493 - val_loss: 0.0846 - val_binary_accuracy: 0.9699 Epoch 3/20 291/291 [==============================] - 83s 286ms/step - loss: 0.1135 - binary_accuracy: 0.9531 - val_loss: 0.0751 - val_binary_accuracy: 0.9708 Epoch 4/20 291/291 [==============================] - 83s 285ms/step - loss: 0.1037 - binary_accuracy: 0.9558 - val_loss: 0.0704 - val_binary_accuracy: 0.9712 Epoch 5/20 291/291 [==============================] - 83s 285ms/step - loss: 0.1024 - binary_accuracy: 0.9582 - val_loss: 0.0718 - val_binary_accuracy: 0.9733 Epoch 6/20 291/291 [==============================] - 83s 284ms/step - loss: 0.1006 - binary_accuracy: 0.9595 - val_loss: 0.0749 - val_binary_accuracy: 0.9699 Epoch 7/20 291/291 [==============================] - 83s 287ms/step - loss: 0.0961 - binary_accuracy: 0.9580 - val_loss: 0.0720 - val_binary_accuracy: 0.9699 Epoch 8/20 291/291 [==============================] - 83s 285ms/step - loss: 0.0952 - binary_accuracy: 0.9598 - val_loss: 0.0737 - val_binary_accuracy: 0.9712 Epoch 9/20 291/291 [==============================] - 83s 286ms/step - loss: 0.0984 - binary_accuracy: 0.9614 - val_loss: 0.0729 - val_binary_accuracy: 0.9708 Epoch 10/20 291/291 [==============================] - 83s 285ms/step - loss: 0.1007 - binary_accuracy: 0.9581 - val_loss: 0.0811 - val_binary_accuracy: 0.9686 Epoch 11/20 291/291 [==============================] - 83s 285ms/step - loss: 0.0960 - binary_accuracy: 0.9611 - val_loss: 0.0813 - val_binary_accuracy: 0.9703 Epoch 12/20 291/291 [==============================] - 83s 287ms/step - loss: 0.0950 - binary_accuracy: 0.9606 - val_loss: 0.0745 - val_binary_accuracy: 0.9703 Epoch 13/20 291/291 [==============================] - 84s 289ms/step - loss: 0.0970 - binary_accuracy: 0.9602 - val_loss: 0.0756 - val_binary_accuracy: 0.9703 Epoch 14/20 291/291 [==============================] - 83s 285ms/step - loss: 0.0915 - binary_accuracy: 0.9632 - val_loss: 0.0754 - val_binary_accuracy: 0.9690 Epoch 15/20 291/291 [==============================] - 84s 290ms/step - loss: 0.0938 - binary_accuracy: 0.9628 - val_loss: 0.0786 - val_binary_accuracy: 0.9682 Epoch 16/20 291/291 [==============================] - 82s 283ms/step - loss: 0.0958 - binary_accuracy: 0.9609 - val_loss: 0.0784 - val_binary_accuracy: 0.9682 Epoch 17/20 291/291 [==============================] - 83s 284ms/step - loss: 0.0907 - binary_accuracy: 0.9616 - val_loss: 0.0720 - val_binary_accuracy: 0.9721 Epoch 18/20 291/291 [==============================] - 83s 287ms/step - loss: 0.0946 - binary_accuracy: 0.9621 - val_loss: 0.0751 - val_binary_accuracy: 0.9716 Epoch 19/20 291/291 [==============================] - 83s 286ms/step - loss: 0.1004 - binary_accuracy: 0.9597 - val_loss: 0.0726 - val_binary_accuracy: 0.9703 Epoch 20/20 291/291 [==============================] - 82s 283ms/step - loss: 0.0891 - binary_accuracy: 0.9635 - val_loss: 0.0736 - val_binary_accuracy: 0.9712 <keras.src.callbacks.History at 0x7f701c5093a0>
对整个模型进行一轮微调
最后,让我们解冻基础模型,并以低学习率对整个模型进行端到端训练。
重要的是,尽管基础模型变得可训练,但它仍然以推理模式运行,因为我们在构建模型时调用它时传递了 training=False
。这意味着,内部的批归一化层不会更新其批统计信息。如果它们更新了,它们将对模型迄今为止学到的表示造成破坏。
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()
model.compile(
optimizer=keras.optimizers.Adam(1e-5), # Low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 150, 150, 3)] 0 sequential_3 (Sequential) (None, 150, 150, 3) 0 rescaling (Rescaling) (None, 150, 150, 3) 0 xception (Functional) (None, 5, 5, 2048) 20861480 global_average_pooling2d ( (None, 2048) 0 GlobalAveragePooling2D) dropout (Dropout) (None, 2048) 0 dense_7 (Dense) (None, 1) 2049 ================================================================= Total params: 20863529 (79.59 MB) Trainable params: 20809001 (79.38 MB) Non-trainable params: 54528 (213.00 KB) _________________________________________________________________ Epoch 1/10 291/291 [==============================] - 321s 1s/step - loss: 0.0751 - binary_accuracy: 0.9698 - val_loss: 0.0537 - val_binary_accuracy: 0.9781 Epoch 2/10 291/291 [==============================] - 302s 1s/step - loss: 0.0566 - binary_accuracy: 0.9787 - val_loss: 0.0490 - val_binary_accuracy: 0.9802 Epoch 3/10 291/291 [==============================] - 296s 1s/step - loss: 0.0455 - binary_accuracy: 0.9810 - val_loss: 0.0477 - val_binary_accuracy: 0.9794 Epoch 4/10 291/291 [==============================] - 289s 994ms/step - loss: 0.0351 - binary_accuracy: 0.9859 - val_loss: 0.0457 - val_binary_accuracy: 0.9789 Epoch 5/10 291/291 [==============================] - 289s 993ms/step - loss: 0.0268 - binary_accuracy: 0.9907 - val_loss: 0.0522 - val_binary_accuracy: 0.9794 Epoch 6/10 291/291 [==============================] - 288s 990ms/step - loss: 0.0258 - binary_accuracy: 0.9900 - val_loss: 0.0529 - val_binary_accuracy: 0.9789 Epoch 7/10 291/291 [==============================] - 286s 982ms/step - loss: 0.0209 - binary_accuracy: 0.9918 - val_loss: 0.0518 - val_binary_accuracy: 0.9776 Epoch 8/10 291/291 [==============================] - 289s 994ms/step - loss: 0.0185 - binary_accuracy: 0.9936 - val_loss: 0.0467 - val_binary_accuracy: 0.9832 Epoch 9/10 291/291 [==============================] - 289s 994ms/step - loss: 0.0150 - binary_accuracy: 0.9953 - val_loss: 0.0509 - val_binary_accuracy: 0.9802 Epoch 10/10 291/291 [==============================] - 292s 1s/step - loss: 0.0148 - binary_accuracy: 0.9952 - val_loss: 0.0501 - val_binary_accuracy: 0.9832 <keras.src.callbacks.History at 0x7f701c7e4f10>
经过 10 个时期后,微调在这里为我们带来了不错的改进。