TensorFlow.js 提供保存和加载模型的功能,这些模型是使用 Layers API 创建的,或者从现有的 TensorFlow 模型转换而来。这些可能是您自己训练的模型,也可能是其他人训练的模型。使用 Layers API 的一个主要好处是,使用它创建的模型是可序列化的,这就是我们将在本教程中探讨的内容。
本教程将重点介绍保存和加载 TensorFlow.js 模型(由 JSON 文件标识)。我们也可以导入 TensorFlow Python 模型。加载这些模型将在以下两个教程中介绍
保存 tf.Model
tf.Model 和 tf.Sequential 都提供一个函数 model.save,允许您保存模型的拓扑结构和权重。
拓扑结构:这是一个描述模型架构的文件(即它使用哪些操作)。它包含对模型权重的引用,这些权重存储在外部。
权重:这些是二进制文件,以高效的格式存储给定模型的权重。它们通常存储在与拓扑结构相同的文件夹中。
让我们看一下保存模型的代码示例
const saveResult = await model.save('localstorage://my-model-1');
需要注意的几点
save方法接受一个以方案开头的类似 URL 的字符串参数。这描述了我们尝试保存模型的目标类型。在上面的示例中,方案是localstorage://- 方案后跟一个路径。在上面的示例中,路径是
my-model-1。 save方法是异步的。model.save的返回值是一个 JSON 对象,它包含模型拓扑结构和权重的字节大小等信息。- 用于保存模型的环境不会影响哪些环境可以加载模型。在 node.js 中保存模型不会阻止它在浏览器中加载。
下面我们将检查可用的不同方案。
本地存储(仅限浏览器)
方案: localstorage://
await model.save('localstorage://my-model');
这将在浏览器的 本地存储 中以 my-model 的名称保存模型。这将在刷新之间持续存在,尽管如果空间成为问题,用户或浏览器本身可以清除本地存储。每个浏览器还为给定域设置了他们自己的本地存储数据大小限制。
IndexedDB(仅限浏览器)
方案: indexeddb://
await model.save('indexeddb://my-model');
这会将模型保存到浏览器的 IndexedDB 存储中。与本地存储一样,它在刷新之间持续存在,它也往往对存储对象的尺寸有更大的限制。
文件下载(仅限浏览器)
方案: downloads://
await model.save('downloads://my-model');
这将导致浏览器将模型文件下载到用户的机器上。将生成两个文件
- 一个名为
[my-model].json的文本 JSON 文件,它包含拓扑结构和对下面描述的权重文件的引用。 - 一个名为
[my-model].weights.bin的二进制文件,它包含权重值。
您可以更改名称 [my-model] 以获取不同名称的文件。
因为 .json 文件使用相对路径指向 .bin,所以这两个文件应该在同一个文件夹中。
HTTP(S) 请求
方案: http:// 或 https://
await model.save('http://model-server.domain/upload')
这将创建一个 Web 请求,将模型保存到远程服务器。您应该控制该远程服务器,以便您可以确保它能够处理该请求。
模型将通过 POST 请求发送到指定的 HTTP 服务器。POST 的主体采用 multipart/form-data 格式,包含两个文件
- 一个名为
model.json的文本 JSON 文件,其中包含拓扑结构和对下面描述的权重文件的引用。 - 一个名为
model.weights.bin的二进制文件,其中包含权重值。
请注意,这两个文件的名称始终与上面指定的名称完全相同(名称内置于函数中)。此 api 文档 包含一个 Python 代码片段,演示了如何使用 flask Web 框架处理来自 save 的请求。
通常,您需要向 HTTP 服务器传递更多参数或请求标头(例如,用于身份验证或如果您想指定要保存模型的文件夹)。您可以通过替换 tf.io.browserHTTPRequest 中的 URL 字符串参数,对来自 save 的请求的这些方面进行细粒度控制。此 API 为控制 HTTP 请求提供了更大的灵活性。
例如
await model.save(tf.io.browserHTTPRequest(
'http://model-server.domain/upload',
{method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));
本地文件系统(仅限 Node.js)
方案: file://
await model.save('file:///path/to/my-model');
在 Node.js 上运行时,我们还可以直接访问文件系统,并将模型保存到那里。上面的命令将把两个文件保存到 scheme 后面指定的 path 中。
- 一个名为
[model].json的文本 JSON 文件,其中包含拓扑结构和对下面描述的权重文件的引用。 - 一个名为
[model].weights.bin的二进制文件,其中包含权重值。
请注意,这两个文件的名称始终与上面指定的名称完全相同(名称内置于函数中)。
加载 tf.Model
对于使用上述方法之一保存的模型,我们可以使用 tf.loadLayersModel API 加载它。
让我们看看加载模型的代码是什么样的
const model = await tf.loadLayersModel('localstorage://my-model-1');
需要注意的几点
- 与
model.save()一样,loadLayersModel函数接受一个以 **方案** 开头的 URL 类字符串参数。这描述了我们尝试从中加载模型的目标类型。 - 方案后跟一个路径。在上面的示例中,路径是
my-model-1。 - URL 类字符串可以替换为与 IOHandler 接口匹配的对象。
tf.loadLayersModel()函数是异步的。tf.loadLayersModel的返回值是tf.Model
下面我们将检查可用的不同方案。
本地存储(仅限浏览器)
方案: localstorage://
const model = await tf.loadLayersModel('localstorage://my-model');
这将从浏览器的 本地存储 中加载一个名为 my-model 的模型。
IndexedDB(仅限浏览器)
方案: indexeddb://
const model = await tf.loadLayersModel('indexeddb://my-model');
这将从浏览器的 IndexedDB 存储中加载模型。
HTTP(S)
方案: http:// 或 https://
const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');
这将从 http 端点加载模型。加载 json 文件后,该函数将对 json 文件引用的相应 .bin 文件发出请求。
本地文件系统(仅限 Node.js)
方案: file://
const model = await tf.loadLayersModel('file://path/to/my-model/model.json');
在 Node.js 上运行时,我们还可以直接访问文件系统,并从那里加载模型。请注意,在上面的函数调用中,我们引用了 model.json 文件本身(而在保存时,我们指定了一个文件夹)。相应的 .bin 文件应该与 json 文件位于同一个文件夹中。
使用 IOHandler 加载模型
如果上面的方案不足以满足您的需求,您可以使用 IOHandler 实现自定义加载行为。TensorFlow.js 提供的一个 IOHandler 是 tf.io.browserFiles,它允许浏览器用户在浏览器中上传模型文件。有关更多信息,请参阅 文档。
使用自定义 IOHandler 保存和加载模型
如果上面的方案不足以满足您的加载或保存需求,您可以通过实现 IOHandler 来实现自定义序列化行为。
一个 IOHandler 是一个具有 save 和 load 方法的对象。
save 函数接受一个与 ModelArtifacts 接口匹配的参数,并应返回一个解析为 SaveResult 对象的 Promise。
load 函数不接受任何参数,并应返回一个解析为 ModelArtifacts 对象的 Promise。这是传递给 save 的相同对象。
有关如何实现 IOHandler 的示例,请参阅 BrowserHTTPRequest。