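Overview
This is an end-to-end example showing how to use the **cluster preserving quantization aware training (CQAT)** API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.
Other pages
For an introduction to the pipeline and the other available techniques, see the collaborative optimization overview page.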
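Contents
In this tutorial, you will:
- Train a keras model for the MNIST dataset from scratch.
- Fine-tune the model with clustering and check the accuracy.
- Apply QAT and observe the loss of clusters.
- Apply CQAT and observe that the clustering applied earlier has been preserved.
- Generate a TFLite model and observe the effects of applying CQAT on it.
- Compare the accuracy of the CQAT model with that of a model quantized using post-training quantization.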
Setup
You can run this Jupyter notebook in your local virtualenv or in colab. For details on setting up dependencies, refer to the installation guide.
pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras
import numpy as np
import tempfile
import zipfile
import os
Train a keras model for MNIST without clustering
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                        activation=tf.nn.relu),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
2024-03-09 12:45:06.324078: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3191 - accuracy: 0.9090 - val_loss: 0.1358 - val_accuracy: 0.9645
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1291 - accuracy: 0.9635 - val_loss: 0.0912 - val_accuracy: 0.9748
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0886 - accuracy: 0.9740 - val_loss: 0.0749 - val_accuracy: 0.9795
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0710 - accuracy: 0.9789 - val_loss: 0.0637 - val_accuracy: 0.9818
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0601 - accuracy: 0.9819 - val_loss: 0.0659 - val_accuracy: 0.9817
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0532 - accuracy: 0.9838 - val_loss: 0.0630 - val_accuracy: 0.9828
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0477 - accuracy: 0.9855 - val_loss: 0.0639 - val_accuracy: 0.9832
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0427 - accuracy: 0.9865 - val_loss: 0.0598 - val_accuracy: 0.9850
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0393 - accuracy: 0.9876 - val_loss: 0.0590 - val_accuracy: 0.9837
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0353 - accuracy: 0.9891 - val_loss: 0.0610 - val_accuracy: 0.9842
<tf_keras.src.callbacks.History at 0x7f22b1b9a0d0>
Evaluate the baseline model and save it for later use
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9818000197410583
Saving model to: /tmpfs/tmp/tmpo7dgy4fg.h5
/tmpfs/tmp/ipykernel_38069/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)
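Cluster and fine-tune the model with 8 clusters
Apply the cluster_weights() API to cluster the whole pre-trained model, to demonstrate and observe its effectiveness in reducing the model size when zip compression is applied, while maintaining accuracy. For how best to use the API to achieve the best compression rate while maintaining your target accuracy, refer to the clustering comprehensive guide.
Define the model and apply the clustering API
The model needs to be pre-trained before using the clustering API.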
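To make the idea of weight clustering concrete, the following is a minimal NumPy sketch, not the `tfmot` implementation. It assumes a simple k-means loop with linearly spaced initial centroids (this tutorial itself configures `KMEANS_PLUS_PLUS`), and the function name `cluster_weights_sketch` and the random stand-in kernel are hypothetical.

```python
import numpy as np

def cluster_weights_sketch(weights, number_of_clusters, iterations=10):
    """Replace each weight with the nearest of `number_of_clusters` centroids."""
    flat = weights.ravel()
    # Assumption: linearly spaced initial centroids over the weight range.
    centroids = np.linspace(flat.min(), flat.max(), number_of_clusters)
    for _ in range(iterations):
        # Assign every weight to its closest centroid.
        indices = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
        # Move each centroid to the mean of the weights assigned to it.
        for k in range(number_of_clusters):
            members = flat[indices == k]
            if members.size:
                centroids[k] = members.mean()
    # Final assignment against the updated centroids.
    indices = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
    return centroids[indices].reshape(weights.shape), centroids

# Hypothetical stand-in for a trained dense kernel.
rng = np.random.default_rng(0)
dense_kernel = rng.normal(size=(64, 10)).astype(np.float32)
clustered, centroids = cluster_weights_sketch(dense_kernel, number_of_clusters=8)
print(len(np.unique(clustered)))  # at most 8 distinct values remain
```

After clustering, the kernel can be described by a small table of centroids plus a per-weight index, which is why a generic compressor such as zip tends to shrink the file.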
import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
    'cluster_per_channel': True,
}
clustered_model = cluster_weights(model, **clustering_params)
# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)
clustered_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy'])
clustered_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                            Output Shape          Param #
=================================================================
 cluster_reshape (ClusterWeights)        (None, 28, 28, 1)     0
 cluster_conv2d (ClusterWeights)         (None, 26, 26, 12)    324
 cluster_max_pooling2d (ClusterWeights)  (None, 13, 13, 12)    0
 cluster_flatten (ClusterWeights)        (None, 2028)          0
 cluster_dense (ClusterWeights)          (None, 10)            40578
=================================================================
Total params: 40902 (239.41 KB)
Trainable params: 20514 (80.13 KB)
Non-trainable params: 20388 (159.28 KB)
_________________________________________________________________
Fine-tune the model and evaluate the accuracy against baseline
Fine-tune the model with clustering for 3 epochs.
# Fine-tune model
clustered_model.fit(
    train_images,
    train_labels,
    epochs=3,
    validation_split=0.1)
Epoch 1/3
1688/1688 [==============================] - 11s 5ms/step - loss: 0.0316 - accuracy: 0.9909 - val_loss: 0.0610 - val_accuracy: 0.9837
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0297 - accuracy: 0.9916 - val_loss: 0.0603 - val_accuracy: 0.9852
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0291 - accuracy: 0.9919 - val_loss: 0.0596 - val_accuracy: 0.9850
<tf_keras.src.callbacks.History at 0x7f22946065e0>
Define a helper function to calculate and print the number of clusters in each kernel of the model.
def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )
Check that the model kernels were correctly clustered. The clustering wrapper needs to be stripped first.
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)
print_model_weight_clusters(stripped_clustered_model)
conv2d/kernel:0: 96 clusters
dense/kernel:0: 8 clusters
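The count of 96 for the convolution kernel is consistent with the `'cluster_per_channel': True` setting used above: the `Conv2D` layer has 12 filters, and if each output channel is clustered separately into 8 centroids, the kernel can hold up to 12 × 8 = 96 distinct values. The sketch below reproduces that arithmetic with NumPy on a synthetic kernel; the centroid values are hypothetical and chosen only so that no two channels share a value.

```python
import numpy as np

number_of_clusters = 8
# Conv2D kernel shape in this tutorial: 3x3 window, 1 input channel, 12 filters.
kernel_shape = (3, 3, 1, 12)
num_channels = kernel_shape[-1]
weights_per_channel = int(np.prod(kernel_shape[:-1]))  # 9 weights per filter

kernel = np.empty(kernel_shape, dtype=np.float32)
for c in range(num_channels):
    # Hypothetical centroids: 8 values that belong only to this channel.
    centroids = np.arange(number_of_clusters, dtype=np.float32) + c * number_of_clusters
    # 9 weights drawn from 8 centroids, so one centroid is used twice.
    idx = np.arange(weights_per_channel) % number_of_clusters
    kernel[..., c] = centroids[idx].reshape(kernel_shape[:-1])

per_channel = [len(np.unique(kernel[..., c])) for c in range(num_channels)]
total = len(np.unique(kernel))
print(per_channel)  # 8 unique values in every channel
print(total)        # 96 unique values across the whole kernel
```

Because `print_model_weight_clusters` counts unique values over the whole kernel rather than per channel, a per-channel clustered layer reports the combined figure.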
For this example, there is minimal loss in test accuracy after clustering, compared to the baseline.
_, clustered_model_accuracy = clustered_model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9818000197410583
Clustered test accuracy: 0.9818999767303467
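Apply QAT and CQAT and check the effect on model clusters in both cases
Next, we apply both QAT and cluster preserving QAT (CQAT) to the clustered model and observe that CQAT preserves the weight clusters of the clustered model. Note that the clustering wrappers were stripped from the model with tfmot.clustering.keras.strip_clustering before applying the CQAT API.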
# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)
qat_model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
# CQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
    stripped_clustered_model)
cqat_model = tfmot.quantization.keras.quantize_apply(
    quant_aware_annotate_model,
    tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme())
cqat_model.compile(optimizer='adam',
                   loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   metrics=['accuracy'])
print('Train cqat model:')
cqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0315 - accuracy: 0.9905 - val_loss: 0.0573 - val_accuracy: 0.9855
WARNING:root:Input layer does not contain zero weights, so apply CQAT instead.
Train cqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
422/422 [==============================] - 6s 8ms/step - loss: 0.0290 - accuracy: 0.9917 - val_loss: 0.0597 - val_accuracy: 0.9847
<tf_keras.src.callbacks.History at 0x7f229414f130>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("CQAT Model clusters:")
print_model_weight_clusters(cqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 108 clusters
quant_dense/dense/kernel:0: 19910 clusters
CQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 96 clusters
quant_dense/dense/kernel:0: 8 clusters
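The difference between the two results can be illustrated with a conceptual NumPy sketch. It is not the `tfmot` implementation: it only assumes that a plain training step updates every weight independently, while a cluster-preserving step keeps each weight's cluster assignment fixed and updates the shared centroids instead. The gradient values, learning rate, and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
number_of_clusters = 8
centroids = np.linspace(-1.0, 1.0, number_of_clusters)    # shared centroid values
indices = rng.integers(0, number_of_clusters, size=1000)  # fixed cluster id per weight
weights = centroids[indices]                               # clustered weights
grads = rng.normal(size=weights.shape)                     # stand-in gradients
lr = 0.01

# Plain update: each weight moves by its own gradient, so the clusters dissolve.
plain = weights - lr * grads

# Cluster-preserving update: accumulate gradients per cluster, move only the
# centroids, then rebuild the weights from the unchanged assignments.
centroid_grads = np.bincount(indices, weights=grads, minlength=number_of_clusters)
new_centroids = centroids - lr * centroid_grads
preserved = new_centroids[indices]

print(len(np.unique(plain)))      # many more than 8 distinct values
print(len(np.unique(preserved)))  # still at most 8 distinct values
```

The same pattern shows up in the numbers above: after plain QAT the dense kernel holds thousands of distinct values, while after CQAT it still holds 8.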
See the compression benefits of the CQAT model
Define a helper function to get the size of the zipped model file.
def get_gzipped_model_size(file):
    # It returns the size of the gzipped model in kilobytes.
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
    return os.path.getsize(zipped_file)/1000
Note that this is a small model. Applying clustering and CQAT to a bigger production model would yield a more significant compression.
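Why clustered weights compress better can be sketched with the standard library alone. The example below is an illustration under stated assumptions, not a measurement of the models in this tutorial: it compares DEFLATE output sizes (the same algorithm as `zipfile.ZIP_DEFLATED`) for a byte buffer that uses all 256 values against one that uses only 8 distinct values. The buffer size and the random byte contents are hypothetical stand-ins for int8 weights.

```python
import random
import zlib

random.seed(0)
n = 20000  # hypothetical number of int8 weights

# Stand-in for unclustered weights: every byte value is possible.
unclustered = bytes(random.randrange(256) for _ in range(n))

# Stand-in for clustered weights: only 8 distinct byte values occur.
palette = random.sample(range(256), 8)
clustered = bytes(random.choice(palette) for _ in range(n))

unclustered_size = len(zlib.compress(unclustered, 9))
clustered_size = len(zlib.compress(clustered, 9))
print(unclustered_size, clustered_size)  # the clustered buffer compresses to fewer bytes
```

With only 8 symbols, each weight carries at most 3 bits of information, which an entropy coder can exploit even when the weights appear in random order.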
# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)
# CQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(cqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
cqat_tflite_model = converter.convert()
cqat_model_file = 'cqat_model.tflite'
# Save the model.
with open(cqat_model_file, 'wb') as f:
    f.write(cqat_tflite_model)
print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("CQAT model size: ", get_gzipped_model_size(cqat_model_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqnilvtco/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709988433.596770 38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988433.596818 38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpqlp0tpfp/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
QAT model size: 17.487 KB
CQAT model size: 10.64 KB
W0000 00:00:1709988436.818205 38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988436.818236 38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
See the persistence of accuracy from TF to TFLite
Define a helper function to evaluate the TFLite model on the test dataset.
def eval_model(interpreter):
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # Run predictions on every image in the "test" dataset.
    prediction_digits = []
    for i, test_image in enumerate(test_images):
        if i % 1000 == 0:
            print(f"Evaluated on {i} results so far.")
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)
        # Run inference.
        interpreter.invoke()
        # Post-processing: remove batch dimension and find the digit with highest
        # probability.
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)
    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == test_labels).mean()
    return accuracy
Evaluate the model, which has been clustered and quantized, and check that the accuracy obtained in TensorFlow persists in the TFLite backend.
interpreter = tf.lite.Interpreter(cqat_model_file)
interpreter.allocate_tensors()
cqat_test_accuracy = eval_model(interpreter)
print('Clustered and quantized TFLite test_accuracy:', cqat_test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor).
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.
Clustered and quantized TFLite test_accuracy: 0.9822
Clustered TF test accuracy: 0.9818999767303467
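Apply post-training quantization and compare to the CQAT model
Next, we use post-training quantization (no fine-tuning) on the clustered model and check its accuracy against the CQAT model. This demonstrates why you would use CQAT to improve the quantized model's accuracy. The difference may not be very visible, because the MNIST model is quite small and overparametrized.
First, define a generator for the calibration dataset from the first 1000 training images.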
def mnist_representative_data_gen():
    for image in train_images[:1000]:
        image = np.expand_dims(image, axis=0).astype(np.float32)
        yield [image]
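A calibration dataset lets a quantizer observe the value ranges that actually occur before it picks quantization parameters. The sketch below shows a generic affine int8 scheme in NumPy under that assumption; it is not the TFLite converter's procedure, which may choose ranges differently. The helper names `calibrate`, `quantize`, and `dequantize` and the random stand-in images are hypothetical.

```python
import numpy as np

def calibrate(samples):
    """Derive an int8 scale and zero point from the observed value range."""
    lo = min(float(s.min()) for s in samples)
    hi = max(float(s.max()) for s in samples)
    # Widen the range to include 0.0 so that zero is exactly representable.
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    scale = (hi - lo) / 255.0  # assumes hi > lo, i.e. non-constant data
    zero_point = int(round(-128 - lo / scale))
    return scale, zero_point

def quantize(x, scale, zero_point):
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

# Hypothetical stand-in for normalized images with pixel values in [0, 1).
rng = np.random.default_rng(0)
calibration = [rng.random((28, 28)).astype(np.float32) for _ in range(100)]
scale, zero_point = calibrate(calibration)

x = calibration[0]
restored = dequantize(quantize(x, scale, zero_point), scale, zero_point)
error = float(np.abs(restored - x).max())
print(scale, zero_point, error)  # the round-trip error stays near scale / 2 or below
```

Values inside the calibrated range are rounded to the nearest of 256 levels, so the round-trip error is bounded by about half the scale; values outside the range would be clipped, which is why representative calibration data matters.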
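Quantize the model and compare its accuracy to the previously obtained CQAT model. Note that the model quantized with fine-tuning achieves higher accuracy.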
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)
# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()
post_training_test_accuracy = eval_model(interpreter)
print('CQAT TFLite test_accuracy:', cqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpxyohbvab/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn(
W0000 00:00:1709988438.608574 38069 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988438.608603 38069 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.
CQAT TFLite test_accuracy: 0.9822
Post-training (no fine-tuning) TF test accuracy: 0.9817
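Conclusion
In this tutorial, you learned how to create a model, cluster it using the cluster_weights() API, and apply cluster preserving quantization aware training (CQAT) to preserve the clusters while using QAT. The CQAT model was compared to the QAT model to show that the clusters are preserved in the former and lost in the latter. Next, the models were converted to TFLite to show the compression benefits of chaining the clustering and CQAT model optimization techniques, and the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend. Finally, the CQAT model was compared to a quantized clustered model obtained with the post-training quantization API, to demonstrate the advantage of CQAT in recovering the accuracy loss caused by normal quantization.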