Cluster preserving quantization aware training (CQAT) Keras example


Overview

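This is an end-to-end example showing the usage of the **cluster preserving quantization aware training (CQAT)** API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.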

Other pages

For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.

Contents

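In this tutorial, you will:

  1. Train a Keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with clustering and check the accuracy.
  3. Apply QAT and observe the loss of clusters.
  4. Apply CQAT and observe that the clustering applied earlier has been preserved.
  5. Generate a TFLite model and observe the effects of applying CQAT on it.
  6. Compare the accuracy of the CQAT model with that of a model quantized using post-training quantization.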

Setup

You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.

 pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a Keras model for MNIST without clustering

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3191 - accuracy: 0.9090 - val_loss: 0.1358 - val_accuracy: 0.9645
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1291 - accuracy: 0.9635 - val_loss: 0.0912 - val_accuracy: 0.9748
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0886 - accuracy: 0.9740 - val_loss: 0.0749 - val_accuracy: 0.9795
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0710 - accuracy: 0.9789 - val_loss: 0.0637 - val_accuracy: 0.9818
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0601 - accuracy: 0.9819 - val_loss: 0.0659 - val_accuracy: 0.9817
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0532 - accuracy: 0.9838 - val_loss: 0.0630 - val_accuracy: 0.9828
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0477 - accuracy: 0.9855 - val_loss: 0.0639 - val_accuracy: 0.9832
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0427 - accuracy: 0.9865 - val_loss: 0.0598 - val_accuracy: 0.9850
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0393 - accuracy: 0.9876 - val_loss: 0.0590 - val_accuracy: 0.9837
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0353 - accuracy: 0.9891 - val_loss: 0.0610 - val_accuracy: 0.9842
<tf_keras.src.callbacks.History at 0x7f22b1b9a0d0>

Evaluate the baseline model and save it for later use

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9818000197410583
Saving model to:  /tmpfs/tmp/tmpo7dgy4fg.h5
/tmpfs/tmp/ipykernel_38069/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

Cluster and fine-tune the model with 8 clusters

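Apply the cluster_weights() API to cluster the whole pre-trained model, to demonstrate and observe its effectiveness in reducing the model size when compressed with zip, while maintaining accuracy. For how best to use the API to achieve the best compression rate while maintaining your target accuracy, refer to the clustering comprehensive guide.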

Define the model and apply the clustering API

The model needs to be pre-trained before using the clustering API.
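Conceptually, weight clustering replaces the many distinct values in a kernel with a small set of shared centroid values. The sketch below illustrates that idea on a flat list of floats with 1-D k-means and k-means++ seeding (the tutorial selects `CentroidInitialization.KMEANS_PLUS_PLUS`). It is a simplified, pure-Python sketch under that assumption, not tfmot's implementation; `cluster_weights_1d`, `kmeans_plus_plus_init` and `nearest` are hypothetical helpers.

```python
import random


def nearest(w, centroids):
    """Index of the centroid closest to weight w."""
    return min(range(len(centroids)), key=lambda j: abs(w - centroids[j]))


def kmeans_plus_plus_init(values, k, rng):
    """Choose up to k initial centroids with k-means++ seeding (1-D)."""
    centroids = [rng.choice(values)]
    while len(centroids) < k:
        # Squared distance from each value to its nearest chosen centroid.
        d2 = [min((v - c) ** 2 for c in centroids) for v in values]
        total = sum(d2)
        if total == 0:
            break  # every value already coincides with a centroid
        # Sample the next centroid with probability proportional to d2.
        r = rng.uniform(0, total)
        acc = 0.0
        for v, d in zip(values, d2):
            acc += d
            if d > 0 and acc >= r:
                centroids.append(v)
                break
    return centroids


def cluster_weights_1d(weights, k, iters=20, seed=0):
    """Hypothetical helper: cluster a flat list of weights into k shared values."""
    rng = random.Random(seed)
    centroids = kmeans_plus_plus_init(weights, k, rng)
    for _ in range(iters):
        # Assignment step: nearest centroid for every weight.
        assign = [nearest(w, centroids) for w in weights]
        # Update step: move each centroid to the mean of its members.
        for j in range(len(centroids)):
            members = [w for w, a in zip(weights, assign) if a == j]
            if members:  # leave empty clusters where they are
                centroids[j] = sum(members) / len(members)
    # Snap every weight to its final centroid.
    clustered = [centroids[nearest(w, centroids)] for w in weights]
    return clustered, centroids


rng = random.Random(42)
weights = [rng.gauss(0.0, 0.1) for _ in range(1000)]
clustered, centroids = cluster_weights_1d(weights, k=8)
print("unique values before:", len(set(weights)))
print("unique values after: ", len(set(clustered)))  # at most 8
```

During fine-tuning, tfmot additionally trains the centroids themselves; this sketch stops after the clustering step.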

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'cluster_per_channel': True,
}

clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 cluster_reshape (ClusterWe  (None, 28, 28, 1)         0         
 ights)                                                          
                                                                 
 cluster_conv2d (ClusterWei  (None, 26, 26, 12)        324       
 ghts)                                                           
                                                                 
 cluster_max_pooling2d (Clu  (None, 13, 13, 12)        0         
 sterWeights)                                                    
                                                                 
 cluster_flatten (ClusterWe  (None, 2028)              0         
 ights)                                                          
                                                                 
 cluster_dense (ClusterWeig  (None, 10)                40578     
 hts)                                                            
                                                                 
=================================================================
Total params: 40902 (239.41 KB)
Trainable params: 20514 (80.13 KB)
Non-trainable params: 20388 (159.28 KB)
_________________________________________________________________

Fine-tune the model and evaluate the accuracy against the baseline

Fine-tune the model with clustering for 3 epochs.

# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 11s 5ms/step - loss: 0.0316 - accuracy: 0.9909 - val_loss: 0.0610 - val_accuracy: 0.9837
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0297 - accuracy: 0.9916 - val_loss: 0.0603 - val_accuracy: 0.9852
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0291 - accuracy: 0.9919 - val_loss: 0.0596 - val_accuracy: 0.9850
<tf_keras.src.callbacks.History at 0x7f22946065e0>

Define a helper function to calculate and print the number of clusters in each kernel of the model.

def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Check that the model kernels were correctly clustered. We need to strip the clustering wrapper first.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 96 clusters 
dense/kernel:0: 8 clusters
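The conv2d kernel reports 96 unique values rather than 8 because the tutorial sets `cluster_per_channel: True`. As we understand the tfmot documentation, this option applies to Conv2D layers: each of the 12 output channels gets its own 8 centroids, so up to 12 × 8 = 96 distinct values can remain, while the Dense kernel is clustered as a whole into 8 values. The pure-Python sketch below reproduces that bound on a hypothetical 3×3×1×12 kernel; `kmeans_1d` is a hypothetical plain k-means helper, not a tfmot function.

```python
import random


def kmeans_1d(values, k, iters=20):
    """Plain 1-D k-means; centroids start at evenly spaced sorted values."""
    s = sorted(values)
    step = (len(s) - 1) / max(k - 1, 1)
    centroids = [s[round(i * step)] for i in range(k)]
    for _ in range(iters):
        groups = [[] for _ in centroids]
        for v in values:
            # Assign v to the nearest centroid.
            idx = min(range(k), key=lambda i: abs(v - centroids[i]))
            groups[idx].append(v)
        # Move each centroid to the mean of its members (keep it if empty).
        centroids = [sum(g) / len(g) if g else c for g, c in zip(groups, centroids)]
    # Replace every value by its nearest final centroid.
    return [min(centroids, key=lambda c: abs(v - c)) for v in values]


rng = random.Random(0)
# Hypothetical 3x3x1 kernel with 12 output channels: 12 groups of 9 weights.
conv_kernel = [[rng.gauss(0.0, 0.3) for _ in range(9)] for _ in range(12)]

# Per-channel clustering: each output channel gets its own 8 centroids.
per_channel = [kmeans_1d(ch, 8) for ch in conv_kernel]
# Whole-tensor clustering: one shared set of 8 centroids.
whole = kmeans_1d([w for ch in conv_kernel for w in ch], 8)

print("per-channel unique values:", len({w for ch in per_channel for w in ch}))  # at most 96
print("whole-tensor unique values:", len(set(whole)))  # at most 8
```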

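For this example, clustering causes no meaningful loss in test accuracy compared to the baseline; in this run the clustered model even scores marginally higher (0.9819 vs. 0.9818).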

_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9818000197410583
Clustered test accuracy: 0.9818999767303467

Apply QAT and CQAT and check the effect on model clusters in both cases

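Next, we apply both QAT and cluster preserving QAT (CQAT) to the clustered model and observe that CQAT preserves the weight clusters of the clustered model. Note that we stripped the clustering wrappers from the model with tfmot.clustering.keras.strip_clustering before applying the CQAT API.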
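Why does plain QAT lose the clusters while CQAT keeps them? The sketch below is a conceptual illustration under an assumption: that CQAT trains the shared centroid values and rebuilds each weight from its cluster index, whereas QAT updates every weight independently. Under that assumption, the gradient for a centroid is the sum of the gradients of the weights that use it (chain rule). The gradients here are made-up random numbers, and none of this is tfmot code. The `Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0']` warnings in the CQAT log below appear consistent with this picture, since the original kernels are apparently not updated directly.

```python
import random

rng = random.Random(0)
centroids = [-0.3, -0.1, 0.1, 0.3]
# Each weight stores only an index into `centroids` (its cluster assignment).
indices = [rng.randrange(len(centroids)) for _ in range(100)]
weights = [centroids[i] for i in indices]
# Hypothetical per-weight gradients from one training step.
grads = [rng.gauss(0.0, 1.0) for _ in weights]
lr = 0.01

# QAT-style step: every weight moves by its own gradient, so clusters break up.
qat_weights = [w - lr * g for w, g in zip(weights, grads)]

# CQAT-style step (conceptual): accumulate gradients per cluster, update each
# shared centroid, then rebuild every weight from its cluster's centroid.
cluster_grads = [0.0] * len(centroids)
for i, g in zip(indices, grads):
    cluster_grads[i] += g
new_centroids = [c - lr * g for c, g in zip(centroids, cluster_grads)]
cqat_weights = [new_centroids[i] for i in indices]

print("unique before:", len(set(weights)))
print("unique after QAT-style step:", len(set(qat_weights)))
print("unique after CQAT-style step:", len(set(cqat_weights)))
```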

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())

cqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0315 - accuracy: 0.9905 - val_loss: 0.0573 - val_accuracy: 0.9855
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
422/422 [==============================] - 6s 8ms/step - loss: 0.0290 - accuracy: 0.9917 - val_loss: 0.0597 - val_accuracy: 0.9847
<tf_keras.src.callbacks.History at 0x7f229414f130>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters 
quant_dense/dense/kernel:0: 19910 clusters 
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 96 clusters 
quant_dense/dense/kernel:0: 8 clusters

See the compression benefits of the CQAT model

Define a helper function to get the size of the zipped model file.

def get_gzipped_model_size(file):
  # Returns the size, in kilobytes, of the model file after zip (DEFLATE) compression.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Note that this is a small model. Applying clustering and CQAT to a larger production model would yield more significant compression.
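The size gain comes from the compressor, not from the TFLite file itself: a kernel whose int8 values are drawn from only 8 distinct codes has low entropy, so DEFLATE (the algorithm behind `zipfile.ZIP_DEFLATED` and `zlib`) can encode it far more compactly. The sketch below demonstrates this with synthetic byte strings; the uniformly random "unclustered" data is an assumption chosen for simplicity, and real weight distributions compress somewhat better than that.

```python
import random
import zlib

rng = random.Random(0)
n = 20_000  # roughly the size of the dense kernel in this tutorial (2028 * 10)

# Unclustered int8 weights: values spread over the full 256-value range.
unclustered = bytes(rng.randrange(256) for _ in range(n))
# Clustered int8 weights: every weight is one of (at most) 8 centroid codes.
codebook = [rng.randrange(256) for _ in range(8)]
clustered = bytes(rng.choice(codebook) for _ in range(n))

size_u = len(zlib.compress(unclustered, 9))
size_c = len(zlib.compress(clustered, 9))
print(f"unclustered: {size_u} bytes, clustered: {size_c} bytes")
```

With 8 roughly equiprobable codes, an entropy coder needs about 3 bits per weight instead of 8, which is why the clustered string shrinks to well under half its size.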

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqnilvtco/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709988433.596770   38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988433.596818   38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqlp0tpfp/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
QAT model size:  17.487  KB
CQAT model size:  10.64  KB
W0000 00:00:1709988436.818205   38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988436.818236   38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

You evaluate the model, which has been clustered and quantized, and then see that the accuracy from TensorFlow persists in the TFLite backend.
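One reason both the clusters and the accuracy can survive conversion: int8 quantization is a deterministic per-value mapping, so weights that share a centroid map to the same integer code, and the round-trip error of an in-range value is about half a quantization step. The sketch below uses a simplified affine (asymmetric) int8 scheme with hypothetical helpers `quant_params`, `quantize` and `dequantize`; TFLite's actual weight quantization differs in details (for example, it commonly uses symmetric, per-channel int8 weights).

```python
import random


def quant_params(x_min, x_max, qmin=-128, qmax=127):
    """Scale and zero point of an affine int8 mapping covering [x_min, x_max]."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # the range must include 0.0
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    return max(qmin, min(qmax, int(round(x / scale)) + zero_point))


def dequantize(q, scale, zero_point):
    """x' = scale * (q - zero_point)."""
    return scale * (q - zero_point)


# Hypothetical clustered weights: 1000 weights sharing 8 centroid values.
centroids = [-0.42, -0.2, -0.05, 0.0, 0.07, 0.19, 0.33, 0.5]
rng = random.Random(0)
weights = [rng.choice(centroids) for _ in range(1000)]

scale, zero_point = quant_params(min(weights), max(weights))
q = [quantize(w, scale, zero_point) for w in weights]
restored = [dequantize(v, scale, zero_point) for v in q]

print("unique int8 values:", len(set(q)))  # equal inputs map to equal codes
print("max round-trip error:", max(abs(a - b) for a, b in zip(weights, restored)))
print("half a quantization step:", scale / 2)
```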

interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()

cqat_test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9822
Clustered TF test accuracy: 0.9818999767303467

Apply post-training quantization and compare to the CQAT model

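Next, we apply post-training quantization (no fine-tuning) to the clustered model and check its accuracy against the CQAT model. This demonstrates why you would need to use CQAT to improve the quantized model's accuracy. The difference may not be very visible, because the MNIST model is quite small and overparameterized.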

First, define a generator for the calibration dataset from the first 1000 training images.

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]
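The converter uses the representative dataset to estimate value ranges for quantizing activations. The sketch below shows the simplest version of that idea, min/max calibration: scan every sample the generator yields, track the observed range, and derive an int8 scale and zero point from it. It is a simplified illustration under that assumption, not the TFLite converter's exact algorithm; `representative_data_gen`, `calibrate_min_max` and `int8_params` are hypothetical helpers, and the "images" are random flat lists standing in for normalized MNIST pixels.

```python
import random


def representative_data_gen(samples):
    """Yields one single-input list per sample, like mnist_representative_data_gen."""
    for s in samples:
        yield [s]


def calibrate_min_max(gen):
    """Track the global min and max over everything the generator yields."""
    lo, hi = float("inf"), float("-inf")
    for (sample,) in gen:
        lo = min(lo, min(sample))
        hi = max(hi, max(sample))
    return lo, hi


def int8_params(lo, hi, qmin=-128, qmax=127):
    """Affine int8 scale and zero point for the calibrated range."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # the range must include 0.0
    scale = (hi - lo) / (qmax - qmin)
    zero_point = int(round(qmin - lo / scale))
    return scale, zero_point


rng = random.Random(0)
# Stand-ins for 100 normalized 28x28 images, flattened to pixel lists in [0, 1).
fake_images = [[rng.random() for _ in range(28 * 28)] for _ in range(100)]

lo, hi = calibrate_min_max(representative_data_gen(fake_images))
scale, zp = int8_params(lo, hi)
print(f"range [{lo:.4f}, {hi:.4f}] -> scale {scale:.6f}, zero point {zp}")
```

For inputs in [0, 1), the calibrated range starts at 0, so the zero point lands at the bottom of the int8 range (-128) and the full code range is spent on non-negative values.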

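Quantize the model and compare its accuracy with the previously acquired CQAT model. Note that the model quantized with fine-tuning (CQAT) achieves higher accuracy, although in this run the margin is small (0.9822 vs. 0.9817).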

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TFLite test accuracy:', post_training_test_accuracy)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpxyohbvab/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1709988438.608574   38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988438.608603   38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


CQAT TFLite test_accuracy: 0.9822
Post-training (no fine-tuning) TFLite test accuracy: 0.9817

Conclusion

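In this tutorial, you learned how to create a model, cluster it using the cluster_weights() API, and apply cluster preserving quantization aware training (CQAT) to preserve the clusters while using QAT. The CQAT model was then compared with the QAT model to show that the clusters are preserved in the former and lost in the latter. Next, the models were converted to TFLite to show the compression benefits of chaining the clustering and CQAT model optimization techniques, and the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend. Finally, the CQAT model was compared with a quantized clustered model produced with the post-training quantization API, to demonstrate the advantage of CQAT in recovering the accuracy loss from normal quantization.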