TFX for Mobile

简介

本指南演示了 Tensorflow Extended (TFX) 如何创建和评估将在设备上部署的机器学习模型。TFX 现在为 TFLite 提供原生支持,这使得在移动设备上执行高效推理成为可能。

本指南将引导您完成对任何管道进行的更改,以生成和评估 TFLite 模型。我们提供了一个完整的示例 此处,演示了 TFX 如何训练和评估从 MNIST 数据集训练的 TFLite 模型。此外,我们还展示了如何使用相同的管道同时导出标准基于 Keras 的 SavedModel 以及 TFLite 模型,允许用户比较两者的质量。

我们假设您熟悉 TFX、我们的组件和我们的管道。如果不是,请参阅此 教程

步骤

在 TFX 中创建和评估 TFLite 模型只需要两个步骤。第一步是在 TFX Trainer 的上下文中调用 TFLite 重写器,将训练后的 TensorFlow 模型转换为 TFLite 模型。第二步是配置 Evaluator 以评估 TFLite 模型。我们现在依次讨论每个步骤。

在 Trainer 中调用 TFLite 重写器。

TFX Trainer 预计用户定义的 run_fn 在模块文件中指定。此 run_fn 定义要训练的模型,对其进行指定迭代次数的训练,并导出训练后的模型。

在本节的其余部分,我们提供代码片段,这些代码片段显示了调用 TFLite 重写器和导出 TFLite 模型所需的更改。所有这些代码都位于 MNIST TFLite 模块run_fn 中。

如下面的代码所示,我们必须首先创建一个签名,该签名将每个特征的 Tensor 作为输入。请注意,这与 TFX 中大多数现有模型不同,大多数现有模型将序列化 tf.Example 协议作为输入。

 signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(
              model, tf_transform_output).get_concrete_function(
                  tf.TensorSpec(
                      shape=[None, 784],
                      dtype=tf.float32,
                      name='image_floats'))
  }

然后,Keras 模型以与通常相同的方式保存为 SavedModel。

  temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
  model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)

最后,我们创建 TFLite 重写器 (tfrw) 的实例,并在 SavedModel 上调用它以获取 TFLite 模型。我们将此 TFLite 模型存储在 run_fn 调用者提供的 serving_model_dir 中。这样,TFLite 模型将存储在所有下游 TFX 组件预期找到模型的位置。

  tfrw = rewriter_factory.create_rewriter(
      rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
  converters.rewrite_saved_model(temp_saving_model_dir,
                                 fn_args.serving_model_dir,
                                 tfrw,
                                 rewriter.ModelType.TFLITE_MODEL)

评估 TFLite 模型。

TFX Evaluator 提供了分析训练后的模型的能力,以了解其在各种指标上的质量。除了分析 SavedModel 之外,TFX Evaluator 现在还能够分析 TFLite 模型。

以下代码片段(摘自 MNIST 管道)展示了如何配置分析 TFLite 模型的 Evaluator。

  # Informs the evaluator that the model is a TFLite model.
  eval_config_lite.model_specs[0].model_type = 'tf_lite'

  ...

  # Uses TFMA to compute the evaluation statistics over features of a TFLite
  # model.
  model_analyzer_lite = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer_lite.outputs['model'],
      eval_config=eval_config_lite,
  ).with_id('mnist_lite')

如上所示,我们需要做的唯一更改是将 model_type 字段设置为 tf_lite。分析 TFLite 模型不需要进行其他配置更改。无论分析的是 TFLite 模型还是 SavedModel,Evaluator 的输出都将具有完全相同的结构。

但是,请注意,Evaluator 假设 TFLite 模型保存在 trainer_lite.outputs['model'] 中名为 tflite 的文件中。