使用 TensorFlow.js 进行预测性预取

在本教程中,您将运行一个示例 Web 应用程序,该应用程序使用 TensorFlow.js 进行资源的预测性预取。该示例使用 Angular 构建,灵感来自 Google 商品商店,但与之不共享任何数据或实现细节。

该示例使用预训练模型进行预测。在实际场景中,您需要使用网站分析数据来训练模型。您可以使用 TFX 进行此类训练。要了解有关为预测性预取训练自定义模型的更多信息,请参阅 这篇博文

示例代码可在 GitHub 上获取。

先决条件

要完成本教程,您需要在开发环境中安装以下内容

安装示例

获取源代码并安装依赖项

  1. 克隆或下载 tfjs-examples 存储库。
  2. 进入 angular-predictive-prefetching/client 目录并安装依赖项

    cd tfjs-examples/angular-predictive-prefetching/client && yarn
    
  3. 进入 angular-predictive-prefetching/server 目录并安装依赖项

    cd ../server && yarn
    

运行示例

启动服务器和客户端

  1. 启动服务器:在 server 目录中,运行 yarn start

  2. 启动客户端

    1. 打开另一个终端窗口。
    2. 进入 tfjs-examples/angular-predictive-prefetching/client
    3. 运行以下命令

      yarn build
      cd dist/merch-store
      npx serve -s .
      

      您可能会被提示安装 serve 包。如果是,请输入 y 以安装该包。

  3. 在浏览器中导航到 https://127.0.0.1:3000。您应该会看到一个模拟的 Google 商品商店。

使用开发者工具进行探索

使用 Chrome 开发者工具查看预测性预取的实际操作

  1. 打开开发者工具并选择 控制台
  2. 在应用程序中导航到几个不同的页面,以启动应用程序。然后在左侧导航中选择 特价。您应该会看到类似于以下内容的日志输出

    Navigating from: 'sale'
    'quickview' -> 0.381757915019989
    'apparel-unisex' -> 0.3150934875011444
    'store.html' -> 0.1957530975341797
    '' -> 0.052346792072057724
    'signin.html' -> 0.0007763378671370447
    

    此输出显示了对您(用户)将要访问的页面的预测。应用程序根据这些预测获取资源。

  3. 要查看获取请求,请选择 网络。输出有点嘈杂,但您应该能够找到对预测页面资源的请求。例如,在预测 quickview 后,应用程序会向 https://127.0.0.1:8000/api/merch/quickview 发出请求。

预测性预取的工作原理

示例应用程序使用预训练模型来预测用户将要访问的下一个页面。当用户导航到新页面时,应用程序会查询模型,然后预取与预测页面关联的图像。

应用程序在 服务工作者 上执行预测性预取,以便它可以在不阻塞主线程的情况下查询模型。根据用户的导航历史记录,服务工作者会对未来的导航进行预测,并预取相关的产品图像。

服务工作者在 Angular 应用程序的主文件中加载,即 main.ts

if ('serviceWorker' in navigator) {
  navigator.serviceWorker.register('/prefetch.service-worker.js', { scope: '/' });
}

上面的代码段会下载 prefetch.service-worker.js 脚本并在后台运行它。

merch-display.component.ts 中,应用程序会将导航事件转发到服务工作者

this.route.params.subscribe((routeParams) => {
  this.getMerch(routeParams.category);
  if (this._serviceWorker) {
    this._serviceWorker.postMessage({ page: routeParams.category });
  }
});

在上面的代码段中,应用程序会监视 URL 参数的变化。发生变化时,脚本会将页面的类别转发到服务工作者。

服务工作者脚本 prefetch.service-worker.js 会处理来自主线程的消息,根据这些消息进行预测,并预取相关的资源。

服务工作者使用 loadGraphModel加载预训练模型

const MODEL_URL = "/assets/model.json";

let model = null;
tf.loadGraphModel(MODEL_URL).then((m) => (model = m));

预测发生在以下函数表达式中。

const predict = async (path, userId) => {
  if (!model) {
    return;
  }
  const page = pages.indexOf(path);
  const pageId = tf.tensor1d([parseInt(page)], "int32");

  const sessionIndex = tf.tensor1d([parseInt(userId)], "int32");

  const result = model.predict({
    cur_page: pageId,
    session_index: sessionIndex,
  });
  const values = result.dataSync();
  const orders = sortWithIndices(values).slice(0, 5);
  return orders;
};

然后,predict 函数由prefetch 函数调用。

const prefetch = async (path, sessionId) => {
  const predictions = await predict(path, sessionId);
  const formattedPredictions = predictions
    .map(([a, b]) => `'${b}' -> ${a}`)
    .join("\n");
  console.log(`Navigating from: '${path}'`);
  console.log(formattedPredictions);
  const connectionSpeed = navigator.connection.effectiveType;
  const threshold = connectionSpeeds[connectionSpeed];
  const cache = await caches.open(ImageCache);
  predictions.forEach(async ([probability, category]) => {
    if (probability >= threshold) {
      const merchs = (await getMerchList(category)).map(getUrl);
      [...new Set(merchs)].forEach((url) => {
        const request = new Request(url, {
          mode: "no-cors",
        });
        fetch(request).then((response) => cache.put(request, response));
      });
    }
  });
};

首先,prefetch 预测用户接下来可能访问的页面。然后,它遍历这些预测。对于每个预测,如果概率超过基于连接速度的某个阈值,该函数将获取预测页面的资源。通过在下一个页面请求之前获取这些资源,应用程序可以潜在地更快地提供内容并提供更好的用户体验。

下一步

在本教程中,示例应用程序使用预训练模型进行预测。您可以使用TFX 训练用于预测性预取的模型。要了解更多信息,请参阅使用机器学习加速您的网站