重新训练图像分类器

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看 下载笔记本 查看 TF Hub 模型

简介

图像分类模型具有数百万个参数。从头开始训练它们需要大量标记的训练数据和大量的计算能力。迁移学习是一种技术,它可以通过使用已在相关任务上训练过的模型的一部分并将其重新用于新模型来缩短这一过程。

此 Colab 演示了如何构建一个 Keras 模型,通过使用来自 TensorFlow Hub 的预训练 TF2 SavedModel 进行图像特征提取来对五种花卉进行分类,该模型是在更大、更通用的 ImageNet 数据集上训练的。可以选择对特征提取器进行训练(“微调”),同时训练新添加的分类器。

需要工具吗?

这是一个 TensorFlow 编码教程。如果您需要一个工具来构建 TensorFlow 或 TFLite 模型,请查看 make_image_classifier 命令行工具,该工具可以通过 PIP 包 tensorflow-hub[make_image_classifier] 安装,或者在 这里 查看 TFLite Colab。

设置

import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

选择要使用的 TF2 SavedModel 模块

首先,使用 https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4。相同的 URL 可用于代码中识别 SavedModel,以及在浏览器中显示其文档。(请注意,TF1 Hub 格式的模型在此处不起作用。)

您可以在 此处 找到更多生成图像特征向量的 TF2 模型。

有多种可能的模型可供尝试。您只需在下面的单元格中选择一个不同的模型,然后继续使用笔记本。

设置花卉数据集

输入已针对所选模块进行了适当的调整。数据集增强(即每次读取图像时随机扭曲图像)可以改善训练,尤其是在微调时。

data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

定义模型

只需在 feature_extractor_layer 上添加一个线性分类器,该分类器使用 Hub 模块。

为了提高速度,我们从一个不可训练的 feature_extractor_layer 开始,但您也可以启用微调以获得更高的准确率。

do_fine_tuning = False
print("Building model with", model_handle)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(len(class_names),
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()

训练模型

model.compile(
  optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9), 
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=['accuracy'])
steps_per_epoch = train_size // BATCH_SIZE
validation_steps = valid_size // BATCH_SIZE
hist = model.fit(
    train_ds,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps).history
plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])

在验证数据中的图像上尝试模型

x, y = next(iter(val_ds))
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()

# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + class_names[true_index])
print("Predicted label: " + class_names[predicted_index])

最后,可以将训练后的模型保存以部署到 TF Serving 或 TFLite(在移动设备上),如下所示。

saved_model_path = f"/tmp/saved_flowers_model_{model_name}"
tf.saved_model.save(model, saved_model_path)

可选:部署到 TensorFlow Lite

TensorFlow Lite 允许您将 TensorFlow 模型部署到移动设备和物联网设备。下面的代码展示了如何将训练后的模型转换为 TFLite 并应用来自 TensorFlow 模型优化工具包 的训练后工具。最后,它在 TFLite 解释器中运行它以检查结果质量

  • 在没有优化的情况下进行转换将提供与之前相同的结果(直到舍入误差)。
  • 在没有数据的情况下进行优化转换会将模型权重量化为 8 位,但推理仍然使用浮点运算来计算神经网络的激活值。这将模型大小缩小了近 4 倍,并提高了移动设备上的 CPU 延迟。
  • 此外,如果提供一个小的参考数据集来校准量化范围,神经网络激活值的计算也可以量化为 8 位整数。在移动设备上,这将进一步加速推理,并使其能够在 Edge TPU 等加速器上运行。

优化设置

interpreter = tf.lite.Interpreter(model_content=lite_model_content)
# This little helper wraps the TFLite Interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
num_eval_examples = 50 
eval_dataset = ((image, label)  # TFLite expects batch size 1.
                for batch in train_ds
                for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count +=1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TFLite model agrees with original model on %d of %d examples (%g%%)." %
      (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TFLite model is accurate on %d of %d examples (%g%%)." %
      (count_lite_correct, count, 100.0 * count_lite_correct / count))