TFDS for Jax and PyTorch

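TFDS has always been framework-agnostic. For instance, you can easily load datasets in NumPy format for use in Jax and PyTorch.

TensorFlow and its data loading solution (tf.data) are first-class citizens in our API by design.

We extended TFDS to support TensorFlow-less, NumPy-only data loading. This is convenient for ML frameworks such as Jax and PyTorch. Indeed, for users of those frameworks, TensorFlow can:

  • reserve GPU/TPU memory;
  • increase build time in CI/CD;
  • take time to import at runtime.

TensorFlow is no longer a dependency for reading datasets.

ML pipelines need a data loader to load examples, decode them, and present them to the model. Data loaders use the "source/sampler/loader" paradigm: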

 TFDS dataset       ┌────────────────┐
   on disk          │                │
        ┌──────────►│      Data      │
|..|... │     |     │     source     ├─┐
├──┼────┴─────┤     │                │ │
│12│image12   │     └────────────────┘ │    ┌────────────────┐
├──┼──────────┤                        │    │                │
│13│image13   │                        ├───►│      Data      ├───► ML pipeline
├──┼──────────┤                        │    │     loader     │
│14│image14   │     ┌────────────────┐ │    │                │
├──┼──────────┤     │                │ │    └────────────────┘
|..|...       |     │     Index      ├─┘
                    │    sampler     │
                    │                │
                    └────────────────┘
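  • The data source is responsible for accessing and decoding examples from a TFDS dataset on the fly.
  • The index sampler is responsible for determining the order in which records are processed. This matters because it lets global transformations (e.g., global shuffling, sharding, repeating for multiple epochs) be applied before any record is read.
  • The data loader orchestrates the loading by leveraging the data source and the index sampler. It allows performance optimizations (e.g., pre-fetching, multiprocessing, or multithreading).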
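
As a rough illustration of how the three roles fit together, here is a minimal pure-Python sketch. The names `ListSource`, `index_sampler`, and `load` are hypothetical and are not part of TFDS, Grain, or PyTorch; a plain list stands in for the dataset on disk.

```python
import random
from typing import Any, Iterable, Iterator, Sequence


class ListSource:
  """Hypothetical data source: a list stands in for records on disk."""

  def __init__(self, records: Sequence[Any]):
    self._records = list(records)

  def __len__(self) -> int:
    return len(self._records)

  def __getitem__(self, index: int) -> Any:
    return self._records[index]


def index_sampler(num_records: int, num_epochs: int, seed: int,
                  shard_index: int = 0, shard_count: int = 1) -> Iterator[int]:
  """Yields record indices only; no record is read at this stage."""
  for epoch in range(num_epochs):
    indices = list(range(num_records))
    # Global shuffle, reseeded per epoch so each epoch has its own order.
    random.Random(seed + epoch).shuffle(indices)
    # Sharding: each shard keeps every `shard_count`-th shuffled index.
    yield from indices[shard_index::shard_count]


def load(source: ListSource, sampler: Iterable[int]) -> Iterator[Any]:
  """Loader: reads records in the order chosen by the sampler."""
  for index in sampler:
    yield source[index]


source = ListSource([f"image{i}" for i in range(10)])
# Two shards built from the same seed see disjoint halves of the data.
shard0 = list(load(source, index_sampler(len(source), 1, seed=0,
                                         shard_index=0, shard_count=2)))
shard1 = list(load(source, index_sampler(len(source), 1, seed=0,
                                         shard_index=1, shard_count=2)))
```

Because the sampler works on indices alone, shuffling and sharding cost nothing in I/O; only the loader touches the records.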

TL;DR

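tfds.data_source is an API to create data sources:

  1. for fast prototyping in pure-Python pipelines;
  2. to manage data-intensive ML pipelines at scale.

Setup

Let's install and import the needed dependencies: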

!pip install array_record
!pip install grain-nightly
!pip install jax jaxlib
!pip install tfds-nightly

import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

Data sources

Data sources are basically Python sequences, so they need to implement the following protocol:

from typing import Any, Protocol, SupportsIndex

class RandomAccessDataSource(Protocol):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, key: SupportsIndex) -> Any:
    """Retrieves the record for the given key."""

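To make the protocol concrete, here is a small sketch of a class that satisfies it. `SquaresSource` is a hypothetical example that computes its records instead of reading them from disk, and the protocol is redeclared with `runtime_checkable` only so that the snippet runs on its own.

```python
import operator
from typing import Any, Protocol, SupportsIndex, runtime_checkable


@runtime_checkable
class RandomAccessDataSource(Protocol):
  """Same two methods as the protocol above, made checkable at runtime."""

  def __len__(self) -> int:
    ...

  def __getitem__(self, key: SupportsIndex) -> Any:
    ...


class SquaresSource:
  """Hypothetical source whose i-th record is computed, not stored."""

  def __init__(self, size: int):
    self._size = size

  def __len__(self) -> int:
    return self._size

  def __getitem__(self, key: SupportsIndex) -> dict:
    index = operator.index(key)
    if not 0 <= index < self._size:
      # Raising IndexError also lets `for record in source` terminate.
      raise IndexError(index)
    return {"index": index, "square": index * index}


squares = SquaresSource(5)
# isinstance only checks that both methods exist, not their signatures.
is_source = isinstance(squares, RandomAccessDataSource)
records = [squares[i] for i in range(len(squares))]
```

Any object with these two methods can be handed to a sampler and loader, which is why the same data source can serve several frameworks.
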
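The underlying file format needs to support efficient random access. At the moment, TFDS relies on array_record.

array_record is a new file format derived from Riegeli, achieving a new frontier of IO efficiency. In particular, ArrayRecord supports parallel read, write, and random access by record index. ArrayRecord builds on top of Riegeli and supports the same compression algorithms.

fashion_mnist is a common dataset for computer vision. To retrieve an ArrayRecord-based data source with TFDS, simply use: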

ds = tfds.data_source('fashion_mnist')
Downloading and preparing dataset 29.45 MiB (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1...
2024-04-26 11:20:57.419076: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Dataset fashion_mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.

tfds.data_source is a convenient wrapper. It is equivalent to:

builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()

This outputs a dictionary of data sources:

{
  'train': DataSource(name=fashion_mnist, split='train', decoders=None),
  'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}

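Once download_and_prepare has run and the record files have been generated, TensorFlow is no longer needed. Everything happens in Python/NumPy!

Let's check this by uninstalling TensorFlow and re-loading the data source in another subprocess: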

!pip uninstall -y tensorflow
/usr/lib/python3.9/pty.py:85: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid, fd = os.forkpty()
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

try:
  import tensorflow as tf
except ImportError:
  print('No TensorFlow found...')

ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')
Writing no_tensorflow.py
!python no_tensorflow.py
No TensorFlow found...
...but the data source could still be loaded...
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...
...and the records can be decoded.
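
If you would rather not uninstall anything, one way to check whether TensorFlow is present in an environment is to look the package up without importing it. This is a small standard-library sketch; the helper name `is_installed` is hypothetical.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
  """Returns True if a top-level module can be found, without importing it."""
  return importlib.util.find_spec(module_name) is not None


has_tensorflow = is_installed("tensorflow")
print("TensorFlow installed:", has_tensorflow)
```

Since `find_spec` only locates a top-level module, this check does not pay TensorFlow's import cost.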

In future versions, we are also going to make the dataset preparation TensorFlow-free.

A data source has a length:

len(ds['train'])
60000

Accessing the first element of the dataset:

%%timeit
ds['train'][0]
WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...
584 µs ± 2.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

...is just as cheap as accessing any other element. This is the definition of random access:

%%timeit
ds['train'][1000]
581 µs ± 2.33 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
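
To see why an indexed file format makes reading record 1000 cost about the same as reading record 0, consider this simplified sketch. It is not the ArrayRecord format: the layout (length-prefixed records plus an in-memory offset table) and the names `write_records` and `read_record` are assumptions chosen purely for illustration.

```python
import io
import struct
from typing import List, Tuple


def write_records(records: List[bytes]) -> Tuple[bytes, List[int]]:
  """Serializes records back to back and remembers where each one starts."""
  buffer = io.BytesIO()
  offsets = []
  for record in records:
    offsets.append(buffer.tell())
    buffer.write(struct.pack("<I", len(record)))  # 4-byte length prefix.
    buffer.write(record)
  return buffer.getvalue(), offsets


def read_record(data: bytes, offsets: List[int], index: int) -> bytes:
  """Jumps straight to one record; earlier records are never scanned."""
  start = offsets[index]
  (length,) = struct.unpack_from("<I", data, start)
  return data[start + 4:start + 4 + length]


payloads = [f"image{i}".encode() for i in range(2000)]
data, offsets = write_records(payloads)
first = read_record(data, offsets, 0)
later = read_record(data, offsets, 1000)
```

Both lookups do the same amount of work: one offset lookup, one length read, one slice.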

Features now use NumPy DTypes (rather than TensorFlow DTypes). You can inspect the features with:

features = tfds.builder('fashion_mnist').info.features

You'll find more information about features in our documentation. Here we can notably retrieve the shape of the images and the number of classes:

shape = features['image'].shape
num_classes = features['label'].num_classes
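
These two values are all that the linear classifier in the PyTorch section needs. As a small worked example, assuming the usual Fashion-MNIST shape of (28, 28, 1) and 10 classes (hard-coded here so the snippet runs without TFDS):

```python
import math

# Assumed values; in the tutorial they come from `features`.
shape = (28, 28, 1)
num_classes = 10

# A linear classifier flattens each image into one vector.
input_size = math.prod(shape)
# One weight per (pixel, class) pair plus one bias per class.
num_parameters = input_size * num_classes + num_classes
print(input_size, num_parameters)
```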

Use in pure Python

You can consume data sources in Python by iterating over them:

for example in ds['train']:
  print(example)
  break
{'image': array([[[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 18],
        [ 77],
        [227],
        [227],
        [208],
        [210],
        [225],
        [216],
        [ 85],
        [ 32],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 61],
        [100],
        [ 97],
        [ 80],
        [ 57],
        [117],
        [227],
        [238],
        [115],
        [ 49],
        [ 78],
        [106],
        [108],
        [ 71],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 81],
        [105],
        [ 80],
        [ 69],
        [ 72],
        [ 64],
        [ 44],
        [ 21],
        [ 13],
        [ 44],
        [ 69],
        [ 75],
        [ 75],
        [ 80],
        [114],
        [ 80],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 26],
        [ 92],
        [ 69],
        [ 68],
        [ 75],
        [ 75],
        [ 71],
        [ 74],
        [ 83],
        [ 75],
        [ 77],
        [ 78],
        [ 74],
        [ 74],
        [ 83],
        [ 77],
        [108],
        [ 34],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 55],
        [ 92],
        [ 69],
        [ 74],
        [ 74],
        [ 71],
        [ 71],
        [ 77],
        [ 69],
        [ 66],
        [ 75],
        [ 74],
        [ 77],
        [ 80],
        [ 80],
        [ 78],
        [ 94],
        [ 63],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 63],
        [ 95],
        [ 66],
        [ 68],
        [ 72],
        [ 72],
        [ 69],
        [ 72],
        [ 74],
        [ 74],
        [ 74],
        [ 75],
        [ 75],
        [ 77],
        [ 80],
        [ 77],
        [106],
        [ 61],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 80],
        [108],
        [ 71],
        [ 69],
        [ 72],
        [ 71],
        [ 69],
        [ 72],
        [ 75],
        [ 75],
        [ 72],
        [ 72],
        [ 75],
        [ 78],
        [ 72],
        [ 85],
        [128],
        [ 64],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 88],
        [120],
        [ 75],
        [ 74],
        [ 77],
        [ 75],
        [ 72],
        [ 77],
        [ 74],
        [ 74],
        [ 77],
        [ 78],
        [ 83],
        [ 83],
        [ 66],
        [111],
        [123],
        [ 78],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 85],
        [134],
        [ 74],
        [ 85],
        [ 69],
        [ 75],
        [ 75],
        [ 74],
        [ 75],
        [ 74],
        [ 75],
        [ 75],
        [ 81],
        [ 75],
        [ 61],
        [151],
        [115],
        [ 91],
        [ 12],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 10],
        [ 85],
        [153],
        [ 83],
        [ 80],
        [ 68],
        [ 77],
        [ 75],
        [ 74],
        [ 75],
        [ 74],
        [ 75],
        [ 77],
        [ 80],
        [ 68],
        [ 61],
        [162],
        [122],
        [ 78],
        [  6],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 30],
        [ 75],
        [154],
        [ 85],
        [ 80],
        [ 71],
        [ 80],
        [ 72],
        [ 77],
        [ 75],
        [ 75],
        [ 77],
        [ 78],
        [ 77],
        [ 75],
        [ 49],
        [191],
        [132],
        [ 72],
        [ 15],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 58],
        [ 66],
        [174],
        [115],
        [ 66],
        [ 77],
        [ 80],
        [ 72],
        [ 78],
        [ 75],
        [ 77],
        [ 78],
        [ 78],
        [ 77],
        [ 66],
        [ 49],
        [222],
        [131],
        [ 77],
        [ 37],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 69],
        [ 55],
        [179],
        [139],
        [ 55],
        [ 92],
        [ 74],
        [ 74],
        [ 78],
        [ 74],
        [ 78],
        [ 77],
        [ 75],
        [ 80],
        [ 64],
        [ 55],
        [242],
        [111],
        [ 95],
        [ 44],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 74],
        [ 57],
        [159],
        [180],
        [ 55],
        [ 92],
        [ 64],
        [ 72],
        [ 74],
        [ 74],
        [ 77],
        [ 75],
        [ 77],
        [ 78],
        [ 55],
        [ 66],
        [255],
        [ 97],
        [108],
        [ 49],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 74],
        [ 66],
        [145],
        [153],
        [ 72],
        [ 83],
        [ 58],
        [ 78],
        [ 77],
        [ 75],
        [ 75],
        [ 75],
        [ 72],
        [ 80],
        [ 30],
        [132],
        [255],
        [ 37],
        [122],
        [ 60],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 80],
        [ 69],
        [142],
        [180],
        [142],
        [ 57],
        [ 64],
        [ 78],
        [ 74],
        [ 75],
        [ 75],
        [ 75],
        [ 72],
        [ 85],
        [ 21],
        [185],
        [227],
        [ 37],
        [143],
        [ 63],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [ 83],
        [ 71],
        [136],
        [194],
        [126],
        [ 46],
        [ 69],
        [ 75],
        [ 72],
        [ 75],
        [ 75],
        [ 75],
        [ 74],
        [ 78],
        [ 38],
        [139],
        [185],
        [ 60],
        [151],
        [ 58],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  4],
        [ 81],
        [ 74],
        [145],
        [177],
        [ 78],
        [ 49],
        [ 74],
        [ 77],
        [ 75],
        [ 75],
        [ 75],
        [ 75],
        [ 74],
        [ 72],
        [ 63],
        [ 80],
        [156],
        [117],
        [153],
        [ 55],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 10],
        [ 80],
        [ 72],
        [157],
        [163],
        [ 61],
        [ 55],
        [ 75],
        [ 77],
        [ 75],
        [ 77],
        [ 75],
        [ 75],
        [ 75],
        [ 77],
        [ 71],
        [ 60],
        [ 98],
        [156],
        [132],
        [ 58],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 13],
        [ 77],
        [ 74],
        [157],
        [143],
        [ 43],
        [ 61],
        [ 72],
        [ 75],
        [ 77],
        [ 75],
        [ 74],
        [ 77],
        [ 77],
        [ 75],
        [ 71],
        [ 58],
        [ 80],
        [157],
        [120],
        [ 66],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 18],
        [ 81],
        [ 74],
        [156],
        [114],
        [ 35],
        [ 72],
        [ 71],
        [ 75],
        [ 78],
        [ 72],
        [ 66],
        [ 80],
        [ 78],
        [ 77],
        [ 75],
        [ 64],
        [ 63],
        [165],
        [119],
        [ 68],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 23],
        [ 85],
        [ 81],
        [177],
        [ 57],
        [ 52],
        [ 77],
        [ 71],
        [ 78],
        [ 80],
        [ 72],
        [ 75],
        [ 74],
        [ 77],
        [ 77],
        [ 75],
        [ 64],
        [ 37],
        [173],
        [ 95],
        [ 72],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 26],
        [ 81],
        [ 86],
        [160],
        [ 20],
        [ 75],
        [ 77],
        [ 77],
        [ 80],
        [ 78],
        [ 80],
        [ 89],
        [ 78],
        [ 81],
        [ 83],
        [ 80],
        [ 74],
        [ 20],
        [177],
        [ 77],
        [ 74],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 49],
        [ 77],
        [ 91],
        [200],
        [  0],
        [ 83],
        [ 95],
        [ 86],
        [ 88],
        [ 88],
        [ 89],
        [ 88],
        [ 89],
        [ 88],
        [ 83],
        [ 89],
        [ 86],
        [  0],
        [191],
        [ 78],
        [ 80],
        [ 24],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 54],
        [ 71],
        [108],
        [165],
        [  0],
        [ 24],
        [ 57],
        [ 52],
        [ 57],
        [ 60],
        [ 60],
        [ 60],
        [ 63],
        [ 63],
        [ 77],
        [ 89],
        [ 52],
        [  0],
        [211],
        [ 97],
        [ 77],
        [ 61],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 68],
        [ 91],
        [117],
        [137],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 18],
        [216],
        [ 94],
        [ 97],
        [ 57],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 54],
        [115],
        [105],
        [185],
        [  0],
        [  0],
        [  1],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [153],
        [ 78],
        [106],
        [ 37],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [ 18],
        [ 61],
        [ 41],
        [103],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [106],
        [ 47],
        [ 69],
        [ 23],
        [  0],
        [  0],
        [  0]]], dtype=uint8), 'label': 2}

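If you inspect the elements, you will also notice that all the features are already decoded with NumPy. Behind the scenes, we use OpenCV by default because it is fast. If OpenCV is not installed, we fall back to Pillow, which provides lightweight and fast image decoding.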

{
  'image': array([[[0], [0], ..., [0]],
                  [[0], [0], ..., [0]]], dtype=uint8),
  'label': 2,
}

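Use with PyTorch

PyTorch uses the source/sampler/loader paradigm. In Torch, "data sources" are called "datasets". torch.utils.data contains all the details you need to build efficient input pipelines in Torch.

TFDS data sources can be used as regular map-style datasets.

First we install and import Torch: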

!pip install torch

from tqdm import tqdm
import torch

We already defined data sources for training and testing (respectively, ds['train'] and ds['test']). We can now define the sampler and the loaders:

batch_size = 128
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)

Using PyTorch, we train and evaluate a simple logistic regression on the first examples:

class LinearClassifier(torch.nn.Module):
  def __init__(self, shape, num_classes):
    super(LinearClassifier, self).__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)


model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
for example in tqdm(test_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  num_examples += image.shape[0]
  predicted_label = prediction.argmax(dim=1)
  true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/torch/cuda/__init__.py:619: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
Training...
100%|██████████| 40/40 [00:01<00:00, 31.22it/s]
Testing...
100%|██████████| 79/79 [00:02<00:00, 32.66it/s]
Accuracy: 65.63%
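
The progress bars above report 40 training batches and 79 test batches. Here is a quick check of that arithmetic, assuming the test split holds 10,000 examples (the standard Fashion-MNIST size, which is not printed in this tutorial) and that the last partial batch is kept, which is the DataLoader default:

```python
import math

batch_size = 128
num_train_samples = 5_000   # num_samples passed to RandomSampler.
num_test_examples = 10_000  # Assumed size of the test split.

# The last, smaller batch is kept, so round up.
train_batches = math.ceil(num_train_samples / batch_size)
test_batches = math.ceil(num_test_examples / batch_size)
print(train_batches, test_batches)
```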

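Use with JAX

Grain is a library for reading data for training and evaluating JAX models. It is open source, fast, and deterministic. Grain uses the source/sampler/loader paradigm, so we can re-use tfds.data_source: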

import grain.python as pygrain
import numpy as np

data_source = tfds.data_source("fashion_mnist", split="train")

# To shuffle the data, use a sampler:
sampler = pygrain.IndexSampler(
    num_records=5,
    num_epochs=1,
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    seed=0,
)

Transformations are defined as classes and can be BatchTransform, FilterTransform, or MapTransform:

class ImageToText(pygrain.MapTransform):
  """Maps an image to text."""

  LABEL_TO_TEXT = {
      0: "zero",
      1: "one",
      2: "two",
      3: "three",
      4: "four",
      5: "five",
      6: "six",
      7: "seven",
      8: "eight",
      9: "nine",
  }

  def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    label = element["label"]
    text = self.LABEL_TO_TEXT[label]
    element["text"] = text
    return element

# You can chain transformations in a list:
operations = [ImageToText()]

Finally, the data loader takes care of orchestrating the loading. You can scale out with multiprocessing to get both the flexibility of Python and the performance of a data loader:

loader = pygrain.DataLoader(
    data_source=data_source,
    operations=operations,
    sampler=sampler,
    worker_count=0,  # Scale to multiple workers in multiprocessing
)

for element in loader:
  print(element["text"])
two
one
one
eight
four
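
The three kinds of transformation can be pictured with plain Python generators. This sketch does not use Grain; `map_records`, `filter_records`, and `batch_records` are hypothetical helpers that only mimic what each kind of transformation does to a stream of records.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List

Record = Dict[str, Any]


def map_records(records: Iterable[Record],
                fn: Callable[[Record], Record]) -> Iterator[Record]:
  """Map: one record in, one record out."""
  for record in records:
    yield fn(record)


def filter_records(records: Iterable[Record],
                   keep: Callable[[Record], bool]) -> Iterator[Record]:
  """Filter: drops the records for which `keep` returns False."""
  for record in records:
    if keep(record):
      yield record


def batch_records(records: Iterable[Record],
                  batch_size: int) -> Iterator[List[Record]]:
  """Batch: groups consecutive records; the last group may be smaller."""
  batch = []
  for record in records:
    batch.append(record)
    if len(batch) == batch_size:
      yield batch
      batch = []
  if batch:
    yield batch


def add_parity(record: Record) -> Record:
  # Return a new dict rather than mutating the input record.
  return {**record, "parity": "even" if record["label"] % 2 == 0 else "odd"}


records = [{"label": label} for label in [2, 1, 1, 8, 4]]
stream = map_records(records, add_parity)
stream = filter_records(stream, lambda record: record["parity"] == "even")
batches = list(batch_records(stream, batch_size=2))
```

Each step is lazy, so records flow through the chain one at a time, in the order the transformations are listed.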

Read more

For more information, please refer to the tfds.data_source API doc.