本指南介绍了 tf.Transform
的基本概念以及如何使用它们。它将
- 定义一个预处理函数,它是将原始数据转换为用于训练机器学习模型的数据的管道的逻辑描述。
- 展示用于通过将预处理函数转换为Beam 管道来转换数据的 Apache Beam 实现。
- 展示其他使用示例。
设置
pip install -U tensorflow_transform
pip install pyarrow
import pkg_resources
import importlib
importlib.reload(pkg_resources)
<module 'pkg_resources' from '/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/pkg_resources/__init__.py'>
import os
import tempfile
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils
from tfx_bsl.public import tfxio
2023-04-13 09:15:54.685940: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2023-04-13 09:15:54.686060: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2023-04-13 09:15:54.686073: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
定义预处理函数
预处理函数是 tf.Transform
中最重要的概念。预处理函数是对数据集转换的逻辑描述。预处理函数接受并返回一个张量字典,其中张量表示 Tensor
或 SparseTensor
。有两种类型的函数用于定义预处理函数
- 任何接受并返回张量的函数。这些函数将 TensorFlow 操作添加到图中,将原始数据转换为转换后的数据。
- 由
tf.Transform
提供的任何分析器。分析器也接受并返回张量,但与 TensorFlow 函数不同,它们不会将操作添加到图中。相反,分析器会导致tf.Transform
在 TensorFlow 之外计算一个全通操作。它们使用整个数据集上的输入张量值来生成一个作为输出返回的常量张量。例如,tft.min
计算整个数据集上张量的最小值。tf.Transform
提供了一组固定的分析器,但这将在未来的版本中扩展。
预处理函数示例
通过组合分析器和常规 TensorFlow 函数,用户可以创建用于转换数据的灵活管道。以下预处理函数以不同的方式转换三个特征中的每一个,并将两个特征组合在一起
def preprocessing_fn(inputs):
x = inputs['x']
y = inputs['y']
s = inputs['s']
x_centered = x - tft.mean(x)
y_normalized = tft.scale_to_0_1(y)
s_integerized = tft.compute_and_apply_vocabulary(s)
x_centered_times_y_normalized = x_centered * y_normalized
return {
'x_centered': x_centered,
'y_normalized': y_normalized,
'x_centered_times_y_normalized': x_centered_times_y_normalized,
's_integerized': s_integerized
}
这里,x
、y
和 s
是表示输入特征的 Tensor
。创建的第一个新张量 x_centered
是通过将 tft.mean
应用于 x
并从 x
中减去它来构建的。 tft.mean(x)
返回一个表示张量 x
平均值的张量。 x_centered
是减去平均值的张量 x
。
第二个新张量 y_normalized
是以类似的方式创建的,但使用便利方法 tft.scale_to_0_1
。此方法执行类似于计算 x_centered
的操作,即计算最大值和最小值,并使用它们来缩放 y
。
张量 s_integerized
显示了字符串操作的示例。在这种情况下,我们获取一个字符串并将其映射到一个整数。这使用了便利函数 tft.compute_and_apply_vocabulary
。此函数使用分析器来计算输入字符串所取的唯一值,然后使用 TensorFlow 操作将输入字符串转换为唯一值表中的索引。
最后一列显示了可以使用 TensorFlow 操作通过组合张量来创建新特征。
预处理函数定义了对数据集的操作管道。为了应用管道,我们依赖于 tf.Transform
API 的具体实现。Apache Beam 实现提供了 PTransform
,它将用户的预处理函数应用于数据。 tf.Transform
用户的典型工作流程将构建一个预处理函数,然后将其合并到一个更大的 Beam 管道中,创建用于训练的数据。
批处理
批处理是 TensorFlow 的重要组成部分。由于 tf.Transform
的目标之一是提供一个用于预处理的 TensorFlow 图,该图可以合并到服务图(以及可选的训练图)中,因此批处理也是 tf.Transform
中的一个重要概念。
虽然在上面的示例中并不明显,但用户定义的预处理函数传递的是表示批次的张量,而不是像在 TensorFlow 的训练和服务过程中那样传递单个实例。另一方面,分析器对整个数据集执行计算,返回单个值,而不是一批值。x
是一个形状为 (batch_size,)
的 Tensor
,而 tft.mean(x)
是一个形状为 ()
的 Tensor
。减法 x - tft.mean(x)
进行广播,其中 tft.mean(x)
的值从 x
所表示的批次的每个元素中减去。
Apache Beam 实现
虽然预处理函数旨在作为在多个数据处理框架上实现的预处理管道的逻辑描述,但 tf.Transform
提供了在 Apache Beam 上使用的规范实现。此实现演示了实现所需的函数。此功能没有正式的 API,因此每个实现可以使用与其特定数据处理框架相适应的 API。
Apache Beam 实现提供了两个 PTransform
,用于处理预处理函数的数据。以下显示了复合 PTransform
的用法 - tft_beam.AnalyzeAndTransformDataset
raw_data = [
{'x': 1, 'y': 1, 's': 'hello'},
{'x': 2, 'y': 2, 's': 'world'},
{'x': 3, 'y': 3, 's': 'hello'}
]
raw_data_metadata = dataset_metadata.DatasetMetadata(
schema_utils.schema_from_feature_spec({
'y': tf.io.FixedLenFeature([], tf.float32),
'x': tf.io.FixedLenFeature([], tf.float32),
's': tf.io.FixedLenFeature([], tf.string),
}))
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
transformed_dataset, transform_fn = (
(raw_data, raw_data_metadata) |
tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features. WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_transform/tf_utils.py:324: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Use ref() instead. 2023-04-13 09:15:56.867283: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_transform/tf_utils.py:324: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Use ref() instead. WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/tmpfs/tmp/tmpzu0d2pwa.json', '--HistoryManager.hist_file=:memory:'] INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpdhm3m_yu/tftransform_tmp/88750e1500194862a87b2f23e04367bc/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpdhm3m_yu/tftransform_tmp/88750e1500194862a87b2f23e04367bc/assets INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpdhm3m_yu/tftransform_tmp/8fad0af5a26242cc9733a752a7652277/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpdhm3m_yu/tftransform_tmp/8fad0af5a26242cc9733a752a7652277/assets INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available.
transformed_data, transformed_metadata = transformed_dataset
下面显示了 transformed_data
的内容,其中包含与原始数据格式相同的转换后的列。特别是,s_integerized
的值为 [0, 1, 0]
- 这些值取决于将单词 hello
和 world
映射到整数的方式,这是确定性的。对于列 x_centered
,我们减去了平均值,因此列 x
的值,即 [1.0, 2.0, 3.0]
,变为 [-1.0, 0.0, 1.0]
。类似地,其余列与其预期值匹配。
transformed_data
[{'s_integerized': 0, 'x_centered': -1.0, 'x_centered_times_y_normalized': -0.0, 'y_normalized': 0.0}, {'s_integerized': 1, 'x_centered': 0.0, 'x_centered_times_y_normalized': 0.0, 'y_normalized': 0.5}, {'s_integerized': 0, 'x_centered': 1.0, 'x_centered_times_y_normalized': 1.0, 'y_normalized': 1.0}]
raw_data
和 transformed_data
都是数据集。接下来的两节将展示 Beam 实现如何表示数据集以及如何将数据读写到磁盘。另一个返回值 transform_fn
表示应用于数据的转换,将在下面详细介绍。
tft_beam.AnalyzeAndTransformDataset
类是实现提供的两个基本转换的组合 tft_beam.AnalyzeDataset
和 tft_beam.TransformDataset
。因此,以下两个代码片段是等效的
my_data = (raw_data, raw_data_metadata)
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
transformed_data, transform_fn = (
my_data | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/tmpfs/tmp/tmpzu0d2pwa.json', '--HistoryManager.hist_file=:memory:'] INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp8afa0l99/tftransform_tmp/8dc250e431e848a386d53f050ae886df/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp8afa0l99/tftransform_tmp/8dc250e431e848a386d53f050ae886df/assets INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp8afa0l99/tftransform_tmp/46d2e23e8b9745219e9812f9b7f5aee1/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp8afa0l99/tftransform_tmp/46d2e23e8b9745219e9812f9b7f5aee1/assets INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available.
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)
transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()
WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/tmpfs/tmp/tmpzu0d2pwa.json', '--HistoryManager.hist_file=:memory:'] INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpoezjiky4/tftransform_tmp/2f6feb69b15d4a429fa4f56dd7fb02a3/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpoezjiky4/tftransform_tmp/2f6feb69b15d4a429fa4f56dd7fb02a3/assets INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpoezjiky4/tftransform_tmp/26cbcc6000e947c798b5af9ad57c0b42/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpoezjiky4/tftransform_tmp/26cbcc6000e947c798b5af9ad57c0b42/assets WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:tensorflow:You are passing instance dicts and DatasetMetadata to TFT which will not provide optimal performance. Consider following the TFT guide to upgrade to the TFXIO format (Apache Arrow RecordBatch). WARNING:apache_beam.options.pipeline_options:Discarding unparseable args: ['/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/tmpfs/tmp/tmpzu0d2pwa.json', '--HistoryManager.hist_file=:memory:'] INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available.
transform_fn
是一个纯函数,表示应用于数据集每一行的操作。特别是,分析器值已经计算出来并被视为常量。在示例中,transform_fn
包含列 x
的平均值、列 y
的最小值和最大值以及用于将字符串映射到整数的词汇表作为常量。
tf.Transform
的一个重要特性是 transform_fn
表示跨行的映射 - 它是一个分别应用于每一行的纯函数。所有用于聚合行的计算都在 AnalyzeDataset
中完成。此外,transform_fn
表示为一个 TensorFlow Graph
,可以嵌入到服务图中。
在这种情况下的优化中提供了 AnalyzeAndTransformDataset
。这与 scikit-learn 中使用的模式相同,提供了 fit
、transform
和 fit_transform
方法。
数据格式和模式
TFT Beam 实现接受两种不同的输入数据格式。“实例字典”格式(如上面的示例以及 simple.ipynb 和 simple_example.py 中所示)是一种直观的格式,适用于小型数据集,而 TFXIO (Apache Arrow) 格式提供更高的性能,适用于大型数据集。
与 PCollection
相关的“元数据”告诉 Beam 实现 PCollection
的格式。
(raw_data, raw_data_metadata) | tft.AnalyzeDataset(...)
- 如果
raw_data_metadata
是一个dataset_metadata.DatasetMetadata
(见下文,“‘实例字典’格式”部分),那么raw_data
预计将采用“实例字典”格式。 - 如果
raw_data_metadata
是一个tfxio.TensorAdapterConfig
(见下文,“TFXIO 格式”部分),那么raw_data
预计将采用 TFXIO 格式。
“实例字典”格式
之前的代码示例使用了这种格式。元数据包含定义数据布局以及如何从各种格式中读取和写入数据的模式。即使这种内存格式也不是自描述的,它需要模式才能被解释为张量。
再次,以下是示例数据的模式定义
import tensorflow_transform as tft
raw_data_metadata = tft.DatasetMetadata.from_feature_spec({
's': tf.io.FixedLenFeature([], tf.string),
'y': tf.io.FixedLenFeature([], tf.float32),
'x': tf.io.FixedLenFeature([], tf.float32),
})
Schema
协议包含从其磁盘或内存格式解析数据到张量所需的信息。它通常通过调用 schema_utils.schema_from_feature_spec
来构建,该函数使用一个字典,将特征键映射到 tf.io.FixedLenFeature
、tf.io.VarLenFeature
和 tf.io.SparseFeature
值。有关更多详细信息,请参阅 tf.parse_example
的文档。
在上面,我们使用 tf.io.FixedLenFeature
来指示每个特征包含固定数量的值,在本例中为单个标量值。因为 tf.Transform
对实例进行批处理,所以表示特征的实际 Tensor
将具有形状 (None,)
,其中未知维度是批次维度。
TFXIO 格式
使用这种格式,数据预计将包含在 pyarrow.RecordBatch
中。对于表格数据,我们的 Apache Beam 实现接受由以下类型列组成的 Arrow RecordBatch
pa.list_(<primitive>)
,其中<primitive>
是pa.int64()
、pa.float32()
pa.binary()
或pa.large_binary()
。pa.large_list(<primitive>)
上面我们使用的玩具输入数据集,当表示为 RecordBatch
时,看起来像这样
import pyarrow as pa
raw_data = [
pa.record_batch(
data=[
pa.array([[1], [2], [3]], pa.list_(pa.float32())),
pa.array([[1], [2], [3]], pa.list_(pa.float32())),
pa.array([['hello'], ['world'], ['hello']], pa.list_(pa.binary())),
],
names=['x', 'y', 's'])
]
类似于与“实例字典”格式相关的 dataset_metadata.DatasetMetadata
实例,tfxio.TensorAdapterConfig
必须与 RecordBatch
相关联。它包含 RecordBatch
的 Arrow 模式,以及 tfxio.TensorRepresentations
,以唯一地确定如何将 RecordBatch
中的列解释为 TensorFlow 张量(包括但不限于 tf.Tensor
、tf.SparseTensor
)。
tfxio.TensorRepresentations
是 Dict[str, tensorflow_metadata.proto.v0.schema_pb2.TensorRepresentation]
的类型别名,它建立了 preprocessing_fn
接受的张量与 RecordBatch
中的列之间的关系。例如
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
tensor_representation = {
'x': text_format.Parse(
"""dense_tensor { column_name: "col1" shape { dim { size: 2 } } }""",
schema_pb2.TensorRepresentation())
}
表示 preprocessing_fn
中的 inputs['x']
应该是一个密集的 tf.Tensor
,其值来自输入 RecordBatch
中名为 'col1'
的列,其(批处理)形状应为 [batch_size, 2]
。
schema_pb2.TensorRepresentation
是在 TensorFlow Metadata 中定义的 Protobuf。
与 TensorFlow 的兼容性
tf.Transform
支持将 transform_fn
导出为 SavedModel,请参阅 简单教程 以获取示例。在 0.30
版本发布之前,默认行为是导出 TF 1.x SavedModel。从 0.30
版本开始,默认行为是导出 TF 2.x SavedModel,除非显式禁用 TF 2.x 行为(通过调用 tf.compat.v1.disable_v2_behavior()
)。
如果使用 TF 1.x 概念,例如 tf.estimator
和 tf.Sessions
,您可以通过将 force_tf_compat_v1=True
传递给 tft_beam.Context
来保留之前的行为,如果使用 tf.Transform
作为独立库,或者将 force_tf_compat_v1=True
传递给 TFX 中的 Transform 组件。
将 transform_fn
导出为 TF 2.x SavedModel 时,预计 preprocessing_fn
可以使用 tf.function
进行跟踪。此外,如果在远程运行管道(例如使用 DataflowRunner
),请确保 preprocessing_fn
和任何依赖项都已按 此处 所述正确打包。
使用 tf.Transform
导出 TF 2.x SavedModel 的已知问题记录在 此处。
Apache Beam 的输入和输出
到目前为止,我们已经看到了 python 列表(RecordBatch
或实例字典)中的输入和输出数据。这是一种简化,它依赖于 Apache Beam 的能力,即能够使用列表以及其主要数据表示形式 PCollection
。
PCollection
是构成 Beam 管道一部分的数据表示形式。Beam 管道通过应用各种 PTransform
(包括 AnalyzeDataset
和 TransformDataset
)并运行管道来形成。PCollection
不是在主二进制文件的内存中创建的,而是分布在工作节点之间(尽管本节使用的是内存执行模式)。
预先准备好的 PCollection
源 (TFXIO
)
我们实现中接受的 RecordBatch
格式是一种通用格式,其他 TFX 库也接受这种格式。因此,TFX 提供了便捷的“源”(也称为 TFXIO
),它们可以读取磁盘上各种格式的文件并生成 RecordBatch
,还可以提供 tfxio.TensorAdapterConfig
,包括推断的 tfxio.TensorRepresentations
。
这些 TFXIO
可以在 tfx_bsl
包中找到 (tfx_bsl.public.tfxio
)。
示例:“人口普查收入”数据集
以下示例需要在磁盘上读取和写入数据,并将数据表示为 PCollection
(而不是列表),请参见:census_example.py
。下面我们将展示如何下载数据并运行此示例。“人口普查收入”数据集由 UCI 机器学习库 提供。此数据集包含分类数据和数值数据。
以下是一些用于下载和预览此数据的代码
wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.data
--2023-04-13 09:16:10-- https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.data Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.203.128, 74.125.141.128, 142.250.98.128, ... Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.203.128|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 3974305 (3.8M) [application/octet-stream] Saving to: ‘adult.data’ adult.data 100%[===================>] 3.79M --.-KB/s in 0.02s 2023-04-13 09:16:10 (153 MB/s) - ‘adult.data’ saved [3974305/3974305]
import pandas as pd
train_data_file = "adult.data"
在下面的单元格中隐藏了一些配置代码。
pd.read_csv(train_data_file, names = ORDERED_CSV_COLUMNS).head()
数据集的列要么是分类的,要么是数值的。此数据集描述了一个分类问题:预测最后一列,即个人年收入是否超过 50K 美元。但是,从 tf.Transform
的角度来看,此标签只是另一个分类列。
我们使用预先准备好的 tfxio.BeamRecordCsvTFXIO
将 CSV 行转换为 RecordBatches
。 TFXIO
需要两条重要的信息
- 一个 TensorFlow 元数据模式,
tfmd.proto.v0.shema_pb2
,它包含有关每个 CSV 列的类型和形状信息。schema_pb2.TensorRepresentation
是模式的可选部分;如果未提供(在本示例中就是这种情况),它们将从类型和形状信息中推断出来。可以通过使用我们提供的辅助函数从 TF 解析规范(在本示例中显示)进行转换,或者通过运行 TensorFlow 数据验证 来获取模式。 - 一个列名列表,按它们在 CSV 文件中出现的顺序排列。请注意,这些名称必须与模式中的特征名称匹配。
pip install -U -q tfx_bsl
from tfx_bsl.public import tfxio
from tfx_bsl.coders.example_coder import RecordBatchToExamples
import apache_beam as beam
pipeline = beam.Pipeline()
csv_tfxio = tfxio.BeamRecordCsvTFXIO(
physical_format='text', column_names=ORDERED_CSV_COLUMNS, schema=SCHEMA)
raw_data = (
pipeline
| 'ReadTrainData' >> beam.io.ReadFromText(
train_data_file, coder=beam.coders.BytesCoder())
| 'FixCommasTrainData' >> beam.Map(
lambda line: line.replace(b', ', b','))
| 'DecodeTrainData' >> csv_tfxio.BeamSource())
raw_data
<PCollection[[21]: DecodeTrainData/RawRecordToRecordBatch/CollectRecordBatchTelemetry/ProfileRecordBatches.None] at 0x7feeaa6fd5b0>
请注意,在读取 CSV 行后,我们必须进行一些额外的修复。否则,我们可以依靠 tfxio.CsvTFXIO
来处理读取文件和转换为 RecordBatch
两种操作。
csv_tfxio = tfxio.CsvTFXIO(train_data_file,
telemetry_descriptors=[], #???
column_names=ORDERED_CSV_COLUMNS,
schema=SCHEMA)
p2 = beam.Pipeline()
raw_data_2 = p2 | 'TFXIORead' >> csv_tfxio.BeamSource()
此数据集的预处理与上一个示例类似,只是预处理函数是通过编程生成的,而不是手动指定每个列。在下面的预处理函数中,NUMERICAL_COLUMNS
和 CATEGORICAL_COLUMNS
是包含数值列和分类列名称的列表。
NUM_OOV_BUCKETS = 1
def preprocessing_fn(inputs):
"""Preprocess input columns into transformed columns."""
# Since we are modifying some features and leaving others unchanged, we
# start by setting `outputs` to a copy of `inputs.
outputs = inputs.copy()
# Scale numeric columns to have range [0, 1].
for key in NUMERIC_FEATURE_KEYS:
outputs[key] = tft.scale_to_0_1(outputs[key])
# For all categorical columns except the label column, we generate a
# vocabulary but do not modify the feature. This vocabulary is instead
# used in the trainer, by means of a feature column, to convert the feature
# from a string to an integer id.
for key in CATEGORICAL_FEATURE_KEYS:
outputs[key] = tft.compute_and_apply_vocabulary(
tf.strings.strip(inputs[key]),
num_oov_buckets=NUM_OOV_BUCKETS,
vocab_filename=key)
# For the label column we provide the mapping from string to index.
with tf.init_scope():
# `init_scope` - Only initialize the table once.
initializer = tf.lookup.KeyValueTensorInitializer(
keys=['>50K', '<=50K'],
values=tf.cast(tf.range(2), tf.int64),
key_dtype=tf.string,
value_dtype=tf.int64)
table = tf.lookup.StaticHashTable(initializer, default_value=-1)
outputs[LABEL_KEY] = table.lookup(outputs[LABEL_KEY])
return outputs
与上一个示例的不同之处在于,标签列手动指定了从字符串到索引的映射。因此,'>50'
映射到 0
,而 '<=50K'
映射到 1
,因为了解训练模型中的哪个索引对应于哪个标签非常有用。
record_batches
变量表示 pyarrow.RecordBatch
的 PCollection
。 tensor_adapter_config
由 csv_tfxio
提供,该配置是从 SCHEMA
(最终在本示例中是从 TF 解析规范)推断出来的。
最后一步是将转换后的数据写入磁盘,其形式与读取原始数据类似。用于执行此操作的模式是 tft_beam.AnalyzeAndTransformDataset
的输出的一部分,该输出会为输出数据推断一个模式。写入磁盘的代码如下所示。模式是元数据的一部分,但在 tf.Transform
API 中使用两者互换(即,将元数据传递给 tft.coders.ExampleProtoCoder
)。请注意,这会写入不同的格式。不要使用 textio.WriteToText
,而是使用 Beam 对 TFRecord
格式的内置支持,并使用编码器将数据编码为 Example
协议缓冲区。这是用于训练的更好格式,如下一节所示。 transformed_eval_data_base
提供了要写入的各个分片的基文件名。
raw_dataset = (raw_data, csv_tfxio.TensorAdapterConfig())
working_dir = tempfile.mkdtemp()
with tft_beam.Context(temp_dir=working_dir):
transformed_dataset, transform_fn = (
raw_dataset | tft_beam.AnalyzeAndTransformDataset(
preprocessing_fn, output_record_batches=True))
output_dir = tempfile.mkdtemp()
transformed_data, _ = transformed_dataset
_ = (
transformed_data
| 'EncodeTrainData' >>
beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))
| 'WriteTrainData' >> beam.io.WriteToTFRecord(
os.path.join(output_dir , 'transformed.tfrecord')))
除了训练数据外,transform_fn
也与元数据一起写入。
_ = (
transform_fn
| 'WriteTransformFn' >> tft_beam.WriteTransformFn(output_dir))
使用 pipeline.run().wait_until_finish()
运行整个 Beam 管道。到目前为止,Beam 管道表示一个延迟的分布式计算。它提供了有关将要执行的操作的说明,但这些说明尚未执行。此最终调用会执行指定的管道。
result = pipeline.run().wait_until_finish()
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/c633cd0eb0c14a2bba2bc6f7ba556ce3/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/c633cd0eb0c14a2bba2bc6f7ba556ce3/assets INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/9080e8c73e2443fea34d6505feed4129/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/9080e8c73e2443fea34d6505feed4129/assets INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
运行管道后,输出目录将包含两个工件。
- 转换后的数据及其描述元数据。
- 包含生成的
preprocessing_fn
的tf.saved_model
ls {output_dir}
transform_fn transformed.tfrecord-00000-of-00001 transformed_metadata
要了解如何使用这些工件,请参阅 高级预处理教程。