在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
概述
本教程演示了数据增强:一种通过应用随机(但现实)变换(如图像旋转)来增加训练集多样性的技术。
您将学习如何通过两种方式应用数据增强
- 使用 Keras 预处理层,例如
tf.keras.layers.Resizing
、tf.keras.layers.Rescaling
、tf.keras.layers.RandomFlip
和tf.keras.layers.RandomRotation
。 - 使用
tf.image
方法,例如tf.image.flip_left_right
、tf.image.rgb_to_grayscale
、tf.image.adjust_brightness
、tf.image.central_crop
和tf.image.stateless_random*
。
设置
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
2024-02-15 02:21:03.189917: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-02-15 02:21:03.189965: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-02-15 02:21:03.191420: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
下载数据集
本教程使用 tf_flowers 数据集。为了方便起见,请使用 TensorFlow Datasets 下载数据集。如果您想了解其他导入数据的方法,请查看 加载图像 教程。
(train_ds, val_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)
花卉数据集有五个类别。
num_classes = metadata.features['label'].num_classes
print(num_classes)
5
让我们从数据集中检索一张图像,并使用它来演示数据增强。
get_label_name = metadata.features['label'].int2str
image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2024-02-15 02:21:09.759464: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
使用 Keras 预处理层
调整大小和重新缩放
您可以使用 Keras 预处理层将图像调整为一致的形状(使用 tf.keras.layers.Resizing
),并将像素值重新缩放(使用 tf.keras.layers.Rescaling
)。
IMG_SIZE = 180
resize_and_rescale = tf.keras.Sequential([
layers.Resizing(IMG_SIZE, IMG_SIZE),
layers.Rescaling(1./255)
])
您可以可视化将这些层应用于图像的结果。
result = resize_and_rescale(image)
_ = plt.imshow(result)
验证像素是否在 [0, 1]
范围内
print("Min and max pixel values:", result.numpy().min(), result.numpy().max())
Min and max pixel values: 0.0 1.0
数据增强
您也可以使用 Keras 预处理层进行数据增强,例如 tf.keras.layers.RandomFlip
和 tf.keras.layers.RandomRotation
。
让我们创建一些预处理层,并将它们重复应用于同一图像。
data_augmentation = tf.keras.Sequential([
layers.RandomFlip("horizontal_and_vertical"),
layers.RandomRotation(0.2),
])
# Add the image to a batch.
image = tf.cast(tf.expand_dims(image, 0), tf.float32)
plt.figure(figsize=(10, 10))
for i in range(9):
augmented_image = data_augmentation(image)
ax = plt.subplot(3, 3, i + 1)
plt.imshow(augmented_image[0])
plt.axis("off")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
您可以使用各种预处理层进行数据增强,包括 tf.keras.layers.RandomContrast
、tf.keras.layers.RandomCrop
、tf.keras.layers.RandomZoom
等。
使用 Keras 预处理层的两种选择
您可以通过两种方式使用这些预处理层,它们有重要的权衡。
选项 1:将预处理层作为模型的一部分
model = tf.keras.Sequential([
# Add the preprocessing layers you created earlier.
resize_and_rescale,
data_augmentation,
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
# Rest of your model.
])
在这种情况下,需要注意两点
数据增强将在设备上运行,与您的其他层同步,并受益于 GPU 加速。
当您使用
model.save
导出模型时,预处理层将与您的其他模型一起保存。如果您稍后部署此模型,它将自动标准化图像(根据您的层的配置)。这可以节省您在服务器端重新实现该逻辑的努力。
选项 2:将预处理层应用于您的数据集
aug_ds = train_ds.map(
lambda x, y: (resize_and_rescale(x, training=True), y))
使用这种方法,您可以使用 Dataset.map
创建一个数据集,该数据集会生成增强图像的批次。在这种情况下
- 数据增强将在 CPU 上异步发生,并且是非阻塞的。您可以使用
Dataset.prefetch
(如下所示)将 GPU 上的模型训练与数据预处理重叠。 - 在这种情况下,当您调用
Model.save
时,预处理层不会与模型一起导出。您需要在保存之前将它们附加到模型,或者在服务器端重新实现它们。训练后,您可以在导出之前附加预处理层。
您可以在 图像分类 教程中找到第一个选项的示例。让我们在这里演示第二个选项。
将预处理层应用于数据集
使用您之前创建的 Keras 预处理层配置训练、验证和测试数据集。您还将配置数据集以提高性能,使用并行读取和缓冲预取,以从磁盘生成批次,而不会让 I/O 成为阻塞因素。(在 使用 tf.data API 提高性能 指南中了解有关数据集性能的更多信息。)
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE
def prepare(ds, shuffle=False, augment=False):
# Resize and rescale all datasets.
ds = ds.map(lambda x, y: (resize_and_rescale(x), y),
num_parallel_calls=AUTOTUNE)
if shuffle:
ds = ds.shuffle(1000)
# Batch all datasets.
ds = ds.batch(batch_size)
# Use data augmentation only on the training set.
if augment:
ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
num_parallel_calls=AUTOTUNE)
# Use buffered prefetching on all datasets.
return ds.prefetch(buffer_size=AUTOTUNE)
train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)
训练模型
为了完整起见,您现在将使用刚刚准备好的数据集训练模型。
Sequential 模型由三个卷积块(tf.keras.layers.Conv2D
)组成,每个卷积块中都有一个最大池化层(tf.keras.layers.MaxPooling2D
)。在它之上有一个具有 128 个单元的全连接层(tf.keras.layers.Dense
),它由 ReLU 激活函数('relu'
)激活。此模型尚未针对准确性进行调整(目标是向您展示机制)。
model = tf.keras.Sequential([
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
选择 tf.keras.optimizers.Adam
优化器和 tf.keras.losses.SparseCategoricalCrossentropy
损失函数。要查看每个训练纪元的训练和验证准确率,请将 metrics
参数传递给 Model.compile
。
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
训练几个纪元
epochs=5
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
Epoch 1/5 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1707963675.836253 10200 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 92/92 [==============================] - 10s 46ms/step - loss: 1.3130 - accuracy: 0.4380 - val_loss: 1.0730 - val_accuracy: 0.6213 Epoch 2/5 92/92 [==============================] - 2s 21ms/step - loss: 1.0332 - accuracy: 0.5988 - val_loss: 1.0585 - val_accuracy: 0.5613 Epoch 3/5 92/92 [==============================] - 2s 21ms/step - loss: 0.9444 - accuracy: 0.6264 - val_loss: 0.9728 - val_accuracy: 0.6104 Epoch 4/5 92/92 [==============================] - 2s 21ms/step - loss: 0.8922 - accuracy: 0.6451 - val_loss: 0.9163 - val_accuracy: 0.6431 Epoch 5/5 92/92 [==============================] - 2s 21ms/step - loss: 0.8407 - accuracy: 0.6638 - val_loss: 0.8458 - val_accuracy: 0.6785
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
12/12 [==============================] - 1s 7ms/step - loss: 0.8492 - accuracy: 0.6322 Accuracy 0.6321526169776917
自定义数据增强
您也可以创建自定义数据增强层。
本教程的这一部分展示了两种方法
- 首先,您将创建一个
tf.keras.layers.Lambda
层。这是编写简洁代码的好方法。 - 接下来,您将通过 子类化 编写一个新层,这将为您提供更多控制权。
这两个层都将根据一定的概率随机反转图像中的颜色。
def random_invert_img(x, p=0.5):
if tf.random.uniform([]) < p:
x = (255-x)
else:
x
return x
def random_invert(factor=0.5):
return layers.Lambda(lambda x: random_invert_img(x, factor))
random_invert = random_invert()
plt.figure(figsize=(10, 10))
for i in range(9):
augmented_image = random_invert(image)
ax = plt.subplot(3, 3, i + 1)
plt.imshow(augmented_image[0].numpy().astype("uint8"))
plt.axis("off")
接下来,通过 子类化 实现一个自定义层
class RandomInvert(layers.Layer):
def __init__(self, factor=0.5, **kwargs):
super().__init__(**kwargs)
self.factor = factor
def call(self, x):
return random_invert_img(x)
_ = plt.imshow(RandomInvert()(image)[0])
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
这两个层都可以像上面选项 1 和 2 中描述的那样使用。
使用 tf.image
上面的 Keras 预处理实用程序很方便。但是,为了更精细的控制,您可以使用 tf.data
和 tf.image
编写自己的数据增强管道或层。(您可能还想查看 TensorFlow Addons Image:操作 和 TensorFlow I/O:颜色空间转换。)
由于花卉数据集之前已配置了数据增强,因此让我们重新导入它以从头开始
(train_ds, val_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)
检索要使用的图像
image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2024-02-15 02:21:35.218784: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
让我们使用以下函数来可视化并比较原始图像和增强图像并排
def visualize(original, augmented):
fig = plt.figure()
plt.subplot(1,2,1)
plt.title('Original image')
plt.imshow(original)
plt.subplot(1,2,2)
plt.title('Augmented image')
plt.imshow(augmented)
数据增强
翻转图像
使用 tf.image.flip_left_right
垂直或水平翻转图像
flipped = tf.image.flip_left_right(image)
visualize(image, flipped)
将图像转换为灰度
您可以使用 tf.image.rgb_to_grayscale
将图像转换为灰度
grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
_ = plt.colorbar()
饱和图像
通过提供饱和度因子,使用 tf.image.adjust_saturation
饱和图像
saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)
更改图像亮度
通过提供亮度因子,使用 tf.image.adjust_brightness
更改图像亮度
bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)
中心裁剪图像
使用 tf.image.central_crop
从中心裁剪图像,直到您想要的图像部分
cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image, cropped)
旋转图像
使用 tf.image.rot90
将图像旋转 90 度
rotated = tf.image.rot90(image)
visualize(image, rotated)
随机变换
将随机变换应用于图像可以进一步帮助泛化和扩展数据集。当前的 tf.image
API 提供了八种这样的随机图像操作(ops)
tf.image.stateless_random_brightness
tf.image.stateless_random_contrast
tf.image.stateless_random_crop
tf.image.stateless_random_flip_left_right
tf.image.stateless_random_flip_up_down
tf.image.stateless_random_hue
tf.image.stateless_random_jpeg_quality
tf.image.stateless_random_saturation
这些随机图像 ops 纯粹是函数式的:输出只取决于输入。这使得它们易于在高性能、确定性输入管道中使用。它们要求在每一步输入一个 seed
值。给定相同的 seed
,它们返回相同的结果,与调用次数无关。
在以下部分中,您将
- 回顾使用随机图像操作来变换图像的示例。
- 演示如何将随机变换应用于训练数据集。
随机更改图像亮度
使用 tf.image.stateless_random_brightness
通过提供亮度因子和 seed
来随机更改 image
的亮度。亮度因子是在 [-max_delta, max_delta)
范围内随机选择的,并且与给定的 seed
相关联。
for i in range(3):
seed = (i, 0) # tuple of size (2,)
stateless_random_brightness = tf.image.stateless_random_brightness(
image, max_delta=0.95, seed=seed)
visualize(image, stateless_random_brightness)
随机更改图像对比度
使用 tf.image.stateless_random_contrast
随机改变 image
的对比度,方法是提供一个对比度范围和 seed
。对比度范围在 [lower, upper]
区间内随机选择,并与给定的 seed
相关联。
for i in range(3):
seed = (i, 0) # tuple of size (2,)
stateless_random_contrast = tf.image.stateless_random_contrast(
image, lower=0.1, upper=0.9, seed=seed)
visualize(image, stateless_random_contrast)
随机裁剪图像
使用 tf.image.stateless_random_crop
随机裁剪 image
,方法是提供目标 size
和 seed
。从 image
中裁剪出来的部分位于随机选择的偏移量处,并与给定的 seed
相关联。
for i in range(3):
seed = (i, 0) # tuple of size (2,)
stateless_random_crop = tf.image.stateless_random_crop(
image, size=[210, 300, 3], seed=seed)
visualize(image, stateless_random_crop)
将增强应用于数据集
让我们首先再次下载图像数据集,以防它们在前面的部分中被修改。
(train_datasets, val_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)
接下来,定义一个用于调整图像大小和重新缩放的实用函数。此函数将用于统一数据集中图像的大小和比例。
def resize_and_rescale(image, label):
image = tf.cast(image, tf.float32)
image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
image = (image / 255.0)
return image, label
让我们也定义 augment
函数,该函数可以将随机变换应用于图像。此函数将在下一步中用于数据集。
def augment(image_label, seed):
image, label = image_label
image, label = resize_and_rescale(image, label)
image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
# Make a new seed.
new_seed = tf.random.split(seed, num=1)[0, :]
# Random crop back to the original size.
image = tf.image.stateless_random_crop(
image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
# Random brightness.
image = tf.image.stateless_random_brightness(
image, max_delta=0.5, seed=new_seed)
image = tf.clip_by_value(image, 0, 1)
return image, label
选项 1:使用 tf.data.experimental.Counter
创建一个 tf.data.experimental.Counter
对象(我们称之为 counter
),并使用 Dataset.zip
将数据集与 (counter, counter)
结合。这将确保数据集中的每个图像都与一个唯一的值(形状为 (2,)
)相关联,该值基于 counter
,该值稍后可以作为 seed
值传递到 augment
函数中,用于随机变换。
# Create a `Counter` object and `Dataset.zip` it together with the training set.
counter = tf.data.experimental.Counter()
train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10025/587852618.py:2: CounterV2 (from tensorflow.python.data.experimental.ops.counter) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Dataset.counter(...)` instead. WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10025/587852618.py:2: CounterV2 (from tensorflow.python.data.experimental.ops.counter) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Dataset.counter(...)` instead.
将 augment
函数映射到训练数据集
train_ds = (
train_ds
.shuffle(1000)
.map(augment, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
val_ds = (
val_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
test_ds = (
test_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
选项 2:使用 tf.random.Generator
- 使用初始
seed
值创建一个tf.random.Generator
对象。在同一个生成器对象上调用make_seeds
函数始终返回一个新的、唯一的seed
值。 - 定义一个包装器函数,该函数:1) 调用
make_seeds
函数;2) 将新生成的seed
值传递到augment
函数中,用于随机变换。
# Create a generator.
rng = tf.random.Generator.from_seed(123, alg='philox')
# Create a wrapper function for updating seeds.
def f(x, y):
seed = rng.make_seeds(1)[:, 0]
image, label = augment((x, y), seed)
return image, label
将包装器函数 f
映射到训练数据集,并将 resize_and_rescale
函数映射到验证集和测试集
train_ds = (
train_datasets
.shuffle(1000)
.map(f, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
val_ds = (
val_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
test_ds = (
test_ds
.map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(AUTOTUNE)
)
现在可以使用这些数据集来训练模型,如前所述。
下一步
本教程演示了使用 Keras 预处理层和 tf.image
进行数据增强。