Trainer TFX 管道组件训练 TensorFlow 模型。
Trainer 和 TensorFlow
Trainer 广泛使用 Python TensorFlow API 来训练模型。
组件
Trainer 采用
- 用于训练和评估的 tf.Examples。
- 用户提供的定义 trainer 逻辑的模块文件。
- Protobuf 定义训练参数和评估参数。
- (可选)由 SchemaGen 管道组件创建并由开发者修改的数据架构。
- (可选)由上游 Transform 组件生成的转换图。
- (可选)用于热启动等场景的预训练模型。
- (可选)超参数,将传递给用户模块函数。有关与 Tuner 集成的详细信息,请参阅此处。
Trainer 发射:至少一个用于推理/服务的模型(通常采用 SavedModelFormat),以及用于评估的另一个模型(通常采用 EvalSavedModel)。
我们通过模型重写库为备选模型格式(如 TFLite)提供支持。请参阅模型重写库的链接,了解如何转换 Estimator 和 Keras 模型的示例。
通用 Trainer
通用 trainer 使开发者能够将任何 TensorFlow 模型 API 与 Trainer 组件配合使用。除了 TensorFlow Estimator 之外,开发者还可以使用 Keras 模型或自定义训练循环。有关详细信息,请参阅通用 trainer 的 RFC。
配置 Trainer 组件
通用 Trainer 的典型管道 DSL 代码如下所示
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer 调用训练模块,该模块在 module_file
参数中指定。如果在 custom_executor_spec
中指定了 GenericExecutor
,则模块文件中需要 run_fn
,而不是 trainer_fn
。 trainer_fn
负责创建模型。除此之外,run_fn
还需要处理训练部分,并将训练后的模型输出到 FnArgs 给出的所需位置。
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
这是一个带有 run_fn
的 示例模块文件。
请注意,如果管道中未使用转换组件,则 Trainer 将直接从 ExampleGen 中获取示例
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
更多详细信息可在 Trainer API 参考 中找到。