本页面介绍了在实现新数据集时常见的实现问题。
应避免使用旧版 SplitGenerator
旧版 tfds.core.SplitGenerator
API 已弃用。
def _split_generator(...):
return [
tfds.core.SplitGenerator(name='train', gen_kwargs={'path': train_path}),
tfds.core.SplitGenerator(name='test', gen_kwargs={'path': test_path}),
]
应替换为
def _split_generator(...):
return {
'train': self._generate_examples(path=train_path),
'test': self._generate_examples(path=test_path),
}
原因:新 API 更加简洁,更明确。旧 API 将在未来版本中移除。
新数据集应在一个文件夹中自包含
在 tensorflow_datasets/
存储库中添加数据集时,请确保遵循数据集作为文件夹的结构(所有校验和、虚拟数据、实现代码都自包含在一个文件夹中)。
- 旧数据集(错误):
<category>/<ds_name>.py
- 新数据集(正确):
<category>/<ds_name>/<ds_name>.py
使用 TFDS CLI (tfds new
或 gtfds new
用于 Google 员工) 生成模板。
原因:旧结构需要校验和、虚拟数据的绝对路径,并且在许多地方分发数据集文件。这使得在 TFDS 存储库之外实现数据集变得更加困难。为了保持一致性,现在应该在所有地方使用新结构。
描述列表应格式化为 Markdown
DatasetInfo.description
str
格式化为 Markdown。Markdown 列表需要在第一个项目之前留空一行
_DESCRIPTION = """
Some text.
# << Empty line here !!!
1. Item 1
2. Item 1
3. Item 1
# << Empty line here !!!
Some other text.
"""
原因:格式错误的描述会在我们的目录文档中产生视觉效果。如果没有空行,上面的文本将呈现为
一些文本。1. 项目 1 2. 项目 1 3. 项目 1 其他一些文本
忘记 ClassLabel 名称
使用 tfds.features.ClassLabel
时,请尝试使用 names=
或 names_file=
提供人类可读的标签 str
(而不是 num_classes=10
)。
features = {
'label': tfds.features.ClassLabel(names=['dog', 'cat', ...]),
}
原因:人类可读的标签在许多地方使用
- 允许在
_generate_examples
中直接生成str
:yield {'label': 'dog'}
- 在用户中公开为
info.features['label'].names
(转换方法.str2int('dog')
,... 也可用) - 在 可视化工具
tfds.show_examples
、tfds.as_dataframe
中使用
忘记图像形状
使用 tfds.features.Image
、tfds.features.Video
时,如果图像具有静态形状,则应明确指定它们
features = {
'image': tfds.features.Image(shape=(256, 256, 3)),
}
原因:它允许静态形状推断(例如 ds.element_spec['image'].shape
),这是批处理所必需的(对未知形状的图像进行批处理需要先调整它们的大小)。
优先使用更具体的类型,而不是 tfds.features.Tensor
尽可能优先使用更具体的类型 tfds.features.ClassLabel
、tfds.features.BBoxFeatures
,... 而不是通用的 tfds.features.Tensor
。
原因:除了更具语义正确性之外,特定功能还为用户提供额外的元数据,并且会被工具检测到。
全局空间中的延迟导入
延迟导入不应从全局空间调用。例如,以下操作是错误的
tfds.lazy_imports.apache_beam # << Error: Import beam in the global scope
def f() -> beam.Map:
...
原因:在全局范围内使用延迟导入将为所有 tfds 用户导入模块,从而失去了延迟导入的意义。
动态计算训练/测试拆分
如果数据集没有提供官方拆分,TFDS 也不应该提供。以下情况应避免
_TRAIN_TEST_RATIO = 0.7
def _split_generator():
ids = list(range(num_examples))
np.random.RandomState(seed).shuffle(ids)
# Split train/test
train_ids = ids[_TRAIN_TEST_RATIO * num_examples:]
test_ids = ids[:_TRAIN_TEST_RATIO * num_examples]
return {
'train': self._generate_examples(train_ids),
'test': self._generate_examples(test_ids),
}
理由:TFDS 试图提供尽可能接近原始数据的 数据集。应该使用 子拆分 API 来让用户动态创建他们想要的子拆分
ds_train, ds_test = tfds.load(..., split=['train[:80%]', 'train[80%:]'])
Python 风格指南
优先使用 pathlib API
而不是使用 tf.io.gfile
API,最好使用 pathlib API。所有 dl_manager
方法返回与 GCS、S3 等兼容的类似 pathlib 的对象。
path = dl_manager.download_and_extract('http://some-website/my_data.zip')
json_path = path / 'data/file.json'
json.loads(json_path.read_text())
理由:pathlib API 是一个现代的面向对象的 文件 API,它消除了样板代码。使用 .read_text()
/ .read_bytes()
还可以保证文件正确关闭。
如果方法没有使用 self
,它应该是一个函数
如果类方法没有使用 self
,它应该是一个简单的函数(在类之外定义)。
理由:这使读者明确地知道该函数没有副作用,也没有隐藏的输入/输出
x = f(y) # Clear inputs/outputs
x = self.f(y) # Does f depend on additional hidden variables ? Is it stateful ?
Python 中的延迟导入
我们延迟导入像 TensorFlow 这样的大型模块。延迟导入将模块的实际导入推迟到模块第一次使用时。因此,不需要这个大型模块的用户永远不会导入它。我们使用 etils.epy.lazy_imports
。
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
# After this statement, TensorFlow is not imported yet
...
features = tfds.features.Image(dtype=tf.uint8)
# After using it (`tf.uint8`), TensorFlow is now imported
在幕后,LazyModule
类 充当工厂,只有在访问属性(__getattr__
)时才会实际导入模块。
你也可以方便地使用上下文管理器
from etils import epy
with epy.lazy_imports(error_callback=..., success_callback=...):
import some_big_module