Trainer TFX 管道组件

Trainer TFX 管道组件训练 TensorFlow 模型。

Trainer 和 TensorFlow

Trainer 广泛使用 Python TensorFlow API 来训练模型。

组件

Trainer 接受

  • 用于训练和评估的 tf.Examples。
  • 用户提供的模块文件,用于定义训练器逻辑。
  • Protobuf 定义的训练参数和评估参数。
  • (可选) 由 SchemaGen 管道组件创建的数据模式,开发人员可以对其进行修改。
  • (可选) 由上游 Transform 组件生成的转换图。
  • (可选) 用于热启动等场景的预训练模型。
  • (可选) 超参数,将传递给用户模块函数。有关与 Tuner 集成的详细信息,请参见 此处

Trainer 发出:至少一个用于推理/服务的模型(通常为 SavedModelFormat),以及可选的另一个用于评估的模型(通常为 EvalSavedModel)。

我们通过 模型重写库 为 TFLite 等其他模型格式提供支持。有关如何转换 Estimator 和 Keras 模型的示例,请参阅模型重写库的链接。

通用训练器

通用训练器使开发人员能够使用任何 TensorFlow 模型 API 与 Trainer 组件一起使用。除了 TensorFlow Estimators 之外,开发人员还可以使用 Keras 模型或自定义训练循环。有关详细信息,请参阅 通用训练器的 RFC

配置 Trainer 组件

通用 Trainer 的典型管道 DSL 代码如下所示

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer 调用训练模块,该模块在 module_file 参数中指定。如果在 custom_executor_spec 中指定了 GenericExecutor,则模块文件中需要 run_fn 而不是 trainer_fntrainer_fn 负责创建模型。除了这一点之外,run_fn 还需要处理训练部分,并将训练后的模型输出到由 FnArgs 给出的所需位置

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

这是一个包含 run_fn示例模块文件

请注意,如果管道中未使用 Transform 组件,则 Trainer 将直接从 ExampleGen 获取示例

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

有关更多详细信息,请参阅 Trainer API 参考