将 TensorFlow 模型导入 TensorFlow.js

TensorFlow 基于 GraphDef 的模型(通常通过 Python API 创建)可以保存为以下格式之一

  1. TensorFlow SavedModel
  2. 冻结模型
  3. Tensorflow Hub 模块

所有上述格式都可以通过 TensorFlow.js 转换器 转换为可以直接加载到 TensorFlow.js 中进行推理的格式。

(注意:TensorFlow 已弃用会话捆绑包格式。请将您的模型迁移到 SavedModel 格式。)

要求

转换过程需要 Python 环境;您可能希望使用 pipenvvirtualenv 保持一个隔离的环境。

要安装转换器,请运行以下命令

 pip install tensorflowjs

将 TensorFlow 模型导入 TensorFlow.js 是一个两步过程。首先,将现有模型转换为 TensorFlow.js Web 格式,然后将其加载到 TensorFlow.js 中。

步骤 1. 将现有 TensorFlow 模型转换为 TensorFlow.js Web 格式

运行 pip 包提供的转换器脚本

SavedModel 示例

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

冻结模型示例

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Tensorflow Hub 模块示例

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
位置参数 描述
input_path 保存的模型目录、会话捆绑包目录、冻结模型文件或 TensorFlow Hub 模块句柄或路径的完整路径。
output_path 所有输出工件的路径。
选项 描述
--input_format 输入模型的格式。对于 SavedModel 使用 tf_saved_model,对于冻结模型使用 tf_frozen_model,对于会话捆绑包使用 tf_session_bundle,对于 TensorFlow Hub 模块使用 tf_hub,对于 Keras HDF5 使用 keras。
--output_node_names 输出节点的名称,用逗号分隔。
--saved_model_tags 仅适用于 SavedModel 转换。要加载的 MetaGraphDef 的标签,以逗号分隔的格式。默认为 serve
--signature_name 仅适用于 TensorFlow Hub 模块转换,要加载的签名。默认为 default。请参阅 https://tensorflowcn.cn/hub/common_signatures/

使用以下命令获取详细的帮助消息

tensorflowjs_converter --help

转换器生成的的文件

上面的转换脚本生成两种类型的文件

  • model.json:数据流图和权重清单
  • group1-shard\*of\*:二进制权重文件的集合

例如,以下是转换 MobileNet v2 的输出

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

步骤 2:在浏览器中加载和运行

  1. 安装 tfjs-converter npm 包

yarn add @tensorflow/tfjsnpm install @tensorflow/tfjs

  1. 实例化 FrozenModel 类 并运行推理。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

查看 MobileNet 演示

loadGraphModel API 接受一个额外的 LoadOptions 参数,可用于与请求一起发送凭据或自定义标头。有关详细信息,请参阅 loadGraphModel() 文档

支持的操作

目前 TensorFlow.js 支持有限的 TensorFlow 操作集。如果您的模型使用不支持的操作,则 tensorflowjs_converter 脚本将失败并打印出模型中不支持的操作列表。请为每个操作提交一个 问题,让我们知道您需要哪些操作的支持。

仅加载权重

如果您希望仅加载权重,可以使用以下代码片段

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...