tf.data: Build TensorFlow input pipelines


The tf.data API enables you to build complex input pipelines from simple, reusable pieces. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training. The pipeline for a text model might involve extracting symbols from raw text data, converting them to embedding identifiers with a lookup table, and batching together sequences of different lengths. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations.

The tf.data API introduces a tf.data.Dataset abstraction that represents a sequence of elements, in which each element consists of one or more components. For example, in an image pipeline, an element might be a single training example, with a pair of tensor components representing the image and its label.

There are two distinct ways to create a dataset:

  • A data source constructs a Dataset from data stored in memory or in one or more files.

  • A data transformation constructs a dataset from one or more tf.data.Dataset objects.

import tensorflow as tf
import pathlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)

Basic mechanics

To create an input pipeline, you must start with a data source. For example, to construct a Dataset from data in memory, you can use tf.data.Dataset.from_tensors() or tf.data.Dataset.from_tensor_slices(). Alternatively, if your input data is stored in a file in the recommended TFRecord format, you can use tf.data.TFRecordDataset().
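As a quick sketch of the difference between those two in-memory constructors (a made-up example, not part of the original notebook): from_tensors produces a single element containing the whole tensor, while from_tensor_slices slices it along the first dimension.

example = tf.constant([[1, 2], [3, 4], [5, 6]])

# One element whose component is the entire (3, 2) tensor.
print(tf.data.Dataset.from_tensors(example).element_spec)
# Three elements, each a (2,) slice along the first axis.
print(tf.data.Dataset.from_tensor_slices(example).element_spec)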

Once you have a Dataset object, you can transform it into a new Dataset by chaining method calls on the tf.data.Dataset object. For example, you can apply per-element transformations such as Dataset.map, and multi-element transformations such as Dataset.batch. Refer to the documentation for tf.data.Dataset for a complete list of transformations.

The Dataset object is a Python iterable. This makes it possible to consume its elements using a for loop:

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
<_TensorSliceDataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>
for elem in dataset:
  print(elem.numpy())
8
3
0
8
2
1

Or by explicitly creating a Python iterator using iter and consuming its elements using next:

it = iter(dataset)

print(next(it).numpy())
8

Alternatively, dataset elements can be consumed using the reduce transformation, which reduces all elements to produce a single result. The following example illustrates how to use the reduce transformation to compute the sum of a dataset of integers.

print(dataset.reduce(0, lambda state, value: state + value).numpy())
22

Dataset structure

A dataset produces a sequence of elements, where each element is the same (nested) structure of components. Individual components of the structure can be of any type representable by tf.TypeSpec, including tf.Tensor, tf.sparse.SparseTensor, tf.RaggedTensor, tf.TensorArray, or tf.data.Dataset.

The Python constructs that can be used to express the (nested) structure of elements include tuple, dict, NamedTuple, and OrderedDict. In particular, list is not a valid construct for expressing the structure of dataset elements. This is because early tf.data users felt strongly about list inputs (for example, when passed to tf.data.Dataset.from_tensors) being automatically packed as tensors, and list outputs (for example, return values of user-defined functions) being coerced into a tuple. As a consequence, if you would like a list input to be treated as a structure, you need to convert it into a tuple, and if you would like a list output to be a single component, then you need to explicitly pack it using tf.stack.
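A minimal sketch of these rules (dataset names here are purely illustrative; the specs noted in the comments follow from the description above):

# A list input is packed into a single tensor component...
ds_from_list = tf.data.Dataset.from_tensors([1, 2, 3])
print(ds_from_list.element_spec)   # TensorSpec(shape=(3,), ...)

# ...whereas a tuple input is treated as a structure of three scalar components.
ds_from_tuple = tf.data.Dataset.from_tensors((1, 2, 3))
print(ds_from_tuple.element_spec)  # (TensorSpec(shape=(), ...), ...)

# A list returned from a mapped function is coerced into a tuple of components;
# wrap it in tf.stack to keep it as a single tensor component instead.
print(ds_from_list.map(lambda x: [x, x + 1]).element_spec)
print(ds_from_list.map(lambda x: tf.stack([x, x + 1])).element_spec)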

The Dataset.element_spec property allows you to inspect the type of each element component. The property returns a nested structure of tf.TypeSpec objects, matching the structure of the element, which may be a single component, a tuple of components, or a nested tuple of components. For example:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor

The Dataset transformations support datasets of any structure. When using the Dataset.map and Dataset.filter transformations, which apply a function to each element, the element structure determines the arguments of the function:

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1
<_TensorSliceDataset element_spec=TensorSpec(shape=(10,), dtype=tf.int32, name=None)>
for z in dataset1:
  print(z.numpy())
[2 1 3 2 1 3 9 6 4 5]
[8 9 2 9 1 4 7 2 4 5]
[3 9 6 8 4 6 8 4 9 5]
[4 3 7 2 8 6 4 9 7 6]
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2
<_TensorSliceDataset element_spec=(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None))>
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3
<_ZipDataset element_spec=(TensorSpec(shape=(10,), dtype=tf.int32, name=None), (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None)))>
for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

Reading input data

Consuming NumPy arrays

Refer to the Loading NumPy arrays tutorial for more examples.

If all of your input data fits in memory, the simplest way to create a Dataset from it is to convert it to tf.Tensor objects and use Dataset.from_tensor_slices:

train, test = tf.keras.datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 [==============================] - 0s 0us/step
images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
<_TensorSliceDataset element_spec=(TensorSpec(shape=(28, 28), dtype=tf.float64, name=None), TensorSpec(shape=(), dtype=tf.uint8, name=None))>

Consuming Python generators

Another common data source that can easily be ingested as a tf.data.Dataset is the Python generator.

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)
0
1
2
3
4

The Dataset.from_generator constructor converts the Python generator to a fully functional tf.data.Dataset.

The constructor takes a callable as input, not an iterator. This allows it to restart the generator when it reaches the end. It takes an optional args argument, which is passed as the callable's arguments.

The output_types argument is required because tf.data builds a tf.Graph internally, and graph edges require a tf.dtype.

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

The output_shapes argument is not required, but is highly recommended as many TensorFlow operations do not support tensors with an unknown rank. If the length of a particular axis is unknown or variable, set it as None in the output_shapes.

It's also important to note that the output_shapes and output_types follow the same nesting rules as other dataset methods.

Here is an example generator that demonstrates both aspects: it returns tuples of arrays, where the second array is a vector with unknown length.

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
0 : [-1.1977 -1.3099]
1 : [-0.3763  1.3639  0.0703  0.7888  0.3659 -0.0422 -0.5699 -1.4458]
2 : [ 0.4438  0.2206 -0.8348  3.0743 -0.2304  0.6876]
3 : [ 0.1138  0.3484 -0.3989  0.1871 -0.9462  0.7905  0.0224  0.204  -1.2715]
4 : [ 1.0292 -1.7965 -1.1569  0.437   1.9364  0.4718 -0.5036 -0.1318]
5 : [0.6893]
6 : [-0.2385 -0.3129  0.4913  0.2546  1.4849 -1.3109 -0.3785]

The first output is an int32, the second is a float32.

The first item is a scalar, shape (), and the second is a vector of unknown length, shape (None,).

ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),
    output_shapes=((), (None,)))

ds_series
<_FlatMapDataset element_spec=(TensorSpec(shape=(), dtype=tf.int32, name=None), TensorSpec(shape=(None,), dtype=tf.float32, name=None))>

Now it can be used like a regular tf.data.Dataset. Note that when batching a dataset with a variable shape, you need to use Dataset.padded_batch.

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
[ 4 18 19  2  8 22 13  0 25 27]

[[-0.5268  0.8465  1.8949 -0.6337 -0.9212  0.2917  0.1995 -0.2283  1.5621]
 [-0.7196  0.3447 -0.5744 -1.6807  1.9387 -0.7832  1.1232  0.5444  0.3566]
 [-1.0073  0.      0.      0.      0.      0.      0.      0.      0.    ]
 [ 1.3614 -0.0866  0.4309 -1.1438  0.066   0.3847 -0.8009  0.      0.    ]
 [-0.7528  0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.006   0.9022  1.2462  0.0703  0.      0.      0.      0.      0.    ]
 [ 0.5811  0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.1996  1.6923 -0.274  -0.7509 -0.6734 -1.687  -0.8438 -1.0904  0.    ]
 [ 0.3178  0.0775  1.3367  1.0921  0.1651  0.9298  0.0764 -0.4039  0.    ]
 [ 1.5668 -1.3154  0.8587 -0.7022  0.      0.      0.      0.      0.    ]]

For a more realistic example, try wrapping preprocessing.image.ImageDataGenerator as a tf.data.Dataset.

First download the data:

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228813984/228813984 [==============================] - 1s 0us/step

Create the image.ImageDataGenerator:

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
Found 3670 images belonging to 5 classes.
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
float32 (32, 256, 256, 3)
float32 (32, 5)
ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers),
    output_types=(tf.float32, tf.float32),
    output_shapes=([32,256,256,3], [32,5])
)

ds.element_spec
(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
for images, labels in ds.take(1):
  print('images.shape: ', images.shape)
  print('labels.shape: ', labels.shape)
Found 3670 images belonging to 5 classes.
images.shape:  (32, 256, 256, 3)
labels.shape:  (32, 5)

Consuming TFRecord data

Refer to the Loading TFRecords tutorial for an end-to-end example.

The tf.data API supports a variety of file formats so that you can process large datasets that do not fit in memory. For example, the TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files as part of an input pipeline.

Here is an example using the test file from the French Street Name Signs (FSNS) dataset.

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7904079/7904079 [==============================] - 0s 0us/step

filenames参数传递给TFRecordDataset初始化程序可以是字符串、字符串列表或tf.Tensor字符串。因此,如果您有两组用于训练和验证的文件,您可以创建一个工厂方法来生成数据集,并将文件名作为输入参数

dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

Many TensorFlow projects use serialized tf.train.Example records in their TFRecord files. These need to be decoded before they can be inspected:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']
bytes_list {
  value: "Rue Perreyon"
}

Consuming text data

Refer to the Load text tutorial for an end-to-end example.

Many datasets are distributed as one or more text files. The tf.data.TextLineDataset provides an easy way to extract lines from one or more text files. Given one or more filenames, a TextLineDataset will produce one string-valued element per line of those files.

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
815980/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
809730/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
807992/807992 [==============================] - 0s 0us/step
dataset = tf.data.TextLineDataset(file_paths)

Here are the first few lines of the first file:

for line in dataset.take(5):
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

To alternate lines between files use Dataset.interleave. This makes it easier to shuffle files together. Here are the first, second and third lines from each translation:

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

By default, a TextLineDataset yields every line of each file, which may not be desirable, for example, if the file starts with a header line, or contains comments. These lines can be removed using the Dataset.skip() or Dataset.filter transformations. Here, you skip the first line, then filter to find only survivors.

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
30874/30874 [==============================] - 0s 0us/step
for line in titanic_lines.take(10):
  print(line.numpy())
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

Consuming CSV data

Refer to the Loading CSV Files and Loading Pandas DataFrames tutorials for more examples.

The CSV file format is a popular format for storing tabular data in plain text.

For example:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file)
df.head()

If your data fits in memory, the same Dataset.from_tensor_slices method works on dictionaries, allowing this data to be easily imported:

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
  'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

A more scalable approach is to load from disk as necessary.

The tf.data module provides methods to extract records from one or more CSV files that comply with RFC 4180.

The tf.data.experimental.make_csv_dataset function is the high-level interface for reading sets of CSV files. It supports column type inference and many other features, like batching and shuffling, to make usage simple.

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [0 1 0 0]
features:
  'sex'               : [b'male' b'male' b'male' b'female']
  'age'               : [28. 25. 30. 28.]
  'n_siblings_spouses': [1 1 0 3]
  'parch'             : [0 0 0 1]
  'fare'              : [15.5     7.775  27.75   25.4667]
  'class'             : [b'Third' b'Third' b'First' b'Third']
  'deck'              : [b'unknown' b'unknown' b'C' b'unknown']
  'embark_town'       : [b'Queenstown' b'Southampton' b'Cherbourg' b'Southampton']
  'alone'             : [b'n' b'n' b'y' b'n']

If you only need a subset of columns, you can use the select_columns argument.

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [1 0 1 0]
  'fare'              : [11.1333 24.15    7.925  52.    ]
  'class'             : [b'Third' b'Third' b'Third' b'First']

There is also a lower-level experimental.CsvDataset class which provides finer grained control. It does not support column type inference. Instead, you must specify the type of each column.

titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

If some columns are empty, this low-level interface allows you to provide default values instead of column types.

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
Writing missing.csv
# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<_MapDataset element_spec=TensorSpec(shape=(4,), dtype=tf.int32, name=None)>
for line in dataset:
  print(line.numpy())
[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

By default, a CsvDataset yields every column of every line of the file, which may not be desirable, for example, if the file starts with a header line that should be ignored, or if some columns are not required in the input. These lines and fields can be removed with the header and select_cols arguments respectively.

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<_MapDataset element_spec=TensorSpec(shape=(2,), dtype=tf.int32, name=None)>
for line in dataset:
  print(line.numpy())
[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

Consuming sets of files

There are many datasets distributed as a set of files, where each file is an example.

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

The root directory contains a directory for each class:

for item in flowers_root.glob("*"):
  print(item.name)
roses
sunflowers
LICENSE.txt
dandelion
tulips
daisy

The files in each class directory are examples:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/4520577328_a94c11e806_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/22679060358_561ec823ae_m.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/daisy/6978826370_7b9aa7c7d5.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/3447650747_8299786b80_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/tulips/112951022_4892b1348b_n.jpg'

Read the data using the tf.io.read_file function and extract the label from the path, returning (image, label) pairs:

def process_path(file_path):
  label = tf.strings.split(file_path, os.sep)[-2]
  return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xe1\t\xcbXMP\x00://ns.adobe.com/xap/1.0/\x00<?xpacket begin='\xef\xbb\xbf' id='W5M0MpCehiHzreSzNTczk"

b'roses'

Batching dataset elements

Simple batching

The simplest form of batching stacks n consecutive elements of a dataset into a single element. The Dataset.batch() transformation does exactly this, with the same constraints as the tf.stack() operator, applied to each component of the elements: that is, for each component i, all elements must have a tensor of the exact same shape.

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

While tf.data tries to propagate shape information, the default settings of Dataset.batch result in an unknown batch size because the last batch may not be full. Note the Nones in the shape:

batched_dataset
<_BatchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

Use the drop_remainder argument to ignore that last batch, and get full shape propagation:

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset
<_BatchDataset element_spec=(TensorSpec(shape=(7,), dtype=tf.int64, name=None), TensorSpec(shape=(7,), dtype=tf.int64, name=None))>

Batching tensors with padding

The above recipe works for tensors that all have the same size. However, many models (including sequence models) work with input data that can have varying size (for example, sequences of different lengths). To handle this case, the Dataset.padded_batch transformation enables you to batch tensors of different shapes by specifying one or more dimensions in which they may be padded.

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
  print(batch.numpy())
  print()
[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]

The Dataset.padded_batch transformation allows you to set different padding for each dimension of each component, and it may be variable length (signified by None in the example above) or constant length. It is also possible to override the padding value, which defaults to 0.
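As a rough sketch of those two options (reusing the variable-length rows from the example above; the exact values are only illustrative), you can pass a constant padded_shapes and a non-default padding_values:

dataset = tf.data.Dataset.range(8)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
# Pad every row to a constant length of 8, using -1 instead of the default 0.
dataset = dataset.padded_batch(
    4, padded_shapes=(8,), padding_values=tf.constant(-1, dtype=tf.int64))

for batch in dataset.take(1):
  print(batch.numpy())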

Training workflows

Processing multiple epochs

The tf.data API offers two main ways to process multiple epochs of the same data.

The simplest way to iterate over a dataset in multiple epochs is to use the Dataset.repeat() transformation. First, create a dataset of Titanic data:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')

Applying the Dataset.repeat() transformation with no arguments will repeat the input indefinitely.

The Dataset.repeat transformation concatenates its arguments without signaling the end of one epoch and the beginning of the next epoch. Because of this, a Dataset.batch applied after Dataset.repeat will yield batches that straddle epoch boundaries:

titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

[Figure: bar chart of batch size vs. batch number]

If you need clear epoch separation, put Dataset.batch before the repeat:

titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)

[Figure: bar chart of batch size vs. batch number]

If you would like to perform a custom computation (for example, to collect statistics) at the end of each epoch, then it's simplest to restart the dataset iteration on each epoch:

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  2

Randomly shuffling input data

The Dataset.shuffle() transformation maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.

Add an index to the dataset so you can see the effect:

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_14491/4092668703.py:2: CounterV2 (from tensorflow.python.data.experimental.ops.counter) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.
<_BatchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None))>

Since the buffer_size is 100 and the batch size is 20, the first batch contains no elements with an index over 120.

n,line_batch = next(iter(dataset))
print(n.numpy())
[ 43  13  17  72  33  89  83 105  96  81   0  67  97  84  73  32  30  71
  64 103]

Dataset.batch 一样,相对于 Dataset.repeat 的顺序很重要。

Dataset.shuffle doesn't signal the end of an epoch until the shuffle buffer is empty. So a shuffle placed before a repeat will show every element of one epoch before moving to the next:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[625 583 586 606 504 597 453 615 575 429]
[424 456 452 605 483 566 395 556 492 365]
[570 573 611 540 545 559 388 579]
[  0  18  92  79  81  86  62 103  29  82]
[47 69 17 95  9 11 77 84 31 53]
shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f32b0355dc0>

[Figure: mean item ID per batch for shuffle().repeat()]

But a repeat before a shuffle mixes the epoch boundaries together:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[464 377 595 328 614 504   7  20 433 623]
[419 550 615 509 499 540 557 622 570 544]
[588 618 616   4 460  38  16 617  31 591]
[440 610 585 600  36  17  35  52 592  19]
[523  59 545 624 607  51  53  26  33 318]
[510  37   6 448 612 469  32  10  39 594]
[ 41  63  13 627  67  76 386 579 412  55]
[  1  54 626  71  64  22  47 553 525  65]
[ 69   3  15 102  14 455  23  98  74  78]
[596  12  50   5  18 112 114  97  61  42]
[103  84 583  90 350 575 606  85 107 108]
[115 127  60 602 118  43  34  58  46 587]
[119  56 620  75 564 625  88 140 539  45]
[589 100 149 452 110  11  66 132 142 111]
[101 334  94 497 520 158 120  86 135  95]
repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f32b033bdf0>

[Figure: mean item ID per batch for shuffle().repeat() vs. repeat().shuffle()]

Preprocessing data

The Dataset.map(f) transformation produces a new dataset by applying a given function f to each element of the input dataset. It is based on the map() function that is commonly applied to lists (and other structures) in functional programming languages. The function f takes the tf.Tensor objects that represent a single element in the input, and returns the tf.Tensor objects that will represent a single element in the new dataset. Its implementation uses standard TensorFlow operations to transform one element into another.

This section covers common examples of how to use Dataset.map().

Decoding image data and resizing it

When training a neural network on real-world image data, it is often necessary to convert images of different sizes to a common size, so that they may be batched into a fixed size.

Reconstruct the flower filenames dataset:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

Write a function that manipulates the dataset elements.

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(filename, os.sep)
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.io.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label

Test that it works.

file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')

show(image, label)

[Figure: sample flower image, titled with its label]

Map it over the dataset.

images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):
  show(image, label)

[Figures: two resized flower images, each titled with its label]

Applying arbitrary Python logic

For performance reasons, use TensorFlow operations for preprocessing your data whenever possible. However, it is sometimes useful to call external Python libraries when parsing your input data. You can use the tf.py_function operation in a Dataset.map transformation.

For example, if you want to apply a random rotation, the tf.image module only has tf.image.rot90, which is not very useful for image augmentation.

To demonstrate tf.py_function, try using the scipy.ndimage.rotate function instead:

import scipy.ndimage as ndimage

@tf.py_function(Tout=tf.float32)
def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

[Figure: randomly rotated flower image]

To use this function with Dataset.map, the same caveats apply as with Dataset.from_generator: you need to describe the return shapes and types when you apply the function:

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  image = random_rotate_image(image)
  image.set_shape(im_shape)
  return image, label
rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
  show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

[Figures: two randomly rotated flower images]

Parsing tf.Example protocol buffer messages

Many input pipelines extract tf.train.Example protocol buffer messages from a TFRecord format. Each tf.train.Example record contains one or more "features", and the input pipeline typically converts these features into tensors.

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

You can work with tf.train.Example protos outside of a tf.data.Dataset to understand the data:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

[Figure: decoded FSNS image, titled with its 'image/text' feature]

raw_example = next(iter(dataset))
def tf_parse(eg):
  example = tf.io.parse_example(
      eg[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...
decoded = dataset.map(tf_parse)
decoded
<_MapDataset element_spec=(TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.string, name=None))>
image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
TensorShape([10])

Time series windowing

For an end-to-end time series example see: Time series forecasting.

Time series data is often organized with the time axis intact.

Use a simple Dataset.range to demonstrate:

range_ds = tf.data.Dataset.range(100000)

Typically, models based on this sort of data will want a contiguous time slice.

The simplest approach would be to batch the data:

Using batch

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
  print(batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

Or, to make dense predictions one step into the future, you can shift the features and labels by one step relative to each other:

def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8]  =>  [1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18]  =>  [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28]  =>  [21 22 23 24 25 26 27 28 29]

To predict a whole window instead of a fixed offset, you can split the batches into two parts:

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
  return (batch[:-5],   # Inputs: All except the last 5 steps
          batch[-5:])   # Labels: The last 5 steps

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]  =>  [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42 43 44]

To allow some overlap between the features of one batch and the labels of another, use Dataset.zip:

feature_length = 10
label_length = 3

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:label_length])

predicted_steps = tf.data.Dataset.zip((features, labels))

for features, label in predicted_steps.take(5):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12]
[10 11 12 13 14 15 16 17 18 19]  =>  [20 21 22]
[20 21 22 23 24 25 26 27 28 29]  =>  [30 31 32]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42]
[40 41 42 43 44 45 46 47 48 49]  =>  [50 51 52]

Using window

While using Dataset.batch works, there are situations where you may need finer control. The Dataset.window method gives you complete control, but requires some care: it returns a Dataset of Datasets. Go to the Dataset structure section for details.

window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>

The Dataset.flat_map method can take a dataset of datasets and flatten it into a single dataset

for x in windows.flat_map(lambda x: x).take(30):
   print(x.numpy(), end=' ')
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9

In nearly all cases, you will want to Dataset.batch the dataset first

def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

Now, you can see that the shift argument controls how much each window moves over.

Putting this together you might write this function

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows
ds = make_window_dataset(range_ds, window_size=10, shift = 5, stride=3)

for example in ds.take(10):
  print(example.numpy())
[ 0  3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

Then it's easy to extract labels, as before

dense_labels_ds = ds.map(dense_1_step)

for inputs,labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
[ 0  3  6  9 12 15 18 21 24] => [ 3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

Resampling

When working with a dataset that is very class-imbalanced, you may want to resample the dataset. tf.data provides two methods to do this. The credit card fraud dataset is a good example of this sort of problem.

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69155632/69155632 [==============================] - 0s 0us/step
creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

Now, check the distribution of classes; it is highly skewed:

def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)
[0.9953 0.0047]

A common approach to training with an imbalanced dataset is to balance it. tf.data includes a few methods which enable this workflow

Datasets sampling

One approach to resampling a dataset is to use sample_from_datasets. This is more applicable when you have a separate tf.data.Dataset for each class.

Here, just use filter to generate them from the credit card fraud data

negative_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())
positive_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())
for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
[1 1 1 1 1 1 1 1 1 1]

To use tf.data.Dataset.sample_from_datasets, pass the datasets and the weight for each:

balanced_ds = tf.data.Dataset.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

Now the dataset produces examples of each class with a 50/50 probability

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[0 0 0 1 0 0 1 0 0 0]
[1 1 0 0 1 1 1 0 0 1]
[1 0 0 0 1 1 0 0 0 0]
[1 0 0 0 0 1 0 1 0 0]
[1 1 1 1 0 1 0 0 1 0]
[0 1 1 0 0 0 1 1 0 1]
[0 1 1 0 1 1 0 1 1 1]
[0 0 0 1 0 0 1 1 0 0]
[0 1 0 0 1 1 1 0 1 0]
[1 1 1 1 1 1 0 0 0 0]

Rejection resampling

One problem with the above Dataset.sample_from_datasets approach is that it needs a separate tf.data.Dataset per class. You could use Dataset.filter to create those two datasets, but that results in all the data being loaded twice.

The tf.data.Dataset.rejection_resample method can be applied to a dataset to rebalance it, while only loading it once. Elements will be dropped or repeated to achieve balance.

The rejection_resample method takes a class_func argument. This class_func is applied to each dataset element, and is used to determine which class an example belongs to for the purposes of balancing.

The goal here is to balance the label distribution, and the elements of creditcard_ds are already (features, label) pairs. So the class_func just needs to return those labels

def class_func(features, label):
  return label

The resampling method deals with individual examples, so in this case you must unbatch the dataset before applying that method.

The method needs a target distribution, and optionally an initial distribution estimate as inputs.

resample_ds = (
    creditcard_ds
    .unbatch()
    .rejection_resample(class_func, target_dist=[0.5,0.5],
                        initial_dist=fractions)
    .batch(10))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:4963: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:

The rejection_resample method returns (class, example) pairs where the class is the output of the class_func. In this case, the example was already a (feature, label) pair, so use map to drop the extra copy of the labels

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

Now the dataset produces examples of each class with a 50/50 probability

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
Proportion of examples rejected by sampler is high: [0.995312512][0.995312512 0.0046875][0 1]
[0 1 0 1 0 0 1 1 0 1]
[1 1 0 1 0 0 1 1 0 0]
[0 0 1 0 1 0 0 0 0 1]
[1 0 0 1 0 0 1 0 1 0]
[0 1 1 0 0 1 0 1 1 0]
[1 1 0 1 1 0 1 1 0 1]
[1 1 1 0 0 0 0 0 0 1]
[0 1 0 0 0 0 0 1 1 1]
[0 1 0 1 1 0 0 1 0 1]
[0 0 1 0 1 1 1 0 0 1]

Iterator Checkpointing

TensorFlow supports taking checkpoints so that when your training process restarts it can restore the latest checkpoint to recover most of its progress. In addition to checkpointing the model variables, you can also checkpoint the progress of the dataset iterator. This could be useful if you have a large dataset and don't want to start the dataset from the beginning on each restart. Note however that iterator checkpoints may be large, since transformations such as Dataset.shuffle and Dataset.prefetch require buffering elements within the iterator.

To include your iterator in a checkpoint, pass the iterator to the tf.train.Checkpoint constructor.

range_ds = tf.data.Dataset.range(20)

iterator = iter(range_ds)
ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, '/tmp/my_ckpt', max_to_keep=3)

print([next(iterator).numpy() for _ in range(5)])

save_path = manager.save()

print([next(iterator).numpy() for _ in range(5)])

ckpt.restore(manager.latest_checkpoint)

print([next(iterator).numpy() for _ in range(5)])
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[5, 6, 7, 8, 9]

Using tf.data with tf.keras

The tf.keras API simplifies many aspects of creating and executing machine learning models. Its Model.fit, Model.evaluate, and Model.predict APIs support datasets as inputs. Here is a quick dataset and model setup:

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Passing a dataset of (feature, label) pairs is all that's needed for Model.fit and Model.evaluate

model.fit(fmnist_train_ds, epochs=2)
Epoch 1/2
  26/1875 [..............................] - ETA: 3s - loss: 1.7645 - accuracy: 0.3930
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705458306.961075   14743 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5948 - accuracy: 0.7994
Epoch 2/2
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4601 - accuracy: 0.8427
<keras.src.callbacks.History at 0x7f32b0722250>

If you pass an infinite dataset, for example by calling Dataset.repeat, you just need to also pass the steps_per_epoch argument

model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4791 - accuracy: 0.8344
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4777 - accuracy: 0.8438
<keras.src.callbacks.History at 0x7f32b0457100>

For evaluation you can pass the number of evaluation steps

loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4396 - accuracy: 0.8490
Loss : 0.4396059811115265
Accuracy : 0.8489833474159241

For long datasets, set the number of steps to evaluate

loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)
10/10 [==============================] - 0s 2ms/step - loss: 0.4079 - accuracy: 0.8844
Loss : 0.4079427719116211
Accuracy : 0.8843749761581421

The labels are not required when calling Model.predict.

predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
10/10 [==============================] - 0s 1ms/step
(320, 10)

But the labels are ignored if you do pass a dataset containing them

result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
10/10 [==============================] - 0s 1ms/step
(320, 10)