数据增强

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

概述

本教程演示了数据增强:一种通过应用随机(但现实)变换(如图像旋转)来增加训练集多样性的技术。

您将学习如何通过两种方式应用数据增强

设置

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras import layers
2024-02-15 02:21:03.189917: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-15 02:21:03.189965: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-15 02:21:03.191420: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

下载数据集

本教程使用 tf_flowers 数据集。为了方便起见,请使用 TensorFlow Datasets 下载数据集。如果您想了解其他导入数据的方法,请查看 加载图像 教程。

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

花卉数据集有五个类别。

num_classes = metadata.features['label'].num_classes
print(num_classes)
5

让我们从数据集中检索一张图像,并使用它来演示数据增强。

get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2024-02-15 02:21:09.759464: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

使用 Keras 预处理层

调整大小和重新缩放

您可以使用 Keras 预处理层将图像调整为一致的形状(使用 tf.keras.layers.Resizing),并将像素值重新缩放(使用 tf.keras.layers.Rescaling)。

IMG_SIZE = 180

resize_and_rescale = tf.keras.Sequential([
  layers.Resizing(IMG_SIZE, IMG_SIZE),
  layers.Rescaling(1./255)
])

您可以可视化将这些层应用于图像的结果。

result = resize_and_rescale(image)
_ = plt.imshow(result)

png

验证像素是否在 [0, 1] 范围内

print("Min and max pixel values:", result.numpy().min(), result.numpy().max())
Min and max pixel values: 0.0 1.0

数据增强

您也可以使用 Keras 预处理层进行数据增强,例如 tf.keras.layers.RandomFliptf.keras.layers.RandomRotation

让我们创建一些预处理层,并将它们重复应用于同一图像。

data_augmentation = tf.keras.Sequential([
  layers.RandomFlip("horizontal_and_vertical"),
  layers.RandomRotation(0.2),
])
# Add the image to a batch.
image = tf.cast(tf.expand_dims(image, 0), tf.float32)
plt.figure(figsize=(10, 10))
for i in range(9):
  augmented_image = data_augmentation(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0])
  plt.axis("off")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

您可以使用各种预处理层进行数据增强,包括 tf.keras.layers.RandomContrasttf.keras.layers.RandomCroptf.keras.layers.RandomZoom 等。

使用 Keras 预处理层的两种选择

您可以通过两种方式使用这些预处理层,它们有重要的权衡。

选项 1:将预处理层作为模型的一部分

model = tf.keras.Sequential([
  # Add the preprocessing layers you created earlier.
  resize_and_rescale,
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  # Rest of your model.
])

在这种情况下,需要注意两点

  • 数据增强将在设备上运行,与您的其他层同步,并受益于 GPU 加速。

  • 当您使用 model.save 导出模型时,预处理层将与您的其他模型一起保存。如果您稍后部署此模型,它将自动标准化图像(根据您的层的配置)。这可以节省您在服务器端重新实现该逻辑的努力。

选项 2:将预处理层应用于您的数据集

aug_ds = train_ds.map(
  lambda x, y: (resize_and_rescale(x, training=True), y))

使用这种方法,您可以使用 Dataset.map 创建一个数据集,该数据集会生成增强图像的批次。在这种情况下

  • 数据增强将在 CPU 上异步发生,并且是非阻塞的。您可以使用 Dataset.prefetch(如下所示)将 GPU 上的模型训练与数据预处理重叠。
  • 在这种情况下,当您调用 Model.save 时,预处理层不会与模型一起导出。您需要在保存之前将它们附加到模型,或者在服务器端重新实现它们。训练后,您可以在导出之前附加预处理层。

您可以在 图像分类 教程中找到第一个选项的示例。让我们在这里演示第二个选项。

将预处理层应用于数据集

使用您之前创建的 Keras 预处理层配置训练、验证和测试数据集。您还将配置数据集以提高性能,使用并行读取和缓冲预取,以从磁盘生成批次,而不会让 I/O 成为阻塞因素。(在 使用 tf.data API 提高性能 指南中了解有关数据集性能的更多信息。)

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
  # Resize and rescale all datasets.
  ds = ds.map(lambda x, y: (resize_and_rescale(x), y), 
              num_parallel_calls=AUTOTUNE)

  if shuffle:
    ds = ds.shuffle(1000)

  # Batch all datasets.
  ds = ds.batch(batch_size)

  # Use data augmentation only on the training set.
  if augment:
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), 
                num_parallel_calls=AUTOTUNE)

  # Use buffered prefetching on all datasets.
  return ds.prefetch(buffer_size=AUTOTUNE)
train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

训练模型

为了完整起见,您现在将使用刚刚准备好的数据集训练模型。

Sequential 模型由三个卷积块(tf.keras.layers.Conv2D)组成,每个卷积块中都有一个最大池化层(tf.keras.layers.MaxPooling2D)。在它之上有一个具有 128 个单元的全连接层(tf.keras.layers.Dense),它由 ReLU 激活函数('relu')激活。此模型尚未针对准确性进行调整(目标是向您展示机制)。

model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])

选择 tf.keras.optimizers.Adam 优化器和 tf.keras.losses.SparseCategoricalCrossentropy 损失函数。要查看每个训练纪元的训练和验证准确率,请将 metrics 参数传递给 Model.compile

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

训练几个纪元

epochs=5
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707963675.836253   10200 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
92/92 [==============================] - 10s 46ms/step - loss: 1.3130 - accuracy: 0.4380 - val_loss: 1.0730 - val_accuracy: 0.6213
Epoch 2/5
92/92 [==============================] - 2s 21ms/step - loss: 1.0332 - accuracy: 0.5988 - val_loss: 1.0585 - val_accuracy: 0.5613
Epoch 3/5
92/92 [==============================] - 2s 21ms/step - loss: 0.9444 - accuracy: 0.6264 - val_loss: 0.9728 - val_accuracy: 0.6104
Epoch 4/5
92/92 [==============================] - 2s 21ms/step - loss: 0.8922 - accuracy: 0.6451 - val_loss: 0.9163 - val_accuracy: 0.6431
Epoch 5/5
92/92 [==============================] - 2s 21ms/step - loss: 0.8407 - accuracy: 0.6638 - val_loss: 0.8458 - val_accuracy: 0.6785
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
12/12 [==============================] - 1s 7ms/step - loss: 0.8492 - accuracy: 0.6322
Accuracy 0.6321526169776917

自定义数据增强

您也可以创建自定义数据增强层。

本教程的这一部分展示了两种方法

  • 首先,您将创建一个 tf.keras.layers.Lambda 层。这是编写简洁代码的好方法。
  • 接下来,您将通过 子类化 编写一个新层,这将为您提供更多控制权。

这两个层都将根据一定的概率随机反转图像中的颜色。

def random_invert_img(x, p=0.5):
  if  tf.random.uniform([]) < p:
    x = (255-x)
  else:
    x
  return x
def random_invert(factor=0.5):
  return layers.Lambda(lambda x: random_invert_img(x, factor))

random_invert = random_invert()
plt.figure(figsize=(10, 10))
for i in range(9):
  augmented_image = random_invert(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0].numpy().astype("uint8"))
  plt.axis("off")

png

接下来,通过 子类化 实现一个自定义层

class RandomInvert(layers.Layer):
  def __init__(self, factor=0.5, **kwargs):
    super().__init__(**kwargs)
    self.factor = factor

  def call(self, x):
    return random_invert_img(x)
_ = plt.imshow(RandomInvert()(image)[0])
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

这两个层都可以像上面选项 1 和 2 中描述的那样使用。

使用 tf.image

上面的 Keras 预处理实用程序很方便。但是,为了更精细的控制,您可以使用 tf.datatf.image 编写自己的数据增强管道或层。(您可能还想查看 TensorFlow Addons Image:操作TensorFlow I/O:颜色空间转换。)

由于花卉数据集之前已配置了数据增强,因此让我们重新导入它以从头开始

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

检索要使用的图像

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2024-02-15 02:21:35.218784: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

让我们使用以下函数来可视化并比较原始图像和增强图像并排

def visualize(original, augmented):
  fig = plt.figure()
  plt.subplot(1,2,1)
  plt.title('Original image')
  plt.imshow(original)

  plt.subplot(1,2,2)
  plt.title('Augmented image')
  plt.imshow(augmented)

数据增强

翻转图像

使用 tf.image.flip_left_right 垂直或水平翻转图像

flipped = tf.image.flip_left_right(image)
visualize(image, flipped)

png

将图像转换为灰度

您可以使用 tf.image.rgb_to_grayscale 将图像转换为灰度

grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
_ = plt.colorbar()

png

饱和图像

通过提供饱和度因子,使用 tf.image.adjust_saturation 饱和图像

saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)

png

更改图像亮度

通过提供亮度因子,使用 tf.image.adjust_brightness 更改图像亮度

bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)

png

中心裁剪图像

使用 tf.image.central_crop 从中心裁剪图像,直到您想要的图像部分

cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image, cropped)

png

旋转图像

使用 tf.image.rot90 将图像旋转 90 度

rotated = tf.image.rot90(image)
visualize(image, rotated)

png

随机变换

将随机变换应用于图像可以进一步帮助泛化和扩展数据集。当前的 tf.image API 提供了八种这样的随机图像操作(ops)

这些随机图像 ops 纯粹是函数式的:输出只取决于输入。这使得它们易于在高性能、确定性输入管道中使用。它们要求在每一步输入一个 seed 值。给定相同的 seed,它们返回相同的结果,与调用次数无关。

在以下部分中,您将

  1. 回顾使用随机图像操作来变换图像的示例。
  2. 演示如何将随机变换应用于训练数据集。

随机更改图像亮度

使用 tf.image.stateless_random_brightness 通过提供亮度因子和 seed 来随机更改 image 的亮度。亮度因子是在 [-max_delta, max_delta) 范围内随机选择的,并且与给定的 seed 相关联。

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_brightness = tf.image.stateless_random_brightness(
      image, max_delta=0.95, seed=seed)
  visualize(image, stateless_random_brightness)

png

png

png

随机更改图像对比度

使用 tf.image.stateless_random_contrast 随机改变 image 的对比度,方法是提供一个对比度范围和 seed。对比度范围在 [lower, upper] 区间内随机选择,并与给定的 seed 相关联。

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_contrast = tf.image.stateless_random_contrast(
      image, lower=0.1, upper=0.9, seed=seed)
  visualize(image, stateless_random_contrast)

png

png

png

随机裁剪图像

使用 tf.image.stateless_random_crop 随机裁剪 image,方法是提供目标 sizeseed。从 image 中裁剪出来的部分位于随机选择的偏移量处,并与给定的 seed 相关联。

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_crop = tf.image.stateless_random_crop(
      image, size=[210, 300, 3], seed=seed)
  visualize(image, stateless_random_crop)

png

png

png

将增强应用于数据集

让我们首先再次下载图像数据集,以防它们在前面的部分中被修改。

(train_datasets, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

接下来,定义一个用于调整图像大小和重新缩放的实用函数。此函数将用于统一数据集中图像的大小和比例。

def resize_and_rescale(image, label):
  image = tf.cast(image, tf.float32)
  image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
  image = (image / 255.0)
  return image, label

让我们也定义 augment 函数,该函数可以将随机变换应用于图像。此函数将在下一步中用于数据集。

def augment(image_label, seed):
  image, label = image_label
  image, label = resize_and_rescale(image, label)
  image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
  # Make a new seed.
  new_seed = tf.random.split(seed, num=1)[0, :]
  # Random crop back to the original size.
  image = tf.image.stateless_random_crop(
      image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
  # Random brightness.
  image = tf.image.stateless_random_brightness(
      image, max_delta=0.5, seed=new_seed)
  image = tf.clip_by_value(image, 0, 1)
  return image, label

选项 1:使用 tf.data.experimental.Counter

创建一个 tf.data.experimental.Counter 对象(我们称之为 counter),并使用 Dataset.zip 将数据集与 (counter, counter) 结合。这将确保数据集中的每个图像都与一个唯一的值(形状为 (2,))相关联,该值基于 counter,该值稍后可以作为 seed 值传递到 augment 函数中,用于随机变换。

# Create a `Counter` object and `Dataset.zip` it together with the training set.
counter = tf.data.experimental.Counter()
train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10025/587852618.py:2: CounterV2 (from tensorflow.python.data.experimental.ops.counter) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10025/587852618.py:2: CounterV2 (from tensorflow.python.data.experimental.ops.counter) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.

augment 函数映射到训练数据集

train_ds = (
    train_ds
    .shuffle(1000)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
val_ds = (
    val_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
test_ds = (
    test_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

选项 2:使用 tf.random.Generator

  • 使用初始 seed 值创建一个 tf.random.Generator 对象。在同一个生成器对象上调用 make_seeds 函数始终返回一个新的、唯一的 seed 值。
  • 定义一个包装器函数,该函数:1) 调用 make_seeds 函数;2) 将新生成的 seed 值传递到 augment 函数中,用于随机变换。
# Create a generator.
rng = tf.random.Generator.from_seed(123, alg='philox')
# Create a wrapper function for updating seeds.
def f(x, y):
  seed = rng.make_seeds(1)[:, 0]
  image, label = augment((x, y), seed)
  return image, label

将包装器函数 f 映射到训练数据集,并将 resize_and_rescale 函数映射到验证集和测试集

train_ds = (
    train_datasets
    .shuffle(1000)
    .map(f, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
val_ds = (
    val_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
test_ds = (
    test_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

现在可以使用这些数据集来训练模型,如前所述。

下一步

本教程演示了使用 Keras 预处理层和 tf.image 进行数据增强。

  • 要了解如何在模型中包含预处理层,请参阅 图像分类 教程。
  • 您可能还对了解预处理层如何帮助您对文本进行分类感兴趣,如 基本文本分类 教程中所示。
  • 您可以在此 指南 中了解有关 tf.data 的更多信息,并且您可以了解如何 此处 配置您的输入管道以提高性能。