加载和预处理图像

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

本教程展示了如何以三种方式加载和预处理图像数据集

设置

import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
2024-07-13 05:34:47.523839: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-13 05:34:47.550016: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-13 05:34:47.550056: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
print(tf.__version__)
2.16.2

下载花卉数据集

本教程使用包含数千张花卉照片的数据集。该花卉数据集包含五个子目录,每个子目录对应一个类别。

flowers_photos/
  daisy/
  dandelion/
  roses/
  sunflowers/
  tulips/
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
archive = tf.keras.utils.get_file(origin=dataset_url, extract=True)
data_dir = pathlib.Path(archive).with_suffix('')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228813984/228813984 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step

下载(218MB)后,您应该现在拥有花卉照片的副本。共有 3,670 张图像。

image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)
3670

每个目录包含该类型花的图像。这里有一些玫瑰。

roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))

png

roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[1]))

png

使用 Keras 实用程序加载数据

让我们使用有用的 tf.keras.utils.image_dataset_from_directory 实用程序从磁盘加载这些图像。

创建数据集

为加载器定义一些参数

batch_size = 32
img_height = 180
img_width = 180

在开发模型时使用验证拆分是一种良好的做法。您将使用 80% 的图像进行训练,20% 的图像进行验证。

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
Found 3670 files belonging to 5 classes.
Using 734 files for validation.

您可以在这些数据集的 class_names 属性中找到类名。

class_names = train_ds.class_names
print(class_names)
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

可视化数据

以下是训练数据集中的前九张图像。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")
2024-07-13 05:34:57.652132: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

png

您可以通过将这些数据集传递给 model.fit(在本教程的后面部分显示)来训练模型。如果您愿意,您也可以手动遍历数据集并检索图像批次。

for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
(32, 180, 180, 3)
(32,)

image_batch 是形状为 (32, 180, 180, 3) 的张量。这是一个包含 32 张形状为 180x180x3 的图像的批次(最后一个维度指的是颜色通道 RGB)。label_batch 是形状为 (32,) 的张量,这些是与 32 张图像相对应的标签。

您可以对这两个张量中的任何一个调用 .numpy() 以将其转换为 numpy.ndarray

标准化数据

RGB 通道值在 [0, 255] 范围内。这对于神经网络来说并不理想;通常您应该努力使您的输入值变小。

在这里,您将使用 tf.keras.layers.Rescaling 将值标准化为 [0, 1] 范围。

normalization_layer = tf.keras.layers.Rescaling(1./255)

使用此层有两种方法。您可以通过调用 Dataset.map 将其应用于数据集。

normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixel values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))
0.0 0.96902645

或者,您可以在模型定义中包含该层以简化部署。您将在这里使用第二种方法。

配置数据集以提高性能

让我们确保使用缓冲预取,这样您就可以从磁盘生成数据,而不会让 I/O 成为阻塞。在加载数据时,您应该使用这两种重要方法。

  • Dataset.cache 在第一个时期从磁盘加载图像后将其保留在内存中。这将确保数据集在训练模型时不会成为瓶颈。如果您的数据集太大而无法放入内存,您也可以使用此方法来创建高效的磁盘缓存。
  • Dataset.prefetch 在训练期间重叠数据预处理和模型执行。

感兴趣的读者可以在 使用 tf.data API 提高性能 指南的“预取”部分中了解有关这两种方法以及如何将数据缓存到磁盘的更多信息。

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

训练模型

为了完整起见,我们将展示如何使用您刚刚准备好的数据集来训练一个简单的模型。

顺序 模型由三个卷积块(tf.keras.layers.Conv2D)组成,每个卷积块中都有一个最大池化层(tf.keras.layers.MaxPooling2D)。在它上面有一个具有 128 个单元的全连接层(tf.keras.layers.Dense),该层由 ReLU 激活函数('relu')激活。此模型未以任何方式进行调整——目标是向您展示使用您刚刚创建的数据集的机制。要了解有关图像分类的更多信息,请访问 图像分类 教程。

num_classes = 5

model = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1./255),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(num_classes)
])

选择 tf.keras.optimizers.Adam 优化器和 tf.keras.losses.SparseCategoricalCrossentropy 损失函数。要查看每个训练时期的训练和验证准确率,请将 metrics 参数传递给 Model.compile

model.compile(
  optimizer='adam',
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['accuracy'])
model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=3
)
Epoch 1/3
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720848900.011329  457136 service.cc:145] XLA service 0x7f2c68006430 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1720848900.011431  457136 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1720848900.011437  457136 service.cc:153]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1720848900.011441  457136 service.cc:153]   StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1720848900.011446  457136 service.cc:153]   StreamExecutor device (3): Tesla T4, Compute Capability 7.5
10/92 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - accuracy: 0.1843 - loss: 1.7662
I0000 00:00:1720848903.023839  457136 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
92/92 ━━━━━━━━━━━━━━━━━━━━ 11s 69ms/step - accuracy: 0.3327 - loss: 1.4697 - val_accuracy: 0.5668 - val_loss: 1.0640
Epoch 2/3
92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - accuracy: 0.6000 - loss: 1.0092 - val_accuracy: 0.5790 - val_loss: 1.0287
Epoch 3/3
92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - accuracy: 0.6545 - loss: 0.8629 - val_accuracy: 0.6144 - val_loss: 0.9710
<keras.src.callbacks.history.History at 0x7f2dec3897f0>

您可能会注意到验证准确率与训练准确率相比很低,这表明您的模型过度拟合。您可以在此 教程 中了解有关过度拟合以及如何减少过度拟合的更多信息。

使用 tf.data 进行更精细的控制

上述 Keras 预处理实用程序——tf.keras.utils.image_dataset_from_directory——是使用图像目录创建 tf.data.Dataset 的便捷方法。

为了更精细的控制,您可以使用 tf.data 编写自己的输入管道。本节将展示如何做到这一点,从您之前下载的 TGZ 文件中的文件路径开始。

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/roses/102501987_3cdb8e5394_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/6050020905_881295ac72_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/6061177447_d8ce96aee0.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/4325834819_ab56661dcc_m.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/19413898445_69344f9956_n.jpg'
2024-07-13 05:35:12.764905: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

文件的树结构可用于编译 class_names 列表。

class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"]))
print(class_names)
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']

将数据集拆分为训练集和验证集

val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)

您可以按如下方式打印每个数据集的长度。

print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())
2936
734

编写一个将文件路径转换为 (img, label) 对的简短函数。

def get_label(file_path):
  # Convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  # The second to last is the class-directory
  one_hot = parts[-2] == class_names
  # Integer encode the label
  return tf.argmax(one_hot)
def decode_img(img):
  # Convert the compressed string to a 3D uint8 tensor
  img = tf.io.decode_jpeg(img, channels=3)
  # Resize the image to the desired size
  return tf.image.resize(img, [img_height, img_width])
def process_path(file_path):
  label = get_label(file_path)
  # Load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

使用 Dataset.map 创建一个包含 image, label 对的数据集。

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
for image, label in train_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())
Image shape:  (180, 180, 3)
Label:  1
2024-07-13 05:35:13.038820: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

配置数据集以提高性能

要使用此数据集训练模型,您将需要数据

  • 充分洗牌。
  • 进行批处理。
  • 批次尽快可用。

可以使用 tf.data API 添加这些功能。有关更多详细信息,请访问 输入管道性能 指南。

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=AUTOTUNE)
  return ds

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)

可视化数据

您可以像以前一样可视化此数据集。

image_batch, label_batch = next(iter(train_ds))

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].numpy().astype("uint8"))
  label = label_batch[i]
  plt.title(class_names[label])
  plt.axis("off")
2024-07-13 05:35:13.274884: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

继续训练模型

您现在已经手动构建了一个类似于上面 tf.keras.utils.image_dataset_from_directory 创建的 tf.data.Dataset。您可以继续使用它训练模型。与之前一样,您将只训练几个时期以缩短运行时间。

model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=3
)
Epoch 1/3
92/92 ━━━━━━━━━━━━━━━━━━━━ 5s 42ms/step - accuracy: 0.6975 - loss: 0.7811 - val_accuracy: 0.7112 - val_loss: 0.7230
Epoch 2/3
92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 22ms/step - accuracy: 0.7734 - loss: 0.5998 - val_accuracy: 0.7262 - val_loss: 0.6824
Epoch 3/3
92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 22ms/step - accuracy: 0.8561 - loss: 0.4123 - val_accuracy: 0.7425 - val_loss: 0.6843
<keras.src.callbacks.history.History at 0x7f2dcc1aa820>

使用 TensorFlow 数据集

到目前为止,本教程一直专注于从磁盘加载数据。您还可以通过探索 大型目录 中易于下载的数据集来找到要使用的数据集,该目录位于 TensorFlow 数据集 中。

由于您之前已经从磁盘加载了 Flowers 数据集,因此现在让我们使用 TensorFlow 数据集导入它。

使用 TensorFlow 数据集下载 Flowers 数据集

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

Flowers 数据集有五个类别。

num_classes = metadata.features['label'].num_classes
print(num_classes)
5

从数据集中检索图像。

get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2024-07-13 05:35:25.302143: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

与之前一样,请记住对训练集、验证集和测试集进行批处理、洗牌和配置以提高性能。

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
test_ds = configure_for_performance(test_ds)

您可以在 数据增强 教程中找到使用 Flowers 数据集和 TensorFlow 数据集的完整示例。

后续步骤

本教程展示了两种从磁盘加载图像的方法。首先,您学习了如何使用 Keras 预处理层和实用程序加载和预处理图像数据集。接下来,您学习了如何使用 tf.data 从头开始编写输入管道。最后,您学习了如何从 TensorFlow 数据集下载数据集。

下一步