如果您有一个 tf.train.Example
协议(在 .tfrecord
、.riegeli
等中),该协议是由第三方工具生成的,您希望直接使用 tfds API 加载,那么本页适合您。
为了加载您的 .tfrecord
文件,您只需要
- 遵循 TFDS 命名约定。
- 在您的 tfrecord 文件旁边添加元数据文件(
dataset_info.json
、features.json
)。
限制
tf.train.SequenceExample
不受支持,只支持tf.train.Example
。- 您需要能够用
tfds.features
来表达tf.train.Example
(参见下面的部分)。
文件命名约定
TFDS 支持为文件名定义模板,这提供了使用不同文件名方案的灵活性。模板由 tfds.core.ShardedFileTemplate
表示,并支持以下变量:{DATASET}
、{SPLIT}
、{FILEFORMAT}
、{SHARD_INDEX}
、{NUM_SHARDS}
和 {SHARD_X_OF_Y}
。例如,TFDS 的默认文件名方案是:{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}
。对于 MNIST,这意味着 文件名 如下所示
mnist-test.tfrecord-00000-of-00001
mnist-train.tfrecord-00000-of-00001
添加元数据
提供功能结构
为了让 TFDS 能够解码 tf.train.Example
协议,您需要提供与您的规范匹配的 tfds.features
结构。例如
features = tfds.features.FeaturesDict({
'image':
tfds.features.Image(
shape=(256, 256, 3),
doc='Picture taken by smartphone, downscaled.'),
'label':
tfds.features.ClassLabel(names=['dog', 'cat']),
'objects':
tfds.features.Sequence({
'camera/K': tfds.features.Tensor(shape=(3,), dtype=tf.float32),
}),
})
对应于以下 tf.train.Example
规范
{
'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
'label': tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
'objects/camera/K': tf.io.FixedLenSequenceFeature(shape=(3,), dtype=tf.int64),
}
指定功能允许 TFDS 自动解码图像、视频等。与任何其他 TFDS 数据集一样,功能元数据(例如标签名称等)将公开给用户(例如 info.features['label'].names
)。
如果您控制生成管道
如果您在 TFDS 之外生成数据集,但仍然控制生成管道,则可以使用 tfds.features.FeatureConnector.serialize_example
将您的数据从 dict[np.ndarray]
编码为 tf.train.Example
协议 bytes
with tf.io.TFRecordWriter('path/to/file.tfrecord') as writer:
for ex in all_exs:
ex_bytes = features.serialize_example(data)
writer.write(ex_bytes)
这将确保与 TFDS 的功能兼容性。
类似地,一个 feature.deserialize_example
存在于解码协议(示例)
如果您不控制生成管道
如果您想了解 tfds.features
如何在 tf.train.Example
中表示,您可以在 colab 中查看。
- 要将
tfds.features
转换为tf.train.Example
的人类可读结构,您可以调用features.get_serialized_info()
。 - 要获取传递给
tf.io.parse_single_example
的确切FixedLenFeature
等规范,您可以使用spec = features.tf_example_spec
获取拆分的统计信息
TFDS 需要知道每个分片中示例的确切数量。这是诸如 len(ds)
或 子拆分 API 之类的功能所必需的:split='train[75%:]'
。
如果您有此信息,您可以显式地创建一个
tfds.core.SplitInfo
列表,并跳到下一节split_infos = [ tfds.core.SplitInfo( name='train', shard_lengths=[1024, ...], # Num of examples in shard0, shard1,... num_bytes=0, # Total size of your dataset (if unknown, set to 0) ), tfds.core.SplitInfo(name='test', ...), ]
如果您不知道此信息,可以使用
compute_split_info.py
脚本(或在您自己的脚本中使用tfds.folder_dataset.compute_split_info
)来计算它。它将启动一个 beam 管道,该管道将读取给定目录上的所有分片并计算信息。
添加元数据文件
要自动将适当的元数据文件添加到您的数据集旁边,请使用 tfds.folder_dataset.write_metadata
tfds.folder_dataset.write_metadata(
data_dir='/path/to/my/dataset/1.0.0/',
features=features,
# Pass the `out_dir` argument of compute_split_info (see section above)
# You can also explicitly pass a list of `tfds.core.SplitInfo`.
split_infos='/path/to/my/dataset/1.0.0/',
# Pass a custom file name template or use None for the default TFDS
# file name template.
filename_template='{SPLIT}-{SHARD_X_OF_Y}.{FILEFORMAT}',
# Optionally, additional DatasetInfo metadata can be provided
# See:
# https://tensorflowcn.cn/datasets/api_docs/python/tfds/core/DatasetInfo
description="""Multi-line description."""
homepage='http://my-project.org',
supervised_keys=('image', 'label'),
citation="""BibTex citation.""",
)
一旦您在数据集目录上调用了该函数一次,元数据文件(dataset_info.json
等)就会被添加,您的数据集就可以使用 TFDS 加载(见下一节)。
使用 TFDS 加载数据集
直接从文件夹加载
元数据生成后,可以使用 tfds.builder_from_directory
加载数据集,该函数会返回一个 tfds.core.DatasetBuilder
,它具有标准的 TFDS API(如 tfds.builder
)。
builder = tfds.builder_from_directory('~/path/to/my_dataset/3.0.0/')
# Metadata are available as usual
builder.info.splits['train'].num_examples
# Construct the tf.data.Dataset pipeline
ds = builder.as_dataset(split='train[75%:]')
for ex in ds:
...
直接从多个文件夹加载
也可以从多个文件夹加载数据。例如,在强化学习中,多个代理分别生成单独的数据集,您可能希望将它们全部加载在一起,就会出现这种情况。其他用例包括定期生成新数据集,例如每天生成一个新数据集,并且您希望加载某个日期范围内的所有数据。
要从多个文件夹加载数据,请使用 tfds.builder_from_directories
,该函数会返回一个 tfds.core.DatasetBuilder
,它具有标准的 TFDS API(如 tfds.builder
)。
builder = tfds.builder_from_directories(builder_dirs=[
'~/path/my_dataset/agent1/1.0.0/',
'~/path/my_dataset/agent2/1.0.0/',
'~/path/my_dataset/agent3/1.0.0/',
])
# Metadata are available as usual
builder.info.splits['train'].num_examples
# Construct the tf.data.Dataset pipeline
ds = builder.as_dataset(split='train[75%:]')
for ex in ds:
...
文件夹结构(可选)
为了更好地与 TFDS 兼容,您可以将数据组织为 <data_dir>/<dataset_name>[/<dataset_config>]/<dataset_version>
。例如
data_dir/
dataset0/
1.0.0/
1.0.1/
dataset1/
config0/
2.0.0/
config1/
2.0.0/
这样,您的数据集将与 tfds.load
/ tfds.builder
API 兼容,只需提供 data_dir/
即可。
ds0 = tfds.load('dataset0', data_dir='data_dir/')
ds1 = tfds.load('dataset1/config0', data_dir='data_dir/')