简介
本指南演示了 Tensorflow Extended (TFX) 如何创建和评估将在设备上部署的机器学习模型。TFX 现在为 TFLite 提供原生支持,这使得在移动设备上执行高效推理成为可能。
本指南将引导您完成对任何管道进行的更改,以生成和评估 TFLite 模型。我们提供了一个完整的示例 此处,演示了 TFX 如何训练和评估从 MNIST 数据集训练的 TFLite 模型。此外,我们还展示了如何使用相同的管道同时导出标准基于 Keras 的 SavedModel 以及 TFLite 模型,允许用户比较两者的质量。
我们假设您熟悉 TFX、我们的组件和我们的管道。如果不是,请参阅此 教程。
步骤
在 TFX 中创建和评估 TFLite 模型只需要两个步骤。第一步是在 TFX Trainer 的上下文中调用 TFLite 重写器,将训练后的 TensorFlow 模型转换为 TFLite 模型。第二步是配置 Evaluator 以评估 TFLite 模型。我们现在依次讨论每个步骤。
在 Trainer 中调用 TFLite 重写器。
TFX Trainer 预计用户定义的 run_fn
在模块文件中指定。此 run_fn
定义要训练的模型,对其进行指定迭代次数的训练,并导出训练后的模型。
在本节的其余部分,我们提供代码片段,这些代码片段显示了调用 TFLite 重写器和导出 TFLite 模型所需的更改。所有这些代码都位于 MNIST TFLite 模块 的 run_fn
中。
如下面的代码所示,我们必须首先创建一个签名,该签名将每个特征的 Tensor
作为输入。请注意,这与 TFX 中大多数现有模型不同,大多数现有模型将序列化 tf.Example 协议作为输入。
signatures = {
'serving_default':
_get_serve_tf_examples_fn(
model, tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None, 784],
dtype=tf.float32,
name='image_floats'))
}
然后,Keras 模型以与通常相同的方式保存为 SavedModel。
temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)
最后,我们创建 TFLite 重写器 (tfrw
) 的实例,并在 SavedModel 上调用它以获取 TFLite 模型。我们将此 TFLite 模型存储在 run_fn
调用者提供的 serving_model_dir
中。这样,TFLite 模型将存储在所有下游 TFX 组件预期找到模型的位置。
tfrw = rewriter_factory.create_rewriter(
rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
converters.rewrite_saved_model(temp_saving_model_dir,
fn_args.serving_model_dir,
tfrw,
rewriter.ModelType.TFLITE_MODEL)
评估 TFLite 模型。
TFX Evaluator 提供了分析训练后的模型的能力,以了解其在各种指标上的质量。除了分析 SavedModel 之外,TFX Evaluator 现在还能够分析 TFLite 模型。
以下代码片段(摘自 MNIST 管道)展示了如何配置分析 TFLite 模型的 Evaluator。
# Informs the evaluator that the model is a TFLite model.
eval_config_lite.model_specs[0].model_type = 'tf_lite'
...
# Uses TFMA to compute the evaluation statistics over features of a TFLite
# model.
model_analyzer_lite = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer_lite.outputs['model'],
eval_config=eval_config_lite,
).with_id('mnist_lite')
如上所示,我们需要做的唯一更改是将 model_type
字段设置为 tf_lite
。分析 TFLite 模型不需要进行其他配置更改。无论分析的是 TFLite 模型还是 SavedModel,Evaluator
的输出都将具有完全相同的结构。
但是,请注意,Evaluator 假设 TFLite 模型保存在 trainer_lite.outputs['model'] 中名为 tflite
的文件中。