Trainer TFX 管道组件

Trainer TFX 管道组件训练 TensorFlow 模型。

Trainer 和 TensorFlow

Trainer 广泛使用 Python TensorFlow API 来训练模型。

组件

Trainer 采用

  • 用于训练和评估的 tf.Examples。
  • 用户提供的定义 trainer 逻辑的模块文件。
  • Protobuf 定义训练参数和评估参数。
  • (可选)由 SchemaGen 管道组件创建并由开发者修改的数据架构。
  • (可选)由上游 Transform 组件生成的转换图。
  • (可选)用于热启动等场景的预训练模型。
  • (可选)超参数,将传递给用户模块函数。有关与 Tuner 集成的详细信息,请参阅此处

Trainer 发射:至少一个用于推理/服务的模型(通常采用 SavedModelFormat),以及用于评估的另一个模型(通常采用 EvalSavedModel)。

我们通过模型重写库为备选模型格式(如 TFLite)提供支持。请参阅模型重写库的链接,了解如何转换 Estimator 和 Keras 模型的示例。

通用 Trainer

通用 trainer 使开发者能够将任何 TensorFlow 模型 API 与 Trainer 组件配合使用。除了 TensorFlow Estimator 之外,开发者还可以使用 Keras 模型或自定义训练循环。有关详细信息,请参阅通用 trainer 的 RFC

配置 Trainer 组件

通用 Trainer 的典型管道 DSL 代码如下所示

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer 调用训练模块,该模块在 module_file 参数中指定。如果在 custom_executor_spec 中指定了 GenericExecutor,则模块文件中需要 run_fn,而不是 trainer_fntrainer_fn 负责创建模型。除此之外,run_fn 还需要处理训练部分,并将训练后的模型输出到 FnArgs 给出的所需位置。

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

这是一个带有 run_fn示例模块文件

请注意,如果管道中未使用转换组件,则 Trainer 将直接从 ExampleGen 中获取示例

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

更多详细信息可在 Trainer API 参考 中找到。