FeatureConnector

tfds.features.FeatureConnector API

  • 定义最终 tf.data.Dataset 的结构、形状、数据类型
  • 抽象化到/从磁盘的序列化。
  • 公开其他元数据(例如标签名称、音频采样率,...)

概览

tfds.features.FeatureConnector 定义数据集功能结构(在 tfds.core.DatasetInfo 中)

tfds.core.DatasetInfo(
    features=tfds.features.FeaturesDict({
        'image': tfds.features.Image(shape=(28, 28, 1), doc='Grayscale image'),
        'label': tfds.features.ClassLabel(
            names=['no', 'yes'],
            doc=tfds.features.Documentation(
                desc='Whether this is a picture of a cat',
                value_range='yes or no'
            ),
        ),
        'metadata': {
            'id': tf.int64,
            'timestamp': tfds.features.Scalar(
                tf.int64,
                doc='Timestamp when this picture was taken as seconds since epoch'),
            'language': tf.string,
        },
    }),
)

可以通过仅使用文本描述(doc='description')或直接使用 tfds.features.Documentation 来记录功能,以提供更详细的功能描述。

功能可以是

在生成期间,示例将由 FeatureConnector.encode_example 自动序列化为适合磁盘的格式(当前为 tf.train.Example 协议缓冲区)

yield {
    'image': '/path/to/img0.png',  # `np.array`, file bytes,... also accepted
    'label': 'yes',  # int (0-num_classes) also accepted
    'metadata': {
        'id': 43,
        'language': 'en',
    },
}

在读取数据集时(例如使用 tfds.load),数据将使用 FeatureConnector.decode_example 自动解码。返回的 tf.data.Dataset 将与 tfds.core.DatasetInfo 中定义的 dict 结构匹配

ds = tfds.load(...)
ds.element_spec == {
    'image': tf.TensorSpec(shape=(28, 28, 1), tf.uint8),
    'label': tf.TensorSpec(shape=(), tf.int64),
    'metadata': {
        'id': tf.TensorSpec(shape=(), tf.int64),
        'language': tf.TensorSpec(shape=(), tf.string),
    },
}

序列化/反序列化为 proto

TFDS 公开了一个低级 API,用于将示例序列化/反序列化为 tf.train.Example proto。

要将 dict[np.ndarray | Path | str | ...] 序列化为 proto bytes,请使用 features.serialize_example

with tf.io.TFRecordWriter('path/to/file.tfrecord') as writer:
  for ex in all_exs:
    ex_bytes = features.serialize_example(data)
    f.write(ex_bytes)

要将 proto bytes 反序列化为 tf.Tensor,请使用 features.deserialize_example

ds = tf.data.TFRecordDataset('path/to/file.tfrecord')
ds = ds.map(features.deserialize_example)

访问元数据

参阅简介文档以访问特征元数据(标签名称、形状、数据类型等)。示例

ds, info = tfds.load(..., with_info=True)

info.features['label'].names  # ['cat', 'dog', ...]
info.features['label'].str2int('cat')  # 0

创建自己的tfds.features.FeatureConnector

如果您认为可用特征中缺少某个特征,请打开新问题

要创建自己的特征连接器,您需要从tfds.features.FeatureConnector继承并实现抽象方法。

tfds.features.FeatureConnector对象抽象了特征在磁盘上的编码方式与向用户呈现的方式。以下是显示数据集抽象层以及从原始数据集文件到tf.data.Dataset对象的转换的图表。

DatasetBuilder abstraction layers

要创建自己的特征连接器,请对tfds.features.FeatureConnector进行子类化并实现抽象方法

  • encode_example(data):定义如何将生成器_generate_examples()中给出的数据编码到与tf.train.Example兼容的数据中。可以返回单个值或值的dict
  • decode_example(data):定义如何将从 tf.train.Example 中读取的张量中的数据解码为 tf.data.Dataset 返回的用户张量。
  • get_tensor_info():指示 tf.data.Dataset 返回的张量形状/数据类型。如果从另一个 tfds.features 继承,则可以是可选的。
  • (可选)get_serialized_info():如果 get_tensor_info() 返回的信息与数据实际写入磁盘的方式不同,则需要覆盖 get_serialized_info() 以匹配 tf.train.Example 的规格
  • to_json_content/from_json_content:需要此项才能在没有原始源代码的情况下加载数据集。请参阅 音频功能 以获取示例。

有关更多信息,请查看 tfds.features.FeatureConnector 文档。最好还查看 实际示例