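This guide trains a neural network model to classify images of clothing, such as sneakers and shirts, saves the trained model, and then serves it with TensorFlow Serving. The focus is on TensorFlow Serving rather than on modeling and training in TensorFlow, so for a complete example that concentrates on modeling and training, see the Basic Classification example.
This guide uses tf.keras, a high-level API for building and training models in TensorFlow.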
import sys
# Confirm that we're using Python 3
assert sys.version_info.major == 3, 'Oops, not running Python 3. Use Runtime > Change runtime type'
# TensorFlow and tf.keras
print("Installing dependencies for Colab environment")
!pip install -Uq grpcio==1.26.0
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess
print('TensorFlow version: {}'.format(tf.__version__))
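Create your model
Import the Fashion MNIST dataset
This guide uses the Fashion MNIST dataset, which contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:
Figure 1. Fashion-MNIST samples (by Zalando, MIT License).
Fashion MNIST is intended as a drop-in replacement for the classic MNIST dataset, which is often used as the "Hello, World" of machine learning programs for computer vision. You can access Fashion MNIST directly from TensorFlow: just import and load the data.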
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# scale the values to 0.0 to 1.0
train_images = train_images / 255.0
test_images = test_images / 255.0
# reshape for feeding into the model
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 [==============================] - 0s 0us/step

train_images.shape: (60000, 28, 28, 1), of float64
test_images.shape: (10000, 28, 28, 1), of float64
Train and evaluate your model
Let's use the simplest possible CNN, since we're not focused on the modeling part.
model = keras.Sequential([
keras.layers.Conv2D(input_shape=(28,28,1), filters=8, kernel_size=3,
strides=2, activation='relu', name='Conv1'),
keras.layers.Flatten(),
keras.layers.Dense(10, name='Dense')
])
model.summary()
testing = False
epochs = 5
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=epochs)
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy: {}'.format(test_acc))
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 Conv1 (Conv2D)              (None, 13, 13, 8)         80
 flatten (Flatten)           (None, 1352)              0
 Dense (Dense)               (None, 10)                13530
=================================================================
Total params: 13610 (53.16 KB)
Trainable params: 13610 (53.16 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Epoch 1/5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1714473676.987867 178558 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
1875/1875 [==============================] - 6s 2ms/step - loss: 0.5614 - sparse_categorical_accuracy: 0.8059
Epoch 2/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4274 - sparse_categorical_accuracy: 0.8499
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3859 - sparse_categorical_accuracy: 0.8654
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3561 - sparse_categorical_accuracy: 0.8740
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3328 - sparse_categorical_accuracy: 0.8818
313/313 [==============================] - 1s 2ms/step - loss: 0.3635 - sparse_categorical_accuracy: 0.8706

Test accuracy: 0.8705999851226807
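The shapes and parameter counts in the summary can be checked by hand. The following is a small sketch of that arithmetic, assuming the Conv2D layer uses Keras's default 'valid' padding (the model code does not set padding); the helper name conv_output_size is made up for this illustration.

```python
def conv_output_size(input_size: int, kernel_size: int, stride: int) -> int:
    """Output length along one axis for a convolution with 'valid' (no) padding."""
    return (input_size - kernel_size) // stride + 1


# Conv1: 28x28x1 input, 8 filters of size 3x3, stride 2.
side = conv_output_size(28, kernel_size=3, stride=2)
conv_params = (3 * 3 * 1) * 8 + 8      # kernel weights plus one bias per filter

# Flatten: side * side * filters values feed the Dense layer.
flat_units = side * side * 8
dense_params = flat_units * 10 + 10    # weight matrix plus one bias per class

total_params = conv_params + dense_params
print(side, conv_params, flat_units, dense_params, total_params)
# prints: 13 80 1352 13530 13610
```

These values match the Output Shape and Param # columns reported by model.summary().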
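Save your model
To load our trained model into TensorFlow Serving, we first need to save it in SavedModel format. This creates a protobuf file in a well-defined directory hierarchy and includes a version number. TensorFlow Serving lets us select which version of a model, or "servable", to use when we make inference requests. Each version is exported to a different subdirectory under the given path.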
# Fetch the Keras session and save the model
# The signature definition is defined by the input and output tensors,
# and stored with the default serving key
import tempfile
MODEL_DIR = tempfile.gettempdir()
version = 1
export_path = os.path.join(MODEL_DIR, str(version))
print('export_path = {}\n'.format(export_path))
tf.keras.models.save_model(
model,
export_path,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None
)
print('\nSaved model:')
!ls -l {export_path}
export_path = /tmpfs/tmp/1

INFO:tensorflow:Assets written to: /tmpfs/tmp/1/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/1/assets

Saved model:
total 104
drwxr-xr-x 2 kbuilder kbuilder  4096 Apr 30 10:41 assets
-rw-rw-r-- 1 kbuilder kbuilder    54 Apr 30 10:41 fingerprint.pb
-rw-rw-r-- 1 kbuilder kbuilder  8798 Apr 30 10:41 keras_metadata.pb
-rw-rw-r-- 1 kbuilder kbuilder 78816 Apr 30 10:41 saved_model.pb
drwxr-xr-x 2 kbuilder kbuilder  4096 Apr 30 10:41 variables
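Since each version lives in its own numbered subdirectory, exporting a retrained model usually means choosing the next unused number. The sketch below shows one possible way to do that with the standard library only; next_version_dir is a hypothetical helper, not part of TensorFlow, and the demonstration uses a throwaway directory rather than this guide's MODEL_DIR.

```python
import os
import tempfile


def next_version_dir(model_base_path: str) -> str:
    """Return the path of the next numeric version subdirectory (hypothetical helper)."""
    existing = [
        int(name)
        for name in os.listdir(model_base_path)
        if name.isdigit() and os.path.isdir(os.path.join(model_base_path, name))
    ]
    # Start at 1 when no numeric version directory exists yet.
    return os.path.join(model_base_path, str(max(existing, default=0) + 1))


# Demonstration in a temporary directory that already holds versions 1 and 2.
with tempfile.TemporaryDirectory() as base:
    os.makedirs(os.path.join(base, "1"))
    os.makedirs(os.path.join(base, "2"))
    print(os.path.basename(next_version_dir(base)))  # prints: 3
```

The returned path could then be used in place of export_path when saving a new version of the model.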
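Examine your saved model
We'll use the command line utility saved_model_cli to look at the MetaGraphDefs (the models) and SignatureDefs (the methods you can call) in our SavedModel. See the discussion of the SavedModel CLI in the TensorFlow Guide.
!saved_model_cli show --dir {export_path} --all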
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['Conv1_input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 28, 28, 1)
        name: serving_default_Conv1_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['Dense'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
The MetaGraph with tag set ['serve'] contains the following ops: {'SaveV2', 'Pack', 'ShardedFilename', 'NoOp', 'ReadVariableOp', 'Placeholder', 'Reshape', 'RestoreV2', 'StaticRegexFullMatch', 'Const', 'Identity', 'Select', 'Conv2D', 'AssignVariableOp', 'Relu', 'BiasAdd', 'MergeV2Checkpoints', 'DisableCopyOnRead', 'StringJoin', 'StatefulPartitionedCall', 'VarHandleOp', 'MatMul'}

Concrete Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None

  Function Name: '_default_save_signature'
    Option #1
      Callable with:
        Argument #1
          Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')

  Function Name: 'call_and_return_all_conditional_losses'
    Option #1
      Callable with:
        Argument #1
          Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
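This tells us a lot about our model! In this case we just trained the model, so we already know the inputs and outputs, but if we didn't, this would be important information. It doesn't tell us everything, such as the fact that this is grayscale image data, but it's a great start.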
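One practical use of the reported input shape, (-1, 28, 28, 1), is to sanity-check a batch before sending it to the server. The sketch below does this for plain nested lists, the form a JSON request carries; nested_shape and matches_signature are hypothetical helpers written for this illustration, and -1 is treated as "any batch size".

```python
from typing import Tuple

# Input shape reported for serving_default; -1 marks the variable batch dimension.
SERVING_INPUT_SHAPE: Tuple[int, ...] = (-1, 28, 28, 1)


def nested_shape(value) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, read from the first element at each level."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        if not value:
            break
        value = value[0]
    return tuple(shape)


def matches_signature(shape: Tuple[int, ...],
                      expected: Tuple[int, ...] = SERVING_INPUT_SHAPE) -> bool:
    """True when ranks agree and every fixed dimension matches (-1 matches anything)."""
    if len(shape) != len(expected):
        return False
    return all(e == -1 or e == s for s, e in zip(shape, expected))


# A stand-in batch of three all-zero 28x28x1 images.
batch = [[[[0.0] for _ in range(28)] for _ in range(28)] for _ in range(3)]
print(nested_shape(batch))                     # prints: (3, 28, 28, 1)
print(matches_signature(nested_shape(batch)))  # prints: True
print(matches_signature((3, 28, 28)))          # prints: False (channel axis missing)
```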
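Serve your model with TensorFlow Serving
Add the TensorFlow Serving distribution URI as a package source
We're preparing to install TensorFlow Serving using APT, since this Colab runs in a Debian environment. We'll add the tensorflow-model-server package to the list of packages that APT knows about. Note that on Colab we're running as root; elsewhere, the commands below are prefixed with sudo.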
import sys
# We need sudo prefix if not on a Google Colab.
if 'google.colab' not in sys.modules:
  SUDO_IF_NEEDED = 'sudo'
else:
  SUDO_IF_NEEDED = ''
# This is the same as you would do from your command line, but without the [arch=amd64], and no sudo
# You would instead do:
# echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
# curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
!echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | {SUDO_IF_NEEDED} tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | {SUDO_IF_NEEDED} apt-key add -
!{SUDO_IF_NEEDED} apt update
deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2943  100  2943    0     0  52553      0 --:--:-- --:--:-- --:--:-- 53509
OK
Hit:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal InRelease
Hit:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates InRelease
Hit:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-backports InRelease
Hit:4 https://download.docker.com/linux/ubuntu focal InRelease
Get:5 https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64 InRelease [1484 B]
Hit:6 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64 InRelease
Hit:7 https://nvidia.github.io/nvidia-docker/ubuntu18.04/amd64 InRelease
Hit:8 http://security.ubuntu.com/ubuntu focal-security InRelease
Get:9 http://storage.googleapis.com/tensorflow-serving-apt stable InRelease [3026 B]
Hit:10 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease
Hit:11 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease
Hit:13 http://ppa.launchpad.net/longsleep/golang-backports/ubuntu focal InRelease
Hit:14 http://ppa.launchpad.net/openjdk-r/ppa/ubuntu focal InRelease
Hit:12 https://apt.llvm.org/focal llvm-toolchain-focal-17 InRelease
Get:15 http://storage.googleapis.com/tensorflow-serving-apt stable/tensorflow-model-server amd64 Packages [341 B]
Get:16 http://storage.googleapis.com/tensorflow-serving-apt stable/tensorflow-model-server-universal amd64 Packages [349 B]
Fetched 5200 B in 2s (3446 B/s)
235 packages can be upgraded. Run 'apt list --upgradable' to see them.
Install TensorFlow Serving
This is all you need: one command line!
# TODO: Use the latest model server version when colab supports it.
#!{SUDO_IF_NEEDED} apt-get install tensorflow-model-server
# We need to install Tensorflow Model server 2.8 instead of latest version
# Tensorflow Serving >2.9.0 required `GLIBC_2.29` and `GLIBCXX_3.4.26`. Currently colab environment doesn't support latest version of`GLIBC`,so workaround is to use specific version of Tensorflow Serving `2.8.0` to mitigate issue.
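!wget 'http://storage.googleapis.com/tensorflow-serving-apt/pool/tensorflow-model-server-2.8.0/t/tensorflow-model-server/tensorflow-model-server_2.8.0_all.deb'
# dpkg needs root privileges; without them it fails with "requested operation requires superuser privilege".
!{SUDO_IF_NEEDED} dpkg -i tensorflow-model-server_2.8.0_all.deb
!pip3 install tensorflow-serving-api==2.8.0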
--2024-04-30 10:41:50-- http://storage.googleapis.com/tensorflow-serving-apt/pool/tensorflow-model-server-2.8.0/t/tensorflow-model-server/tensorflow-model-server_2.8.0_all.deb
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.172.207, 142.251.180.207, 142.251.183.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.172.207|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 340152790 (324M) [application/x-debian-package]
Saving to: ‘tensorflow-model-server_2.8.0_all.deb’
tensorflow-model-se 100%[===================>] 324.39M 223MB/s in 1.5s
2024-04-30 10:41:51 (223 MB/s) - ‘tensorflow-model-server_2.8.0_all.deb’ saved [340152790/340152790]
dpkg: error: requested operation requires superuser privilege
Collecting tensorflow-serving-api==2.8.0
Downloading tensorflow_serving_api-2.8.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting grpcio<2,>=1.0 (from tensorflow-serving-api==2.8.0)
Downloading grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)
Downloading tensorflow_serving_api-2.8.0-py2.py3-none-any.whl (37 kB)
Downloading grpcio-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.6 MB)
Installing collected packages: grpcio, tensorflow-serving-api
Attempting uninstall: grpcio
Found existing installation: grpcio 1.26.0
Uninstalling grpcio-1.26.0:
Successfully uninstalled grpcio-1.26.0
Attempting uninstall: tensorflow-serving-api
Found existing installation: tensorflow-serving-api 2.15.1
Uninstalling tensorflow-serving-api-2.15.1:
Successfully uninstalled tensorflow-serving-api-2.15.1
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.16.0 requires tensorflow<2.17,>=2.16, but you have tensorflow 2.15.1 which is incompatible.
tfx 1.15.0 requires tensorflow-serving-api<2.16,>=2.15, but you have tensorflow-serving-api 2.8.0 which is incompatible.
tfx-bsl 1.15.1 requires tensorflow-serving-api<3,>=2.13.0, but you have tensorflow-serving-api 2.8.0 which is incompatible.
Successfully installed grpcio-1.62.2 tensorflow-serving-api-2.8.0
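Start running TensorFlow Serving
This is where we start running TensorFlow Serving and load our model. After it loads, we can start making inference requests using REST. There are some important parameters:
- rest_api_port: The port that you'll use for REST requests.
- model_name: You'll use this in the URL of REST requests. It can be anything.
- model_base_path: The path to the directory where you've saved your model.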
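The three flags listed above can also be assembled programmatically, which may be convenient when launching the server from a script instead of a notebook cell. The sketch below only builds and prints the argument list; model_server_command is a hypothetical helper, and "/tmp" stands in for whatever directory holds your numbered model versions.

```python
import shlex
from typing import List


def model_server_command(rest_api_port: int, model_name: str,
                         model_base_path: str) -> List[str]:
    """Build the tensorflow_model_server argument list from the three parameters."""
    return [
        "tensorflow_model_server",
        f"--rest_api_port={rest_api_port}",
        f"--model_name={model_name}",
        f"--model_base_path={model_base_path}",
    ]


# "/tmp" is a placeholder base path chosen for this example.
cmd = model_server_command(8501, "fashion_model", "/tmp")
print(shlex.join(cmd))
# prints: tensorflow_model_server --rest_api_port=8501 --model_name=fashion_model --model_base_path=/tmp
```

A list like this can be handed to subprocess.Popen, with stdout and stderr redirected to a log file, to start the server in the background.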
os.environ["MODEL_DIR"] = MODEL_DIR
%%bash --bg
nohup tensorflow_model_server \
  --rest_api_port=8501 \
  --model_name=fashion_model \
  --model_base_path="${MODEL_DIR}" >server.log 2>&1
!tail server.log
nohup: failed to run command 'tensorflow_model_server': No such file or directory
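Because the server starts in the background, it can be useful to confirm that it is actually up before sending requests, and to fail clearly when it is not (for example, when the binary was never installed). The following is a rough sketch under some assumptions: wait_until_ready and model_is_available are hypothetical helpers, the check relies on TensorFlow Serving's model status endpoint at /v1/models/fashion_model, and an HTTP 200 response is taken as "ready" without inspecting the response body.

```python
import http.client
import time
import urllib.request
from typing import Callable


def model_is_available(url: str = "http://localhost:8501/v1/models/fashion_model") -> bool:
    """Return True if the model status endpoint answers with HTTP 200 (assumed to mean ready)."""
    try:
        with urllib.request.urlopen(url, timeout=2) as response:
            return response.status == 200
    except (OSError, http.client.HTTPException):
        # Connection refused, timeout, HTTP error status, or malformed response.
        return False


def wait_until_ready(check: Callable[[], bool], timeout_s: float = 30.0,
                     interval_s: float = 0.5) -> bool:
    """Poll check() until it returns True or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        if check():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Demonstration with a stand-in check that succeeds on the third attempt.
attempts = iter([False, False, True])
print(wait_until_ready(lambda: next(attempts), timeout_s=5.0, interval_s=0.01))  # prints: True
```

Against a real server the call would be wait_until_ready(model_is_available); a False result suggests looking at server.log before going further.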
Make a request to your model in TensorFlow Serving
First, let's take a look at a random example from our test data.
def show(idx, title):
  plt.figure()
  plt.imshow(test_images[idx].reshape(28,28))
  plt.axis('off')
  plt.title('\n\n{}'.format(title), fontdict={'size': 16})
import random
rando = random.randint(0,len(test_images)-1)
show(rando, 'An Example Image: {}'.format(class_names[test_labels[rando]]))
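OK, that looks interesting. How hard is that for you to recognize? Now let's create the JSON object for a batch of three inference requests, and see how well our model recognizes things: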
import json
data = json.dumps({"signature_name": "serving_default", "instances": test_images[0:3].tolist()})
print('Data: {} ... {}'.format(data[:50], data[len(data)-52:]))
Data: {"signature_name": "serving_default", "instances": ... [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]]]}
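The request body is ordinary JSON with two keys: signature_name selects the SignatureDef, and instances holds one nested list per example. The sketch below builds the same structure from tiny stand-in data so the layout is easy to see; build_predict_request is a hypothetical helper, and the 2x2x1 "images" are only illustrative (the served model expects 28x28x1).

```python
import json


def build_predict_request(instances, signature_name: str = "serving_default") -> str:
    """Serialize a batch of examples into the row-format JSON body used above."""
    return json.dumps({"signature_name": signature_name, "instances": instances})


# Two tiny 2x2x1 stand-in images: height x width x channel, one list per example.
tiny_batch = [
    [[[0.0], [0.5]], [[1.0], [0.25]]],
    [[[0.1], [0.2]], [[0.3], [0.4]]],
]

body = build_predict_request(tiny_batch)
decoded = json.loads(body)
print(sorted(decoded))            # prints: ['instances', 'signature_name']
print(len(decoded["instances"]))  # prints: 2
```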
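Make REST requests
Newest version of the servable
We'll send a predict request as a POST to our server's REST endpoint and pass it three examples. By not specifying a particular version, we ask the server to give us the latest version of our servable.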
# docs_infra: no_execute
!pip install -q requests
import requests
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/fashion_model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']
show(0, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
class_names[np.argmax(predictions[0])], np.argmax(predictions[0]), class_names[test_labels[0]], test_labels[0]))
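Because the Dense layer has no activation and the loss was built with from_logits=True, each entry in predictions is a list of ten raw logits rather than probabilities. np.argmax is enough to pick the class, but a softmax can turn the logits into a confidence value. The sketch below uses only the standard library; the logits are made-up numbers for illustration, not output from the model.

```python
import math
from typing import List, Tuple

CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def softmax(logits: List[float]) -> List[float]:
    """Convert logits to probabilities, subtracting the max for numerical stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_class(logits: List[float]) -> Tuple[int, float]:
    """Return the index of the largest logit and its softmax probability."""
    index = max(range(len(logits)), key=lambda i: logits[i])
    return index, softmax(logits)[index]


# Made-up logits for one example (illustrative values only).
example_logits = [-2.0, -5.1, 0.3, -1.2, 0.1, -7.4, 1.9, -9.0, -3.3, -8.2]
idx, prob = top_class(example_logits)
print(CLASS_NAMES[idx], round(prob, 3))
```

With real server output, top_class(predictions[0]) gives the same class index as np.argmax(predictions[0]).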
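A particular version of the servable
Now let's specify a particular version of our servable. Since we only have one, let's select version 1. We'll also look at all three results.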
# docs_infra: no_execute
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/fashion_model/versions/1:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']
for i in range(0,3):
  show(i, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
    class_names[np.argmax(predictions[i])], np.argmax(predictions[i]), class_names[test_labels[i]], test_labels[i]))
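The two requests in this section differ only in their URL: omitting the version segment asks for the latest version, while /versions/N pins a specific one. A small sketch that builds both forms follows; predict_url is a hypothetical helper, with the host and port defaults taken from the requests above.

```python
from typing import Optional


def predict_url(model_name: str, version: Optional[int] = None,
                host: str = "localhost", port: int = 8501) -> str:
    """Build the REST predict URL, optionally pinned to a specific model version."""
    url = f"http://{host}:{port}/v1/models/{model_name}"
    if version is not None:
        url += f"/versions/{version}"
    return url + ":predict"


print(predict_url("fashion_model"))
# prints: http://localhost:8501/v1/models/fashion_model:predict
print(predict_url("fashion_model", version=1))
# prints: http://localhost:8501/v1/models/fashion_model/versions/1:predict
```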