Quantization aware training in Keras example


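Overview

Welcome to an end-to-end example for quantization aware training.

Other pages

For an introduction to what quantization aware training is and to determine if you should use it (including what's supported), see the overview page.

To quickly find the APIs you need for your use case (beyond fully quantizing a model with 8 bits), see the comprehensive guide.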

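Summary

In this tutorial, you will:

  1. Train a Keras model for MNIST from scratch.
  2. Fine-tune the model by applying the quantization aware training API, check the accuracy, and export a quantization aware model.
  3. Use the model to create an actually quantized model for the TFLite backend.
  4. See the persistence of accuracy in TFLite and a roughly 4x smaller model. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.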

Setup

 pip install -q tensorflow
 pip install -q tensorflow-model-optimization
import tempfile
import os

import tensorflow as tf

from tensorflow_model_optimization.python.core.keras.compat import keras

Train a model for MNIST without quantization aware training

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
2024-03-09 12:32:07.505187: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3232 - accuracy: 0.9089 - val_loss: 0.1358 - val_accuracy: 0.9618
<tf_keras.src.callbacks.History at 0x7fc8a42a07f0>

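Clone and fine-tune the pre-trained model with quantization aware training

Define the model

You will apply quantization aware training to the whole model and see this in the model summary. All layers are now prefixed by "quant".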

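Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The later sections show how to create a quantized model from the quantization aware one.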
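To make the "quantization aware but not quantized" distinction concrete, here is a minimal sketch of fake quantization in plain Python. The values stay floats, but they are snapped onto the grid an 8-bit integer could represent, so training sees the rounding error. The function name `fake_quantize`, the symmetric scheme, and the sample weights are illustrative assumptions, not the TFMOT implementation.

```python
def fake_quantize(values, num_bits=8):
    """Round floats onto a symmetric integer grid, then map them back to floats.

    Simplified illustration only; not the scheme TFMOT implements internally.
    """
    q_max = 2 ** (num_bits - 1) - 1          # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / q_max if max_abs > 0 else 1.0
    result = []
    for v in values:
        q = round(v / scale)                 # nearest integer level
        q = max(-q_max, min(q_max, q))       # clamp to the representable range
        result.append(q * scale)             # back to float ("dequantize")
    return result, scale


# Hypothetical weights, chosen only for illustration.
weights = [-0.42, -0.07, 0.0, 0.013, 0.31, 0.5]
simulated, scale = fake_quantize(weights)

for w, s in zip(weights, simulated):
    print(f"{w:+.4f} -> {s:+.4f} (error {abs(w - s):.5f})")
```

Each simulated weight differs from the original by at most half a quantization step, and the stored type is still a float, which is the state the quantization aware model is in before conversion.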

In the comprehensive guide, you can see how to quantize only some layers to improve model accuracy.

import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 quantize_layer (QuantizeLa  (None, 28, 28)            3         
 yer)                                                            
                                                                 
 quant_reshape (QuantizeWra  (None, 28, 28, 1)         1         
 pperV2)                                                         
                                                                 
 quant_conv2d (QuantizeWrap  (None, 26, 26, 12)        147       
 perV2)                                                          
                                                                 
 quant_max_pooling2d (Quant  (None, 13, 13, 12)        1         
 izeWrapperV2)                                                   
                                                                 
 quant_flatten (QuantizeWra  (None, 2028)              1         
 pperV2)                                                         
                                                                 
 quant_dense (QuantizeWrapp  (None, 10)                20295     
 erV2)                                                           
                                                                 
=================================================================
Total params: 20448 (79.88 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 38 (152.00 Byte)
_________________________________________________________________
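The parameter counts in the summary can be checked by hand. The sketch below recomputes the trainable parameters of the original Conv2D and Dense layers and compares them with the "Param #" column. The reading that the extra 38 non-trainable parameters are quantization bookkeeping added by the wrappers (such as value ranges) is an assumption here, not something the summary states.

```python
# Parameter counts for the original (unwrapped) layers.
conv_params = 3 * 3 * 1 * 12 + 12      # 3x3 kernels, 1 input channel, 12 filters, plus biases
dense_params = 2028 * 10 + 10          # 13 * 13 * 12 = 2028 inputs, 10 outputs, plus biases
trainable = conv_params + dense_params

# "Param #" column of the summary above, in layer order.
summary_counts = {
    "quantize_layer": 3,
    "quant_reshape": 1,
    "quant_conv2d": 147,
    "quant_max_pooling2d": 1,
    "quant_flatten": 1,
    "quant_dense": 20295,
}
total = sum(summary_counts.values())
extra = total - trainable              # parameters added by the quantization wrappers

print(trainable, total, extra)         # 20410 20448 38
```

The three numbers match the "Trainable params", "Total params", and "Non-trainable params" lines of the summary.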

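Train and evaluate the model against the baseline

To demonstrate fine-tuning after training the model for just one epoch, fine-tune with quantization aware training on a subset of the training data.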

train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
2/2 [==============================] - 2s 480ms/step - loss: 0.1749 - accuracy: 0.9500 - val_loss: 0.1863 - val_accuracy: 0.9700
<tf_keras.src.callbacks.History at 0x7fc8a4006640>

For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9550999999046326
Quant test accuracy: 0.9584000110626221
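As a quick way to read the two numbers above, the following sketch converts the gap into percentage points and into an approximate count of test images. The accuracy values are copied from the printed output. The 10,000-image test set size is the standard MNIST split.

```python
baseline_accuracy = 0.9550999999046326   # copied from the output above
quant_accuracy = 0.9584000110626221      # copied from the output above

delta_points = (quant_accuracy - baseline_accuracy) * 100
num_test_images = 10000                  # size of the MNIST test set
delta_images = round((quant_accuracy - baseline_accuracy) * num_test_images)

print(f"Difference: {delta_points:+.2f} percentage points (~{delta_images} images)")
```

In this run the quantization aware model is marginally ahead, which is plausible given that it received a little extra fine-tuning. A difference this small should not be read as a systematic gain.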

Create a quantized model for the TFLite backend

After this, you have an actually quantized model with int8 weights and uint8 activations.

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp3vwylslo/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709987557.593385   26473 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987557.593449   26473 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
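Integer tensors such as the uint8 activations mentioned above are typically interpreted through a scale and a zero point, with real value ≈ scale × (code − zero_point). The sketch below shows that mapping in plain Python. The helper names and the activation range of [-1.0, 3.0] are hypothetical; real ranges are learned during quantization aware training and stored in the converted model.

```python
def quantize_uint8(values, scale, zero_point):
    """Map floats to uint8 codes: q = round(x / scale) + zero_point, clamped to [0, 255]."""
    return [max(0, min(255, round(x / scale) + zero_point)) for x in values]


def dequantize(codes, scale, zero_point):
    """Recover approximate floats: x ~= scale * (q - zero_point)."""
    return [scale * (q - zero_point) for q in codes]


# Hypothetical activation range; chosen only for illustration.
range_min, range_max = -1.0, 3.0
scale = (range_max - range_min) / 255
zero_point = round(-range_min / scale)        # the code that represents 0.0

activations = [-1.0, 0.0, 0.5, 2.999, 10.0]   # 10.0 lies outside the range and saturates
codes = quantize_uint8(activations, scale, zero_point)
restored = dequantize(codes, scale, zero_point)

print(codes)
print([round(x, 3) for x in restored])
```

Note that 0.0 maps to an exact code, so zero values survive the round trip unchanged, while out-of-range values are clipped to the nearest representable code.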

See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TF Lite model on the test dataset.

import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
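The post-processing step in the helper above reduces to an argmax over each output vector followed by a comparison with the labels. Because the model's last layer is `Dense(10)` with no softmax, the outputs are logits rather than probabilities, but the argmax picks the same class either way. A minimal sketch with made-up scores (the helper names `argmax` and `accuracy` are illustrative):

```python
def argmax(scores):
    """Index of the largest score (the predicted class)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(all_scores, labels):
    """Fraction of examples whose top-scoring class matches the label."""
    correct = sum(argmax(s) == y for s, y in zip(all_scores, labels))
    return correct / len(labels)


# Made-up logits for three examples and three classes, purely for illustration.
toy_scores = [
    [2.0, 0.1, -1.0],   # predicts class 0
    [0.3, 0.2, 1.5],    # predicts class 2
    [0.0, 4.0, 0.5],    # predicts class 1
]
toy_labels = [0, 2, 2]

print(accuracy(toy_scores, toy_labels))   # two of three predictions are correct
```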

You evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.9584
Quant TF test accuracy: 0.9584000110626221

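See a roughly 4x smaller model from quantization

You create a float TFLite model and then see that the quantized TFLite model is roughly 4x smaller (about 3.4x for the file sizes printed below).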

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmppztp93bk/assets
Float model in Mb: 0.08068466186523438
Quantized model in Mb: 0.0236053466796875
W0000 00:00:1709987559.464849   26473 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987559.464884   26473 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
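The "4x" figure comes from storing each weight in 1 byte (int8) instead of 4 bytes (float32). The sketch below compares that idealized estimate with the file sizes printed above. The parameter count of 20,410 is taken from the model summary. Treating everything beyond the raw weights as fixed "overhead" is a simplifying assumption.

```python
MIB = 2 ** 20

float_mib = 0.08068466186523438     # copied from the output above
quant_mib = 0.0236053466796875      # copied from the output above
print(f"Measured ratio: {float_mib / quant_mib:.2f}x")

# Rough weight-only estimate: 20,410 trainable parameters (from the model summary).
num_params = 20410
float_weights_mib = num_params * 4 / MIB   # float32: 4 bytes per weight
int8_weights_mib = num_params * 1 / MIB    # int8: 1 byte per weight

# Whatever is left in each file is treated here as non-weight overhead.
float_overhead = float_mib - float_weights_mib
quant_overhead = quant_mib - int8_weights_mib
print(f"Overhead: float {float_overhead * MIB:.0f} bytes, "
      f"quantized {quant_overhead * MIB:.0f} bytes")
```

For a model this small, the non-weight portion of the file is a noticeable fraction of the total, which is one plausible reason the measured ratio lands below the ideal 4x. Larger models should come closer to it.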

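Conclusion

In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then quantized models for the TFLite backend.

For an MNIST model, you saw a roughly 4x model size compression benefit (about 3.4x for the file sizes printed above) with minimal accuracy difference. To see the latency benefits on mobile, try out the TFLite examples in the TFLite app repository.

We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.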