所有 TFDS 数据集都公开各种数据拆分(例如 'train'、'test'),可在 目录 中进行浏览。除了 all(这是一个保留术语,对应于所有拆分的并集,见下文)之外,任何字母数字字符串都可以用作拆分名称。
除了“官方”数据集拆分之外,TFDS 还允许选择拆分和各种组合的切片。
切片 API
切片说明在 tfds.load 或 tfds.DatasetBuilder.as_dataset 中通过 split= kwarg 指定。
ds = tfds.load('my_dataset', split='train[:75%]')
builder = tfds.builder('my_dataset')
ds = builder.as_dataset(split='test+train[:75%]')
拆分可以是
- 普通拆分名称(例如
'train'、'test'等字符串):选择拆分中的所有示例。 - 切片:切片的语义与 python 切片符号 相同。切片可以是
- 绝对 (
'train[123:450]',train[:4000]):(有关读取顺序的警告,请参见下方的注释) - 百分比 (
'train[:75%]','train[25%:75%]'):将完整数据分成相等的切片。如果数据不能平均分配,则某些百分比可能包含额外的示例。支持小数百分比。 - 分片 (
train[:4shard],train[4shard]):选择请求的分片中的所有示例。(请参阅info.splits['train'].num_shards以获取拆分的碎片数)
- 绝对 (
- 拆分的并集 (
'train+test','train[:25%]+test'):拆分将交错在一起。 - 完整数据集 (
'all'):'all'是一个特殊拆分名称,对应于所有拆分的并集(等同于'train+test+...')。 - 拆分列表 (
['train', 'test']):多个tf.data.Dataset将分别返回
# Returns both train and test split separately
train_ds, test_ds = tfds.load('mnist', split=['train', 'test[:50%]'])
tfds.even_splits 和多主机训练
tfds.even_splits 会生成一个大小相同的非重叠子拆分列表。
# Divide the dataset into 3 even parts, each containing 1/3 of the data
split0, split1, split2 = tfds.even_splits('train', n=3)
ds = tfds.load('my_dataset', split=split2)
这在分布式设置中进行训练时特别有用,其中每个主机都应接收原始数据的一部分。
使用 Jax 时,可以使用 tfds.split_for_jax_process 进一步简化
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
tfds.split_for_jax_process 是以下内容的简单别名
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
tfds.even_splits、tfds.split_for_jax_process 接受任何拆分值作为输入(例如 'train[75%:]+test')
切片和元数据
可以使用 数据集信息 获取拆分/子拆分(num_examples、file_instructions 等)的其他信息
builder = tfds.builder('my_dataset')
builder.info.splits['train'].num_examples # 10_000
builder.info.splits['train[:75%]'].num_examples # 7_500 (also works with slices)
builder.info.splits.keys() # ['train', 'test']
交叉验证
使用字符串 API 进行 10 折交叉验证的示例
vals_ds = tfds.load('mnist', split=[
f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)
])
trains_ds = tfds.load('mnist', split=[
f'train[:{k}%]+train[{k+10}%:]' for k in range(0, 100, 10)
])
验证数据集每个为 10%:[0%:10%]、[10%:20%]、...、[90%:100%]。训练数据集每个为互补的 90%:[10%:100%](对应验证集为 [0%:10%])、`[0%:10%]
- [20%:100%]
(对应验证集为[10%:20%]`),...
tfds.core.ReadInstruction 和舍入
可以将拆分作为 tfds.core.ReadInstruction 传递,而不是 str
例如,split = 'train[50%:75%] + test' 等效于
split = (
tfds.core.ReadInstruction(
'train',
from_=50,
to=75,
unit='%',
)
+ tfds.core.ReadInstruction('test')
)
ds = tfds.load('my_dataset', split=split)
unit 可以是
abs:绝对切片%:百分比切片shard:分片切片
tfds.ReadInstruction 还有一个舍入参数。如果数据集中的示例数量不能平均分配
rounding='closest'(默认):剩余示例在百分比中分配,因此某些百分比可能包含其他示例。rounding='pct1_dropremainder':剩余示例被丢弃,但这保证所有百分比包含完全相同的示例数量(例如:len(5%) == 5 * len(1%))。
可重复性和确定性
在生成期间,对于给定的数据集版本,TFDS 保证示例在磁盘上以确定性方式随机排列。因此,在两台不同的计算机上生成数据集两次不会更改示例顺序。
类似地,子拆分 API 始终会选择相同的示例 set,无论平台、架构等如何。这意味着 set('train[:20%]') == set('train[:10%]') + set('train[10%:20%]')。
但是,读取示例的顺序可能不是确定性的。这取决于其他参数(例如 shuffle_files=True)。