在本教程中,您将探索一个示例 Web 应用程序,它使用 Web 工作线程 来训练 循环神经网络 (RNN) 进行整数加法。示例应用程序未明确定义加法运算符。相反,它使用示例和进行 RNN 训练。
当然,这不是加两个整数的最有效方法!但本教程演示了 Web ML 中的一项重要技术:如何在不阻塞处理 UI 逻辑的主线程的情况下执行长时间运行的计算。
本教程的示例应用程序在线提供,因此您无需下载任何代码或设置开发环境。如果您想在本地运行代码,请完成在本地运行示例中的可选步骤。如果您不想设置开发环境,您可以跳到探索示例。
示例代码在GitHub上提供。
(可选) 在本地运行示例
先决条件
要在本地运行示例应用程序,您需要在开发环境中安装以下内容
安装并运行示例应用程序
- 克隆或下载
tfjs-examples
存储库。 切换到
addition-rnn-webworker
目录cd tfjs-examples/addition-rnn-webworker
安装依赖项
yarn
启动开发服务器
yarn run watch
探索示例
打开示例应用程序。(或者,如果您在本地运行示例,请在浏览器中转到https://127.0.0.1:1234
。)
您应该会看到一个标题为TensorFlow.js: 加法 RNN的页面。按照说明试用该应用程序。
使用网络表单,您可以更新用于训练模型的一些参数,包括以下参数
- 数字:要相加的项中的最大数字。
- 训练规模:要生成的训练示例数。
- RNN 类型:SimpleRNN、GRU 或LSTM之一。
- RNN 隐藏层大小:输出空间的维数(必须为正整数)。
- 批量大小:每次梯度更新的样本数。
- 训练迭代次数:通过调用
model.fit()
训练模型的次数 - 测试示例数:要生成的示例字符串数(例如,
27+41
)。
尝试使用不同的参数训练模型,看看您是否可以提高对各种数字集的预测准确性。还要注意模型拟合时间如何受到不同参数的影响。
探索代码
示例应用程序演示了可用于配置 RNN 训练的部分参数。它还演示了使用 Web 工作器在主线程之外训练模型。Web 工作器在 Web ML 中很重要,因为它们允许你在后台线程上运行计算量大的训练任务,从而避免在主线程上出现可能影响用户性能的问题。主线程和工作器线程通过消息事件相互通信。
要了解有关 Web 工作器的更多信息,请参阅 Web Workers API 和 使用 Web 工作器。
示例应用程序的主模块是 index.js
。index.js
脚本 创建一个 Web 工作器,该工作器运行 worker.js
模块
const worker =
new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});
index.js
主要由一个函数组成,runAdditionRNNDemo
,它处理表单提交、处理表单数据、将表单数据传递给工作器、等待工作器训练模型并返回结果,然后在页面上显示结果。
要将表单数据发送给工作器,脚本 在工作器上调用 postMessage
worker.postMessage({
digits,
trainingSize,
rnnType,
layers,
hiddenSize,
trainIterations,
batchSize,
numTestExamples
});
工作器 侦听此消息,并将表单数据传递给准备数据和开始训练的函数
self.addEventListener('message', async (e) => {
const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
await demo.train(trainIterations, batchSize, numTestExamples);
})
在训练期间,工作器可以发送两种不同的消息类型,一种 将 isPredict
设置为 true
self.postMessage({
isPredict: true,
i, iterations, modelFitTime,
lossValues, accuracyValues,
});
而 另一种 将 isPredict
设置为 false
。
self.postMessage({
isPredict: false,
isCorrect, examples
});
当 UI 线程 (index.js
) 处理消息事件时,它会 检查 isPredict
标志,以确定从工作线程返回的数据的形状。如果 isPredict
为 true,则数据应表示预测,并且脚本会 使用 tfjs-vis
更新页面。如果 isPredict
为 false,则脚本会运行 一段代码,该代码假定数据表示示例。它将数据包装在 HTML 中,并将 HTML 插入页面中。
下一步
本教程提供了一个使用 Web 工作线程来避免长时间运行的训练过程阻塞 UI 线程的示例。要详细了解在后台线程上进行昂贵计算的好处,请参阅 使用 Web 工作线程在浏览器主线程之外运行 JavaScript。
要详细了解如何训练 TensorFlow.js 模型,请参阅 训练模型。