在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
本教程展示了如何以三种方式加载和预处理图像数据集
- 首先,您将使用高级 Keras 预处理实用程序(例如
tf.keras.utils.image_dataset_from_directory
)和层(例如tf.keras.layers.Rescaling
)来读取磁盘上的图像目录。 - 接下来,您将从头开始编写自己的输入管道 使用 tf.data.
- 最后,您将从 大型目录 中下载数据集,该目录在 TensorFlow Datasets 中可用。
设置
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
2024-07-13 05:34:47.523839: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-07-13 05:34:47.550016: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-07-13 05:34:47.550056: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
print(tf.__version__)
2.16.2
下载花卉数据集
本教程使用包含数千张花卉照片的数据集。该花卉数据集包含五个子目录,每个子目录对应一个类别。
flowers_photos/
daisy/
dandelion/
roses/
sunflowers/
tulips/
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
archive = tf.keras.utils.get_file(origin=dataset_url, extract=True)
data_dir = pathlib.Path(archive).with_suffix('')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228813984/228813984 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
下载(218MB)后,您应该现在拥有花卉照片的副本。共有 3,670 张图像。
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)
3670
每个目录包含该类型花的图像。这里有一些玫瑰。
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[1]))
使用 Keras 实用程序加载数据
让我们使用有用的 tf.keras.utils.image_dataset_from_directory
实用程序从磁盘加载这些图像。
创建数据集
为加载器定义一些参数
batch_size = 32
img_height = 180
img_width = 180
在开发模型时使用验证拆分是一种良好的做法。您将使用 80% 的图像进行训练,20% 的图像进行验证。
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
Found 3670 files belonging to 5 classes. Using 2936 files for training.
val_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
Found 3670 files belonging to 5 classes. Using 734 files for validation.
您可以在这些数据集的 class_names
属性中找到类名。
class_names = train_ds.class_names
print(class_names)
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
可视化数据
以下是训练数据集中的前九张图像。
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[labels[i]])
plt.axis("off")
2024-07-13 05:34:57.652132: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
您可以通过将这些数据集传递给 model.fit
(在本教程的后面部分显示)来训练模型。如果您愿意,您也可以手动遍历数据集并检索图像批次。
for image_batch, labels_batch in train_ds:
print(image_batch.shape)
print(labels_batch.shape)
break
(32, 180, 180, 3) (32,)
image_batch
是形状为 (32, 180, 180, 3)
的张量。这是一个包含 32 张形状为 180x180x3
的图像的批次(最后一个维度指的是颜色通道 RGB)。label_batch
是形状为 (32,)
的张量,这些是与 32 张图像相对应的标签。
您可以对这两个张量中的任何一个调用 .numpy()
以将其转换为 numpy.ndarray
。
标准化数据
RGB 通道值在 [0, 255]
范围内。这对于神经网络来说并不理想;通常您应该努力使您的输入值变小。
在这里,您将使用 tf.keras.layers.Rescaling
将值标准化为 [0, 1]
范围。
normalization_layer = tf.keras.layers.Rescaling(1./255)
使用此层有两种方法。您可以通过调用 Dataset.map
将其应用于数据集。
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixel values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))
0.0 0.96902645
或者,您可以在模型定义中包含该层以简化部署。您将在这里使用第二种方法。
配置数据集以提高性能
让我们确保使用缓冲预取,这样您就可以从磁盘生成数据,而不会让 I/O 成为阻塞。在加载数据时,您应该使用这两种重要方法。
Dataset.cache
在第一个时期从磁盘加载图像后将其保留在内存中。这将确保数据集在训练模型时不会成为瓶颈。如果您的数据集太大而无法放入内存,您也可以使用此方法来创建高效的磁盘缓存。Dataset.prefetch
在训练期间重叠数据预处理和模型执行。
感兴趣的读者可以在 使用 tf.data API 提高性能 指南的“预取”部分中了解有关这两种方法以及如何将数据缓存到磁盘的更多信息。
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
训练模型
为了完整起见,我们将展示如何使用您刚刚准备好的数据集来训练一个简单的模型。
顺序 模型由三个卷积块(tf.keras.layers.Conv2D
)组成,每个卷积块中都有一个最大池化层(tf.keras.layers.MaxPooling2D
)。在它上面有一个具有 128 个单元的全连接层(tf.keras.layers.Dense
),该层由 ReLU 激活函数('relu'
)激活。此模型未以任何方式进行调整——目标是向您展示使用您刚刚创建的数据集的机制。要了解有关图像分类的更多信息,请访问 图像分类 教程。
num_classes = 5
model = tf.keras.Sequential([
tf.keras.layers.Rescaling(1./255),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(num_classes)
])
选择 tf.keras.optimizers.Adam
优化器和 tf.keras.losses.SparseCategoricalCrossentropy
损失函数。要查看每个训练时期的训练和验证准确率,请将 metrics
参数传递给 Model.compile
。
model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_ds,
validation_data=val_ds,
epochs=3
)
Epoch 1/3 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1720848900.011329 457136 service.cc:145] XLA service 0x7f2c68006430 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: I0000 00:00:1720848900.011431 457136 service.cc:153] StreamExecutor device (0): Tesla T4, Compute Capability 7.5 I0000 00:00:1720848900.011437 457136 service.cc:153] StreamExecutor device (1): Tesla T4, Compute Capability 7.5 I0000 00:00:1720848900.011441 457136 service.cc:153] StreamExecutor device (2): Tesla T4, Compute Capability 7.5 I0000 00:00:1720848900.011446 457136 service.cc:153] StreamExecutor device (3): Tesla T4, Compute Capability 7.5 10/92 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - accuracy: 0.1843 - loss: 1.7662 I0000 00:00:1720848903.023839 457136 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 92/92 ━━━━━━━━━━━━━━━━━━━━ 11s 69ms/step - accuracy: 0.3327 - loss: 1.4697 - val_accuracy: 0.5668 - val_loss: 1.0640 Epoch 2/3 92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - accuracy: 0.6000 - loss: 1.0092 - val_accuracy: 0.5790 - val_loss: 1.0287 Epoch 3/3 92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 18ms/step - accuracy: 0.6545 - loss: 0.8629 - val_accuracy: 0.6144 - val_loss: 0.9710 <keras.src.callbacks.history.History at 0x7f2dec3897f0>
您可能会注意到验证准确率与训练准确率相比很低,这表明您的模型过度拟合。您可以在此 教程 中了解有关过度拟合以及如何减少过度拟合的更多信息。
使用 tf.data 进行更精细的控制
上述 Keras 预处理实用程序——tf.keras.utils.image_dataset_from_directory
——是使用图像目录创建 tf.data.Dataset
的便捷方法。
为了更精细的控制,您可以使用 tf.data
编写自己的输入管道。本节将展示如何做到这一点,从您之前下载的 TGZ 文件中的文件路径开始。
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
for f in list_ds.take(5):
print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/roses/102501987_3cdb8e5394_n.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/6050020905_881295ac72_n.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/6061177447_d8ce96aee0.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/roses/4325834819_ab56661dcc_m.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/tulips/19413898445_69344f9956_n.jpg' 2024-07-13 05:35:12.764905: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
文件的树结构可用于编译 class_names
列表。
class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"]))
print(class_names)
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']
将数据集拆分为训练集和验证集
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)
您可以按如下方式打印每个数据集的长度。
print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())
2936 734
编写一个将文件路径转换为 (img, label)
对的简短函数。
def get_label(file_path):
# Convert the path to a list of path components
parts = tf.strings.split(file_path, os.path.sep)
# The second to last is the class-directory
one_hot = parts[-2] == class_names
# Integer encode the label
return tf.argmax(one_hot)
def decode_img(img):
# Convert the compressed string to a 3D uint8 tensor
img = tf.io.decode_jpeg(img, channels=3)
# Resize the image to the desired size
return tf.image.resize(img, [img_height, img_width])
def process_path(file_path):
label = get_label(file_path)
# Load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = decode_img(img)
return img, label
使用 Dataset.map
创建一个包含 image, label
对的数据集。
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
for image, label in train_ds.take(1):
print("Image shape: ", image.numpy().shape)
print("Label: ", label.numpy())
Image shape: (180, 180, 3) Label: 1 2024-07-13 05:35:13.038820: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
配置数据集以提高性能
要使用此数据集训练模型,您将需要数据
- 充分洗牌。
- 进行批处理。
- 批次尽快可用。
可以使用 tf.data
API 添加这些功能。有关更多详细信息,请访问 输入管道性能 指南。
def configure_for_performance(ds):
ds = ds.cache()
ds = ds.shuffle(buffer_size=1000)
ds = ds.batch(batch_size)
ds = ds.prefetch(buffer_size=AUTOTUNE)
return ds
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
可视化数据
您可以像以前一样可视化此数据集。
image_batch, label_batch = next(iter(train_ds))
plt.figure(figsize=(10, 10))
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image_batch[i].numpy().astype("uint8"))
label = label_batch[i]
plt.title(class_names[label])
plt.axis("off")
2024-07-13 05:35:13.274884: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
继续训练模型
您现在已经手动构建了一个类似于上面 tf.keras.utils.image_dataset_from_directory
创建的 tf.data.Dataset
。您可以继续使用它训练模型。与之前一样,您将只训练几个时期以缩短运行时间。
model.fit(
train_ds,
validation_data=val_ds,
epochs=3
)
Epoch 1/3 92/92 ━━━━━━━━━━━━━━━━━━━━ 5s 42ms/step - accuracy: 0.6975 - loss: 0.7811 - val_accuracy: 0.7112 - val_loss: 0.7230 Epoch 2/3 92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 22ms/step - accuracy: 0.7734 - loss: 0.5998 - val_accuracy: 0.7262 - val_loss: 0.6824 Epoch 3/3 92/92 ━━━━━━━━━━━━━━━━━━━━ 2s 22ms/step - accuracy: 0.8561 - loss: 0.4123 - val_accuracy: 0.7425 - val_loss: 0.6843 <keras.src.callbacks.history.History at 0x7f2dcc1aa820>
使用 TensorFlow 数据集
到目前为止,本教程一直专注于从磁盘加载数据。您还可以通过探索 大型目录 中易于下载的数据集来找到要使用的数据集,该目录位于 TensorFlow 数据集 中。
由于您之前已经从磁盘加载了 Flowers 数据集,因此现在让我们使用 TensorFlow 数据集导入它。
使用 TensorFlow 数据集下载 Flowers 数据集
(train_ds, val_ds, test_ds), metadata = tfds.load(
'tf_flowers',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)
Flowers 数据集有五个类别。
num_classes = metadata.features['label'].num_classes
print(num_classes)
5
从数据集中检索图像。
get_label_name = metadata.features['label'].int2str
image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2024-07-13 05:35:25.302143: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
与之前一样,请记住对训练集、验证集和测试集进行批处理、洗牌和配置以提高性能。
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
test_ds = configure_for_performance(test_ds)
您可以在 数据增强 教程中找到使用 Flowers 数据集和 TensorFlow 数据集的完整示例。
后续步骤
本教程展示了两种从磁盘加载图像的方法。首先,您学习了如何使用 Keras 预处理层和实用程序加载和预处理图像数据集。接下来,您学习了如何使用 tf.data
从头开始编写输入管道。最后,您学习了如何从 TensorFlow 数据集下载数据集。
下一步
- 您可以学习 如何添加数据增强。
- 要了解有关
tf.data
的更多信息,您可以访问 tf.data:构建 TensorFlow 输入管道 指南。