Trainer TFX 管道组件训练 TensorFlow 模型。
Trainer 和 TensorFlow
Trainer 广泛使用 Python TensorFlow API 来训练模型。
组件
Trainer 接受
- 用于训练和评估的 tf.Examples。
- 用户提供的模块文件,用于定义训练器逻辑。
- Protobuf 定义的训练参数和评估参数。
- (可选) 由 SchemaGen 管道组件创建的数据模式,开发人员可以对其进行修改。
- (可选) 由上游 Transform 组件生成的转换图。
- (可选) 用于热启动等场景的预训练模型。
- (可选) 超参数,将传递给用户模块函数。有关与 Tuner 集成的详细信息,请参见 此处。
Trainer 发出:至少一个用于推理/服务的模型(通常为 SavedModelFormat),以及可选的另一个用于评估的模型(通常为 EvalSavedModel)。
我们通过 模型重写库 为 TFLite 等其他模型格式提供支持。有关如何转换 Estimator 和 Keras 模型的示例,请参阅模型重写库的链接。
通用训练器
通用训练器使开发人员能够使用任何 TensorFlow 模型 API 与 Trainer 组件一起使用。除了 TensorFlow Estimators 之外,开发人员还可以使用 Keras 模型或自定义训练循环。有关详细信息,请参阅 通用训练器的 RFC。
配置 Trainer 组件
通用 Trainer 的典型管道 DSL 代码如下所示
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer 调用训练模块,该模块在 module_file
参数中指定。如果在 custom_executor_spec
中指定了 GenericExecutor
,则模块文件中需要 run_fn
而不是 trainer_fn
。 trainer_fn
负责创建模型。除了这一点之外,run_fn
还需要处理训练部分,并将训练后的模型输出到由 FnArgs 给出的所需位置
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
这是一个包含 run_fn
的 示例模块文件。
请注意,如果管道中未使用 Transform 组件,则 Trainer 将直接从 ExampleGen 获取示例
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
有关更多详细信息,请参阅 Trainer API 参考。