TFDS 和确定性

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看 下载笔记本

本文档解释了

  • TFDS 对确定性的保证
  • TFDS 以何种顺序读取示例
  • 各种注意事项和问题

设置

数据集

需要一些上下文才能了解 TFDS 如何读取数据。

在生成过程中,TFDS 将原始数据写入标准化的 .tfrecord 文件中。对于大型数据集,会创建多个 .tfrecord 文件,每个文件包含多个示例。我们称每个 .tfrecord 文件为一个 **分片**。

本指南使用 imagenet,它有 1024 个分片

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

查找数据集示例 ID

如果您只想了解确定性,可以跳过以下部分。

每个数据集示例都由一个唯一的 id 标识(例如 'imagenet2012-train.tfrecord-01023-of-01024__32')。您可以通过传递 read_config.add_tfds_id = True 来恢复此 id,这将在 tf.data.Dataset 的字典中添加一个 'tfds_id' 键。

在本教程中,我们定义了一个小型实用程序,它将打印数据集的示例 ID(转换为整数以更易于理解)

读取时的确定性

本节解释了 tfds.load 的确定性保证。

使用 shuffle_files=False(默认值)

默认情况下,TFDS 以确定性方式生成示例 (shuffle_files=False)

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

为了提高性能,TFDS 使用 tf.data.Dataset.interleave 同时读取多个分片。在这个示例中,我们看到 TFDS 在读取 16 个示例后切换到分片 2 (..., 14, 15, 1251, 1252, ...)。有关 interleave 的更多信息,请参见下文。

类似地,子拆分 API 也是确定性的

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

如果您要训练多个 epoch,则不建议使用上述设置,因为所有 epoch 都将以相同的顺序读取分片(因此随机性仅限于 ds = ds.shuffle(buffer) 缓冲区大小)。

使用 shuffle_files=True

使用 shuffle_files=True,每个 epoch 的分片都会被洗牌,因此读取不再是确定性的。

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

请参见下面的食谱以获得确定性的文件洗牌。

确定性注意事项:interleave 参数

更改 read_config.interleave_cycle_lengthread_config.interleave_block_length 将更改示例顺序。

TFDS 依赖于 tf.data.Dataset.interleave 仅加载少量分片,从而提高性能并减少内存使用量。

示例顺序仅保证在 interleave 参数固定时保持一致。请参见 interleave 文档 了解 cycle_lengthblock_length 对应的内容。

  • cycle_length=16block_length=16(默认值,与上面相同)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3block_length=2
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

在第二个示例中,我们看到数据集读取了 2 个 (block_length=2) 示例,然后切换到下一个分片。每 2 * 3 (cycle_length=3) 个示例,它会回到第一个分片 (shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2,...)。

子集和示例顺序

每个示例都有一个 id 0, 1, ..., num_examples-1子集 API 选择示例的切片(例如 train[:x] 选择 0, 1, ..., x-1)。

但是,在子集内,示例不是按递增的 id 顺序读取的(由于分片和交织)。

更具体地说,ds.take(x)split='train[:x]' 等效!

这在上面的交织示例中很容易看到,其中示例来自不同的分片。

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

在 16 个(block_length)示例之后,.take(25) 切换到下一个分片,而 train[:25] 继续从第一个分片中读取示例。

食谱

获得确定性文件混洗

有两种方法可以进行确定性混洗

  1. 设置 shuffle_seed。注意:这需要在每个 epoch 更改种子,否则分片将在 epoch 之间按相同的顺序读取。
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  1. 使用 experimental_interleave_sort_fn:这可以完全控制读取哪些分片以及读取顺序,而不是依赖于 ds.shuffle 顺序。
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]

获得确定性可抢占管道

这个更复杂。没有简单、令人满意的解决方案。

  1. 没有 ds.shuffle 并且使用确定性混洗,理论上应该可以计算已读取的示例,并推断每个分片中已读取的示例(作为 cycle_lengthblock_length 和分片顺序的函数)。然后,可以通过 experimental_interleave_sort_fn 注入每个分片的 skiptake

  2. 使用 ds.shuffle,如果没有重放整个训练管道,这可能是不可能的。它需要保存 ds.shuffle 缓冲区状态以推断已读取的示例。示例可能是非连续的(例如,shard5_ex2shard5_ex4 已读取,但 shard5_ex3 未读取)。

  3. 使用 ds.shuffle,一种方法是保存所有已读取的 shard_ids/example_ids(从 tfds_id 推断),然后从中推断文件指令。

对于 1.,最简单的案例是让 .skip(x).take(y) 匹配 train[x:x+y] 匹配。它需要

  • 设置 cycle_length=1(以便按顺序读取分片)
  • 设置 shuffle_files=False
  • 不要使用 ds.shuffle

它应该只用于训练仅为 1 个 epoch 的大型数据集。示例将按默认的混洗顺序读取。

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job get pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]

查找给定子集读取了哪些分片/示例

使用 tfds.core.DatasetInfo,您可以直接访问读取指令。

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]