TensorFlow 基于 GraphDef 的模型(通常通过 Python API 创建)可以保存为以下格式之一
- TensorFlow SavedModel
- 冻结模型
- Tensorflow Hub 模块
所有上述格式都可以通过 TensorFlow.js 转换器 转换为可以直接加载到 TensorFlow.js 中进行推理的格式。
(注意:TensorFlow 已弃用会话捆绑包格式。请将您的模型迁移到 SavedModel 格式。)
要求
转换过程需要 Python 环境;您可能希望使用 pipenv 或 virtualenv 保持一个隔离的环境。
要安装转换器,请运行以下命令
pip install tensorflowjs
将 TensorFlow 模型导入 TensorFlow.js 是一个两步过程。首先,将现有模型转换为 TensorFlow.js Web 格式,然后将其加载到 TensorFlow.js 中。
步骤 1. 将现有 TensorFlow 模型转换为 TensorFlow.js Web 格式
运行 pip 包提供的转换器脚本
SavedModel 示例
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
冻结模型示例
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Tensorflow Hub 模块示例
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
位置参数 | 描述 |
---|---|
input_path |
保存的模型目录、会话捆绑包目录、冻结模型文件或 TensorFlow Hub 模块句柄或路径的完整路径。 |
output_path |
所有输出工件的路径。 |
选项 | 描述 |
---|---|
--input_format |
输入模型的格式。对于 SavedModel 使用 tf_saved_model,对于冻结模型使用 tf_frozen_model,对于会话捆绑包使用 tf_session_bundle,对于 TensorFlow Hub 模块使用 tf_hub,对于 Keras HDF5 使用 keras。 |
--output_node_names |
输出节点的名称,用逗号分隔。 |
--saved_model_tags |
仅适用于 SavedModel 转换。要加载的 MetaGraphDef 的标签,以逗号分隔的格式。默认为 serve 。 |
--signature_name |
仅适用于 TensorFlow Hub 模块转换,要加载的签名。默认为 default 。请参阅 https://tensorflowcn.cn/hub/common_signatures/ |
使用以下命令获取详细的帮助消息
tensorflowjs_converter --help
转换器生成的的文件
上面的转换脚本生成两种类型的文件
model.json
:数据流图和权重清单group1-shard\*of\*
:二进制权重文件的集合
例如,以下是转换 MobileNet v2 的输出
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
步骤 2:在浏览器中加载和运行
- 安装 tfjs-converter npm 包
yarn add @tensorflow/tfjs
或 npm install @tensorflow/tfjs
- 实例化 FrozenModel 类 并运行推理。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
查看 MobileNet 演示。
loadGraphModel
API 接受一个额外的 LoadOptions
参数,可用于与请求一起发送凭据或自定义标头。有关详细信息,请参阅 loadGraphModel() 文档。
支持的操作
目前 TensorFlow.js 支持有限的 TensorFlow 操作集。如果您的模型使用不支持的操作,则 tensorflowjs_converter
脚本将失败并打印出模型中不支持的操作列表。请为每个操作提交一个 问题,让我们知道您需要哪些操作的支持。
仅加载权重
如果您希望仅加载权重,可以使用以下代码片段
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
// Use `weightMap` ...