自定义 Python 函数组件

基于 Python 函数的组件定义使您更轻松地创建 TFX 自定义组件,因为它可以为您节省定义组件规范类、执行器类和组件接口类的麻烦。在此组件定义样式中,您编写一个用类型提示注释的函数。类型提示描述了组件的输入工件、输出工件和参数。

使用此样式编写自定义组件非常简单,如下面的示例所示。

class MyOutput(TypedDict):
  accuracy: float

@component
def MyValidationComponent(
    model: InputArtifact[Model],
    blessing: OutputArtifact[Model],
    accuracy_threshold: Parameter[int] = 10,
) -> MyOutput:
  '''My simple custom model validation component.'''

  accuracy = evaluate_model(model)
  if accuracy >= accuracy_threshold:
    write_output_blessing(blessing)

  return {
    'accuracy': accuracy
  }

在幕后,这定义了一个自定义组件,它是 BaseComponent 及其 Spec 和 Executor 类的子类。

如果您想定义 BaseBeamComponent 的子类,以便您可以将 beam 管道与 TFX 管道级共享配置一起使用,即在编译管道时使用 beam_pipeline_args芝加哥出租车管道示例),您可以在装饰器中设置 use_beam=True,并在您的函数中添加另一个 BeamComponentParameter,其默认值为 None,如下面的示例所示

@component(use_beam=True)
def MyDataProcessor(
    examples: InputArtifact[Example],
    processed_examples: OutputArtifact[Example],
    beam_pipeline: BeamComponentParameter[beam.Pipeline] = None,
    ) -> None:
  '''My simple custom model validation component.'''

  with beam_pipeline as p:
    # data pipeline definition with beam_pipeline begins
    ...
    # data pipeline definition with beam_pipeline ends

如果您不熟悉 TFX 管道,请了解有关 TFX 管道核心概念的更多信息

输入、输出和参数

在 TFX 中,输入和输出作为 Artifact 对象进行跟踪,这些对象描述了基础数据的存储位置以及与其关联的元数据属性;此信息存储在 ML Metadata 中。Artifact 可以描述复杂数据类型或简单数据类型,例如:int、float、bytes 或 unicode 字符串。

参数是在管道构建时已知的组件参数(int、float、bytes 或 unicode 字符串)。参数对于指定参数和超参数(如训练迭代次数、dropout 率和其他配置)到您的组件很有用。参数在 ML Metadata 中进行跟踪时,会存储为组件执行的属性。

定义

要创建自定义组件,请编写一个实现自定义逻辑的函数,并使用 @component 装饰器(来自 tfx.dsl.component.experimental.decorators 模块)对其进行装饰。要定义组件的输入和输出模式,请使用 tfx.dsl.component.experimental.annotations 模块中的注释来注释函数的参数和返回值

  • 对于每个工件输入,请应用 InputArtifact[ArtifactType] 类型提示注释。将 ArtifactType 替换为工件的类型,它是 tfx.types.Artifact 的子类。这些输入可以是可选参数。

  • 对于每个**输出工件**,应用OutputArtifact[ArtifactType]类型提示注释。将ArtifactType替换为工件的类型,该类型是tfx.types.Artifact的子类。组件输出工件应作为函数的输入参数传递,以便您的组件可以将输出写入系统管理的位置并设置适当的工件元数据属性。此参数可以是可选的,也可以定义为具有默认值。

  • 对于每个**参数**,使用类型提示注释Parameter[T]。将T替换为参数的类型。我们目前只支持原始 Python 类型:boolintfloatstrbytes

  • 对于**Beam 管道**,使用类型提示注释BeamComponentParameter[beam.Pipeline]。将默认值设置为None。值None将被BaseBeamExecutor_make_beam_pipeline()创建的实例化 Beam 管道替换。

  • 对于每个**简单数据类型输入**(intfloatstrbytes),在管道构建时未知,使用类型提示T。请注意,在 TFX 0.22 版本中,此类型输入的具体值不能在管道构建时传递(请改用Parameter注释,如上一节所述)。此参数可以是可选的,也可以定义为具有默认值。如果您的组件具有简单数据类型输出(intfloatstrbytes),您可以使用TypedDict作为返回类型注释并返回适当的字典对象来返回这些输出。

在函数体中,输入和输出工件作为tfx.types.Artifact对象传递;您可以检查其.uri以获取其系统管理的位置并读取/设置任何属性。输入参数和简单数据类型输入作为指定类型的对象传递。简单数据类型输出应作为字典返回,其中键是相应的输出名称,值是所需的返回值。

完成的函数组件可能如下所示

from typing import TypedDict
import tfx.v1 as tfx
from tfx.dsl.component.experimental.decorators import component

class MyOutput(TypedDict):
  loss: float
  accuracy: float

@component
def MyTrainerComponent(
    training_data: tfx.dsl.components.InputArtifact[tfx.types.standard_artifacts.Examples],
    model: tfx.dsl.components.OutputArtifact[tfx.types.standard_artifacts.Model],
    dropout_hyperparameter: float,
    num_iterations: tfx.dsl.components.Parameter[int] = 10
) -> MyOutput:
  '''My simple trainer component.'''

  records = read_examples(training_data.uri)
  model_obj = train_model(records, num_iterations, dropout_hyperparameter)
  model_obj.write_to(model.uri)

  return {
    'loss': model_obj.loss,
    'accuracy': model_obj.accuracy
  }

# Example usage in a pipeline graph definition:
# ...
trainer = MyTrainerComponent(
    examples=example_gen.outputs['examples'],
    dropout_hyperparameter=other_component.outputs['dropout'],
    num_iterations=1000)
pusher = Pusher(model=trainer.outputs['model'])
# ...

前面的示例将MyTrainerComponent定义为基于 Python 函数的自定义组件。此组件使用examples工件作为其输入,并生成model工件作为其输出。该组件使用artifact_instance.uri在其系统管理的位置读取或写入工件。该组件接受num_iterations输入参数和dropout_hyperparameter简单数据类型值,并且该组件输出lossaccuracy指标作为简单数据类型输出值。然后,输出model工件由Pusher组件使用。