迁移学习与微调

作者: fchollet

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 在 keras.io 上查看

设置

import numpy as np
import tensorflow as tf
from tensorflow import keras
2023-10-03 11:11:08.160283: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-03 11:11:08.160349: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-03 11:11:08.160404: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

简介

迁移学习是指将在一个问题上学习到的特征应用于另一个类似的新问题。例如,从学习识别浣熊的模型中获得的特征可能有助于启动一个旨在识别狸猫的模型。

迁移学习通常用于数据集数据量不足以从头开始训练完整模型的任务。

深度学习中迁移学习最常见的形式是以下工作流程

  1. 从先前训练的模型中获取层。
  2. 冻结它们,以避免在未来的训练回合中破坏它们包含的任何信息。
  3. 在冻结层之上添加一些新的可训练层。它们将学习将旧特征转换为对新数据集的预测。
  4. 在您的数据集上训练新层。

最后,可选步骤是微调,它包括解冻您上面获得的整个模型(或部分模型),并在新数据上以非常低的学习率重新训练它。这可以通过逐步调整预训练特征以适应新数据来实现有意义的改进。

首先,我们将详细介绍 Keras 的 trainable API,它构成了大多数迁移学习和微调工作流程的基础。

然后,我们将通过获取在 ImageNet 数据集上预训练的模型,并在 Kaggle 的“猫与狗”分类数据集上对其进行重新训练来演示典型的工作流程。

这是从 使用 Python 进行深度学习 和 2016 年的博客文章 “使用很少的数据构建强大的图像分类模型” 中改编而来。

冻结层:了解 trainable 属性

层和模型具有三个权重属性

  • weights 是层的权重变量列表。
  • trainable_weights 是指在训练过程中通过梯度下降法更新以最小化损失的权重列表。
  • non_trainable_weights 是指在训练过程中不会被更新的权重列表。通常它们会在模型的前向传播过程中被更新。

例如:Dense 层有两个可训练权重(kernel 和 bias)。

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0
2023-10-03 11:11:10.677246: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflowcn.cn/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...

一般来说,所有权重都是可训练权重。唯一具有不可训练权重的内置层是 BatchNormalization 层。它使用不可训练权重来跟踪训练过程中输入的均值和方差。要了解如何在自定义层中使用不可训练权重,请参阅 从头开始编写新层的指南

例如:BatchNormalization 层有两个可训练权重和两个不可训练权重。

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

层和模型还具有一个布尔属性 trainable。它的值可以更改。将 layer.trainable 设置为 False 会将所有层的权重从可训练权重移动到不可训练权重。这被称为“冻结”层:冻结层的状态在训练过程中不会更新(无论是使用 fit() 训练还是使用任何依赖于 trainable_weights 来应用梯度更新的自定义循环)。

例如:将 trainable 设置为 False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

当可训练权重变为不可训练权重时,它的值在训练过程中不再更新。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 324ms/step - loss: 0.0178

不要将 layer.trainable 属性与 layer.__call__() 中的参数 training 混淆(它控制层应该以推理模式还是训练模式运行其前向传播)。有关更多信息,请参阅 Keras 常见问题解答

递归设置 trainable 属性

如果您在模型或任何具有子层的层上设置 trainable = False,所有子层也将变为不可训练。

示例

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的迁移学习工作流程

这将我们引向如何在 Keras 中实现典型的迁移学习工作流程。

  1. 实例化一个基础模型,并将预训练权重加载到其中。
  2. 通过设置 trainable = False 来冻结基础模型中的所有层。
  3. 在基础模型中一个(或多个)层的输出之上创建一个新模型。
  4. 在您的新数据集上训练您的新模型。

请注意,另一种更轻量级的流程也可以是

  1. 实例化一个基础模型,并将预训练权重加载到其中。
  2. 将您的新数据集通过它,并记录基础模型中一个(或多个)层的输出。这被称为 **特征提取**。
  3. 将该输出用作新的小模型的输入数据。

第二种工作流程的一个主要优势是,您只需在数据上运行一次基础模型,而不是在每个训练时期运行一次。因此它更快且更便宜。

但是,第二种工作流程的一个问题是,它不允许您在训练期间动态修改新模型的输入数据,例如,在进行数据增强时需要这样做。迁移学习通常用于您的新数据集数据量太少而无法从头开始训练一个完整模型的任务,在这种情况下,数据增强非常重要。因此,在接下来的内容中,我们将重点关注第一个工作流程。

以下是 Keras 中第一个工作流程的样子

首先,使用预训练权重实例化一个基础模型。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

然后,冻结基础模型。

base_model.trainable = False

在顶部创建一个新模型。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

在新的数据上训练模型。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微调

一旦您的模型在新数据上收敛,您可以尝试解冻基础模型的全部或部分,并以非常低的学习率重新训练整个模型。

这是一个可选的最后一步,它可能会为您带来增量改进。它也可能导致快速过拟合 - 请记住这一点。

至关重要的是,只有在具有冻结层的模型训练到收敛后才执行此步骤。如果您将随机初始化的可训练层与保存预训练特征的可训练层混合,随机初始化的层将在训练期间导致非常大的梯度更新,这将破坏您的预训练特征。

在这个阶段使用非常低的学习率也很重要,因为您正在训练一个比第一轮训练大得多的模型,并且在通常非常小的数据集上进行训练。因此,如果您应用较大的权重更新,您有快速过拟合的风险。在这里,您只想以增量方式重新调整预训练权重。

这是如何实现整个基础模型的微调

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

关于 compile()trainable 的重要说明

在模型上调用 compile() 意味着“冻结”该模型的行为。这意味着,在编译模型时 trainable 属性值应在该模型的整个生命周期内保留,直到再次调用 compile。因此,如果您更改任何 trainable 值,请确保再次在您的模型上调用 compile(),以使您的更改生效。

关于 BatchNormalization 层的重要说明

许多图像模型包含 BatchNormalization 层。该层在所有可想象的方面都是一个特例。以下是一些需要注意的事项。

  • BatchNormalization 包含 2 个在训练过程中更新的不可训练权重。这些是跟踪输入的均值和方差的变量。
  • 当您设置 bn_layer.trainable = False 时,BatchNormalization 层将以推理模式运行,并且不会更新其均值和方差统计信息。这与其他层通常的情况不同,因为 权重可训练性和推理/训练模式是两个正交的概念。但在 BatchNormalization 层的情况下,两者是相关的。
  • 当您解冻包含 BatchNormalization 层的模型以进行微调时,您应该通过在调用基础模型时传递 training=False 来将 BatchNormalization 层保持在推理模式。否则,应用于不可训练权重的更新将突然破坏模型迄今为止学到的内容。

您将在本指南末尾的端到端示例中看到这种模式的实际应用。

使用自定义训练循环进行迁移学习和微调

如果您使用的是自己的低级训练循环,而不是 fit(),工作流程基本上保持不变。您应该注意,在应用梯度更新时,只考虑列表 model.trainable_weights

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

微调也是如此。

端到端示例:在猫狗数据集上微调图像分类模型

为了巩固这些概念,让我们带您完成一个具体的端到端迁移学习和微调示例。我们将加载在 ImageNet 上预训练的 Xception 模型,并在 Kaggle 的“猫狗”分类数据集上使用它。

获取数据

首先,让我们使用 TFDS 获取猫狗数据集。如果您有自己的数据集,您可能希望使用实用程序 keras.utils.image_dataset_from_directory 从磁盘上的一组图像(按类别特定文件夹排列)生成类似的标记数据集对象。

当使用非常小的数据集时,迁移学习最有用。为了使我们的数据集保持较小,我们将使用原始训练数据(25,000 张图像)的 40% 用于训练,10% 用于验证,10% 用于测试。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

这些是训练数据集中前 9 张图像 - 正如您所见,它们的大小各不相同。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

我们还可以看到,标签 1 是“狗”,标签 0 是“猫”。

标准化数据

我们的原始图像具有各种尺寸。此外,每个像素都包含 3 个介于 0 到 255 之间的整数值(RGB 色阶值)。这不太适合馈送神经网络。我们需要做两件事

  • 标准化为固定图像大小。我们选择 150x150。
  • 将像素值归一化为 -1 到 1 之间。我们将使用模型本身中的 Normalization 层来完成此操作。

一般来说,开发以原始数据作为输入的模型是一个好习惯,而不是开发以预处理数据作为输入的模型。原因是,如果您的模型期望预处理数据,那么无论何时您导出模型以在其他地方使用它(在 Web 浏览器中,在移动应用程序中),您都需要重新实现完全相同的数据预处理管道。这很快就会变得非常棘手。因此,我们应该在到达模型之前进行尽可能少的数据预处理。

在这里,我们将在数据管道中进行图像调整大小(因为深度神经网络只能处理连续的数据批次),并且将在创建模型时作为模型的一部分进行输入值缩放。

让我们将图像调整大小为 150x150

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

此外,让我们对数据进行批处理,并使用缓存和预取来优化加载速度。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

使用随机数据增强

当您没有大型图像数据集时,最好通过对训练图像应用随机但现实的变换(例如随机水平翻转或小的随机旋转)来人为地引入样本多样性。这有助于模型接触到训练数据的不同方面,同时减缓过拟合。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ]
)

让我们可视化第一批中的第一张图像在各种随机变换后的样子

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2023-10-03 11:11:16.151536: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

构建模型

现在让我们构建一个遵循我们之前解释的蓝图的模型。

请注意

  • 我们添加了一个 Rescaling 层来将输入值(最初在 [0, 255] 范围内)缩放到 [-1, 1] 范围内。
  • 我们在分类层之前添加了一个 Dropout 层,用于正则化。
  • 我们确保在调用基础模型时传递 training=False,以便它以推理模式运行,这样即使我们在解冻基础模型以进行微调后,批归一化统计信息也不会更新。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 [==============================] - 0s 0us/step
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (  (None, 2048)              0         
 GlobalAveragePooling2D)                                         
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20863529 (79.59 MB)
Trainable params: 2049 (8.00 KB)
Non-trainable params: 20861480 (79.58 MB)
_________________________________________________________________

训练顶层

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
142/291 [=============>................] - ETA: 35s - loss: 0.2211 - binary_accuracy: 0.8963
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
258/291 [=========================>....] - ETA: 7s - loss: 0.1831 - binary_accuracy: 0.9172
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
272/291 [===========================>..] - ETA: 4s - loss: 0.1797 - binary_accuracy: 0.9190
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
275/291 [===========================>..] - ETA: 3s - loss: 0.1789 - binary_accuracy: 0.9194
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1771 - binary_accuracy: 0.9205
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 89s 295ms/step - loss: 0.1771 - binary_accuracy: 0.9205 - val_loss: 0.0835 - val_binary_accuracy: 0.9652
Epoch 2/20
291/291 [==============================] - 85s 293ms/step - loss: 0.1197 - binary_accuracy: 0.9493 - val_loss: 0.0846 - val_binary_accuracy: 0.9699
Epoch 3/20
291/291 [==============================] - 83s 286ms/step - loss: 0.1135 - binary_accuracy: 0.9531 - val_loss: 0.0751 - val_binary_accuracy: 0.9708
Epoch 4/20
291/291 [==============================] - 83s 285ms/step - loss: 0.1037 - binary_accuracy: 0.9558 - val_loss: 0.0704 - val_binary_accuracy: 0.9712
Epoch 5/20
291/291 [==============================] - 83s 285ms/step - loss: 0.1024 - binary_accuracy: 0.9582 - val_loss: 0.0718 - val_binary_accuracy: 0.9733
Epoch 6/20
291/291 [==============================] - 83s 284ms/step - loss: 0.1006 - binary_accuracy: 0.9595 - val_loss: 0.0749 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 83s 287ms/step - loss: 0.0961 - binary_accuracy: 0.9580 - val_loss: 0.0720 - val_binary_accuracy: 0.9699
Epoch 8/20
291/291 [==============================] - 83s 285ms/step - loss: 0.0952 - binary_accuracy: 0.9598 - val_loss: 0.0737 - val_binary_accuracy: 0.9712
Epoch 9/20
291/291 [==============================] - 83s 286ms/step - loss: 0.0984 - binary_accuracy: 0.9614 - val_loss: 0.0729 - val_binary_accuracy: 0.9708
Epoch 10/20
291/291 [==============================] - 83s 285ms/step - loss: 0.1007 - binary_accuracy: 0.9581 - val_loss: 0.0811 - val_binary_accuracy: 0.9686
Epoch 11/20
291/291 [==============================] - 83s 285ms/step - loss: 0.0960 - binary_accuracy: 0.9611 - val_loss: 0.0813 - val_binary_accuracy: 0.9703
Epoch 12/20
291/291 [==============================] - 83s 287ms/step - loss: 0.0950 - binary_accuracy: 0.9606 - val_loss: 0.0745 - val_binary_accuracy: 0.9703
Epoch 13/20
291/291 [==============================] - 84s 289ms/step - loss: 0.0970 - binary_accuracy: 0.9602 - val_loss: 0.0756 - val_binary_accuracy: 0.9703
Epoch 14/20
291/291 [==============================] - 83s 285ms/step - loss: 0.0915 - binary_accuracy: 0.9632 - val_loss: 0.0754 - val_binary_accuracy: 0.9690
Epoch 15/20
291/291 [==============================] - 84s 290ms/step - loss: 0.0938 - binary_accuracy: 0.9628 - val_loss: 0.0786 - val_binary_accuracy: 0.9682
Epoch 16/20
291/291 [==============================] - 82s 283ms/step - loss: 0.0958 - binary_accuracy: 0.9609 - val_loss: 0.0784 - val_binary_accuracy: 0.9682
Epoch 17/20
291/291 [==============================] - 83s 284ms/step - loss: 0.0907 - binary_accuracy: 0.9616 - val_loss: 0.0720 - val_binary_accuracy: 0.9721
Epoch 18/20
291/291 [==============================] - 83s 287ms/step - loss: 0.0946 - binary_accuracy: 0.9621 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 83s 286ms/step - loss: 0.1004 - binary_accuracy: 0.9597 - val_loss: 0.0726 - val_binary_accuracy: 0.9703
Epoch 20/20
291/291 [==============================] - 82s 283ms/step - loss: 0.0891 - binary_accuracy: 0.9635 - val_loss: 0.0736 - val_binary_accuracy: 0.9712
<keras.src.callbacks.History at 0x7f701c5093a0>

对整个模型进行一轮微调

最后,让我们解冻基础模型,并以低学习率对整个模型进行端到端训练。

重要的是,尽管基础模型变得可训练,但它仍然以推理模式运行,因为我们在构建模型时调用它时传递了 training=False。这意味着,内部的批归一化层不会更新其批统计信息。如果它们更新了,它们将对模型迄今为止学到的表示造成破坏。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (  (None, 2048)              0         
 GlobalAveragePooling2D)                                         
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20863529 (79.59 MB)
Trainable params: 20809001 (79.38 MB)
Non-trainable params: 54528 (213.00 KB)
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 321s 1s/step - loss: 0.0751 - binary_accuracy: 0.9698 - val_loss: 0.0537 - val_binary_accuracy: 0.9781
Epoch 2/10
291/291 [==============================] - 302s 1s/step - loss: 0.0566 - binary_accuracy: 0.9787 - val_loss: 0.0490 - val_binary_accuracy: 0.9802
Epoch 3/10
291/291 [==============================] - 296s 1s/step - loss: 0.0455 - binary_accuracy: 0.9810 - val_loss: 0.0477 - val_binary_accuracy: 0.9794
Epoch 4/10
291/291 [==============================] - 289s 994ms/step - loss: 0.0351 - binary_accuracy: 0.9859 - val_loss: 0.0457 - val_binary_accuracy: 0.9789
Epoch 5/10
291/291 [==============================] - 289s 993ms/step - loss: 0.0268 - binary_accuracy: 0.9907 - val_loss: 0.0522 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 288s 990ms/step - loss: 0.0258 - binary_accuracy: 0.9900 - val_loss: 0.0529 - val_binary_accuracy: 0.9789
Epoch 7/10
291/291 [==============================] - 286s 982ms/step - loss: 0.0209 - binary_accuracy: 0.9918 - val_loss: 0.0518 - val_binary_accuracy: 0.9776
Epoch 8/10
291/291 [==============================] - 289s 994ms/step - loss: 0.0185 - binary_accuracy: 0.9936 - val_loss: 0.0467 - val_binary_accuracy: 0.9832
Epoch 9/10
291/291 [==============================] - 289s 994ms/step - loss: 0.0150 - binary_accuracy: 0.9953 - val_loss: 0.0509 - val_binary_accuracy: 0.9802
Epoch 10/10
291/291 [==============================] - 292s 1s/step - loss: 0.0148 - binary_accuracy: 0.9952 - val_loss: 0.0501 - val_binary_accuracy: 0.9832
<keras.src.callbacks.History at 0x7f701c7e4f10>

经过 10 个时期后,微调在这里为我们带来了不错的改进。