使用预训练模型

在本教程中,您将探索一个示例 Web 应用程序,该应用程序演示了使用 TensorFlow.js Layers API 进行迁移学习。该示例加载一个预训练模型,然后在浏览器中重新训练该模型。

该模型已在 Python 中使用 MNIST 数字分类数据集 的数字 0-4 进行预训练。浏览器中的重新训练(或迁移学习)使用数字 5-9。该示例表明,预训练模型的前几层可用于在迁移学习期间从新数据中提取特征,从而能够更快地在新数据上进行训练。

本教程的示例应用程序 在线提供,因此您无需下载任何代码或设置开发环境。如果您想在本地运行代码,请完成 在本地运行示例 中的可选步骤。如果您不想设置开发环境,可以跳到 探索示例

示例代码在 GitHub 上提供。

(可选) 在本地运行示例

先决条件

要在本地运行示例应用程序,您需要在开发环境中安装以下内容

安装并运行示例应用程序

  1. 克隆或下载 tfjs-examples 存储库。
  2. 进入 mnist-transfer-cnn 目录

    cd tfjs-examples/mnist-transfer-cnn
    
  3. 安装依赖项

    yarn
    
  4. 启动开发服务器

    yarn run watch
    

探索示例

打开示例应用程序。(或者,如果您在本地运行示例,请在浏览器中访问 https://127.0.0.1:1234。)

您应该会看到一个名为 MNIST CNN 迁移学习 的页面。按照说明尝试该应用程序。

以下是一些可以尝试的操作

  • 尝试不同的训练模式,并比较损失和准确率。
  • 选择不同的位图示例,并检查分类概率。请注意,每个位图示例中的数字都是表示图像像素的灰度整数值。
  • 编辑位图整数值,并查看更改如何影响分类概率。

探索代码

示例 Web 应用程序加载一个在 MNIST 数据集子集上预训练的模型。预训练是在 Python 程序中定义的:mnist_transfer_cnn.py。Python 程序不在本教程的范围内,但如果您想查看 模型转换 的示例,则值得一看。

index.js 文件包含演示的大部分训练代码。当 index.js 在浏览器中运行时,一个设置函数 setupMnistTransferCNN 实例化并初始化 MnistTransferCNNPredictor,它封装了重新训练和预测例程。

初始化方法 MnistTransferCNNPredictor.init 加载模型、加载重新训练数据并创建测试数据。以下是 加载模型的行

this.model = await loader.loadHostedPretrainedModel(urls.model);

如果您查看 loader.loadHostedPretrainedModel 的定义,您会发现它返回对 tf.loadLayersModel 的调用的结果。这是 TensorFlow.js API,用于加载由 Layer 对象组成的模型。

重新训练逻辑定义在 MnistTransferCNNPredictor.retrainModel 中。如果用户选择了 **冻结特征层** 作为训练模式,则基础模型的前 7 层将被冻结,只有最后的 5 层会在新数据上进行训练。如果用户选择了 **重新初始化权重**,则所有权重都会被重置,应用程序实际上是从头开始训练模型。

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

然后,模型将被 编译,然后使用 model.fit() 在测试数据上 训练

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

要了解有关 model.fit() 参数的更多信息,请参阅 API 文档

在使用新数据集(数字 5-9)训练后,模型可用于进行预测。 MnistTransferCNNPredictor.predict 方法使用 model.predict() 执行此操作。

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

请注意 tf.tidy 的使用,它有助于防止内存泄漏。

了解更多

本教程探讨了一个使用 TensorFlow.js 在浏览器中执行迁移学习的示例应用程序。查看以下资源以了解有关预训练模型和迁移学习的更多信息。

TensorFlow.js

TensorFlow Core