在本教程中,您将探索一个示例 Web 应用程序,该应用程序演示了使用 TensorFlow.js Layers API 进行迁移学习。该示例加载一个预训练模型,然后在浏览器中重新训练该模型。
该模型已在 Python 中使用 MNIST 数字分类数据集 的数字 0-4 进行预训练。浏览器中的重新训练(或迁移学习)使用数字 5-9。该示例表明,预训练模型的前几层可用于在迁移学习期间从新数据中提取特征,从而能够更快地在新数据上进行训练。
本教程的示例应用程序 在线提供,因此您无需下载任何代码或设置开发环境。如果您想在本地运行代码,请完成 在本地运行示例 中的可选步骤。如果您不想设置开发环境,可以跳到 探索示例。
示例代码在 GitHub 上提供。
(可选) 在本地运行示例
先决条件
要在本地运行示例应用程序,您需要在开发环境中安装以下内容
安装并运行示例应用程序
- 克隆或下载
tfjs-examples
存储库。 进入
mnist-transfer-cnn
目录cd tfjs-examples/mnist-transfer-cnn
安装依赖项
yarn
启动开发服务器
yarn run watch
探索示例
打开示例应用程序。(或者,如果您在本地运行示例,请在浏览器中访问 http://localhost:1234
。)
您应该会看到一个名为 MNIST CNN 迁移学习 的页面。按照说明尝试该应用程序。
以下是一些可以尝试的操作
- 尝试不同的训练模式,并比较损失和准确率。
- 选择不同的位图示例,并检查分类概率。请注意,每个位图示例中的数字都是表示图像像素的灰度整数值。
- 编辑位图整数值,并查看更改如何影响分类概率。
探索代码
示例 Web 应用程序加载一个在 MNIST 数据集子集上预训练的模型。预训练是在 Python 程序中定义的:mnist_transfer_cnn.py
。Python 程序不在本教程的范围内,但如果您想查看 模型转换 的示例,则值得一看。
index.js
文件包含演示的大部分训练代码。当 index.js
在浏览器中运行时,一个设置函数 setupMnistTransferCNN
实例化并初始化 MnistTransferCNNPredictor
,它封装了重新训练和预测例程。
初始化方法 MnistTransferCNNPredictor.init
加载模型、加载重新训练数据并创建测试数据。以下是 加载模型的行
this.model = await loader.loadHostedPretrainedModel(urls.model);
如果您查看 loader.loadHostedPretrainedModel
的定义,您会发现它返回对 tf.loadLayersModel
的调用的结果。这是 TensorFlow.js API,用于加载由 Layer 对象组成的模型。
重新训练逻辑定义在 MnistTransferCNNPredictor.retrainModel
中。如果用户选择了 **冻结特征层** 作为训练模式,则基础模型的前 7 层将被冻结,只有最后的 5 层会在新数据上进行训练。如果用户选择了 **重新初始化权重**,则所有权重都会被重置,应用程序实际上是从头开始训练模型。
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
然后,模型将被 编译,然后使用 model.fit()
在测试数据上 训练。
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
要了解有关 model.fit()
参数的更多信息,请参阅 API 文档。
在使用新数据集(数字 5-9)训练后,模型可用于进行预测。 MnistTransferCNNPredictor.predict
方法使用 model.predict()
执行此操作。
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
请注意 tf.tidy
的使用,它有助于防止内存泄漏。
了解更多
本教程探讨了一个使用 TensorFlow.js 在浏览器中执行迁移学习的示例应用程序。查看以下资源以了解有关预训练模型和迁移学习的更多信息。
TensorFlow.js
TensorFlow Core