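TFX Estimator Component Tutorial

A Component-by-Component Introduction to TensorFlow Extended (TFX)

This Colab-based tutorial interactively walks through each built-in component of TensorFlow Extended (TFX).

It covers every step in an end-to-end machine learning pipeline, from data ingestion to pushing a model to serving.

When you're done, the contents of this notebook can be automatically exported as TFX pipeline source code, which you can orchestrate with Apache Airflow and Apache Beam.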

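Background

This notebook demonstrates how to use TFX in a Jupyter/Colab environment. Here, we walk through the Chicago Taxi example in an interactive notebook.

Working in an interactive notebook is a useful way to become familiar with the structure of a TFX pipeline. It is also useful as a lightweight development environment when you develop your own pipelines. However, interactive notebooks differ from production deployments in how they are orchestrated and in how they access metadata artifacts.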

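Orchestration

In a production deployment of TFX, you will use an orchestrator such as Apache Airflow, Kubeflow Pipelines, or Apache Beam to orchestrate a predefined pipeline graph of TFX components. In an interactive notebook, the notebook itself is the orchestrator, running each TFX component as you execute the notebook cells.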

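Metadata

In a production deployment of TFX, you will access metadata through the ML Metadata (MLMD) API. MLMD stores metadata properties in a database such as MySQL or SQLite, and stores the metadata payloads in a persistent store, such as your filesystem. In an interactive notebook, both live under a temporary directory in /tmp on the Jupyter notebook or Colab server. The properties are kept in an ephemeral SQLite database, and the payloads are kept as files alongside it.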
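The following toy, stdlib-only sketch makes the split between properties and payloads concrete. It is not the MLMD API: the artifacts table, its columns, and the record_artifact helper are all hypothetical. Properties (a type name and a URI) go into a throwaway SQLite database, while the payload is written as a file under the same temporary directory.

```python
import os
import sqlite3
import tempfile

# Hypothetical toy store (not the real MLMD schema): artifact *properties*
# live in SQLite, artifact *payloads* live as files on disk.
root = tempfile.mkdtemp(prefix="toy-mlmd-")
conn = sqlite3.connect(os.path.join(root, "metadata.sqlite"))
conn.execute(
    "CREATE TABLE artifacts (id INTEGER PRIMARY KEY, type TEXT, uri TEXT)")


def record_artifact(type_name, payload):
    """Writes the payload to disk and records its properties in SQLite."""
    cur = conn.execute(
        "INSERT INTO artifacts (type, uri) VALUES (?, '')", (type_name,))
    artifact_id = cur.lastrowid
    # The payload directory is derived from the type name and the new row id.
    uri = os.path.join(root, type_name, str(artifact_id))
    os.makedirs(uri)
    with open(os.path.join(uri, "payload.txt"), "w") as f:
        f.write(payload)
    conn.execute("UPDATE artifacts SET uri = ? WHERE id = ?", (uri, artifact_id))
    conn.commit()
    return artifact_id, uri


artifact_id, uri = record_artifact("Examples", "serialized examples go here")
print(conn.execute("SELECT id, type, uri FROM artifacts").fetchall())
```

Only the small property row is queried through SQL. The payload is reached by following its URI, which is also how the real artifacts in this tutorial point at directories under the pipeline root.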

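Setup

First, we install and import the necessary packages, set up paths, and download the data.

Upgrade Pip

To avoid upgrading Pip on your own system when running locally, we first check that we are running in Colab. You can, of course, upgrade Pip on a local system separately.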

try:
  import google.colab  # Only importable inside a Colab runtime.
  !pip install --upgrade pip
except ImportError:
  pass
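The cell above detects Colab by attempting an import. If you would rather check without importing anything, here is a minimal sketch using importlib.util.find_spec. It assumes, as in current Colab runtimes, that Colab's helpers live in the google.colab package. running_in_colab is a hypothetical helper name.

```python
import importlib.util


def running_in_colab() -> bool:
    """Returns True if the google.colab package can be found.

    find_spec on a dotted name imports the parent package first, so a
    missing top-level `google` package raises ModuleNotFoundError.
    """
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        return False


print(running_in_colab())
```

In a notebook you could then guard the upgrade with `if running_in_colab():` followed by the `!pip install --upgrade pip` line.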

Install TFX

pip install tfx

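Did you restart the runtime?

If you are using Google Colab, the first time you run the cell above you must restart the runtime (Runtime > Restart runtime...). This is necessary because of the way Colab loads packages.

Import packages

We import the necessary packages, including the standard TFX component classes.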

import os
import pprint
import tempfile
import urllib.request

import absl
import tensorflow as tf
import tensorflow_model_analysis as tfma
tf.get_logger().propagate = False
pp = pprint.PrettyPrinter()

from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip

Let's check the library versions.

print('TensorFlow version: {}'.format(tf.__version__))
print('TFX version: {}'.format(tfx.__version__))
TensorFlow version: 2.15.1
TFX version: 1.15.0

Set up pipeline paths

# This is the root directory for your TFX pip package installation.
_tfx_root = tfx.__path__[0]

# This is the directory containing the TFX Chicago Taxi Pipeline example.
_taxi_root = os.path.join(_tfx_root, 'examples/chicago_taxi_pipeline')

# This is the path where your model will be pushed for serving.
_serving_model_dir = os.path.join(
    tempfile.mkdtemp(), 'serving_model/taxi_simple')

# Set up logging.
absl.logging.set_verbosity(absl.logging.INFO)

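Download example data

We download the example dataset for use in our TFX pipeline.

The dataset we're using is the Taxi Trips dataset released by the City of Chicago. The columns in this dataset are:

pickup_community_area   fare                     trip_start_month
trip_start_hour         trip_start_day           trip_start_timestamp
pickup_latitude         pickup_longitude         dropoff_latitude
dropoff_longitude       trip_miles               pickup_census_tract
dropoff_census_tract    payment_type             company
trip_seconds            dropoff_community_area   tips

With this dataset, we will build a model that predicts the tips of a trip. More precisely, the Transform code later in this tutorial turns tips into a binary label: whether the tip exceeded 20% of the fare.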

_data_root = tempfile.mkdtemp(prefix='tfx-data')
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/chicago_taxi_pipeline/data/simple/data.csv'
_data_filepath = os.path.join(_data_root, "data.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)
('/tmpfs/tmp/tfx-datafzazc6h4/data.csv',
 <http.client.HTTPMessage at 0x7f393fe02ee0>)

Take a quick look at the CSV file.

!head {_data_filepath}
pickup_community_area,fare,trip_start_month,trip_start_hour,trip_start_day,trip_start_timestamp,pickup_latitude,pickup_longitude,dropoff_latitude,dropoff_longitude,trip_miles,pickup_census_tract,dropoff_census_tract,payment_type,company,trip_seconds,dropoff_community_area,tips
,12.45,5,19,6,1400269500,,,,,0.0,,,Credit Card,Chicago Elite Cab Corp. (Chicago Carriag,0,,0.0
,0,3,19,5,1362683700,,,,,0,,,Unknown,Chicago Elite Cab Corp.,300,,0
60,27.05,10,2,3,1380593700,41.836150155,-87.648787952,,,12.6,,,Cash,Taxi Affiliation Services,1380,,0.0
10,5.85,10,1,2,1382319000,41.985015101,-87.804532006,,,0.0,,,Cash,Taxi Affiliation Services,180,,0.0
14,16.65,5,7,5,1369897200,41.968069,-87.721559063,,,0.0,,,Cash,Dispatch Taxi Affiliation,1080,,0.0
13,16.45,11,12,3,1446554700,41.983636307,-87.723583185,,,6.9,,,Cash,,780,,0.0
16,32.05,12,1,1,1417916700,41.953582125,-87.72345239,,,15.4,,,Cash,,1200,,0.0
30,38.45,10,10,5,1444301100,41.839086906,-87.714003807,,,14.6,,,Cash,,2580,,0.0
11,14.65,1,1,3,1358213400,41.978829526,-87.771166703,,,5.81,,,Cash,,1080,,0.0
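Many cells in the sample are empty; for example, the dropoff coordinates are missing in every row shown. A short stdlib sketch that counts empty cells per column makes this visible. SAMPLE is a hand-copied excerpt containing the header plus three of the rows printed above. In the notebook you could pass open(_data_filepath).read() instead. count_missing is a hypothetical helper.

```python
import csv
import io
from collections import Counter

# Header plus three rows copied from the `head` output above.
SAMPLE = """pickup_community_area,fare,trip_start_month,trip_start_hour,trip_start_day,trip_start_timestamp,pickup_latitude,pickup_longitude,dropoff_latitude,dropoff_longitude,trip_miles,pickup_census_tract,dropoff_census_tract,payment_type,company,trip_seconds,dropoff_community_area,tips
,12.45,5,19,6,1400269500,,,,,0.0,,,Credit Card,Chicago Elite Cab Corp. (Chicago Carriag,0,,0.0
60,27.05,10,2,3,1380593700,41.836150155,-87.648787952,,,12.6,,,Cash,Taxi Affiliation Services,1380,,0.0
13,16.45,11,12,3,1446554700,41.983636307,-87.723583185,,,6.9,,,Cash,,780,,0.0
"""


def count_missing(csv_text):
    """Returns (number of data rows, Counter of empty cells per column)."""
    missing = Counter()
    rows = 0
    for row in csv.DictReader(io.StringIO(csv_text)):
        rows += 1
        for column, value in row.items():
            if value == "":
                missing[column] += 1
    return rows, missing


rows, missing = count_missing(SAMPLE)
for column, n in sorted(missing.items()):
    print(f"{column}: {n}/{rows} empty")
```

These empty cells are why the Transform code later needs a helper that fills in missing values before scaling or bucketizing.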

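Disclaimer: This site provides applications using data that has been modified for use from its original source, www.cityofchicago.org, the official website of the City of Chicago. The City of Chicago makes no claims as to the content, accuracy, timeliness, or completeness of any of the data provided at this site. The data provided at this site is subject to change at any time. It is understood that the data provided at this site is being used at one's own risk.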

Create the InteractiveContext

Finally, we create an InteractiveContext, which lets us run TFX components interactively in this notebook.

# Here, we create an InteractiveContext using default parameters. This will
# use a temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, the optional properties
# `pipeline_root` and `metadata_connection_config` may be passed to
# InteractiveContext. Calls to InteractiveContext are no-ops outside of the
# notebook.
context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/metadata.sqlite.

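Run TFX components interactively

In the cells that follow, we create TFX components one by one, run each of them, and visualize their output artifacts.

ExampleGen

The ExampleGen component usually sits at the start of a TFX pipeline. It will:

  1. split the data into training and evaluation sets (by default, 2/3 training + 1/3 evaluation)
  2. convert the data into the tf.Example format (see the TensorFlow documentation on tf.Example for details)
  3. copy the data into the pipeline root directory (here, the temporary directory created by InteractiveContext) so that other components can access it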
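The default 2:1 train/eval split is deterministic rather than random per run: each record is hashed into a bucket, and the bucket decides the record's split. The sketch below illustrates the idea by applying MD5 to a record's bytes. This does not show which hash function TFX actually uses or exactly what it hashes, so treat assign_split and the SPLITS layout as hypothetical.

```python
import hashlib

# Hypothetical bucket layout for a 2:1 split: buckets 0-1 -> train, 2 -> eval.
SPLITS = [("train", 2), ("eval", 1)]


def assign_split(record: bytes) -> str:
    """Deterministically maps a serialized record to a split name."""
    total = sum(n for _, n in SPLITS)
    bucket = int.from_bytes(hashlib.md5(record).digest()[:8], "big") % total
    for name, n in SPLITS:
        if bucket < n:
            return name
        bucket -= n
    raise AssertionError("unreachable: bucket is always < total")


records = [f"row-{i}".encode() for i in range(3000)]
train = sum(assign_split(r) == "train" for r in records)
print(f"train fraction: {train / len(records):.3f}")
```

Because the split depends only on the record's content, rerunning the pipeline on the same data reproduces the same split. In TFX itself the ratio is configurable: ExampleGen components accept an output_config whose SplitConfig lists each split with a number of hash_buckets. Check the ExampleGen guide for the exact fields.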

ExampleGen takes as input the path to your data source. In our case, this is the _data_root path that contains the downloaded CSV.

example_gen = tfx.components.CsvExampleGen(input_base=_data_root)
context.run(example_gen)
INFO:absl:Running driver for CsvExampleGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:select span and version = (0, None)
INFO:absl:latest span and version = (0, None)
INFO:absl:Running executor for CsvExampleGen
INFO:absl:Generating examples.
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
INFO:absl:Processing input csv data /tmpfs/tmp/tfx-datafzazc6h4/* to TFExample.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
INFO:absl:Examples generated.
INFO:absl:Running publisher for CsvExampleGen
INFO:absl:MetadataStore with DB connection initialized

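Let's examine the output artifacts of ExampleGen. This component produces a single Examples artifact containing two splits, training examples and evaluation examples: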

artifact = example_gen.outputs['examples'].get()[0]
print(artifact.split_names, artifact.uri)
["train", "eval"] /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/CsvExampleGen/examples/1

We can also take a look at the first three training examples:

# Get the URI of the output artifact representing the training examples, which is a directory
train_uri = os.path.join(example_gen.outputs['examples'].get()[0].uri, 'Split-train')

# Get the list of files in this directory (all compressed TFRecord files)
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]

# Create a `TFRecordDataset` to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

# Iterate over the first 3 records and decode them.
for tfrecord in dataset.take(3):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  pp.pprint(example)
features {
  feature {
    key: "company"
    value {
      bytes_list {
        value: "Chicago Elite Cab Corp. (Chicago Carriag"
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 12.449999809265137
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      bytes_list {
        value: "Credit Card"
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "tips"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 6
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 19
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 5
      }
    }
  }
  feature {
    key: "trip_start_timestamp"
    value {
      int64_list {
        value: 1400269500
      }
    }
  }
}

features {
  feature {
    key: "company"
    value {
      bytes_list {
        value: "Taxi Affiliation Services"
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 27.049999237060547
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      bytes_list {
        value: "Cash"
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
        value: 60
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      float_list {
        value: 41.836151123046875
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      float_list {
        value: -87.64878845214844
      }
    }
  }
  feature {
    key: "tips"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 12.600000381469727
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      int64_list {
        value: 1380
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 2
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 10
      }
    }
  }
  feature {
    key: "trip_start_timestamp"
    value {
      int64_list {
        value: 1380593700
      }
    }
  }
}

features {
  feature {
    key: "company"
    value {
      bytes_list {
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 16.450000762939453
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      bytes_list {
        value: "Cash"
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
        value: 13
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      float_list {
        value: 41.98363494873047
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      float_list {
        value: -87.72357940673828
      }
    }
  }
  feature {
    key: "tips"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 6.900000095367432
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      int64_list {
        value: 780
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 12
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 11
      }
    }
  }
  feature {
    key: "trip_start_timestamp"
    value {
      int64_list {
        value: 1446554700
      }
    }
  }
}
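The printed examples show how each CSV cell is represented. Integer columns become int64_list, columns with decimal values become float_list, and text becomes bytes_list. An empty cell becomes an empty list rather than a default value. The stdlib-only sketch below reproduces that mapping using plain dicts instead of tf.train.Example. It relies on a hand-written column-kind map; the real component infers types from the data. COLUMN_KINDS and row_to_features are hypothetical names.

```python
import struct

# Hand-written column -> value-list kind map (an assumption: CsvExampleGen
# infers these types from the data itself).
COLUMN_KINDS = {
    "company": "bytes_list",
    "fare": "float_list",
    "dropoff_census_tract": "int64_list",
    "trip_seconds": "int64_list",
    "payment_type": "bytes_list",
}


def to_float32(value: float) -> float:
    """Rounds a Python float to 32-bit precision, as float_list stores it."""
    return struct.unpack("f", struct.pack("f", value))[0]


def row_to_features(row: dict) -> dict:
    """Maps CSV cells to {name: (kind, values)}; an empty cell gives []."""
    features = {}
    for name, kind in COLUMN_KINDS.items():
        cell = row.get(name, "")
        if cell == "":
            values = []
        elif kind == "int64_list":
            values = [int(cell)]
        elif kind == "float_list":
            values = [to_float32(float(cell))]
        else:
            values = [cell.encode("utf-8")]
        features[name] = (kind, values)
    return features


# First CSV row from the `head` output (only the mapped columns).
row = {
    "company": "Chicago Elite Cab Corp. (Chicago Carriag",
    "fare": "12.45",
    "dropoff_census_tract": "",
    "trip_seconds": "0",
    "payment_type": "Credit Card",
}
for name, (kind, values) in row_to_features(row).items():
    print(name, kind, values)
```

The 32-bit rounding also explains why a fare of 12.45 in the CSV prints as 12.449999809265137 in the first example above.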

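Now that ExampleGen has finished ingesting the data, the next step is data analysis.

StatisticsGen

The StatisticsGen component computes statistics over your dataset, both for data analysis and for use in downstream components. It uses the TensorFlow Data Validation library.

StatisticsGen takes as input the dataset we just ingested using ExampleGen.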

statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
context.run(statistics_gen)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for StatisticsGen
INFO:absl:Generating statistics for split train.
INFO:absl:Statistics for split train written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/StatisticsGen/statistics/2/Split-train.
INFO:absl:Generating statistics for split eval.
INFO:absl:Statistics for split eval written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/StatisticsGen/statistics/2/Split-eval.
INFO:absl:Running publisher for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized

After StatisticsGen finishes running, we can visualize the output statistics. Try playing with the different plots!

context.show(statistics_gen.outputs['statistics'])
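StatisticsGen delegates the heavy lifting to TensorFlow Data Validation, which computes far richer statistics at scale. As a rough intuition for what a numeric feature's summary contains, here is a stdlib sketch that computes count, missing count, mean, standard deviation (population form, an assumption of this sketch), min and max. None stands in for an empty cell. numeric_stats is a hypothetical helper, not a TFDV API.

```python
import math
from typing import Optional, Sequence


def numeric_stats(values: Sequence[Optional[float]]) -> dict:
    """Summarizes one numeric feature; None marks a missing value."""
    present = [v for v in values if v is not None]
    stats = {"count": len(present), "missing": len(values) - len(present)}
    if present:
        mean = sum(present) / len(present)
        stats.update(
            mean=mean,
            std=math.sqrt(sum((v - mean) ** 2 for v in present) / len(present)),
            min=min(present),
            max=max(present),
        )
    return stats


# trip_miles from the nine CSV rows printed earlier.
print(numeric_stats([0.0, 0.0, 12.6, 0.0, 0.0, 6.9, 15.4, 14.6, 5.81]))
# pickup_latitude for the first three rows, two of which are empty.
print(numeric_stats([None, None, 41.836150155]))
```

A large share of zeros or missing values, like the ones in these small samples, is the kind of pattern the StatisticsGen visualization makes easy to spot across the full dataset.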

SchemaGen

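The SchemaGen component generates a schema based on your data statistics. (A schema defines the expected bounds, types, and properties of the features in your dataset.) It also uses the TensorFlow Data Validation library.

SchemaGen takes as input the statistics that we generated with StatisticsGen. Because exclude_splits is not set, it uses the statistics of every split; the log below shows it processing both the train and eval splits.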

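schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    # With infer_feature_shape=False the schema records no feature shapes,
    # so Transform later parses every feature as a sparse tensor.
    infer_feature_shape=False)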
context.run(schema_gen)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for SchemaGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for SchemaGen
INFO:absl:Processing schema from statistics for split train.
INFO:absl:Processing schema from statistics for split eval.
INFO:absl:Schema written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/SchemaGen/schema/3/schema.pbtxt.
INFO:absl:Running publisher for SchemaGen
INFO:absl:MetadataStore with DB connection initialized

After SchemaGen finishes running, we can visualize the generated schema as a table.

context.show(schema_gen.outputs['schema'])

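Each feature in your dataset shows up as a row in the schema table, alongside its properties. The schema also captures all the values that a categorical feature takes on, denoted as its domain.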
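As a rough intuition for how types and string domains can be derived from observed data, here is a toy sketch. It is a hypothetical simplification and not TFDV's inference logic. TFDV's logic works from the computed statistics and also records further constraints, such as feature presence. infer_schema and _parses are illustrative names.

```python
from typing import Dict, List


def _parses(cast, text: str) -> bool:
    """True if cast(text) succeeds, e.g. _parses(int, "300")."""
    try:
        cast(text)
        return True
    except ValueError:
        return False


def infer_schema(columns: Dict[str, List[str]]) -> Dict[str, dict]:
    """Infers {feature: {"type": ..., ["domain": ...]}} from raw CSV cells.

    Empty cells count as missing and are ignored. An all-empty column falls
    through to INT here, which is an arbitrary choice of this sketch.
    """
    schema = {}
    for name, cells in columns.items():
        present = [c for c in cells if c != ""]
        if all(_parses(int, c) for c in present):
            schema[name] = {"type": "INT"}
        elif all(_parses(float, c) for c in present):
            schema[name] = {"type": "FLOAT"}
        else:
            # String features get a domain: the sorted set of observed values.
            schema[name] = {"type": "BYTES", "domain": sorted(set(present))}
    return schema


# Cells from the first four CSV rows printed earlier.
columns = {
    "payment_type": ["Credit Card", "Unknown", "Cash", "Cash"],
    "trip_seconds": ["0", "300", "1380", "180"],
    "fare": ["12.45", "0", "27.05", "5.85"],
}
print(infer_schema(columns))
```

Note that fare comes out as FLOAT even though one of its cells ("0") looks like an integer, because the type has to accommodate every observed value.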

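To learn more about schemas, see the SchemaGen documentation.

ExampleValidator

The ExampleValidator component detects anomalies in your data, based on the expectations defined by the schema. It also uses the TensorFlow Data Validation library.

ExampleValidator takes as input the statistics from StatisticsGen and the schema from SchemaGen.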

example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
context.run(example_validator)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for ExampleValidator
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for ExampleValidator
INFO:absl:Validating schema against the computed statistics for split train.
INFO:absl:Anomalies alerts created for split train.
INFO:absl:Validation complete for split train. Anomalies written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/ExampleValidator/anomalies/4/Split-train.
INFO:absl:Validating schema against the computed statistics for split eval.
INFO:absl:Anomalies alerts created for split eval.
INFO:absl:Validation complete for split eval. Anomalies written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/ExampleValidator/anomalies/4/Split-eval.
INFO:absl:Running publisher for ExampleValidator
INFO:absl:MetadataStore with DB connection initialized

After ExampleValidator finishes running, we can visualize the anomalies as a table.

context.show(example_validator.outputs['anomalies'])

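The anomalies table shows no anomalies. This is what we'd expect, because this is the first dataset we've analyzed and the schema is tailored to it. You should review this schema: anything unexpected in it points to an anomaly in the data. Once reviewed, the schema can be used to guard future data. The anomalies produced here can then help you debug model performance, understand how your data evolves over time, and identify data errors.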
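Conceptually, validation checks new data against the expectations in the schema. The sketch below is a toy version. It uses a hypothetical schema written as a plain dict (a type per feature, plus an optional string domain) rather than TFDV's Schema proto. new_data is made up to trigger one anomaly of each kind the sketch detects: an out-of-domain value, a type mismatch, and a feature the schema does not know.

```python
from typing import Dict, List

# Hypothetical toy schema: a type per feature plus an optional string domain.
SCHEMA = {
    "payment_type": {"type": "BYTES", "domain": ["Cash", "Credit Card", "Unknown"]},
    "trip_seconds": {"type": "INT"},
}
_CASTS = {"INT": int, "FLOAT": float, "BYTES": str}


def find_anomalies(schema: Dict[str, dict], columns: Dict[str, List[str]]) -> List[str]:
    """Lists type mismatches, out-of-domain values, and unknown features."""
    anomalies = []
    for name, cells in columns.items():
        spec = schema.get(name)
        if spec is None:
            anomalies.append(f"{name}: feature not in schema")
            continue
        for cell in (c for c in cells if c != ""):  # empty cell = missing
            try:
                _CASTS[spec["type"]](cell)
            except ValueError:
                anomalies.append(f"{name}: {cell!r} is not {spec['type']}")
                continue
            if "domain" in spec and cell not in spec["domain"]:
                anomalies.append(f"{name}: {cell!r} not in domain")
    return anomalies


new_data = {
    "payment_type": ["Cash", "Mobile"],
    "trip_seconds": ["600", "ten minutes"],
    "tips": ["1.5"],
}
for line in find_anomalies(SCHEMA, new_data):
    print(line)
```

Data that matches the schema yields an empty list, which corresponds to the empty anomalies table shown above.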

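Transform

The Transform component performs feature engineering for both training and serving. It uses the TensorFlow Transform library.

Transform takes as input the data from ExampleGen, the schema from SchemaGen, and a module that contains user-defined Transform code.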
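Transform matters for serving because full-pass statistics are computed once over the training data and then baked into the transformation. The mean and standard deviation used by tft.scale_to_z_score are one example. Serving therefore applies exactly the same constants. The log for this component later confirms that it analyzes the 'train' split and transforms all splits. Below is a stdlib sketch of that analyze-then-transform pattern. analyze and make_transform are hypothetical helpers. The use of population variance and the zero-variance guard are assumptions of this sketch, not documented TF Transform behavior.

```python
import math
from typing import Callable, Sequence, Tuple


def analyze(values: Sequence[float]) -> Tuple[float, float]:
    """Full pass over training data: returns (mean, population std)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, math.sqrt(var)


def make_transform(mean: float, std: float) -> Callable[[float], float]:
    """Freezes the analyzed constants into a per-value z-score function."""
    def z_score(x: float) -> float:
        # Guarding against a zero std is this sketch's own choice.
        return (x - mean) / std if std > 0 else x - mean
    return z_score


# Analyze once on (a few) training fares from the CSV sample...
train_fares = [12.45, 0.0, 27.05, 5.85, 16.65, 16.45, 32.05, 38.45, 14.65]
z = make_transform(*analyze(train_fares))
# ...then apply the identical function to training and serving inputs.
print([round(z(f), 3) for f in train_fares[:3]])
print(round(z(20.0), 3))  # a hypothetical serving-time fare
```

A fare seen only at serving time is scaled with the training mean and standard deviation, which is the training/serving consistency Transform is meant to provide.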

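Let's look at an example of user-defined Transform code below (for an introduction to the TensorFlow Transform APIs, see the TensorFlow Transform tutorial). First, we define a few constants for feature engineering: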

_taxi_constants_module_file = 'taxi_constants.py'
%%writefile {_taxi_constants_module_file}

# Categorical features are assumed to each have a maximum value in the dataset.
MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12]

CATEGORICAL_FEATURE_KEYS = [
    'trip_start_hour', 'trip_start_day', 'trip_start_month',
    'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area',
    'dropoff_community_area'
]

DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds']

# Number of buckets used by tf.transform for encoding each feature.
FEATURE_BUCKET_COUNT = 10

BUCKET_FEATURE_KEYS = [
    'pickup_latitude', 'pickup_longitude', 'dropoff_latitude',
    'dropoff_longitude'
]

# Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform
VOCAB_SIZE = 1000

# Count of out-of-vocab buckets in which unrecognized VOCAB_FEATURES are hashed.
OOV_SIZE = 10

VOCAB_FEATURE_KEYS = [
    'payment_type',
    'company',
]

# Keys
LABEL_KEY = 'tips'
FARE_KEY = 'fare'
Writing taxi_constants.py
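The constants above parameterize two Transform analyzers. VOCAB_SIZE and OOV_SIZE feed tft.compute_and_apply_vocabulary. That analyzer keeps the top_k most frequent values and maps anything else to one of num_oov_buckets extra indices. FEATURE_BUCKET_COUNT feeds tft.bucketize, which splits a numeric feature into buckets using quantile boundaries. The stdlib sketch below imitates both behaviors under stated assumptions:

- Frequency ties are broken alphabetically.
- Out-of-vocabulary values are hashed with zlib.crc32.
- Quantile boundaries are exact rather than approximate.

build_vocab, apply_vocab, quantile_boundaries and bucketize are hypothetical names, not the TF Transform API. The exact indices TF Transform assigns may differ.

```python
import bisect
import zlib
from collections import Counter
from typing import Dict, List, Sequence


def build_vocab(values: Sequence[str], top_k: int) -> Dict[str, int]:
    """Most frequent values first; ties broken alphabetically (assumption)."""
    counts = Counter(values)
    ranked = sorted(counts, key=lambda v: (-counts[v], v))[:top_k]
    return {v: i for i, v in enumerate(ranked)}


def apply_vocab(value: str, vocab: Dict[str, int], num_oov_buckets: int) -> int:
    """In-vocab values map to their rank; others hash into OOV indices."""
    if value in vocab:
        return vocab[value]
    return len(vocab) + zlib.crc32(value.encode("utf-8")) % num_oov_buckets


def quantile_boundaries(values: Sequence[float], num_buckets: int) -> List[float]:
    """Exact quantile cut points (TF Transform approximates these)."""
    ordered = sorted(values)
    return [ordered[len(ordered) * i // num_buckets] for i in range(1, num_buckets)]


def bucketize(value: float, boundaries: List[float]) -> int:
    """Index of the bucket containing value, in [0, len(boundaries)]."""
    return bisect.bisect_right(boundaries, value)


vocab = build_vocab(["Cash", "Cash", "Credit Card", "Cash", "Unknown"], top_k=2)
print(vocab)  # {'Cash': 0, 'Credit Card': 1}
print(apply_vocab("Unknown", vocab, num_oov_buckets=10))  # an index in 2..11

# Hypothetical feature values 0..99 split into 10 buckets.
bounds = quantile_boundaries([float(i) for i in range(100)], num_buckets=10)
print(bounds[:3], bucketize(-5.0, bounds), bucketize(55.0, bounds), bucketize(1e6, bounds))
```

With a vocabulary of VOCAB_SIZE entries and OOV_SIZE extra buckets, every value maps into a fixed index range. This is what lets a model use the encoded feature even when it sees company names that never appeared in training.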

Next, we write a preprocessing_fn that takes raw data as input and returns transformed features that our model can train on:

_taxi_transform_module_file = 'taxi_transform.py'
%%writefile {_taxi_transform_module_file}

import tensorflow as tf
import tensorflow_transform as tft

import taxi_constants

_DENSE_FLOAT_FEATURE_KEYS = taxi_constants.DENSE_FLOAT_FEATURE_KEYS
_VOCAB_FEATURE_KEYS = taxi_constants.VOCAB_FEATURE_KEYS
_VOCAB_SIZE = taxi_constants.VOCAB_SIZE
_OOV_SIZE = taxi_constants.OOV_SIZE
_FEATURE_BUCKET_COUNT = taxi_constants.FEATURE_BUCKET_COUNT
_BUCKET_FEATURE_KEYS = taxi_constants.BUCKET_FEATURE_KEYS
_CATEGORICAL_FEATURE_KEYS = taxi_constants.CATEGORICAL_FEATURE_KEYS
_FARE_KEY = taxi_constants.FARE_KEY
_LABEL_KEY = taxi_constants.LABEL_KEY


def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.
  Args:
    inputs: map from feature keys to raw not-yet-transformed features.
  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}
  for key in _DENSE_FLOAT_FEATURE_KEYS:
    # If sparse, make it dense (filling missing values with 0 or ''), then apply z-score scaling.
    outputs[key] = tft.scale_to_z_score(
        _fill_in_missing(inputs[key]))

  for key in _VOCAB_FEATURE_KEYS:
    # Build a vocabulary for this feature.
    outputs[key] = tft.compute_and_apply_vocabulary(
        _fill_in_missing(inputs[key]),
        top_k=_VOCAB_SIZE,
        num_oov_buckets=_OOV_SIZE)

  for key in _BUCKET_FEATURE_KEYS:
    outputs[key] = tft.bucketize(
        _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT)

  for key in _CATEGORICAL_FEATURE_KEYS:
    outputs[key] = _fill_in_missing(inputs[key])

  # Was this passenger a big tipper?
  taxi_fare = _fill_in_missing(inputs[_FARE_KEY])
  tips = _fill_in_missing(inputs[_LABEL_KEY])
  outputs[_LABEL_KEY] = tf.where(
      tf.math.is_nan(taxi_fare),
      tf.cast(tf.zeros_like(taxi_fare), tf.int64),
      # Test if the tip was > 20% of the fare.
      tf.cast(
          tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64))

  return outputs


def _fill_in_missing(x):
  """Replace missing values in a SparseTensor.
  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.
  Args:
    x: A `SparseTensor` of rank 2.  Its dense shape should have size at most 1
      in the second dimension.
  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  if not isinstance(x, tf.sparse.SparseTensor):
    return x

  default_value = '' if x.dtype == tf.string else 0
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)
Writing taxi_transform.py
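Two pieces of preprocessing_fn are easy to check outside TensorFlow. The first is _fill_in_missing, which turns a batch whose rows hold zero or one value into a flat list with a default. The second is the label: a trip counts as a big tip when tips > 0.2 * fare, and a NaN fare yields 0. Here is a plain-Python mirror of that logic over lists. fill_in_missing and big_tipper_label are hypothetical stand-ins that operate on Python lists, not on SparseTensors.

```python
import math
from typing import List, Sequence, Union

Value = Union[str, int, float]


def fill_in_missing(batch: Sequence[Sequence[Value]], default: Value) -> List[Value]:
    """Batch of rows with at most one value each -> flat list with defaults."""
    return [row[0] if row else default for row in batch]


def big_tipper_label(fare: float, tips: float) -> int:
    """1 if the tip exceeded 20% of the fare; 0 if the fare is NaN."""
    if math.isnan(fare):
        return 0
    return int(tips > fare * 0.2)


fares = fill_in_missing([[12.45], [], [27.05], [float("nan")]], default=0.0)
tips = fill_in_missing([[3.0], [1.0], [0.0], [5.0]], default=0.0)
print([big_tipper_label(f, t) for f, t in zip(fares, tips)])  # [1, 1, 0, 0]
```

Note the second trip. Because _fill_in_missing replaces a missing fare with 0 before the NaN check, any positive tip on such a trip is labeled 1. Only an explicit NaN fare yields 0.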

Now we pass this feature engineering code to the Transform component and run it to transform the data.

transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_taxi_transform_module_file))
context.run(transform)
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/taxi_transform.py' (including modules: ['taxi_transform', 'taxi_constants']).
INFO:absl:User module package has hash fingerprint version f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmpagxnedeh/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpmro2o5jq', '--dist-dir', '/tmpfs/tmp/tmpqxoax40b']
INFO:absl:Successfully built user code wheel distribution at '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl'; target user module is 'taxi_transform'.
INFO:absl:Full user module path is 'taxi_transform@/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl'
INFO:absl:Running driver for Transform
INFO:absl:MetadataStore with DB connection initialized
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying taxi_transform.py -> build/lib
copying taxi_constants.py -> build/lib
installing to /tmpfs/tmp/tmpmro2o5jq
running install
running install_lib
copying build/lib/taxi_transform.py -> /tmpfs/tmp/tmpmro2o5jq
copying build/lib/taxi_constants.py -> /tmpfs/tmp/tmpmro2o5jq
running install_egg_info
running egg_info
creating tfx_user_code_Transform.egg-info
writing tfx_user_code_Transform.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmpmro2o5jq/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpmro2o5jq/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424.dist-info/WHEEL
creating '/tmpfs/tmp/tmpqxoax40b/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl' and adding '/tmpfs/tmp/tmpmro2o5jq' to it
adding 'taxi_constants.py'
adding 'taxi_transform.py'
adding 'tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424.dist-info/METADATA'
adding 'tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424.dist-info/WHEEL'
adding 'tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424.dist-info/top_level.txt'
adding 'tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424.dist-info/RECORD'
removing /tmpfs/tmp/tmpmro2o5jq
INFO:absl:Running executor for Transform
INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'taxi_transform@/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp5nkxki37', '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl']
Processing /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl'.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'taxi_transform@/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl', 'stats_options_updater_fn': None} 'stats_options_updater_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp5jgr4gdg', '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424
Processing /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl'.
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmppif0r6r3', '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424
Processing /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Transform-0.0+f78e5f6b4988b5d5289aab277eceaff03bd38343154c2f602e06d95c6acd5424-py3-none-any.whl'.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.

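Let's examine the output artifacts of Transform. This component produces two main types of output (the full list of output channels is printed below):

  • transform_graph is the graph that can perform the preprocessing operations (this graph will be included in the serving and evaluation models).
  • transformed_examples represents the preprocessed training and evaluation data.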
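Conceptually, transform_graph is the result of a full-pass "analyze" step (dataset-wide constants such as means, quantile boundaries and vocabularies) frozen into a graph that is then applied example by example, which is why the same graph can be reused at serving time. The pure-Python sketch below illustrates that split. It is not tf.Transform's implementation; it assumes the preprocessing_fn defined earlier does roughly what tft.scale_to_z_score, tft.bucketize and tft.compute_and_apply_vocabulary do, and the function names analyze and transform, the simplified quantile rule and the OOV hash are hypothetical stand-ins.

```python
import bisect
import statistics
from collections import Counter


def analyze(rows, float_keys, bucket_keys, vocab_keys, num_buckets, top_k):
    """Full pass over the data: compute the constants the per-example step needs."""
    constants = {}
    for key in float_keys:
        values = [r[key] for r in rows]
        # Population mean and standard deviation, as in a z-score transform.
        constants[key] = ("zscore", statistics.fmean(values), statistics.pstdev(values))
    for key in bucket_keys:
        values = sorted(r[key] for r in rows)
        # num_buckets - 1 simplified quantile boundaries split the data into num_buckets buckets.
        boundaries = [values[(i * len(values)) // num_buckets] for i in range(1, num_buckets)]
        constants[key] = ("bucketize", boundaries)
    for key in vocab_keys:
        counts = Counter(r[key] for r in rows)
        # Most frequent tokens get the smallest ids; ties broken by token for determinism.
        ranked = sorted(counts, key=lambda t: (-counts[t], t))[:top_k]
        constants[key] = ("vocab", {tok: i for i, tok in enumerate(ranked)})
    return constants


def transform(row, constants, num_oov_buckets=1):
    """Per-example pass: apply the frozen constants (the part reused at serving time)."""
    out = {}
    for key, spec in constants.items():
        kind, value = spec[0], row[key]
        if kind == "zscore":
            _, mean, std = spec
            out[key] = (value - mean) / std if std > 0 else 0.0
        elif kind == "bucketize":
            out[key] = bisect.bisect_right(spec[1], value)
        else:
            vocab = spec[1]
            # Unknown tokens map to ids after the vocabulary (stand-in hash, not TFT's).
            out[key] = vocab.get(value, len(vocab) + sum(value.encode()) % num_oov_buckets)
    return out
```

The key design point is that transform only reads the constants returned by analyze, so it can run on a single example with no access to the training set.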
transform.outputs
{'transform_graph': OutputChannel(artifact_type=TransformGraph, producer_component_id=Transform, output_key=transform_graph, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'transformed_examples': OutputChannel(artifact_type=Examples, producer_component_id=Transform, output_key=transformed_examples, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'updated_analyzer_cache': OutputChannel(artifact_type=TransformCache, producer_component_id=Transform, output_key=updated_analyzer_cache, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'pre_transform_schema': OutputChannel(artifact_type=Schema, producer_component_id=Transform, output_key=pre_transform_schema, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'pre_transform_stats': OutputChannel(artifact_type=ExampleStatistics, producer_component_id=Transform, output_key=pre_transform_stats, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'post_transform_schema': OutputChannel(artifact_type=Schema, producer_component_id=Transform, output_key=post_transform_schema, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'post_transform_stats': OutputChannel(artifact_type=ExampleStatistics, producer_component_id=Transform, output_key=post_transform_stats, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'post_transform_anomalies': OutputChannel(artifact_type=ExampleAnomalies, producer_component_id=Transform, output_key=post_transform_anomalies, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False)}

Take a peek at the transform_graph artifact. It points to a directory containing three subdirectories:

transform_graph_uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(transform_graph_uri)
['transform_fn', 'metadata', 'transformed_metadata']

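The transformed_metadata subdirectory contains the schema of the preprocessed data. The transform_fn subdirectory contains the actual preprocessing graph. The metadata subdirectory contains the schema of the original (raw) data.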
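If you copy artifacts between machines or storage locations, it can help to confirm that a transform_graph URI still has this layout. The small stdlib-only helper below maps each expected subdirectory to the role described above and reports any that are missing; describe_transform_graph_dir is a hypothetical name, not part of TFX.

```python
import os

# Expected subdirectories of a transform_graph artifact and their roles.
TRANSFORM_GRAPH_SUBDIRS = {
    "transform_fn": "the actual preprocessing graph",
    "transformed_metadata": "schema of the preprocessed data",
    "metadata": "schema of the raw input data",
}


def describe_transform_graph_dir(uri):
    """Return ({subdir: role} for expected subdirs present, sorted list of missing ones)."""
    present = set(os.listdir(uri))
    found = {name: role for name, role in TRANSFORM_GRAPH_SUBDIRS.items() if name in present}
    missing = sorted(set(TRANSFORM_GRAPH_SUBDIRS) - present)
    return found, missing
```

In the notebook it could be called as describe_transform_graph_dir(transform.outputs['transform_graph'].get()[0].uri).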

We can also take a look at the first three transformed examples:

# Get the URI of the output artifact representing the transformed examples, which is a directory
train_uri = os.path.join(transform.outputs['transformed_examples'].get()[0].uri, 'Split-train')

# Get the list of files in this directory (all compressed TFRecord files)
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]

# Create a `TFRecordDataset` to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

# Iterate over the first 3 records and decode them.
for tfrecord in dataset.take(3):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  pp.pprint(example)
features {
  feature {
    key: "company"
    value {
      int64_list {
        value: 8
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      int64_list {
        value: 9
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 0.061060599982738495
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      int64_list {
        value: 9
      }
    }
  }
  feature {
    key: "tips"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: -0.15886740386486053
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      float_list {
        value: -0.7118487358093262
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 6
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 19
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 5
      }
    }
  }
}

features {
  feature {
    key: "company"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      int64_list {
        value: 9
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 1.2521240711212158
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
        value: 60
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "tips"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 0.532160758972168
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      float_list {
        value: 0.5509493350982666
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 2
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 10
      }
    }
  }
}

features {
  feature {
    key: "company"
    value {
      int64_list {
        value: 48
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      int64_list {
        value: 9
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 0.3873794376850128
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
        value: 13
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      int64_list {
        value: 9
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "tips"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 0.21955278515815735
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      float_list {
        value: 0.0019067146349698305
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 12
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 11
      }
    }
  }
}
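The loop above relies on tf.data.TFRecordDataset(..., compression_type="GZIP") to undo two layers: GZIP compression and TFRecord framing, in which each record is stored as an 8-byte little-endian length, a masked CRC-32C of that length, the payload, and a masked CRC-32C of the payload. The stdlib-only sketch below spells out that framing for illustration or debugging; in the pipeline itself you should keep using TFRecordDataset. Each payload would be one serialized tf.train.Example, and the helper names are hypothetical.

```python
import gzip
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli), the checksum used by the TFRecord format."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a rotated-and-offset ("masked") CRC rather than the raw value."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, payloads):
    """Write byte payloads as a GZIP-compressed TFRecord file."""
    with gzip.open(path, "wb") as f:
        for data in payloads:
            header = struct.pack("<Q", len(data))
            f.write(header)
            f.write(struct.pack("<I", masked_crc(header)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc(data)))


def read_records(path):
    """Yield raw record bytes from a GZIP-compressed TFRecord file, verifying both CRCs."""
    with gzip.open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) != 8:
                raise ValueError("truncated record header")
            (length,) = struct.unpack("<Q", header)
            (length_crc,) = struct.unpack("<I", f.read(4))
            if length_crc != masked_crc(header):
                raise ValueError("corrupt length CRC")
            data = f.read(length)
            (data_crc,) = struct.unpack("<I", f.read(4))
            if data_crc != masked_crc(data):
                raise ValueError("corrupt data CRC")
            yield data
```

The bitwise CRC loop is slow but short; it is only meant to make the on-disk format concrete.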

After the Transform component has transformed your data into features, the next step is to train a model.

Trainer

The Trainer component trains a model that you define in TensorFlow (either with the Estimator API, or with the Keras API via model_to_estimator).

Trainer takes as input the schema from SchemaGen, the transformed data and graph from Transform, training parameters, and a module that contains user-defined model code.
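The model code in this tutorial returns a tf.estimator.DNNLinearCombinedClassifier, a "wide & deep" model: a linear part over the categorical identity columns and a DNN (ReLU hidden layers by default) over the numeric columns, whose logits are summed and passed through a sigmoid for the binary tip label. The toy forward pass below illustrates only that combination; it is not TensorFlow's implementation, the weights are made up, and the function names are hypothetical. identity_column mimics how categorical_column_with_identity(..., default_value=0) replaces out-of-range ids.

```python
import math


def identity_column(value, num_buckets, default_value=0):
    """Ids outside [0, num_buckets) fall back to default_value."""
    return value if 0 <= value < num_buckets else default_value


def wide_and_deep_prob(categorical, numeric, linear_weights, linear_bias, dnn_layers):
    """Return P(label=1) = sigmoid(linear_logit + dnn_logit).

    categorical: {key: (id, num_buckets)}; linear_weights: {key: [weight per bucket]}.
    numeric: floats fed to the DNN; dnn_layers: [(W, b), ...], W as rows of weights,
    with a single output unit in the last layer.
    """
    # Wide part: each categorical id selects one weight (same as one-hot dot weights).
    linear_logit = linear_bias
    for key, (value, num_buckets) in categorical.items():
        linear_logit += linear_weights[key][identity_column(value, num_buckets)]
    # Deep part: dense layers with ReLU on hidden layers, linear output layer.
    h = list(numeric)
    for i, (W, b) in enumerate(dnn_layers):
        z = [sum(w * x for w, x in zip(row, h)) + bj for row, bj in zip(W, b)]
        h = z if i == len(dnn_layers) - 1 else [max(0.0, v) for v in z]
    dnn_logit = h[0]
    return 1.0 / (1.0 + math.exp(-(linear_logit + dnn_logit)))
```

This is also why _build_estimator sizes each identity column with num_buckets: every id the transform can emit, including out-of-vocabulary ids, needs its own linear weight.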

Let's look at an example of user-defined model code below (for an introduction to the TensorFlow Estimator API, see the tutorial):

_taxi_trainer_module_file = 'taxi_trainer.py'
%%writefile {_taxi_trainer_module_file}

import tensorflow as tf
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils
from tfx_bsl.tfxio import dataset_options

import taxi_constants

_DENSE_FLOAT_FEATURE_KEYS = taxi_constants.DENSE_FLOAT_FEATURE_KEYS
_VOCAB_FEATURE_KEYS = taxi_constants.VOCAB_FEATURE_KEYS
_VOCAB_SIZE = taxi_constants.VOCAB_SIZE
_OOV_SIZE = taxi_constants.OOV_SIZE
_FEATURE_BUCKET_COUNT = taxi_constants.FEATURE_BUCKET_COUNT
_BUCKET_FEATURE_KEYS = taxi_constants.BUCKET_FEATURE_KEYS
_CATEGORICAL_FEATURE_KEYS = taxi_constants.CATEGORICAL_FEATURE_KEYS
_MAX_CATEGORICAL_FEATURE_VALUES = taxi_constants.MAX_CATEGORICAL_FEATURE_VALUES
_LABEL_KEY = taxi_constants.LABEL_KEY


# tf.Transform considers these features as "raw".
def _get_raw_feature_spec(schema):
  return schema_utils.schema_as_feature_spec(schema).feature_spec


def _build_estimator(config, hidden_units=None, warm_start_from=None):
  """Build an estimator for predicting the tipping behavior of taxi riders.
  Args:
    config: tf.estimator.RunConfig defining the runtime environment for the
      estimator (including model_dir).
    hidden_units: [int], the layer sizes of the DNN (input layer first).
    warm_start_from: Optional directory to warm start from.
  Returns:
    A tf.estimator.DNNLinearCombinedClassifier that uses the categorical
    columns for its linear part and the dense float columns for its DNN part.
  """
  real_valued_columns = [
      tf.feature_column.numeric_column(key, shape=())
      for key in _DENSE_FLOAT_FEATURE_KEYS
  ]
  categorical_columns = [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0)
      for key in _VOCAB_FEATURE_KEYS
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0)
      for key in _BUCKET_FEATURE_KEYS
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(  # pylint: disable=g-complex-comprehension
          key,
          num_buckets=num_buckets,
          default_value=0) for key, num_buckets in zip(
              _CATEGORICAL_FEATURE_KEYS,
              _MAX_CATEGORICAL_FEATURE_VALUES)
  ]
  return tf.estimator.DNNLinearCombinedClassifier(
      config=config,
      linear_feature_columns=categorical_columns,
      dnn_feature_columns=real_valued_columns,
      dnn_hidden_units=hidden_units or [100, 70, 50, 25],
      warm_start_from=warm_start_from)


def _example_serving_receiver_fn(tf_transform_graph, schema):
  """Build the serving inputs.
  Args:
    tf_transform_graph: A TFTransformOutput.
    schema: the schema of the input data.
  Returns:
    A ServingInputReceiver that parses serialized tf.Examples and applies the
    tf.Transform graph to them.
  """
  raw_feature_spec = _get_raw_feature_spec(schema)
  raw_feature_spec.pop(_LABEL_KEY)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  transformed_features = tf_transform_graph.transform_raw_features(
      serving_input_receiver.features)

  return tf.estimator.export.ServingInputReceiver(
      transformed_features, serving_input_receiver.receiver_tensors)


def _eval_input_receiver_fn(tf_transform_graph, schema):
  """Build everything needed for the tf-model-analysis to run the model.
  Args:
    tf_transform_graph: A TFTransformOutput.
    schema: the schema of the input data.
  Returns:
    EvalInputReceiver function, which contains:
      - Tensorflow graph which parses raw untransformed features, applies the
        tf-transform preprocessing operators.
      - Set of raw, untransformed features.
      - Label against which predictions will be compared.
  """
  # Notice that the inputs are raw features, not transformed features here.
  raw_feature_spec = _get_raw_feature_spec(schema)

  serialized_tf_example = tf.compat.v1.placeholder(
      dtype=tf.string, shape=[None], name='input_example_tensor')

  # Add a parse_example operator to the tensorflow graph, which will parse
  # raw, untransformed, tf examples.
  features = tf.io.parse_example(serialized_tf_example, raw_feature_spec)

  # Now that we have our raw examples, process them through the tf-transform
  # function computed during the preprocessing step.
  transformed_features = tf_transform_graph.transform_raw_features(
      features)

  # The key name MUST be 'examples'.
  receiver_tensors = {'examples': serialized_tf_example}

  # NOTE: The model is driven by transformed features (since training works on
  # the materialized output of TFT), but slicing will happen on raw features.
  features.update(transformed_features)

  return tfma.export.EvalInputReceiver(
      features=features,
      receiver_tensors=receiver_tensors,
      labels=transformed_features[_LABEL_KEY])


def _input_fn(file_pattern, data_accessor, tf_transform_output, batch_size=200):
  """Generates features and label for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY),
      tf_transform_output.transformed_metadata.schema)


# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.
  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.
  Returns:
    A dict of the following:
      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_graph = tft.TFTransformOutput(trainer_fn_args.transform_output)

  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      trainer_fn_args.data_accessor,
      tf_transform_graph,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      trainer_fn_args.data_accessor,
      tf_transform_graph,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(  # pylint: disable=g-long-lambda
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_graph, schema)

  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)

  estimator = _build_estimator(
      # Construct layer sizes with exponential decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=trainer_fn_args.base_model)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_graph, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }
Writing taxi_trainer.py
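In this module, _build_estimator feeds every categorical feature to the linear part of the DNNLinearCombinedClassifier through categorical_column_with_identity, and trainer_fn shrinks the DNN layers geometrically. The stdlib-only sketch below mimics both calculations so they can be checked without TensorFlow. identity_bucket is a hypothetical helper that reproduces the documented default_value behavior of categorical_column_with_identity (integer ids outside [0, num_buckets) are replaced by default_value); it is not the TensorFlow implementation. The VOCAB_SIZE and OOV_SIZE values are illustrative assumptions, since the real ones live in taxi_constants, which is not shown in this section.

```python
def identity_bucket(value, num_buckets, default_value=0):
    """Hypothetical pure-Python mimic of categorical_column_with_identity.

    Integer ids in [0, num_buckets) are kept as-is; any other id is
    replaced by default_value (0 in taxi_trainer.py).
    """
    return value if 0 <= value < num_buckets else default_value


def dnn_layer_sizes(first_layer_size=100, num_layers=4, decay_factor=0.7):
    """Same expression trainer_fn uses to build hidden_units.

    Each layer is decay_factor times the previous one; int() truncates
    toward zero and max(2, ...) keeps every layer at least 2 units wide.
    """
    return [
        max(2, int(first_layer_size * decay_factor**i))
        for i in range(num_layers)
    ]


# Illustrative values only: the real VOCAB_SIZE and OOV_SIZE come from
# taxi_constants.
VOCAB_SIZE, OOV_SIZE = 1000, 10
num_buckets = VOCAB_SIZE + OOV_SIZE

print(identity_bucket(5, num_buckets))     # in range, kept as 5
print(identity_bucket(5000, num_buckets))  # out of range, mapped to 0
print(dnn_layer_sizes())                   # hidden_units passed by trainer_fn
```

Because trainer_fn always passes a non-empty hidden_units list, the [100, 70, 50, 25] fallback in _build_estimator is not used in this tutorial. Note also that int() truncates rather than rounds, so a size such as 100 * 0.7**2 can come out one unit below its rounded value because of floating-point error.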

Now we pass this model code to the Trainer component and run it to train the model.

from tfx.components.trainer.executor import Executor
from tfx.dsl.components.base import executor_spec

trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_taxi_trainer_module_file),
    custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
    examples=transform.outputs['transformed_examples'],
    schema=schema_gen.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=tfx.proto.TrainArgs(num_steps=10000),
    eval_args=tfx.proto.EvalArgs(num_steps=5000))
context.run(trainer)
WARNING:absl:`custom_executor_spec` is deprecated. Please customize component directly.
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py' (including modules: ['taxi_trainer', 'taxi_transform', 'taxi_constants']).
INFO:absl:User module package has hash fingerprint version e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmpnta9yff_/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmp5wdb8vq6', '--dist-dir', '/tmpfs/tmp/tmp8v8hc7ie']
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl'; target user module is 'taxi_trainer'.
INFO:absl:Full user module path is 'taxi_trainer@/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl'
INFO:absl:Running driver for Trainer
INFO:absl:MetadataStore with DB connection initialized
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying taxi_trainer.py -> build/lib
copying taxi_transform.py -> build/lib
copying taxi_constants.py -> build/lib
installing to /tmpfs/tmp/tmp5wdb8vq6
running install
running install_lib
copying build/lib/taxi_trainer.py -> /tmpfs/tmp/tmp5wdb8vq6
copying build/lib/taxi_transform.py -> /tmpfs/tmp/tmp5wdb8vq6
copying build/lib/taxi_constants.py -> /tmpfs/tmp/tmp5wdb8vq6
running install_egg_info
running egg_info
creating tfx_user_code_Trainer.egg-info
writing tfx_user_code_Trainer.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Trainer.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Trainer.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
Copying tfx_user_code_Trainer.egg-info to /tmpfs/tmp/tmp5wdb8vq6/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmp5wdb8vq6/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618.dist-info/WHEEL
creating '/tmpfs/tmp/tmp8v8hc7ie/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl' and adding '/tmpfs/tmp/tmp5wdb8vq6' to it
adding 'taxi_constants.py'
adding 'taxi_trainer.py'
adding 'taxi_transform.py'
adding 'tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618.dist-info/METADATA'
adding 'tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618.dist-info/WHEEL'
adding 'tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618.dist-info/top_level.txt'
adding 'tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618.dist-info/RECORD'
removing /tmpfs/tmp/tmp5wdb8vq6
INFO:absl:Running executor for Trainer
INFO:absl:Train on the 'train' split when train_args.splits is not set.
INFO:absl:Evaluate on the 'eval' split when eval_args.splits is not set.
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
INFO:absl:udf_utils.get_fn {'train_args': '{\n  "num_steps": 10000\n}', 'eval_args': '{\n  "num_steps": 5000\n}', 'module_file': None, 'run_fn': None, 'trainer_fn': None, 'custom_config': 'null', 'module_path': 'taxi_trainer@/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl'} 'trainer_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp86irddos', '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl']
Processing /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/_wheels/tfx_user_code_Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618-py3-none-any.whl'.
Installing collected packages: tfx-user-code-Trainer
Successfully installed tfx-user-code-Trainer-0.0+e337a512821685b6d91445dbd0628b47de0e4c751e9e54edf78bcf0866309618
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:188: TrainSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:195: FinalExporter.__init__ (from tensorflow_estimator.python.estimator.exporter) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:196: EvalSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:202: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:41: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:45: categorical_column_with_identity (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:62: DNNLinearCombinedClassifierV2.__init__ (from tensorflow_estimator.python.estimator.canned.dnn_linear_combined) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/head_utils.py:54: BinaryClassHead.__init__ (from tensorflow_estimator.python.estimator.head.binary_class_head) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn_linear_combined.py:586: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 999, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:absl:Training model.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx/components/trainer/executor.py:270: train_and_evaluate (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 999 or save_checkpoints_secs None.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:absl:Feature company has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature fare has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_hour has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_month has a shape . Setting to DenseTensor.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:343: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.
INFO:absl:Feature company has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature fare has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_hour has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_month has a shape . Setting to DenseTensor.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/legacy/adagrad.py:93: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
2024-05-08 09:29:41.217001: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    for Tuple type infernce function 0
    while inferring type of node 'dnn/zero_fraction/cond/output/_18'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 0.7004041, step = 0
INFO:tensorflow:global_step/sec: 94.9559
INFO:tensorflow:loss = 0.58275115, step = 100 (1.054 sec)
INFO:tensorflow:global_step/sec: 126.379
INFO:tensorflow:loss = 0.5188283, step = 200 (0.791 sec)
INFO:tensorflow:global_step/sec: 127.047
INFO:tensorflow:loss = 0.5207444, step = 300 (0.787 sec)
INFO:tensorflow:global_step/sec: 126.868
INFO:tensorflow:loss = 0.5565426, step = 400 (0.788 sec)
INFO:tensorflow:global_step/sec: 128.906
INFO:tensorflow:loss = 0.44116384, step = 500 (0.776 sec)
INFO:tensorflow:global_step/sec: 127.421
INFO:tensorflow:loss = 0.4577001, step = 600 (0.785 sec)
INFO:tensorflow:global_step/sec: 127.514
INFO:tensorflow:loss = 0.4446908, step = 700 (0.784 sec)
INFO:tensorflow:global_step/sec: 125.617
INFO:tensorflow:loss = 0.4720246, step = 800 (0.796 sec)
INFO:tensorflow:global_step/sec: 131.021
INFO:tensorflow:loss = 0.4437019, step = 900 (0.763 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 999...
INFO:tensorflow:Saving checkpoints for 999 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1067: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 999...
INFO:absl:Feature company has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature fare has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_hour has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_month has a shape . Setting to DenseTensor.
INFO:absl:Feature company has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature fare has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_hour has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_month has a shape . Setting to DenseTensor.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2024-05-08T09:29:55
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt-999
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [500/5000]
INFO:tensorflow:Evaluation [1000/5000]
INFO:tensorflow:Evaluation [1500/5000]
INFO:tensorflow:Evaluation [2000/5000]
INFO:tensorflow:Evaluation [2500/5000]
INFO:tensorflow:Evaluation [3000/5000]
INFO:tensorflow:Evaluation [3500/5000]
INFO:tensorflow:Evaluation [4000/5000]
INFO:tensorflow:Evaluation [4500/5000]
INFO:tensorflow:Evaluation [5000/5000]
INFO:tensorflow:Inference Time : 31.08471s
INFO:tensorflow:Finished evaluation at 2024-05-08-09:30:26
INFO:tensorflow:Saving dict for global step 999: accuracy = 0.771235, accuracy_baseline = 0.771235, auc = 0.9186248, auc_precision_recall = 0.6492264, average_loss = 0.4624323, global_step = 999, label/mean = 0.228765, loss = 0.46243173, precision = 0.0, prediction/mean = 0.24984238, recall = 0.0
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 999: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt-999
INFO:tensorflow:global_step/sec: 2.90748
INFO:tensorflow:loss = 0.42686376, step = 1000 (34.394 sec)
INFO:tensorflow:global_step/sec: 126.964
INFO:tensorflow:loss = 0.44425964, step = 1100 (0.788 sec)
INFO:tensorflow:global_step/sec: 127.535
INFO:tensorflow:loss = 0.47197676, step = 1200 (0.784 sec)
INFO:tensorflow:global_step/sec: 125.471
INFO:tensorflow:loss = 0.506038, step = 1300 (0.797 sec)
INFO:tensorflow:global_step/sec: 127.967
INFO:tensorflow:loss = 0.4099118, step = 1400 (0.781 sec)
INFO:tensorflow:global_step/sec: 126.812
INFO:tensorflow:loss = 0.5171037, step = 1500 (0.789 sec)
INFO:tensorflow:global_step/sec: 131.046
INFO:tensorflow:loss = 0.4371317, step = 1600 (0.763 sec)
INFO:tensorflow:global_step/sec: 128.146
INFO:tensorflow:loss = 0.45776543, step = 1700 (0.780 sec)
INFO:tensorflow:global_step/sec: 128.92
INFO:tensorflow:loss = 0.50659925, step = 1800 (0.776 sec)
INFO:tensorflow:global_step/sec: 129.06
INFO:tensorflow:loss = 0.43417373, step = 1900 (0.775 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1998...
INFO:tensorflow:Saving checkpoints for 1998 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1998...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 103.772
INFO:tensorflow:loss = 0.45863923, step = 2000 (0.963 sec)
INFO:tensorflow:global_step/sec: 125.899
INFO:tensorflow:loss = 0.3514557, step = 2100 (0.795 sec)
INFO:tensorflow:global_step/sec: 129.752
INFO:tensorflow:loss = 0.43468484, step = 2200 (0.771 sec)
INFO:tensorflow:global_step/sec: 131.014
INFO:tensorflow:loss = 0.48132992, step = 2300 (0.763 sec)
INFO:tensorflow:global_step/sec: 130.271
INFO:tensorflow:loss = 0.44048753, step = 2400 (0.768 sec)
INFO:tensorflow:global_step/sec: 129.449
INFO:tensorflow:loss = 0.3523005, step = 2500 (0.773 sec)
INFO:tensorflow:global_step/sec: 130.936
INFO:tensorflow:loss = 0.3773502, step = 2600 (0.764 sec)
INFO:tensorflow:global_step/sec: 129.258
INFO:tensorflow:loss = 0.43350023, step = 2700 (0.774 sec)
INFO:tensorflow:global_step/sec: 133.75
INFO:tensorflow:loss = 0.37304792, step = 2800 (0.748 sec)
INFO:tensorflow:global_step/sec: 130.275
INFO:tensorflow:loss = 0.3801176, step = 2900 (0.768 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2997...
INFO:tensorflow:Saving checkpoints for 2997 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2997...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 107.092
INFO:tensorflow:loss = 0.3836586, step = 3000 (0.933 sec)
INFO:tensorflow:global_step/sec: 133.249
INFO:tensorflow:loss = 0.43525982, step = 3100 (0.751 sec)
INFO:tensorflow:global_step/sec: 124.615
INFO:tensorflow:loss = 0.42075485, step = 3200 (0.803 sec)
INFO:tensorflow:global_step/sec: 122.737
INFO:tensorflow:loss = 0.3901537, step = 3300 (0.815 sec)
INFO:tensorflow:global_step/sec: 122.54
INFO:tensorflow:loss = 0.35952353, step = 3400 (0.816 sec)
INFO:tensorflow:global_step/sec: 124.721
INFO:tensorflow:loss = 0.3873772, step = 3500 (0.802 sec)
INFO:tensorflow:global_step/sec: 128.574
INFO:tensorflow:loss = 0.36566123, step = 3600 (0.778 sec)
INFO:tensorflow:global_step/sec: 126.009
INFO:tensorflow:loss = 0.40229043, step = 3700 (0.794 sec)
INFO:tensorflow:global_step/sec: 122.1
INFO:tensorflow:loss = 0.4070228, step = 3800 (0.819 sec)
INFO:tensorflow:global_step/sec: 123.903
INFO:tensorflow:loss = 0.4688112, step = 3900 (0.807 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3996...
INFO:tensorflow:Saving checkpoints for 3996 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3996...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 100.776
INFO:tensorflow:loss = 0.49602365, step = 4000 (0.992 sec)
INFO:tensorflow:global_step/sec: 123.848
INFO:tensorflow:loss = 0.2742646, step = 4100 (0.808 sec)
INFO:tensorflow:global_step/sec: 123.423
INFO:tensorflow:loss = 0.44800407, step = 4200 (0.810 sec)
INFO:tensorflow:global_step/sec: 123.175
INFO:tensorflow:loss = 0.43835735, step = 4300 (0.812 sec)
INFO:tensorflow:global_step/sec: 123.891
INFO:tensorflow:loss = 0.32207388, step = 4400 (0.807 sec)
INFO:tensorflow:global_step/sec: 125.024
INFO:tensorflow:loss = 0.38216084, step = 4500 (0.800 sec)
INFO:tensorflow:global_step/sec: 124.891
INFO:tensorflow:loss = 0.45092455, step = 4600 (0.801 sec)
INFO:tensorflow:global_step/sec: 123.988
INFO:tensorflow:loss = 0.3750969, step = 4700 (0.807 sec)
INFO:tensorflow:global_step/sec: 125.77
INFO:tensorflow:loss = 0.3925597, step = 4800 (0.795 sec)
INFO:tensorflow:global_step/sec: 125.459
INFO:tensorflow:loss = 0.43026057, step = 4900 (0.797 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4995...
INFO:tensorflow:Saving checkpoints for 4995 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4995...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 103.654
INFO:tensorflow:loss = 0.41449785, step = 5000 (0.964 sec)
INFO:tensorflow:global_step/sec: 125.195
INFO:tensorflow:loss = 0.3261056, step = 5100 (0.799 sec)
INFO:tensorflow:global_step/sec: 127.462
INFO:tensorflow:loss = 0.35417694, step = 5200 (0.785 sec)
INFO:tensorflow:global_step/sec: 127.844
INFO:tensorflow:loss = 0.4030676, step = 5300 (0.782 sec)
INFO:tensorflow:global_step/sec: 130.923
INFO:tensorflow:loss = 0.4126954, step = 5400 (0.764 sec)
INFO:tensorflow:global_step/sec: 130.98
INFO:tensorflow:loss = 0.32259554, step = 5500 (0.764 sec)
INFO:tensorflow:global_step/sec: 128.589
INFO:tensorflow:loss = 0.3811575, step = 5600 (0.777 sec)
INFO:tensorflow:global_step/sec: 125.918
INFO:tensorflow:loss = 0.40286702, step = 5700 (0.794 sec)
INFO:tensorflow:global_step/sec: 129.137
INFO:tensorflow:loss = 0.32921094, step = 5800 (0.775 sec)
INFO:tensorflow:global_step/sec: 130.104
INFO:tensorflow:loss = 0.4013093, step = 5900 (0.768 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5994...
INFO:tensorflow:Saving checkpoints for 5994 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5994...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 106.58
INFO:tensorflow:loss = 0.31678432, step = 6000 (0.938 sec)
INFO:tensorflow:global_step/sec: 126.52
INFO:tensorflow:loss = 0.3150363, step = 6100 (0.791 sec)
INFO:tensorflow:global_step/sec: 128.723
INFO:tensorflow:loss = 0.39068645, step = 6200 (0.777 sec)
INFO:tensorflow:global_step/sec: 128.852
INFO:tensorflow:loss = 0.2832807, step = 6300 (0.776 sec)
INFO:tensorflow:global_step/sec: 125.866
INFO:tensorflow:loss = 0.32048258, step = 6400 (0.794 sec)
INFO:tensorflow:global_step/sec: 125.299
INFO:tensorflow:loss = 0.38626963, step = 6500 (0.798 sec)
INFO:tensorflow:global_step/sec: 125.342
INFO:tensorflow:loss = 0.39416704, step = 6600 (0.798 sec)
INFO:tensorflow:global_step/sec: 127.235
INFO:tensorflow:loss = 0.30232263, step = 6700 (0.786 sec)
INFO:tensorflow:global_step/sec: 126.055
INFO:tensorflow:loss = 0.41977397, step = 6800 (0.793 sec)
INFO:tensorflow:global_step/sec: 127.477
INFO:tensorflow:loss = 0.47491065, step = 6900 (0.784 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6993...
INFO:tensorflow:Saving checkpoints for 6993 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6993...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 104.669
INFO:tensorflow:loss = 0.35919297, step = 7000 (0.955 sec)
INFO:tensorflow:global_step/sec: 126.032
INFO:tensorflow:loss = 0.42433387, step = 7100 (0.794 sec)
INFO:tensorflow:global_step/sec: 126.079
INFO:tensorflow:loss = 0.3359905, step = 7200 (0.793 sec)
INFO:tensorflow:global_step/sec: 124.834
INFO:tensorflow:loss = 0.4118205, step = 7300 (0.801 sec)
INFO:tensorflow:global_step/sec: 126.048
INFO:tensorflow:loss = 0.3594822, step = 7400 (0.793 sec)
INFO:tensorflow:global_step/sec: 127.316
INFO:tensorflow:loss = 0.3544901, step = 7500 (0.786 sec)
INFO:tensorflow:global_step/sec: 125.922
INFO:tensorflow:loss = 0.3517708, step = 7600 (0.794 sec)
INFO:tensorflow:global_step/sec: 127.473
INFO:tensorflow:loss = 0.32316074, step = 7700 (0.784 sec)
INFO:tensorflow:global_step/sec: 127.602
INFO:tensorflow:loss = 0.28583208, step = 7800 (0.784 sec)
INFO:tensorflow:global_step/sec: 128.23
INFO:tensorflow:loss = 0.379911, step = 7900 (0.780 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7992...
INFO:tensorflow:Saving checkpoints for 7992 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7992...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 105.424
INFO:tensorflow:loss = 0.3968008, step = 8000 (0.948 sec)
INFO:tensorflow:global_step/sec: 128.249
INFO:tensorflow:loss = 0.43308416, step = 8100 (0.780 sec)
INFO:tensorflow:global_step/sec: 128.472
INFO:tensorflow:loss = 0.42253828, step = 8200 (0.778 sec)
INFO:tensorflow:global_step/sec: 125.642
INFO:tensorflow:loss = 0.39132017, step = 8300 (0.796 sec)
INFO:tensorflow:global_step/sec: 128.607
INFO:tensorflow:loss = 0.30107036, step = 8400 (0.777 sec)
INFO:tensorflow:global_step/sec: 126.434
INFO:tensorflow:loss = 0.30194753, step = 8500 (0.791 sec)
INFO:tensorflow:global_step/sec: 127.391
INFO:tensorflow:loss = 0.30165237, step = 8600 (0.785 sec)
INFO:tensorflow:global_step/sec: 127.042
INFO:tensorflow:loss = 0.44196972, step = 8700 (0.787 sec)
INFO:tensorflow:global_step/sec: 126.923
INFO:tensorflow:loss = 0.42164555, step = 8800 (0.788 sec)
INFO:tensorflow:global_step/sec: 127.39
INFO:tensorflow:loss = 0.3490799, step = 8900 (0.785 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8991...
INFO:tensorflow:Saving checkpoints for 8991 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8991...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:global_step/sec: 105.59
INFO:tensorflow:loss = 0.31310123, step = 9000 (0.947 sec)
INFO:tensorflow:global_step/sec: 124.933
INFO:tensorflow:loss = 0.4325568, step = 9100 (0.801 sec)
INFO:tensorflow:global_step/sec: 127.59
INFO:tensorflow:loss = 0.30360752, step = 9200 (0.784 sec)
INFO:tensorflow:global_step/sec: 130.138
INFO:tensorflow:loss = 0.29442087, step = 9300 (0.768 sec)
INFO:tensorflow:global_step/sec: 130.544
INFO:tensorflow:loss = 0.31136292, step = 9400 (0.766 sec)
INFO:tensorflow:global_step/sec: 130.849
INFO:tensorflow:loss = 0.34016177, step = 9500 (0.764 sec)
INFO:tensorflow:global_step/sec: 129.551
INFO:tensorflow:loss = 0.39522016, step = 9600 (0.772 sec)
INFO:tensorflow:global_step/sec: 130.262
INFO:tensorflow:loss = 0.34697112, step = 9700 (0.768 sec)
INFO:tensorflow:global_step/sec: 129.093
INFO:tensorflow:loss = 0.38748953, step = 9800 (0.775 sec)
INFO:tensorflow:global_step/sec: 131.427
INFO:tensorflow:loss = 0.29107836, step = 9900 (0.761 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9990...
INFO:tensorflow:Saving checkpoints for 9990 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9990...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10000...
INFO:tensorflow:Saving checkpoints for 10000 into /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10000...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs).
INFO:absl:Feature company has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature fare has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_hour has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_month has a shape . Setting to DenseTensor.
INFO:absl:Feature company has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature fare has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_community_area has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_latitude has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_hour has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_month has a shape . Setting to DenseTensor.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2024-05-08T09:31:40
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt-10000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [500/5000]
INFO:tensorflow:Evaluation [1000/5000]
INFO:tensorflow:Evaluation [1500/5000]
INFO:tensorflow:Evaluation [2000/5000]
INFO:tensorflow:Evaluation [2500/5000]
INFO:tensorflow:Evaluation [3000/5000]
INFO:tensorflow:Evaluation [3500/5000]
INFO:tensorflow:Evaluation [4000/5000]
INFO:tensorflow:Evaluation [4500/5000]
INFO:tensorflow:Evaluation [5000/5000]
INFO:tensorflow:Inference Time : 30.60751s
INFO:tensorflow:Finished evaluation at 2024-05-08-09:32:11
INFO:tensorflow:Saving dict for global step 10000: accuracy = 0.78671, accuracy_baseline = 0.771205, auc = 0.93178755, auc_precision_recall = 0.69907445, average_loss = 0.34567952, global_step = 10000, label/mean = 0.228795, loss = 0.34567845, precision = 0.7014421, prediction/mean = 0.23119248, recall = 0.117987715
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10000: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt-10000
INFO:tensorflow:Performing the final export in the end of training.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
WARNING:tensorflow:From /tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py:81: build_parsing_serving_input_receiver_fn (from tensorflow_estimator.python.estimator.export.export) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/export/export.py:312: ServingInputReceiver.__new__ (from tensorflow_estimator.python.estimator.export.export) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
WARNING:tensorflow:Loading a TF2 SavedModel but eager mode seems disabled.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/base_head.py:786: ClassificationOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/binary_class_head.py:561: RegressionOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/binary_class_head.py:563: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:168: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://tensorflowcn.cn/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/model_utils/export_utils.py:83: get_tensor_from_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://tensorflowcn.cn/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt-10000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/export/chicago-taxi/temp-1715160731/assets
INFO:tensorflow:SavedModel written to: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/export/chicago-taxi/temp-1715160731/saved_model.pb
INFO:tensorflow:Loss for final step: 0.33868045.
INFO:absl:Training complete. Model written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving. ModelRun written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6
INFO:absl:Exporting eval_savedmodel for TFMA.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
WARNING:tensorflow:Loading a TF2 SavedModel but eager mode seems disabled.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/model_utils/export_utils.py:345: _SupervisedOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']
WARNING:tensorflow:Export includes no default signature!
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-Serving/model.ckpt-10000
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-TFMA/temp-1715160733/assets
INFO:tensorflow:SavedModel written to: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-TFMA/temp-1715160733/saved_model.pb
INFO:absl:Exported eval_savedmodel to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model_run/6/Format-TFMA.
WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/serving_model_dir/saved_model.pb"
INFO:absl:Serving model copied to: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model/6/Format-Serving.
WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/eval_model_dir/saved_model.pb"
INFO:absl:Eval model copied to: /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model/6/Format-TFMA.
INFO:absl:Running publisher for Trainer
INFO:absl:MetadataStore with DB connection initialized

Analyze Training with TensorBoard

Optionally, we can connect TensorBoard to the Trainer to analyze the model's training curves.

# Get the URI of the output artifact representing the training logs, which is a directory
model_run_dir = trainer.outputs['model_run'].get()[0].uri

%load_ext tensorboard
%tensorboard --logdir {model_run_dir}
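If TensorBoard is not available (for example, when running outside a notebook), a rough version of the same loss curve can be recovered from the Trainer log printed above. The sketch below is a hypothetical helper, not part of TFX: it assumes the log keeps the `loss = ..., step = ...` line format shown above and uses only Python's `re` module.

```python
import re

# Matches the per-step training lines printed by the Estimator, e.g.
# "INFO:tensorflow:loss = 0.42686376, step = 1000 (34.394 sec)".
# Evaluation summaries ("average_loss = ..., global_step = ...") do not match.
LOSS_LINE = re.compile(r"loss = ([0-9.eE+-]+), step = (\d+)")


def parse_loss_curve(log_text):
    """Return a list of (step, loss) pairs found in a training log.

    Hypothetical helper: lines that are not per-step loss lines
    (checkpoint messages, evaluation results, warnings) are ignored.
    """
    curve = []
    for line in log_text.splitlines():
        match = LOSS_LINE.search(line)
        if match:
            curve.append((int(match.group(2)), float(match.group(1))))
    return curve


sample = """INFO:tensorflow:loss = 0.42686376, step = 1000 (34.394 sec)
INFO:tensorflow:global_step/sec: 126.964
INFO:tensorflow:loss = 0.44425964, step = 1100 (0.788 sec)"""

print(parse_loss_curve(sample))  # [(1000, 0.42686376), (1100, 0.44425964)]
```

The resulting pairs can be plotted with any charting library; TensorBoard remains the better tool because it also shows the evaluation metrics written to `model_run_dir`.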

Evaluator

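The Evaluator component computes model performance metrics over the evaluation set. It uses the TensorFlow Model Analysis library. The Evaluator can also optionally validate that a newly trained model is better than the previous model. This is useful in a production pipeline setting where you may automatically train and validate a model every day. In this notebook we only train one model, so the Evaluator will automatically label the model as "good".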
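The validation idea can be sketched in plain Python. This is a simplified, hypothetical model of how the `accuracy` thresholds in the `eval_config` below are meant to combine, not TFMA's actual implementation: the value threshold is always checked, while the change threshold only applies when a previously blessed baseline model exists. We treat both bounds as inclusive here; check the TFMA documentation for the exact semantics.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Thresholds:
    """Hypothetical stand-in for the 'accuracy' MetricThreshold in eval_config."""
    lower_bound: float = 0.5    # mirrors value_threshold lower_bound
    min_change: float = -1e-10  # mirrors change_threshold absolute, HIGHER_IS_BETTER


def is_blessed(candidate: float, baseline: Optional[float], t: Thresholds) -> bool:
    """Return True if the candidate metric passes every applicable threshold."""
    # Value threshold: the metric must not fall below the lower bound.
    if candidate < t.lower_bound:
        return False
    # Change threshold: skipped when there is no baseline (the first run).
    if baseline is not None and candidate - baseline < t.min_change:
        return False
    return True


t = Thresholds()
print(is_blessed(0.78671, None, t))  # True: first run, no baseline to compare with
print(is_blessed(0.78671, 0.79, t))  # False: worse than the baseline
print(is_blessed(0.45, None, t))     # False: below lower_bound
```

The `0.78671` used here is the final accuracy reported in the Trainer log above; because this notebook has no earlier blessed model, only the value threshold matters.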

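The Evaluator takes as input the data from ExampleGen, the trained model from the Trainer, and a slicing configuration. The slicing configuration lets you slice your metrics by feature values (for example, how does your model perform on taxi trips that start at 8 AM versus those that start at 8 PM?). See below for an example of this configuration: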

eval_config = tfma.EvalConfig(
    model_specs=[
        # Using signature 'eval' implies the use of an EvalSavedModel. To use
        # a serving model, remove the signature so that it defaults to
        # 'serving_default', and add a label_key.
        tfma.ModelSpec(signature_name='eval')
    ],
    metrics_specs=[
        tfma.MetricsSpec(
            # The metrics added here are in addition to those saved with the
            # model (assuming either a keras model or EvalSavedModel is used).
            # Any metrics added into the saved model (for example using
            # model.compile(..., metrics=[...]), etc) will be computed
            # automatically.
            metrics=[
                tfma.MetricConfig(class_name='ExampleCount')
            ],
            # To add validation thresholds for metrics saved with the model,
            # add them keyed by metric name to the thresholds map.
            thresholds = {
                'accuracy': tfma.MetricThreshold(
                    value_threshold=tfma.GenericValueThreshold(
                        lower_bound={'value': 0.5}),
                    # Change threshold will be ignored if there is no
                    # baseline model resolved from MLMD (first run).
                    change_threshold=tfma.GenericChangeThreshold(
                       direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                       absolute={'value': -1e-10}))
            }
        )
    ],
    slicing_specs=[
        # An empty slice spec means the overall slice, i.e. the whole dataset.
        tfma.SlicingSpec(),
        # Data can be sliced along a feature column. In this case, data is
        # sliced along feature column trip_start_hour.
        tfma.SlicingSpec(feature_keys=['trip_start_hour'])
    ])
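To make the two `slicing_specs` concrete, here is a plain-Python sketch of what they compute for a single metric. It is a hypothetical illustration, not TFMA code: the example dicts, the `'label'` and `'prediction'` keys, and the 0.5 decision threshold are assumptions. The slice keys loosely mimic TFMA's convention of `()` for the overall slice and `(('feature', value),)` for a feature slice.

```python
from collections import defaultdict


def sliced_accuracy(examples, feature_key):
    """Compute accuracy for the overall slice and for each value of feature_key.

    `examples` is a list of dicts holding feature values plus hypothetical
    'label' (0 or 1) and 'prediction' (a probability in [0, 1]) entries.
    """
    def correct(ex):
        # Threshold the predicted probability at 0.5, as for a binary classifier.
        return int(ex['prediction'] >= 0.5) == ex['label']

    # tfma.SlicingSpec(): a single slice containing every example.
    results = {(): sum(correct(ex) for ex in examples) / len(examples)}

    # tfma.SlicingSpec(feature_keys=[feature_key]): one slice per feature value.
    totals = defaultdict(lambda: [0, 0])  # value -> [num_correct, num_examples]
    for ex in examples:
        bucket = totals[ex[feature_key]]
        bucket[0] += correct(ex)
        bucket[1] += 1
    for value, (num_correct, count) in sorted(totals.items()):
        results[((feature_key, value),)] = num_correct / count
    return results


examples = [
    {'trip_start_hour': 8, 'label': 1, 'prediction': 0.9},
    {'trip_start_hour': 8, 'label': 0, 'prediction': 0.7},
    {'trip_start_hour': 20, 'label': 0, 'prediction': 0.1},
    {'trip_start_hour': 20, 'label': 0, 'prediction': 0.2},
]
for slice_key, acc in sliced_accuracy(examples, 'trip_start_hour').items():
    print(slice_key, acc)
# () 0.75
# (('trip_start_hour', 8),) 0.5
# (('trip_start_hour', 20),) 1.0
```

A model can look fine on the overall slice while doing poorly on one hour of the day, which is exactly what per-feature slices are meant to surface.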

Next, we pass this configuration to the Evaluator and run it.
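The cell that follows first runs a model resolver with `LatestBlessedModelStrategy`, which looks up the most recently blessed model to serve as the validation baseline. Conceptually, the lookup works like the sketch below. This is a simplified, hypothetical analogue: the real resolver queries Model and ModelBlessing artifacts in ML Metadata, and the `ModelArtifact` class here is an illustration only.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelArtifact:
    """Hypothetical, minimal view of a trained model artifact."""
    id: int        # grows with every Trainer run
    blessed: bool  # True if an Evaluator run blessed this model


def latest_blessed_model(models: List[ModelArtifact]) -> Optional[ModelArtifact]:
    """Return the newest blessed model, or None if no model has been blessed."""
    blessed = [m for m in models if m.blessed]
    return max(blessed, key=lambda m: m.id) if blessed else None


# First run of this notebook: the only model has not been evaluated yet.
print(latest_blessed_model([ModelArtifact(6, False)]))  # None
```

Returning `None` corresponds to the first-run case described in the code comments below: with no baseline, the Evaluator ignores the change threshold and can bless the candidate on its own.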

# Use TFMA to compute evaluation statistics over features of a model and
# validate them against a baseline.

# The model resolver is only required if performing model validation in addition
# to evaluation. In this case we validate against the latest blessed model. If
# no model has been blessed before (as in this case) the evaluator will make our
# candidate the first blessed model.
model_resolver = tfx.dsl.Resolver(
      strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
      model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
      model_blessing=tfx.dsl.Channel(
          type=tfx.types.standard_artifacts.ModelBlessing)).with_id(
              'latest_blessed_model_resolver')
context.run(model_resolver)

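evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    # Compare against the latest blessed model found by the resolver above.
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)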
context.run(evaluator)
INFO:absl:Running driver for latest_blessed_model_resolver
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running publisher for latest_blessed_model_resolver
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running driver for Evaluator
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Evaluator
INFO:absl:udf_utils.get_fn {'eval_config': '{\n  "metrics_specs": [\n    {\n      "metrics": [\n        {\n          "class_name": "ExampleCount"\n        }\n      ],\n      "thresholds": {\n        "accuracy": {\n          "change_threshold": {\n            "absolute": -1e-10,\n            "direction": "HIGHER_IS_BETTER"\n          },\n          "value_threshold": {\n            "lower_bound": 0.5\n          }\n        }\n      }\n    }\n  ],\n  "model_specs": [\n    {\n      "signature_name": "eval"\n    }\n  ],\n  "slicing_specs": [\n    {},\n    {\n      "feature_keys": [\n        "trip_start_hour"\n      ]\n    }\n  ]\n}', 'feature_slicing_spec': None, 'fairness_indicator_thresholds': 'null', 'example_splits': 'null', 'module_file': None, 'module_path': None} 'custom_eval_shared_model'
INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "eval"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  thresholds {
    key: "accuracy"
    value {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
}

INFO:absl:Using /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model/6/Format-TFMA as  model.
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
INFO:absl:The 'example_splits' parameter is not set, using 'eval' split.
INFO:absl:Evaluating model.
INFO:absl:udf_utils.get_fn {'eval_config': '{\n  "metrics_specs": [\n    {\n      "metrics": [\n        {\n          "class_name": "ExampleCount"\n        }\n      ],\n      "thresholds": {\n        "accuracy": {\n          "change_threshold": {\n            "absolute": -1e-10,\n            "direction": "HIGHER_IS_BETTER"\n          },\n          "value_threshold": {\n            "lower_bound": 0.5\n          }\n        }\n      }\n    }\n  ],\n  "model_specs": [\n    {\n      "signature_name": "eval"\n    }\n  ],\n  "slicing_specs": [\n    {},\n    {\n      "feature_keys": [\n        "trip_start_hour"\n      ]\n    }\n  ]\n}', 'feature_slicing_spec': None, 'fairness_indicator_thresholds': 'null', 'example_splits': 'null', 'module_file': None, 'module_path': None} 'custom_extractors'
INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "eval"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  model_names: ""
  thresholds {
    key: "accuracy"
    value {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
}

INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "eval"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  model_names: ""
  thresholds {
    key: "accuracy"
    value {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
}

INFO:absl:eval_shared_models have model_types: {'tfma_eval'}
INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "eval"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  model_names: ""
  thresholds {
    key: "accuracy"
    value {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/eval_saved_model/load.py:163: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Trainer/model/6/Format-TFMA/variables/variables
2024-05-08 09:32:18.153481: W tensorflow/c/c_api.cc:305] Operation '{name:'head/metrics/count_3/Assign' id:1368 op device:{requested: '', assigned: ''} def:{ { {node head/metrics/count_3/Assign} } = AssignVariableOp[_has_manual_control_dependencies=true, dtype=DT_FLOAT, validate_shape=false](head/metrics/count_3, head/metrics/count_3/Initializer/zeros)} }' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. Either don't modify nodes after running them or create a new session.
2024-05-08 09:32:18.323322: W tensorflow/c/c_api.cc:305] Operation '{name:'head/metrics/count_3/Assign' id:1368 op device:{requested: '', assigned: ''} def:{ { {node head/metrics/count_3/Assign} } = AssignVariableOp[_has_manual_control_dependencies=true, dtype=DT_FLOAT, validate_shape=false](head/metrics/count_3, head/metrics/count_3/Initializer/zeros)} }' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. Either don't modify nodes after running them or create a new session.
INFO:absl:Evaluation complete. Results written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Evaluator/evaluation/8.
INFO:absl:Checking validation results.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py:112: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
INFO:absl:Blessing result True written to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Evaluator/blessing/8.
INFO:absl:Running publisher for Evaluator
INFO:absl:MetadataStore with DB connection initialized

Now let's examine the output artifacts of the Evaluator.

evaluator.outputs
{'evaluation': OutputChannel(artifact_type=ModelEvaluation, producer_component_id=Evaluator, output_key=evaluation, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False),
 'blessing': OutputChannel(artifact_type=ModelBlessing, producer_component_id=Evaluator, output_key=blessing, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False)}

Using the evaluation output, we can show the default visualization of global metrics on the entire evaluation set.

context.show(evaluator.outputs['evaluation'])
SlicingMetricsViewer(config={'weightedExamplesColumn': 'example_count'}, data=[{'slice': 'Overall', 'metrics':…

To see the visualization of sliced evaluation metrics, we can call the TensorFlow Model Analysis (TFMA) library directly.

import tensorflow_model_analysis as tfma

# Get the TFMA output result path and load the result.
PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
tfma_result = tfma.load_eval_result(PATH_TO_RESULT)

# Show data sliced along feature column trip_start_hour.
tfma.view.render_slicing_metrics(
    tfma_result, slicing_column='trip_start_hour')
SlicingMetricsViewer(config={'weightedExamplesColumn': 'example_count'}, data=[{'slice': 'trip_start_hour:19',…

This visualization shows the same metrics, but computed separately for each value of the trip_start_hour feature instead of on the entire evaluation set.
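Conceptually, slicing evaluates the same metric on subsets of the evaluation data grouped by a feature value. The plain-Python sketch below illustrates the idea; it is not TFMA's implementation, and the records in it are hypothetical. It computes the example count and accuracy once overall and once per trip_start_hour value, using slice names in the same form as the Overall and trip_start_hour:19 entries in the viewer output above.

```python
from collections import defaultdict

# Hypothetical evaluation records: (trip_start_hour, label, predicted_label).
# Made-up values for illustration only; not taken from the taxi dataset.
EXAMPLES = [
    (19, 1, 1),
    (19, 0, 1),
    (8, 0, 0),
    (8, 1, 1),
    (8, 0, 0),
]


def compute_metrics(rows):
    """Return the example count and accuracy for (hour, label, pred) rows."""
    count = len(rows)
    correct = sum(1 for _, label, pred in rows if label == pred)
    return {"example_count": count, "accuracy": correct / count}


def sliced_metrics(rows):
    """Compute metrics overall and once per distinct trip_start_hour value."""
    results = {"Overall": compute_metrics(rows)}
    by_hour = defaultdict(list)
    for row in rows:
        by_hour[row[0]].append(row)
    for hour, hour_rows in sorted(by_hour.items()):
        results[f"trip_start_hour:{hour}"] = compute_metrics(hour_rows)
    return results


for slice_name, metrics in sliced_metrics(EXAMPLES).items():
    print(slice_name, metrics)
```

With these made-up records the overall accuracy is 0.8, while the trip_start_hour:19 slice drops to 0.5. A gap like this is hidden in the global metric, which is why per-slice views are useful.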

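TensorFlow Model Analysis supports many other visualizations, such as Fairness Indicators and plotting a time series of model performance. To learn more, see the tutorial.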

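Since we added thresholds to the config, validation output is also available. A BLESSED file in the blessing artifact indicates that our model passed validation. Because this is the first validation performed, there is no baseline model to compare against. As the log above notes, the change threshold is therefore ignored, and the candidate only has to meet the value threshold (accuracy of at least 0.5) to be blessed.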

blessing_uri = evaluator.outputs['blessing'].get()[0].uri
!ls -l {blessing_uri}
total 0
-rw-rw-r-- 1 kbuilder kbuilder 0 May  8 09:32 BLESSED
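The blessing decision described above can be sketched in plain Python. This is a simplified illustration, not TFMA's validator: the threshold values are copied from the eval_config in the log (value lower bound 0.5, change threshold absolute -1e-10 with HIGHER_IS_BETTER), the accuracy numbers used with it are hypothetical, and inclusive comparisons are an assumption of this sketch.

```python
from typing import Optional

# Threshold values copied from the eval_config shown in the log above.
VALUE_LOWER_BOUND = 0.5   # value_threshold.lower_bound for accuracy
CHANGE_ABSOLUTE = -1e-10  # change_threshold.absolute, direction HIGHER_IS_BETTER


def is_blessed(candidate_accuracy: float,
               baseline_accuracy: Optional[float] = None) -> bool:
    """Simplified illustration of threshold-based blessing (not TFMA's code)."""
    # The value threshold always applies: accuracy must reach the lower bound.
    if candidate_accuracy < VALUE_LOWER_BOUND:
        return False
    # Without a baseline model the change threshold is skipped, matching the
    # "ignore the baseline ModelSpec and any change thresholds" log message.
    if baseline_accuracy is None:
        return True
    # HIGHER_IS_BETTER: the candidate must not be worse than the baseline
    # by more than |CHANGE_ABSOLUTE|, i.e. effectively no regression allowed.
    return candidate_accuracy - baseline_accuracy >= CHANGE_ABSOLUTE


print(is_blessed(0.7))                          # first run, no baseline
print(is_blessed(0.7, baseline_accuracy=0.75))  # regression vs. baseline
```

On a first run only the value threshold matters; on later runs, once a baseline exists, a candidate that is even slightly worse than the baseline would not be blessed.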

We can now also verify the success by loading the validation result record:

PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
print(tfma.load_validation_result(PATH_TO_RESULT))
validation_ok: true
validation_details {
  slicing_details {
    slicing_spec {
    }
    num_matching_slices: 25
  }
}

Pusher

The Pusher component usually sits at the end of a TFX pipeline. It checks whether a model has passed validation and, if so, exports the model to _serving_model_dir.
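The gating behavior described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the Pusher's implementation: it assumes the blessing directory contains an empty BLESSED marker file (as in the Evaluator's blessing listing above) and that the version directory is named after the current Unix timestamp. The function name push_if_blessed and the directory layout are hypothetical.

```python
import os
import shutil
import tempfile
import time


def push_if_blessed(model_dir: str, blessing_dir: str, serving_base_dir: str):
    """Copy model_dir to serving_base_dir/<version> only if it is blessed.

    Hypothetical helper illustrating the Pusher's behavior, not TFX code.
    Returns the destination path, or None if the model was not pushed.
    """
    # The Evaluator marks a blessed model with an empty BLESSED file.
    if not os.path.exists(os.path.join(blessing_dir, "BLESSED")):
        return None
    # Version directory named after the current Unix timestamp (assumption).
    version = str(int(time.time()))
    destination = os.path.join(serving_base_dir, version)
    # copytree creates missing parent directories but fails if destination
    # already exists, e.g. for two pushes within the same second.
    shutil.copytree(model_dir, destination)
    return destination


# Usage with throwaway directories (hypothetical layout).
with tempfile.TemporaryDirectory() as root:
    model = os.path.join(root, "model")
    os.makedirs(model)
    open(os.path.join(model, "saved_model.pb"), "w").close()
    blessing = os.path.join(root, "blessing")
    os.makedirs(blessing)
    serving = os.path.join(root, "serving_model", "taxi_simple")
    print(push_if_blessed(model, blessing, serving))  # None: not blessed yet
    open(os.path.join(blessing, "BLESSED"), "w").close()
    print(push_if_blessed(model, blessing, serving))  # new version directory
```

Keeping the check and the copy in one step means an unblessed model never reaches the serving directory, which is the property the Pusher is there to guarantee.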

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher)
INFO:absl:Running driver for Pusher
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Pusher
INFO:absl:Model version: 1715160747
INFO:absl:Model written to serving path /tmpfs/tmp/tmp15vk_44e/serving_model/taxi_simple/1715160747.
INFO:absl:Model pushed to /tmpfs/tmp/tfx-interactive-2024-05-08T09_28_51.324450-cz1zlfzs/Pusher/pushed_model/9.
INFO:absl:Running publisher for Pusher
INFO:absl:MetadataStore with DB connection initialized
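The log above shows that the model was written to a numbered subdirectory of the serving path (1715160747, the Unix timestamp at push time). By default, TensorFlow Serving serves the numerically highest version found under a model's base path, so each new push supersedes the previous one. The helper below sketches that selection rule. It is an illustration only, not TensorFlow Serving's code, and the name latest_version_dir is hypothetical.

```python
import os
import tempfile


def latest_version_dir(base_dir: str):
    """Return the highest integer-named subdirectory of base_dir, or None."""
    versions = [
        int(name)
        for name in os.listdir(base_dir)
        # Only all-digit directory names count as model versions.
        if name.isdigit() and os.path.isdir(os.path.join(base_dir, name))
    ]
    if not versions:
        return None
    return os.path.join(base_dir, str(max(versions)))


# Usage: versions are compared numerically, not as strings.
with tempfile.TemporaryDirectory() as base:
    for version in ("1", "1715160747", "20"):
        os.makedirs(os.path.join(base, version))
    print(os.path.basename(latest_version_dir(base)))  # 1715160747
```

Comparing the names as integers matters: as strings, "20" would sort after "1715160747".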

Let's examine the output artifacts of the Pusher.

pusher.outputs
{'pushed_model': OutputChannel(artifact_type=PushedModel, producer_component_id=Pusher, output_key=pushed_model, additional_properties={}, additional_custom_properties={}, _input_trigger=None, _is_async=False)}

In particular, the Pusher exports your model in the SavedModel format, which looks like this:

push_uri = pusher.outputs['pushed_model'].get()[0].uri
model = tf.saved_model.load(push_uri)

for item in model.signatures.items():
  pp.pprint(item)
INFO:absl:Fingerprint not found. Saved model loading will continue.
INFO:absl:path_and_singleprint metric could not be logged. Saved model loading will continue.
('regression',
 <ConcreteFunction () -> Dict[['outputs', TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)]] at 0x7F38C020BDF0>)
('predict',
 <ConcreteFunction () -> Dict[['class_ids', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None)], ['classes', TensorSpec(shape=(None, 1), dtype=tf.string, name=None)], ['probabilities', TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)], ['logits', TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)], ['all_class_ids', TensorSpec(shape=(None, 2), dtype=tf.int32, name=None)], ['logistic', TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)], ['all_classes', TensorSpec(shape=(None, 2), dtype=tf.string, name=None)]] at 0x7F39100E0A90>)
('serving_default',
 <ConcreteFunction () -> Dict[['scores', TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)], ['classes', TensorSpec(shape=(None, 2), dtype=tf.string, name=None)]] at 0x7F38C0424700>)
('classification',
 <ConcreteFunction () -> Dict[['classes', TensorSpec(shape=(None, 2), dtype=tf.string, name=None)], ['scores', TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)]] at 0x7F38C0620D60>)
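The outputs of the predict signature (logits, logistic, probabilities, class_ids, classes, all_class_ids, all_classes) are consistent with a binary classification head. Assuming the standard sigmoid-based binary head used by TensorFlow Estimators, these outputs relate to each other for a single example roughly as follows. This is a plain-Python illustration, not the Estimator's code. The function name binary_head_outputs is hypothetical.

```python
import math


def binary_head_outputs(logit: float):
    """Derive predict-signature-style outputs from one logit (illustrative)."""
    logistic = 1.0 / (1.0 + math.exp(-logit))  # sigmoid(logit)
    probabilities = [1.0 - logistic, logistic]  # [P(class 0), P(class 1)]
    class_id = 1 if logistic > 0.5 else 0       # pick the larger probability
    return {
        "logits": [logit],                        # shape (1,) per example
        "logistic": [logistic],                   # shape (1,)
        "probabilities": probabilities,           # shape (2,)
        "class_ids": [class_id],                  # shape (1,), int
        "classes": [str(class_id).encode()],      # shape (1,), string
        "all_class_ids": [0, 1],                  # shape (2,)
        "all_classes": [b"0", b"1"],              # shape (2,)
    }


print(binary_head_outputs(0.0)["probabilities"])  # [0.5, 0.5]
print(binary_head_outputs(2.0)["classes"])        # [b'1']
```

The per-example list lengths mirror the second dimension of the TensorSpec shapes printed above, for example (None, 2) for probabilities and (None, 1) for logits.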

That concludes our tour of the built-in TFX components!