Sparsity and cluster preserving quantization aware training (PCQAT) Keras example


Overview

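This is an end-to-end example that shows how to use the **sparsity and cluster preserving quantization aware training (PCQAT)** API. The API is part of the collaborative optimization pipeline in the TensorFlow Model Optimization Toolkit.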

Other pages

For an introduction to the pipeline and the other available techniques, see the collaborative optimization overview page.

Contents

In this tutorial, you will:

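  1. Train a keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with pruning, check the accuracy, and observe that the model was successfully pruned.
  3. Apply sparsity preserving clustering on the pruned model and observe that the sparsity applied earlier has been preserved.
  4. Apply QAT and observe the loss of sparsity and clusters.
  5. Apply PCQAT and observe that both the sparsity and the clustering applied earlier have been preserved.
  6. Generate a TFLite model and observe the effects of applying PCQAT to it.
  7. Compare the sizes of the different models to observe the compression benefits of applying sparsity, followed by the collaborative optimization techniques of sparsity preserving clustering and PCQAT.
  8. Compare the accuracy of the fully optimized model with that of the unoptimized baseline model.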

Setup

You can run this Jupyter Notebook in your local virtualenv or in colab. For details on setting up dependencies, see the installation guide.

! pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a keras model for MNIST to be pruned and clustered

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

opt = keras.optimizers.Adam(learning_rate=1e-3)

# Train the digit classification model
model.compile(optimizer=opt,
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3037 - accuracy: 0.9146 - val_loss: 0.1153 - val_accuracy: 0.9682
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1133 - accuracy: 0.9680 - val_loss: 0.0895 - val_accuracy: 0.9762
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0792 - accuracy: 0.9768 - val_loss: 0.0652 - val_accuracy: 0.9825
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0662 - accuracy: 0.9803 - val_loss: 0.0633 - val_accuracy: 0.9823
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0570 - accuracy: 0.9833 - val_loss: 0.0649 - val_accuracy: 0.9825
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0498 - accuracy: 0.9853 - val_loss: 0.0571 - val_accuracy: 0.9842
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0448 - accuracy: 0.9867 - val_loss: 0.0586 - val_accuracy: 0.9840
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0405 - accuracy: 0.9873 - val_loss: 0.0586 - val_accuracy: 0.9848
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0370 - accuracy: 0.9885 - val_loss: 0.0624 - val_accuracy: 0.9828
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0332 - accuracy: 0.9902 - val_loss: 0.0554 - val_accuracy: 0.9848
<tf_keras.src.callbacks.History at 0x7f615076beb0>

Evaluate the baseline model and save it for later use

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9835000038146973
Saving model to:  /tmpfs/tmp/tmpf70eijr3.h5
/tmpfs/tmp/ipykernel_41361/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

Prune the model to 50% sparsity and fine-tune it

Apply the prune_low_magnitude() API to prune the model. The pruned model is clustered in the next step. For more details on the pruning API, see the pruning comprehensive guide.
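The NumPy sketch below illustrates magnitude-based pruning. It zeroes the 50% of a weight tensor with the smallest magnitudes, which is the end state that `ConstantSparsity(0.5, ...)` aims for. The helper `prune_to_sparsity` is hypothetical and does not reflect how `prune_low_magnitude()` works internally. tfmot wraps layers, keeps a pruning mask, and updates that mask on a schedule during training.

```python
import numpy as np

def prune_to_sparsity(weights: np.ndarray, sparsity: float) -> np.ndarray:
    """Hypothetical helper: zero the `sparsity` fraction of smallest-|w| entries."""
    flat = weights.ravel()
    num_pruned = int(round(sparsity * flat.size))  # how many weights to zero
    # Indices of the smallest-magnitude weights (stable sort keeps ties deterministic).
    prune_idx = np.argsort(np.abs(flat), kind="stable")[:num_pruned]
    pruned = flat.copy()
    pruned[prune_idx] = 0.0
    return pruned.reshape(weights.shape)

rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 1, 12))  # same shape as the conv2d kernel (108 weights)
pruned_kernel = prune_to_sparsity(kernel, 0.5)
zeros = np.count_nonzero(pruned_kernel == 0)
print(f"{zeros / pruned_kernel.size:.2%} sparsity ({zeros}/{pruned_kernel.size})")
# prints: 50.00% sparsity (54/108)
```

Every surviving weight is at least as large in magnitude as every pruned one. That is what "low magnitude" pruning means.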

Define the model and apply the sparsity API

Note that the pretrained model is used.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

Fine-tune the model, check sparsity, and evaluate the accuracy against the baseline

Fine-tune the model with pruning for 3 epochs.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)
Epoch 1/3
1688/1688 [==============================] - 10s 4ms/step - loss: 0.0851 - accuracy: 0.9707 - val_loss: 0.0801 - val_accuracy: 0.9768
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0591 - accuracy: 0.9801 - val_loss: 0.0672 - val_accuracy: 0.9808
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0493 - accuracy: 0.9852 - val_loss: 0.0626 - val_accuracy: 0.9837
<tf_keras.src.callbacks.History at 0x7f60c8593ee0>
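The progress bar reports 1688 steps per epoch. With `validation_split=0.1`, Keras trains on 54,000 of the 60,000 MNIST images. `fit()` uses a default batch size of 32, so one epoch is ceil(54000 / 32) = 1688 steps. The later runs use `batch_size=128`, which gives ceil(54000 / 128) = 422 steps.

The sketch below also lists which steps `ConstantSparsity(0.5, begin_step=0, frequency=100)` would mark for a mask update. This is based on our reading of the schedule's semantics: prune every `frequency` steps starting at `begin_step`, where `end_step=-1` means there is no end step. Both helpers are hypothetical. tfmot's internal step counter may also be offset slightly from the one used here.

```python
import math

def steps_per_epoch(num_examples: int, validation_split: float, batch_size: int = 32) -> int:
    """Hypothetical helper: batches per epoch when fit() holds out a validation split."""
    # Keras trains on the first floor(n * (1 - split)) samples.
    num_train = int(math.floor(num_examples * (1.0 - validation_split)))
    return math.ceil(num_train / batch_size)

def pruning_update_steps(total_steps: int, begin_step: int = 0, end_step: int = -1,
                         frequency: int = 100) -> list:
    """Hypothetical helper: steps at which a constant-sparsity schedule updates the mask."""
    return [
        step for step in range(total_steps)
        if step >= begin_step
        and (end_step < 0 or step <= end_step)  # end_step=-1 means "no end"
        and (step - begin_step) % frequency == 0
    ]

per_epoch = steps_per_epoch(60000, 0.1)
updates = pruning_update_steps(3 * per_epoch)
print(per_epoch, steps_per_epoch(60000, 0.1, batch_size=128))  # prints: 1688 422
print(len(updates), updates[:3], updates[-1])  # prints: 51 [0, 100, 200] 5000
```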

Define helper functions that calculate and print the sparsity and the clusters of the model.

def print_model_weights_sparsity(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

First strip the pruning wrapper, then check that the model kernels were correctly pruned.

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)
conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)

Apply sparsity preserving clustering and check its effect on model sparsity

Next, apply sparsity preserving clustering on the pruned model, observe the number of clusters, and check that the sparsity is preserved.
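The key property of sparsity preserving clustering is this: pruned weights stay exactly zero, and the remaining weights are snapped to a small set of shared centroids. The NumPy sketch below shows the idea using plain 1-D k-means over the non-zero weights only.

This is an illustration, not tfmot's algorithm. `cluster_preserving_zeros` is a hypothetical helper. For determinism it uses linear centroid initialization rather than the `KMEANS_PLUS_PLUS` initialization used above.

```python
import numpy as np

def cluster_preserving_zeros(weights, num_clusters=8, iters=20):
    """Hypothetical helper: 1-D k-means on non-zero weights; zeros stay zero."""
    flat = weights.ravel().astype(np.float64)
    nonzero = flat != 0
    values = flat[nonzero]
    # Linear initialization between the min and max of the surviving weights.
    centroids = np.linspace(values.min(), values.max(), num_clusters)
    for _ in range(iters):
        # Assign every non-zero weight to its nearest centroid.
        assign = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
        # Move each centroid to the mean of its members (empty clusters stay put).
        for c in range(num_clusters):
            members = values[assign == c]
            if members.size:
                centroids[c] = members.mean()
    assign = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
    clustered = np.zeros_like(flat)  # pruned positions remain exactly 0.0
    clustered[nonzero] = centroids[assign]
    return clustered.reshape(weights.shape)

rng = np.random.default_rng(1)
w = rng.normal(scale=0.05, size=(2028, 10))  # same shape as the dense kernel
w[rng.random(w.shape) < 0.5] = 0.0           # roughly 50% pruned
c = cluster_preserving_zeros(w)
print(f"sparsity before: {np.mean(w == 0):.2%}, after: {np.mean(c == 0):.2%}")
print("distinct non-zero values:", len(np.unique(c[c != 0])))
```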

import tensorflow_model_optimization as tfmot
# Sparsity preserving clustering (preserve_sparsity=True) comes from the
# experimental clustering API, so use its cluster_weights.
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = cluster.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels,epochs=3, validation_split=0.1)
Train sparsity preserving clustering model:
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0422 - accuracy: 0.9869 - val_loss: 0.0712 - val_accuracy: 0.9818
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0398 - accuracy: 0.9878 - val_loss: 0.0627 - val_accuracy: 0.9848
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0403 - accuracy: 0.9865 - val_loss: 0.0597 - val_accuracy: 0.9830
<tf_keras.src.callbacks.History at 0x7f6080153790>

First strip the clustering wrapper, then check that the model is correctly pruned and clustered.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)
Model sparsity:

kernel:0: 50.93% sparsity  (55/108)
kernel:0: 58.12% sparsity  (11787/20280)

Model clusters:

conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters

Apply QAT and PCQAT and check the effect on model clusters and sparsity

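Next, apply both QAT and PCQAT to the sparse clustered model and observe that PCQAT preserves the weight sparsity and clusters in your model. Note that the stripped model is passed to the QAT and PCQAT APIs.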
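Quantization aware training simulates 8-bit inference by quantizing weights and immediately dequantizing them in the forward pass. The sketch below shows a simplified symmetric, per-tensor int8 version of that step. This is an assumption: tfmot's default 8-bit scheme has its own per-layer and per-channel details.

The sketch also shows why clustering can survive quantization at all. A tensor with 8 distinct values maps to at most 8 int8 codes, and zero maps to zero. `fake_quant_int8` is a hypothetical helper.

```python
import numpy as np

def fake_quant_int8(w):
    """Symmetric per-tensor int8 quantize/dequantize (simplified sketch)."""
    max_abs = float(np.max(np.abs(w)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)  # integer codes
    return q.astype(np.float64) * scale, q, scale  # dequantized floats, codes, scale

rng = np.random.default_rng(2)
w = rng.normal(size=1000)
deq, q, scale = fake_quant_int8(w)
print("max rounding error:", np.max(np.abs(deq - w)), "<= scale / 2 =", scale / 2)

# A clustered and pruned tensor: 7 non-zero shared values plus zeros.
centroids = np.array([-0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4])
clustered = centroids[rng.integers(0, len(centroids), size=1000)]
_, q_clustered, _ = fake_quant_int8(clustered)
print("distinct int8 codes:", len(np.unique(q_clustered)))  # at most 8
```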

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PCQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
pcqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

pcqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0298 - accuracy: 0.9911 - val_loss: 0.0587 - val_accuracy: 0.9853
Train pcqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
422/422 [==============================] - 5s 7ms/step - loss: 0.0315 - accuracy: 0.9904 - val_loss: 0.0563 - val_accuracy: 0.9842
<tf_keras.src.callbacks.History at 0x7f6050606e80>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 100 clusters 
quant_dense/dense/kernel:0: 18251 clusters 

QAT Model sparsity:
conv2d/kernel:0: 8.33% sparsity  (9/108)
dense/kernel:0: 7.52% sparsity  (1525/20280)

PCQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters 

PCQAT Model sparsity:
conv2d/kernel:0: 50.93% sparsity  (55/108)
dense/kernel:0: 58.16% sparsity  (11794/20280)
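The output above shows the difference between the two approaches. After plain QAT, the kernels hold 100 and 18,251 distinct values, and most zeros are gone. The PCQAT kernels keep 8 clusters and their sparsity.

The toy NumPy sketch below models one training step under a simplifying assumption. Plain QAT updates every float weight independently. A PCQAT-style update instead pools gradients per centroid and re-applies the pruning mask. During PCQAT training, Keras warned that no gradients exist for `conv2d/kernel:0` and `dense/kernel:0`. That is consistent with the original kernels not being trained directly. Even so, the sketch is not tfmot's implementation.

```python
import numpy as np

rng = np.random.default_rng(3)
num_weights = 1000
centroids = np.array([-0.3, -0.1, 0.1, 0.3])
assign = rng.integers(0, len(centroids), size=num_weights)  # cluster index per weight
mask = rng.random(num_weights) >= 0.5                        # False marks a pruned weight
weights = np.where(mask, centroids[assign], 0.0)
grad = rng.normal(scale=0.01, size=num_weights)  # stand-in for dL/dw from one batch
lr = 0.1

# Plain QAT (simplified): each float weight takes its own gradient step.
qat_weights = weights - lr * grad

# PCQAT-style step (simplified): pool gradients of unpruned members per centroid,
# move the shared centroids, then rebuild the weights and re-apply the pruning mask.
centroid_grad = np.bincount(assign, weights=grad * mask, minlength=len(centroids))
new_centroids = centroids - lr * centroid_grad
pcqat_weights = np.where(mask, new_centroids[assign], 0.0)

for name, w in [("QAT", qat_weights), ("PCQAT", pcqat_weights)]:
    print(f"{name}: {len(np.unique(w))} distinct values, {np.sum(w == 0)} zeros")
```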

See the compression benefits of the PCQAT model

Define a helper function that zips a model file and returns its size in kilobytes.

def get_gzipped_model_size(file):
  # Returns the size, in kilobytes, of `file` after ZIP (DEFLATE) compression.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

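Observe that applying sparsity, clustering and PCQAT to a model yields a significant compression benefit. In the run below, the zipped PCQAT TFLite model is about 7.9 KB, compared with about 14.0 KB for the QAT model.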

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PCQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model.
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpbd29dk98/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709988717.237025   41361 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988717.237075   41361 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpy4q5o_1n/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1709988720.060897   41361 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988720.060927   41361 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
QAT model size:  13.958  KB
PCQAT model size:  7.876  KB
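Why do sparsity and clustering shrink the zipped file? A tensor with many zeros and only a handful of repeated values is highly redundant. DEFLATE, the algorithm behind `zipfile.ZIP_DEFLATED` in `get_gzipped_model_size`, exploits exactly that redundancy.

The sketch below compares synthetic data only. A random float32 tensor is set against a hypothetical stand-in for a pruned and clustered kernel, with the same element count as the dense kernel above. The absolute sizes will not match the TFLite files, which are also int8-quantized and include the model structure.

```python
import zlib
import numpy as np

def deflated_kb(arr: np.ndarray) -> float:
    """Size in KB of the array's raw bytes after DEFLATE compression."""
    return len(zlib.compress(arr.tobytes(), 9)) / 1000

rng = np.random.default_rng(4)
n = 20280  # element count of the dense kernel above

# Unstructured float32 weights: nearly every value is unique.
unstructured = rng.normal(scale=0.05, size=n).astype(np.float32)

# Hypothetical pruned + clustered kernel: ~50% zeros, the rest drawn from 8 shared values.
centroids = np.linspace(-0.2, 0.2, 8).astype(np.float32)
clustered = centroids[rng.integers(0, len(centroids), size=n)]
clustered[rng.random(n) < 0.5] = 0.0

print(f"raw bytes:            {unstructured.nbytes / 1000:.2f} KB each")
print(f"unstructured, zipped: {deflated_kb(unstructured):.2f} KB")
print(f"clustered, zipped:    {deflated_kb(clustered):.2f} KB")
```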

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the pruned, clustered and quantized model, then check that the accuracy from TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()

pcqat_test_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned, clustered and quantized TFLite test_accuracy: 0.9806
Baseline TF test accuracy: 0.9835000038146973
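The arithmetic below summarizes the run using only numbers printed in this notebook. It computes the accuracy drop of the PCQAT TFLite model relative to the TF baseline, and the size ratio between the zipped QAT and PCQAT TFLite files. Exact values will vary between runs.

```python
# Values copied from the outputs printed in this notebook run.
baseline_acc = 0.9835000038146973  # Baseline TF test accuracy
pcqat_tflite_acc = 0.9806          # Pruned, clustered and quantized TFLite accuracy
qat_kb, pcqat_kb = 13.958, 7.876   # Zipped QAT and PCQAT TFLite sizes

drop_pp = (baseline_acc - pcqat_tflite_acc) * 100
print(f"Accuracy drop vs. baseline: {drop_pp:.2f} percentage points")
# prints: Accuracy drop vs. baseline: 0.29 percentage points
print(f"QAT / PCQAT size ratio: {qat_kb / pcqat_kb:.2f}x")
# prints: QAT / PCQAT size ratio: 1.77x
```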

Conclusion

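In this tutorial, you learned how to create a model and prune it with the prune_low_magnitude() API. You then applied sparsity preserving clustering with the cluster_weights() API, which preserves sparsity while clustering the weights.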

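Next, sparsity and cluster preserving quantization aware training (PCQAT) was applied to preserve model sparsity and clusters during QAT. The final PCQAT model was compared with the QAT model: the former preserves sparsity and clusters, while the latter loses them.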

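Next, the models were converted to TFLite to show the compression benefits of chaining the sparsity, clustering and PCQAT optimization techniques. The TFLite model was then evaluated to confirm that the accuracy persists in the TFLite backend.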

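Finally, the PCQAT TFLite model's accuracy was compared with the pre-optimization baseline accuracy. The collaborative optimization techniques achieved the compression benefits while keeping accuracy close to that of the original model (98.06% versus 98.35% in this run).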