常见实现问题

本页面介绍了在实现新数据集时常见的实现问题。

应避免使用旧版 SplitGenerator

旧版 tfds.core.SplitGenerator API 已弃用。

def _split_generator(...):
  return [
      tfds.core.SplitGenerator(name='train', gen_kwargs={'path': train_path}),
      tfds.core.SplitGenerator(name='test', gen_kwargs={'path': test_path}),
  ]

应替换为

def _split_generator(...):
  return {
      'train': self._generate_examples(path=train_path),
      'test': self._generate_examples(path=test_path),
  }

原因:新 API 更加简洁,更明确。旧 API 将在未来版本中移除。

新数据集应在一个文件夹中自包含

tensorflow_datasets/ 存储库中添加数据集时,请确保遵循数据集作为文件夹的结构(所有校验和、虚拟数据、实现代码都自包含在一个文件夹中)。

  • 旧数据集(错误):<category>/<ds_name>.py
  • 新数据集(正确):<category>/<ds_name>/<ds_name>.py

使用 TFDS CLI (tfds newgtfds new 用于 Google 员工) 生成模板。

原因:旧结构需要校验和、虚拟数据的绝对路径,并且在许多地方分发数据集文件。这使得在 TFDS 存储库之外实现数据集变得更加困难。为了保持一致性,现在应该在所有地方使用新结构。

描述列表应格式化为 Markdown

DatasetInfo.description str 格式化为 Markdown。Markdown 列表需要在第一个项目之前留空一行

_DESCRIPTION = """
Some text.
                      # << Empty line here !!!
1. Item 1
2. Item 1
3. Item 1
                      # << Empty line here !!!
Some other text.
"""

原因:格式错误的描述会在我们的目录文档中产生视觉效果。如果没有空行,上面的文本将呈现为

一些文本。1. 项目 1 2. 项目 1 3. 项目 1 其他一些文本

忘记 ClassLabel 名称

使用 tfds.features.ClassLabel 时,请尝试使用 names=names_file= 提供人类可读的标签 str(而不是 num_classes=10)。

features = {
    'label': tfds.features.ClassLabel(names=['dog', 'cat', ...]),
}

原因:人类可读的标签在许多地方使用

忘记图像形状

使用 tfds.features.Imagetfds.features.Video 时,如果图像具有静态形状,则应明确指定它们

features = {
    'image': tfds.features.Image(shape=(256, 256, 3)),
}

原因:它允许静态形状推断(例如 ds.element_spec['image'].shape),这是批处理所必需的(对未知形状的图像进行批处理需要先调整它们的大小)。

优先使用更具体的类型,而不是 tfds.features.Tensor

尽可能优先使用更具体的类型 tfds.features.ClassLabeltfds.features.BBoxFeatures,... 而不是通用的 tfds.features.Tensor

原因:除了更具语义正确性之外,特定功能还为用户提供额外的元数据,并且会被工具检测到。

全局空间中的延迟导入

延迟导入不应从全局空间调用。例如,以下操作是错误的

tfds.lazy_imports.apache_beam # << Error: Import beam in the global scope

def f() -> beam.Map:
  ...

原因:在全局范围内使用延迟导入将为所有 tfds 用户导入模块,从而失去了延迟导入的意义。

动态计算训练/测试拆分

如果数据集没有提供官方拆分,TFDS 也不应该提供。以下情况应避免

_TRAIN_TEST_RATIO = 0.7

def _split_generator():
  ids = list(range(num_examples))
  np.random.RandomState(seed).shuffle(ids)

  # Split train/test
  train_ids = ids[_TRAIN_TEST_RATIO * num_examples:]
  test_ids = ids[:_TRAIN_TEST_RATIO * num_examples]
  return {
      'train': self._generate_examples(train_ids),
      'test': self._generate_examples(test_ids),
  }

理由:TFDS 试图提供尽可能接近原始数据的 数据集。应该使用 子拆分 API 来让用户动态创建他们想要的子拆分

ds_train, ds_test = tfds.load(..., split=['train[:80%]', 'train[80%:]'])

Python 风格指南

优先使用 pathlib API

而不是使用 tf.io.gfile API,最好使用 pathlib API。所有 dl_manager 方法返回与 GCS、S3 等兼容的类似 pathlib 的对象。

path = dl_manager.download_and_extract('http://some-website/my_data.zip')

json_path = path / 'data/file.json'

json.loads(json_path.read_text())

理由:pathlib API 是一个现代的面向对象的 文件 API,它消除了样板代码。使用 .read_text() / .read_bytes() 还可以保证文件正确关闭。

如果方法没有使用 self,它应该是一个函数

如果类方法没有使用 self,它应该是一个简单的函数(在类之外定义)。

理由:这使读者明确地知道该函数没有副作用,也没有隐藏的输入/输出

x = f(y)  # Clear inputs/outputs

x = self.f(y)  # Does f depend on additional hidden variables ? Is it stateful ?

Python 中的延迟导入

我们延迟导入像 TensorFlow 这样的大型模块。延迟导入将模块的实际导入推迟到模块第一次使用时。因此,不需要这个大型模块的用户永远不会导入它。我们使用 etils.epy.lazy_imports

from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
# After this statement, TensorFlow is not imported yet

...

features = tfds.features.Image(dtype=tf.uint8)
# After using it (`tf.uint8`), TensorFlow is now imported

在幕后,LazyModule 充当工厂,只有在访问属性(__getattr__)时才会实际导入模块。

你也可以方便地使用上下文管理器

from etils import epy

with epy.lazy_imports(error_callback=..., success_callback=...):
  import some_big_module