在 TFX 中使用其他 ML 框架

TFX 作为平台是框架中立的,可以与其他 ML 框架(例如 JAX、scikit-learn)一起使用。

对于模型开发者来说,这意味着他们不需要重写用其他 ML 框架实现的模型代码,而是可以将大部分训练代码原封不动地重用在 TFX 中,并受益于 TFX 和 TensorFlow 生态系统提供的其他功能。

TFX 管道 SDK 和 TFX 中的大多数模块(例如管道编排器)没有直接依赖于 TensorFlow,但有些方面是面向 TensorFlow 的,例如数据格式。通过考虑特定建模框架的需求,TFX 管道可用于在任何其他基于 Python 的 ML 框架中训练模型。这包括 Scikit-learn、XGBoost 和 PyTorch 等。使用标准 TFX 组件与其他框架一起使用的一些注意事项包括

  • ExampleGen 在 TFRecord 文件中输出 tf.train.Example。它是一种通用的训练数据表示形式,下游组件使用 TFXIO 将其读取为内存中的 Arrow/RecordBatch,可以进一步转换为 tf.datasetTensors 或其他格式。除了 tf.train.Example/TFRecord 之外的有效负载/文件格式正在考虑中,但对于 TFXIO 用户来说,它应该是一个黑盒。
  • Transform 可用于生成转换后的训练示例,无论使用什么框架进行训练,但如果模型格式不是 saved_model,用户将无法将转换图嵌入到模型中。在这种情况下,模型预测需要采用转换后的特征而不是原始特征,用户可以在服务时在调用模型预测之前运行转换作为预处理步骤。
  • Trainer 支持 通用训练,因此用户可以使用任何 ML 框架训练其模型。
  • Evaluator 默认情况下只支持 saved_model,但用户可以提供一个 UDF 来生成模型评估的预测。

在非基于 Python 的框架中训练模型需要将自定义训练组件隔离在 Docker 容器中,作为在容器化环境(如 Kubernetes)中运行的管道的一部分。

JAX

JAX 是 Autograd 和 XLA,它们结合在一起用于高性能机器学习研究。 Flax 是一个用于 JAX 的神经网络库和生态系统,专为灵活性和可扩展性而设计。

使用 jax2tf,我们可以将训练好的 JAX/Flax 模型转换为 saved_model 格式,该格式可以在 TFX 中与通用训练和模型评估无缝使用。有关详细信息,请查看此 示例

scikit-learn

Scikit-learn 是 Python 编程语言的机器学习库。我们有一个 e2e 示例,其中在 TFX-Addons 中定制了训练和评估。