编写自定义数据集

按照本指南创建新的数据集(在 TFDS 或您自己的存储库中)。

查看我们的 数据集列表,看看您想要的数据集是否已经存在。

TL;DR

编写新数据集最简单的方法是使用 TFDS CLI

cd path/to/my/project/datasets/
tfds new my_dataset  # Create `my_dataset/my_dataset.py` template files
# [...] Manually modify `my_dataset/my_dataset_dataset_builder.py` to implement your dataset.
cd my_dataset/
tfds build  # Download and prepare the dataset to `~/tensorflow_datasets/`

要使用 tfds.load('my_dataset') 使用新数据集

  • tfds.load 将自动检测并加载在 ~/tensorflow_datasets/my_dataset/ 中生成的数据集(例如,通过 tfds build)。
  • 或者,您可以显式 import my.project.datasets.my_dataset 来注册您的数据集
import my.project.datasets.my_dataset  # Register `my_dataset`

ds = tfds.load('my_dataset')  # `my_dataset` registered

概述

数据集以各种格式分布在各种地方,并且并不总是以可以馈送到机器学习管道中的格式存储。TFDS 应运而生。

TFDS 将这些数据集处理成标准格式(外部数据 -> 序列化文件),然后可以将其加载为机器学习管道(序列化文件 -> tf.data.Dataset)。序列化只执行一次。后续访问将直接从这些预处理文件读取。

大多数预处理都是自动完成的。每个数据集都实现 tfds.core.DatasetBuilder 的子类,它指定

  • 数据来自哪里(即其 URL);
  • 数据集的外观(即其特征);
  • 如何拆分数据(例如,TRAINTEST);
  • 以及数据集中的各个示例。

编写您的数据集

默认模板:tfds new

使用 TFDS CLI 生成所需的模板 Python 文件。

cd path/to/project/datasets/  # Or use `--dir=path/to/project/datasets/` below
tfds new my_dataset

此命令将生成一个新的 my_dataset/ 文件夹,其结构如下

my_dataset/
    __init__.py
    README.md # Markdown description of the dataset.
    CITATIONS.bib # Bibtex citation for the dataset.
    TAGS.txt # List of tags describing the dataset.
    my_dataset_dataset_builder.py # Dataset definition
    my_dataset_dataset_builder_test.py # Test
    dummy_data/ # (optional) Fake data (used for testing)
    checksum.tsv # (optional) URL checksums (see `checksums` section).

在此处搜索 TODO(my_dataset) 并相应地修改。

数据集示例

所有数据集都实现了 tfds.core.DatasetBuilder 的子类,它负责处理大多数样板代码。它支持

  • 可以在单台机器上生成的小型/中型数据集(本教程)。
  • 需要分布式生成(使用 Apache Beam,请参阅我们的 大型数据集指南)的超大型数据集

这是一个基于 tfds.core.GeneratorBasedBuilder 的数据集构建器的最小示例

class Builder(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for my_dataset dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Dataset metadata (homepage, citation,...)."""
    return self.dataset_info_from_configs(
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(shape=(256, 256, 3)),
            'label': tfds.features.ClassLabel(
                names=['no', 'yes'],
                doc='Whether this is a picture of a cat'),
        }),
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Download the data and define splits."""
    extracted_path = dl_manager.download_and_extract('http://data.org/data.zip')
    # dl_manager returns pathlib-like objects with `path.read_text()`,
    # `path.iterdir()`,...
    return {
        'train': self._generate_examples(path=extracted_path / 'train_images'),
        'test': self._generate_examples(path=extracted_path / 'test_images'),
    }

  def _generate_examples(self, path) -> Iterator[Tuple[Key, Example]]:
    """Generator of examples for each split."""
    for img_path in path.glob('*.jpeg'):
      # Yields (key, example)
      yield img_path.name, {
          'image': img_path,
          'label': 'yes' if img_path.name.startswith('yes_') else 'no',
      }

请注意,对于某些特定数据格式,我们提供现成的 数据集构建器 来处理大多数数据处理。

让我们详细了解要覆盖的 3 个抽象方法。

_info:数据集元数据

_info 返回包含 数据集元数据tfds.core.DatasetInfo

def _info(self):
  # The `dataset_info_from_configs` base method will construct the
  # `tfds.core.DatasetInfo` object using the passed-in parameters and
  # adding: builder (self), description/citations/tags from the config
  # files located in the same package.
  return self.dataset_info_from_configs(
      homepage='https://dataset-homepage.org',
      features=tfds.features.FeaturesDict({
          'image_description': tfds.features.Text(),
          'image': tfds.features.Image(),
          # Here, 'label' can be 0-4.
          'label': tfds.features.ClassLabel(num_classes=5),
      }),
      # If there's a common `(input, target)` tuple from the features,
      # specify them here. They'll be used if as_supervised=True in
      # builder.as_dataset.
      supervised_keys=('image', 'label'),
      # Specify whether to disable shuffling on the examples. Set to False by default.
      disable_shuffling=False,
  )

大多数字段不言自明。一些精度

编写 BibText CITATIONS.bib 文件

  • 在数据集网站上搜索引用说明(以 BibTex 格式使用)。
  • 对于 arXiv 论文:找到论文并点击右侧的 BibText 链接。
  • Google Scholar 上找到论文,点击标题下方的双引号,在弹出窗口中点击 BibTeX
  • 如果没有关联的论文(例如,只有一个网站),您可以使用 BibTeX 在线编辑器 创建自定义 BibTeX 条目(下拉菜单中有一个 Online 条目类型)。

更新 TAGS.txt 文件

  • 所有允许的标签都已预先填充到生成的文件中。
  • 删除不适用于数据集的所有标签。
  • 有效标签列在 tensorflow_datasets/core/valid_tags.txt 中。
  • 要将标签添加到该列表中,请发送 PR。

维护数据集顺序

默认情况下,数据集的记录在存储时会进行混洗,以使类在整个数据集中的分布更加均匀,因为通常属于同一类的记录是连续的。为了指定数据集应该按 _generate_examples 生成的键进行排序,字段 disable_shuffling 应该设置为 True。默认情况下,它设置为 False

def _info(self):
  return self.dataset_info_from_configs(
    # [...]
    disable_shuffling=True,
    # [...]
  )

请记住,禁用混洗会影响性能,因为分片不再能够并行读取。

_split_generators:下载和分割数据

下载和提取源数据

大多数数据集需要从网络下载数据。这是使用 tfds.download.DownloadManager_split_generators 输入参数完成的。 dl_manager 具有以下方法

  • download:支持 http(s)://ftp(s)://
  • extract:目前支持 .zip.gz.tar 文件。
  • download_and_extract:与 dl_manager.extract(dl_manager.download(urls)) 相同

所有这些方法都返回 tfds.core.Pathepath.Path 的别名),它们是 pathlib.Path-like 对象。

这些方法支持任意嵌套结构(listdict),例如

extracted_paths = dl_manager.download_and_extract({
    'foo': 'https://example.com/foo.zip',
    'bar': 'https://example.com/bar.zip',
})
# This returns:
assert extracted_paths == {
    'foo': Path('/path/to/extracted_foo/'),
    'bar': Path('/path/extracted_bar/'),
}

手动下载和提取

某些数据无法自动下载(例如,需要登录),在这种情况下,用户将手动下载源数据并将其放置在 manual_dir/(默认为 ~/tensorflow_datasets/downloads/manual/)中。

然后可以通过 dl_manager.manual_dir 访问文件

class MyDataset(tfds.core.GeneratorBasedBuilder):

  MANUAL_DOWNLOAD_INSTRUCTIONS = """
  Register into https://example.org/login to get the data. Place the `data.zip`
  file in the `manual_dir/`.
  """

  def _split_generators(self, dl_manager):
    # data_path is a pathlib-like `Path('<manual_dir>/data.zip')`
    archive_path = dl_manager.manual_dir / 'data.zip'
    # Extract the manually downloaded `data.zip`
    extracted_path = dl_manager.extract(archive_path)
    ...

manual_dir 位置可以使用 tfds build --manual_dir= 或使用 tfds.download.DownloadConfig 自定义。

直接读取存档

dl_manager.iter_archive 顺序读取存档,而无需提取它们。这可以节省存储空间并提高某些文件系统的性能。

for filename, fobj in dl_manager.iter_archive('path/to/archive.zip'):
  ...

fobj 具有与 with open('rb') as fobj: 相同的方法(例如 fobj.read()

指定数据集分割

如果数据集带有预定义的分割(例如,MNIST 具有 traintest 分割),请保留这些分割。否则,只指定一个 all 分割。用户可以使用 子分割 API 动态创建自己的子分割(例如,split='train[80%:]')。请注意,除了上述 all 之外,任何字母字符串都可以用作分割名称。

def _split_generators(self, dl_manager):
  # Download source data
  extracted_path = dl_manager.download_and_extract(...)

  # Specify the splits
  return {
      'train': self._generate_examples(
          images_path=extracted_path / 'train_imgs',
          label_path=extracted_path / 'train_labels.csv',
      ),
      'test': self._generate_examples(
          images_path=extracted_path / 'test_imgs',
          label_path=extracted_path / 'test_labels.csv',
      ),
  }

_generate_examples:示例生成器

_generate_examples 从源数据为每个分割生成示例。

此方法通常会读取源数据集工件(例如,CSV 文件)并生成 (key, feature_dict) 元组

  • key:示例标识符。用于使用 hash(key) 确定性地对示例进行混洗,或者在禁用混洗时按键排序(请参阅部分 维护数据集顺序)。应该是
    • 唯一:如果两个示例使用相同的键,则会引发异常。
    • 确定性:不应依赖于 download_diros.path.listdir 顺序等。两次生成数据应该生成相同的键。
    • 可比较:如果禁用混洗,则键将用于对数据集进行排序。
  • feature_dict:包含示例值的 dict
    • 结构应该与 tfds.core.DatasetInfo 中定义的 features= 结构匹配。
    • 复杂数据类型(图像、视频、音频等)将自动编码。
    • 每个特征通常接受多种输入类型(例如,视频接受 /path/to/vid.mp4np.array(shape=(l, h, w, c))List[paths]List[np.array(shape=(h, w, c)]List[img_bytes] 等)
    • 有关更多信息,请参阅 特征连接器指南
def _generate_examples(self, images_path, label_path):
  # Read the input data out of the source files
  with label_path.open() as f:
    for row in csv.DictReader(f):
      image_id = row['image_id']
      # And yield (key, feature_dict)
      yield image_id, {
          'image_description': row['description'],
          'image': images_path / f'{image_id}.jpeg',
          'label': row['label'],
      }

文件访问和 tf.io.gfile

为了支持云存储系统,请避免使用 Python 内置 I/O 操作。

相反,dl_manager 直接返回与 Google Cloud 存储兼容的 pathlib-like 对象

path = dl_manager.download_and_extract('http://some-website/my_data.zip')

json_path = path / 'data/file.json'

json.loads(json_path.read_text())

或者,使用 tf.io.gfile API 而不是内置 API 进行文件操作

Pathlib 应该优先于 tf.io.gfile(请参阅 理由)。

额外依赖项

某些数据集在生成期间需要额外的 Python 依赖项。例如,SVHN 数据集使用 scipy 加载一些数据。

如果您将数据集添加到 TFDS 存储库中,请使用 tfds.core.lazy_imports 来保持 tensorflow-datasets 包的大小。用户将仅在需要时安装额外的依赖项。

要使用 lazy_imports

  • setup.py 中的 DATASET_EXTRAS 中为您的数据集添加一个条目。这样,用户就可以执行以下操作,例如,pip install 'tensorflow-datasets[svhn]' 来安装额外的依赖项。
  • LazyImporterLazyImportsTest 中添加一个导入条目。
  • 使用 tfds.core.lazy_imports 在您的 DatasetBuilder 中访问依赖项(例如,tfds.core.lazy_imports.scipy)。

损坏的数据

某些数据集并不完全干净,包含一些损坏的数据(例如,图像位于 JPEG 文件中,但有些是无效的 JPEG)。这些示例应该跳过,但在数据集描述中留下一个说明,说明有多少示例被删除以及原因。

数据集配置/变体(tfds.core.BuilderConfig)

某些数据集可能有多个变体,或者关于如何预处理数据并将其写入磁盘的选项。例如,cycle_gan 每个对象对有一个配置(cycle_gan/horse2zebracycle_gan/monet2photo 等)。

这是通过 tfds.core.BuilderConfig 完成的

  1. 将您的配置对象定义为 tfds.core.BuilderConfig 的子类。例如,MyDatasetConfig

    @dataclasses.dataclass
    class MyDatasetConfig(tfds.core.BuilderConfig):
      img_size: Tuple[int, int] = (0, 0)
    
  2. MyDataset 中定义 BUILDER_CONFIGS = [] 类成员,该成员列出数据集公开的 MyDatasetConfig

    class MyDataset(tfds.core.GeneratorBasedBuilder):
      VERSION = tfds.core.Version('1.0.0')
      # pytype: disable=wrong-keyword-args
      BUILDER_CONFIGS = [
          # `name` (and optionally `description`) are required for each config
          MyDatasetConfig(name='small', description='Small ...', img_size=(8, 8)),
          MyDatasetConfig(name='big', description='Big ...', img_size=(32, 32)),
      ]
      # pytype: enable=wrong-keyword-args
    
  3. MyDataset 中使用 self.builder_config 来配置数据生成(例如,shape=self.builder_config.img_size)。这可能包括在 _info() 中设置不同的值或更改下载数据访问。

说明

  • 每个配置都有一个唯一的名称。配置的完全限定名称是 dataset_name/config_name(例如,coco/2017)。
  • 如果没有指定,将使用 BUILDER_CONFIGS 中的第一个配置(例如,tfds.load('c4') 默认使用 c4/en

请参阅 anli,了解使用 BuilderConfig 的数据集示例。

版本

版本可以指两种不同的含义

  • "外部"原始数据版本:例如,COCO v2019、v2017 等。
  • "内部" TFDS 代码版本:例如,在 tfds.features.FeaturesDict 中重命名一个特征,修复 _generate_examples 中的错误

要更新数据集

  • 对于“外部”数据更新:多个用户可能希望同时访问特定年份/版本。这是通过为每个版本使用一个 tfds.core.BuilderConfig(例如,coco/2017coco/2019)或为每个版本使用一个类(例如,Voc2007Voc2012)来完成的。
  • 对于“内部”代码更新:用户只下载最新版本。任何代码更新都应该增加 VERSION 类属性(例如,从 1.0.0VERSION = tfds.core.Version('2.0.0')),遵循 语义版本控制

添加用于注册的导入

不要忘记将数据集模块导入到您的项目 __init__ 中,以便在 tfds.loadtfds.builder 中自动注册。

import my_project.datasets.my_dataset  # Register MyDataset

ds = tfds.load('my_dataset')  # MyDataset available

例如,如果您正在为 tensorflow/datasets 做贡献,请将模块导入添加到其子目录的 __init__.py 中(例如,image/__init__.py)。

检查常见的实现问题

请检查 常见的实现问题

测试您的数据集

下载和准备:tfds build

要生成数据集,请从 my_dataset/ 目录运行 tfds build 命令。

cd path/to/datasets/my_dataset/
tfds build --register_checksums

一些开发中常用的标志

  • --pdb: 如果出现异常,则进入调试模式。
  • --overwrite: 如果数据集已生成,则删除现有文件。
  • --max_examples_per_split: 只生成前 X 个示例(默认值为 1),而不是生成完整数据集。
  • --register_checksums: 记录下载 URL 的校验和。仅在开发期间使用。

有关标志的完整列表,请参阅 CLI 文档

校验和

建议记录数据集的校验和,以确保确定性,帮助文档等。这可以通过使用 --register_checksums 标志生成数据集来完成(请参阅上一节)。

如果您通过 PyPI 发布数据集,请不要忘记导出 checksums.tsv 文件(例如,在 setup.pypackage_data 中)。

对数据集进行单元测试

tfds.testing.DatasetBuilderTestCase 是一个用于全面测试数据集的基类 TestCase。它使用“虚拟数据”作为测试数据,这些数据模拟源数据集的结构。

  • 测试数据应放在 my_dataset/dummy_data/ 目录中,并且应模拟下载和解压缩后的源数据集工件。它可以通过脚本手动或自动创建(示例脚本)。
  • 确保在测试数据拆分中使用不同的数据,因为如果数据集拆分重叠,测试将失败。
  • 测试数据不应包含任何版权材料。如有疑问,请勿使用原始数据集中的材料创建数据。
import tensorflow_datasets as tfds
from . import my_dataset_dataset_builder


class MyDatasetTest(tfds.testing.DatasetBuilderTestCase):
  """Tests for my_dataset dataset."""
  DATASET_CLASS = my_dataset_dataset_builder.Builder
  SPLITS = {
      'train': 3,  # Number of fake train example
      'test': 1,  # Number of fake test example
  }

  # If you are calling `download/download_and_extract` with a dict, like:
  #   dl_manager.download({'some_key': 'http://a.org/out.txt', ...})
  # then the tests needs to provide the fake output paths relative to the
  # fake data directory
  DL_EXTRACT_RESULT = {
      'name1': 'path/to/file1',  # Relative to my_dataset/dummy_data dir.
      'name2': 'file2',
  }


if __name__ == '__main__':
  tfds.testing.test_main()

运行以下命令测试数据集。

python my_dataset_test.py

向我们反馈

我们一直在努力改进数据集创建工作流程,但只有在了解问题的情况下才能做到。您在创建数据集时遇到了哪些问题或错误?是否有令人困惑的部分,或者第一次没有正常工作?

请在 GitHub 上分享您的反馈。