使用 TensorFlow Lite 的 JAX 模型

此页面为希望在 JAX 中训练模型并在移动设备上部署以进行推理的用户提供了一条路径(示例 Colab)。

本指南中的方法会生成一个 tflite_model,该模型可以直接与 TFLite 解释器代码示例一起使用,也可以保存为 TFLite FlatBuffer 文件。

先决条件

建议使用最新的 TensorFlow 夜间构建 Python 包来尝试此功能。

pip install tf-nightly --upgrade

我们将使用 Orbax 导出 库来导出 JAX 模型。请确保您的 JAX 版本至少为 0.4.20 或更高版本。

pip install jax --upgrade
pip install orbax
-export --upgrade

将 JAX 模型转换为 TensorFlow Lite

我们将 TensorFlow SavedModel 作为 JAX 和 TensorFlow Lite 之间的中间格式。获得 SavedModel 后,可以使用现有的 TensorFlow Lite API 完成转换过程。

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
 
return jnp.sin(jnp.cos(x))

jax_module
= JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf
.saved_model.save(
    jax_module
,
   
'/some/directory',
    signatures
=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf
.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
   
),
    options
=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter
= tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model
= converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config
= ServingConfig(
   
'Serving_default',
   
# Corresponds to the input signature of `tf_preprocessor`
    input_signature
=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor
=lambda x: x,
    tf_postprocessor
=lambda out: {'output': out}
)
export_mgr
= ExportManager(jax_module, [serving_config])
export_mgr
.save('/some/directory')
converter
= tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model
= converter.convert()

# Option 3: Convert from TF concrete function directly
converter
= tf.lite.TFLiteConverter.from_concrete_functions(
   
[
        jax_module
.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf
.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
       
)
   
]
)
tflite_model
= converter.convert()

检查转换后的 TFLite 模型

将模型转换为 TFLite 后,您可以运行 TFLite 解释器 API 来检查模型输出。

# Run the model with TensorFlow Lite
interpreter
= tf.lite.Interpreter(model_content=tflite_model)
interpreter
.allocate_tensors() input_details = interpreter.get_input_details()
output_details
= interpreter.get_output_details()
interpreter
.set_tensor(input_details[0]["index"], input_data)
interpreter
.invoke()
result
= interpreter.get_tensor(output_details[0]["index"])