使用 Web 工作线程训练模型

在本教程中,您将探索一个示例 Web 应用程序,它使用 Web 工作线程 来训练 循环神经网络 (RNN) 进行整数加法。示例应用程序未明确定义加法运算符。相反,它使用示例和进行 RNN 训练。

当然,这不是加两个整数的最有效方法!但本教程演示了 Web ML 中的一项重要技术:如何在不阻塞处理 UI 逻辑的主线程的情况下执行长时间运行的计算。

本教程的示例应用程序在线提供,因此您无需下载任何代码或设置开发环境。如果您想在本地运行代码,请完成在本地运行示例中的可选步骤。如果您不想设置开发环境,您可以跳到探索示例

示例代码在GitHub上提供。

(可选) 在本地运行示例

先决条件

要在本地运行示例应用程序,您需要在开发环境中安装以下内容

安装并运行示例应用程序

  1. 克隆或下载tfjs-examples存储库。
  2. 切换到addition-rnn-webworker目录

    cd tfjs-examples/addition-rnn-webworker
    
  3. 安装依赖项

    yarn
    
  4. 启动开发服务器

    yarn run watch
    

探索示例

打开示例应用程序。(或者,如果您在本地运行示例,请在浏览器中转到http://localhost:1234。)

您应该会看到一个标题为TensorFlow.js: 加法 RNN的页面。按照说明试用该应用程序。

使用网络表单,您可以更新用于训练模型的一些参数,包括以下参数

  • 数字:要相加的项中的最大数字。
  • 训练规模:要生成的训练示例数。
  • RNN 类型SimpleRNNGRULSTM之一。
  • RNN 隐藏层大小:输出空间的维数(必须为正整数)。
  • 批量大小:每次梯度更新的样本数。
  • 训练迭代次数:通过调用model.fit()训练模型的次数
  • 测试示例数:要生成的示例字符串数(例如,27+41)。

尝试使用不同的参数训练模型,看看您是否可以提高对各种数字集的预测准确性。还要注意模型拟合时间如何受到不同参数的影响。

探索代码

示例应用程序演示了可用于配置 RNN 训练的部分参数。它还演示了使用 Web 工作器在主线程之外训练模型。Web 工作器在 Web ML 中很重要,因为它们允许你在后台线程上运行计算量大的训练任务,从而避免在主线程上出现可能影响用户性能的问题。主线程和工作器线程通过消息事件相互通信。

要了解有关 Web 工作器的更多信息,请参阅 Web Workers API使用 Web 工作器

示例应用程序的主模块是 index.jsindex.js 脚本 创建一个 Web 工作器,该工作器运行 worker.js 模块

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js 主要由一个函数组成,runAdditionRNNDemo,它处理表单提交、处理表单数据、将表单数据传递给工作器、等待工作器训练模型并返回结果,然后在页面上显示结果。

要将表单数据发送给工作器,脚本 在工作器上调用 postMessage

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

工作器 侦听此消息,并将表单数据传递给准备数据和开始训练的函数

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

在训练期间,工作器可以发送两种不同的消息类型,一种isPredict 设置为 true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

另一种isPredict 设置为 false

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

当 UI 线程 (index.js) 处理消息事件时,它会 检查 isPredict 标志,以确定从工作线程返回的数据的形状。如果 isPredict 为 true,则数据应表示预测,并且脚本会 使用 tfjs-vis 更新页面。如果 isPredict 为 false,则脚本会运行 一段代码,该代码假定数据表示示例。它将数据包装在 HTML 中,并将 HTML 插入页面中。

下一步

本教程提供了一个使用 Web 工作线程来避免长时间运行的训练过程阻塞 UI 线程的示例。要详细了解在后台线程上进行昂贵计算的好处,请参阅 使用 Web 工作线程在浏览器主线程之外运行 JavaScript

要详细了解如何训练 TensorFlow.js 模型,请参阅 训练模型