本教程将向您展示如何使用 TensorFlow Serving 组件导出经过训练的 TensorFlow 模型,并使用标准的 tensorflow_model_server 来服务它。如果您已经熟悉 TensorFlow Serving,并且想要了解有关服务器内部工作原理的更多信息,请参阅 TensorFlow Serving 高级教程。
本教程使用一个简单的 Softmax 回归模型来对手写数字进行分类。它与 TensorFlow 使用 Fashion MNIST 数据集进行图像分类的教程 中介绍的模型非常相似。
本教程的代码分为两个部分
一个 Python 文件,mnist_saved_model.py,用于训练和导出模型。
一个 ModelServer 二进制文件,可以使用 Apt 安装,也可以从 C++ 文件 (main.cc) 编译。TensorFlow Serving ModelServer 会发现新的导出模型,并运行一个 gRPC 服务来服务它们。
在开始之前,请先 安装 Docker。
训练和导出 TensorFlow 模型
在训练阶段,TensorFlow 图表在 TensorFlow 会话 sess
中启动,输入张量(图像)为 x
,输出张量(Softmax 分数)为 y
。
然后,我们使用 TensorFlow 的 SavedModelBuilder 模块 来导出模型。 SavedModelBuilder
将经过训练的模型的“快照”保存到可靠的存储中,以便以后加载以进行推理。
有关 SavedModel 格式的详细信息,请参阅 SavedModel README.md 中的文档。
从 mnist_saved_model.py 中,以下是一段简短的代码片段,用于说明将模型保存到磁盘的一般过程。
export_path_base = sys.argv[-1]
export_path = os.path.join(
tf.compat.as_bytes(export_path_base),
tf.compat.as_bytes(str(FLAGS.model_version)))
print('Exporting trained model to', export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={
'predict_images':
prediction_signature,
tf.compat.v1.saved_model.signature_constants
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
classification_signature,
},
main_op=tf.compat.v1.tables_initializer(),
strip_default_attrs=True)
builder.save()
SavedModelBuilder.__init__
接受以下参数
export_path
是导出目录的路径。
SavedModelBuilder
将在目录不存在的情况下创建它。在本例中,我们将命令行参数和 FLAGS.model_version
连接起来以获得导出目录。 FLAGS.model_version
指定模型的**版本**。在导出同一模型的较新版本时,您应该指定一个更大的整数值。每个版本都将导出到给定路径下的不同子目录中。
您可以使用 SavedModelBuilder.add_meta_graph_and_variables()
将元图和变量添加到构建器中,并使用以下参数
sess
是保存要导出的经过训练模型的 TensorFlow 会话。tags
是用于保存元图的标签集。在本例中,由于我们打算在服务中使用该图,因此我们使用预定义的 SavedModel 标签常量中的serve
标签。有关更多详细信息,请参阅 tag_constants.py 和 相关的 TensorFlow API 文档。signature_def_map
指定了用于将**签名**的用户提供的键映射到 tensorflow::SignatureDef,以添加到元图中。签名指定了要导出的模型类型,以及在运行推理时要绑定的输入/输出张量。特殊的签名键
serving_default
指定了默认服务签名。默认服务签名定义键以及与签名相关的其他常量,定义为 SavedModel 签名常量的一部分。有关更多详细信息,请参阅 signature_constants.py 和 相关的 TensorFlow API 文档。此外,为了帮助轻松构建签名定义,SavedModel API 提供了 签名定义工具。具体来说,在原始的 mnist_saved_model.py 文件中,我们使用
signature_def_utils.build_signature_def()
来构建predict_signature
和classification_signature
。例如,
predict_signature
的定义方式如下,该工具接受以下参数:inputs={'images': tensor_info_x}
指定输入张量信息。outputs={'scores': tensor_info_y}
指定分数张量信息。method_name
是用于推理的方法。对于预测请求,它应该设置为tensorflow/serving/predict
。对于其他方法名称,请参见 signature_constants.py 和 相关的 TensorFlow API 文档。
请注意,tensor_info_x
和 tensor_info_y
具有 tensorflow::TensorInfo
协议缓冲区的结构,该协议缓冲区定义 在此。为了轻松构建张量信息,TensorFlow SavedModel API 还提供了 utils.py,以及 相关的 TensorFlow API 文档。
此外,请注意,images
和 scores
是张量别名。它们可以是您想要的任何唯一的字符串,并且它们将成为您在稍后发送预测请求时用于张量绑定的张量 x
和 y
的逻辑名称。
例如,如果 x
指的是名为 'long_tensor_name_foo' 的张量,而 y
指的是名为 'generated_tensor_name_bar' 的张量,则 builder
将存储张量逻辑名称到真实名称的映射 ('images' -> 'long_tensor_name_foo') 和 ('scores' -> 'generated_tensor_name_bar')。这允许用户在运行推理时使用其逻辑名称来引用这些张量。
让我们运行它!
首先,如果您还没有这样做,请将此存储库克隆到您的本地机器上
git clone https://github.com/tensorflow/serving.git
cd serving
如果导出目录已存在,请清除它
rm -rf /tmp/mnist
现在让我们训练模型
tools/run_in_docker.sh python tensorflow_serving/example/mnist_saved_model.py \
/tmp/mnist
这应该会产生类似于以下内容的输出
Training model...
...
Done training!
Exporting trained model to models/mnist
Done exporting!
现在让我们看一下导出目录。
$ ls /tmp/mnist
1
如上所述,将为导出模型的每个版本创建一个子目录。 FLAGS.model_version
的默认值为 1,因此创建了相应的子目录 1
。
$ ls /tmp/mnist/1
saved_model.pb variables
每个版本子目录包含以下文件
saved_model.pb
是序列化后的 tensorflow::SavedModel。它包含模型的一个或多个图形定义,以及模型的元数据,例如签名。variables
是保存图形序列化变量的文件。
有了这些,您的 TensorFlow 模型就已导出,可以加载了!
使用标准 TensorFlow ModelServer 加载导出的模型
使用 Docker 服务镜像来轻松加载模型以供服务
docker run -p 8500:8500 \
--mount type=bind,source=/tmp/mnist,target=/models/mnist \
-e MODEL_NAME=mnist -t tensorflow/serving &
测试服务器
我们可以使用提供的 mnist_client 工具来测试服务器。客户端下载 MNIST 测试数据,将它们作为请求发送到服务器,并计算推理错误率。
tools/run_in_docker.sh python tensorflow_serving/example/mnist_client.py \
--num_tests=1000 --server=127.0.0.1:8500
这应该输出类似于以下内容的内容
...
Inference error rate: 11.13%
我们预计训练后的 Softmax 模型的准确率约为 90%,对于前 1000 个测试图像,我们获得了 11% 的推理错误率。这证实了服务器成功加载并运行了训练后的模型!