Quantization aware training in Keras example


Overview

Welcome to an end-to-end example for _quantization aware training_.

Other pages

For an introduction to quantization aware training and to determine whether you should use it (including what is supported), see the overview page.

To quickly find the APIs you need for your use case (beyond fully quantizing a model with 8 bits), see the comprehensive guide.

Summary

In this tutorial, you will:

  1. Train a keras model for MNIST from scratch.
  2. Fine-tune the model by applying the quantization aware training API, see the accuracy, and export a quantization-aware model.
  3. Use the model to create an actually quantized model for the TFLite backend.
  4. See the persistence of accuracy in TFLite and a 4x smaller model. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.

Setup

 pip install -q tensorflow
 pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf

from tensorflow_model_optimization.python.core.keras.compat import keras

Train a model for MNIST without quantization aware training

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
2024-03-09 12:32:07.505187: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3232 - accuracy: 0.9089 - val_loss: 0.1358 - val_accuracy: 0.9618
<tf_keras.src.callbacks.History at 0x7fc8a42a07f0>

Clone and fine-tune a pre-trained model with quantization aware training

Define the model

You will apply quantization aware training to the whole model and see this in the model summary. All layers are now prefixed by "quant".

Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The sections after show how to create a quantized model from the quantization-aware one.

In the comprehensive guide, you can see how to quantize only some layers for model accuracy improvements.
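As an aside, here is a minimal sketch of that selective approach, following the annotate-then-apply pattern from the comprehensive guide. It assumes you only want to quantize the Dense layer, and the helper and variable names below are purely illustrative. The rest of this tutorial quantizes the whole model instead.

import tensorflow_model_optimization as tfmot

# Illustrative helper: annotate only Dense layers for quantization.
def apply_quantization_to_dense(layer):
  if isinstance(layer, keras.layers.Dense):
    return tfmot.quantization.keras.quantize_annotate_layer(layer)
  return layer

# Clone the trained base model, annotating the layers to be quantized ...
annotated_model = keras.models.clone_model(
    model, clone_function=apply_quantization_to_dense)

# ... then `quantize_apply` makes only the annotated layers quantization aware.
dense_only_q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)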

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer (QuantizeLa  (None, 28, 28)            3         
 yer)                                                            
                                                                 
 quant_reshape (QuantizeWra  (None, 28, 28, 1)         1         
 pperV2)                                                         
                                                                 
 quant_conv2d (QuantizeWrap  (None, 26, 26, 12)        147       
 perV2)                                                          
                                                                 
 quant_max_pooling2d (Quant  (None, 13, 13, 12)        1         
 izeWrapperV2)                                                   
                                                                 
 quant_flatten (QuantizeWra  (None, 2028)              1         
 pperV2)                                                         
                                                                 
 quant_dense (QuantizeWrapp  (None, 10)                20295     
 erV2)                                                           
                                                                 
=================================================================
Total params: 20448 (79.88 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 38 (152.00 Byte)
_________________________________________________________________
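To confirm the earlier note that the quantization-aware model still stores float32 weights (the int8 representation only appears after conversion), you can print the variable dtypes; this quick check is illustrative and not part of the original notebook.

# The kernels and biases of the quantization-aware model are still float32;
# quantization parameters are tracked separately as non-trainable variables.
for weight in q_aware_model.trainable_weights:
  print(weight.name, weight.dtype)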

Train and evaluate the model against the baseline

To demonstrate fine-tuning after training the model for just one epoch, fine-tune with quantization aware training on a subset of the training data.

train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 2s 480ms/step - loss: 0.1749 - accuracy: 0.9500 - val_loss: 0.1863 - val_accuracy: 0.9700
<tf_keras.src.callbacks.History at 0x7fc8a4006640>

For this example, there is minimal loss in test accuracy after quantization aware training, compared to the baseline.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9550999999046326
Quant test accuracy: 0.9584000110626221

Create a quantized model for the TFLite backend

After this, you will have an actually quantized model with int8 weights and uint8 activations.

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp3vwylslo/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709987557.593385   26473 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987557.593449   26473 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
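If you want to verify that the converted model really stores integer weights, one illustrative check (not part of the original notebook) is to list the tensor dtypes reported by the TFLite interpreter; the Conv2D and Dense weight tensors should appear as int8.

# Illustrative check: print the dtype of each tensor in the quantized TFLite model.
inspect_interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
inspect_interpreter.allocate_tensors()
for detail in inspect_interpreter.get_tensor_details():
  print(detail['name'], detail['dtype'])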

See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

You evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.9584
Quant TF test accuracy: 0.9584000110626221

See 4x smaller model from quantization

You create a float TFLite model and then see that the quantized TFLite model is 4x smaller.

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmppztp93bk/assets
Float model in MB: 0.08068466186523438
Quantized model in MB: 0.0236053466796875
W0000 00:00:1709987559.464849   26473 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987559.464884   26473 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
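If you prefer an explicit ratio over comparing the two sizes by eye, a small follow-up using the float_file and quant_file paths written above could be:

# Compute the size reduction achieved by quantization from the two files above.
ratio = os.path.getsize(float_file) / os.path.getsize(quant_file)
print("Quantized model is {:.1f}x smaller than the float model.".format(ratio))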

Conclusion

In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then quantized models for the TFLite backend.

You saw a 4x model size compression benefit for a model for MNIST, with minimal accuracy difference. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.

We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.