在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看 | 下载笔记本 |
概述
这是一个端到端的示例,展示了 **修剪保留量化感知训练 (PQAT)** API 的用法,该 API 是 TensorFlow 模型优化工具包的协作优化管道的一部分。
其他页面
有关管道和可用技术的介绍,请参阅 协作优化概述页面。
内容
在本教程中,您将
- 从头开始训练一个
keras
模型用于 MNIST 数据集。 - 使用稀疏性 API 对模型进行微调,并查看准确率。
- 应用 QAT 并观察稀疏性的损失。
- 应用 PQAT 并观察之前应用的稀疏性是否得到保留。
- 生成一个 TFLite 模型,并观察在模型上应用 PQAT 的效果。
- 将 PQAT 模型的准确率与使用训练后量化量化的模型进行比较。
设置
您可以在本地 virtualenv 或 colab 中运行此 Jupyter 笔记本。有关设置依赖项的详细信息,请参阅 安装指南。
pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras
import numpy as np
import tempfile
import zipfile
import os
训练一个用于 MNIST 的 keras 模型,不进行修剪
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
model = keras.Sequential([
keras.layers.InputLayer(input_shape=(28, 28)),
keras.layers.Reshape(target_shape=(28, 28, 1)),
keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
activation=tf.nn.relu),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Flatten(),
keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_images,
train_labels,
validation_split=0.1,
epochs=10
)
2024-03-09 12:40:49.225662: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected Epoch 1/10 1688/1688 [==============================] - 21s 4ms/step - loss: 0.3056 - accuracy: 0.9125 - val_loss: 0.1308 - val_accuracy: 0.9640 Epoch 2/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.1348 - accuracy: 0.9614 - val_loss: 0.0882 - val_accuracy: 0.9760 Epoch 3/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0951 - accuracy: 0.9730 - val_loss: 0.0719 - val_accuracy: 0.9797 Epoch 4/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0761 - accuracy: 0.9778 - val_loss: 0.0694 - val_accuracy: 0.9798 Epoch 5/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0648 - accuracy: 0.9808 - val_loss: 0.0599 - val_accuracy: 0.9838 Epoch 6/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0564 - accuracy: 0.9831 - val_loss: 0.0601 - val_accuracy: 0.9837 Epoch 7/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0496 - accuracy: 0.9852 - val_loss: 0.0578 - val_accuracy: 0.9848 Epoch 8/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0445 - accuracy: 0.9864 - val_loss: 0.0556 - val_accuracy: 0.9847 Epoch 9/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0408 - accuracy: 0.9874 - val_loss: 0.0539 - val_accuracy: 0.9853 Epoch 10/10 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0382 - accuracy: 0.9881 - val_loss: 0.0585 - val_accuracy: 0.9848 <tf_keras.src.callbacks.History at 0x7f06cb26d670>
评估基线模型并保存以备后用
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9812999963760376 Saving model to: /tmpfs/tmp/tmpgyooj7vn.h5 /tmpfs/tmp/ipykernel_34779/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`. keras.models.save_model(model, keras_file, include_optimizer=False)
将模型修剪到 50% 的稀疏性并进行微调
应用 prune_low_magnitude()
API 对整个预训练模型进行修剪,以演示和观察其在应用 zip 时减小模型大小的有效性,同时保持准确率。有关如何最佳使用 API 来实现最佳压缩率,同时保持目标准确率,请参阅 修剪综合指南。
定义模型并应用稀疏性 API
在使用稀疏性 API 之前,需要对模型进行预训练。
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
}
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep()
]
pruned_model = prune_low_magnitude(model, **pruning_params)
# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)
pruned_model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=opt,
metrics=['accuracy'])
pruned_model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_reshap (None, 28, 28, 1) 1 e (PruneLowMagnitude) prune_low_magnitude_conv2d (None, 26, 26, 12) 230 (PruneLowMagnitude) prune_low_magnitude_max_po (None, 13, 13, 12) 1 oling2d (PruneLowMagnitude ) prune_low_magnitude_flatte (None, 2028) 1 n (PruneLowMagnitude) prune_low_magnitude_dense (None, 10) 40572 (PruneLowMagnitude) ================================================================= Total params: 40805 (159.41 KB) Trainable params: 20410 (79.73 KB) Non-trainable params: 20395 (79.69 KB) _________________________________________________________________
对模型进行微调,并根据基线评估准确率
对模型进行微调,并进行 3 个 epoch 的修剪。
# Fine-tune model
pruned_model.fit(
train_images,
train_labels,
epochs=3,
validation_split=0.1,
callbacks=callbacks)
Epoch 1/3 1688/1688 [==============================] - 10s 4ms/step - loss: 0.0852 - accuracy: 0.9716 - val_loss: 0.0814 - val_accuracy: 0.9742 Epoch 2/3 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0641 - accuracy: 0.9800 - val_loss: 0.0721 - val_accuracy: 0.9763 Epoch 3/3 1688/1688 [==============================] - 7s 4ms/step - loss: 0.0559 - accuracy: 0.9829 - val_loss: 0.0682 - val_accuracy: 0.9788 <tf_keras.src.callbacks.History at 0x7f06b00f8eb0>
定义辅助函数来计算和打印模型的稀疏性。
def print_model_weights_sparsity(model):
for layer in model.layers:
if isinstance(layer, keras.layers.Wrapper):
weights = layer.trainable_weights
else:
weights = layer.weights
for weight in weights:
# ignore auxiliary quantization weights
if "quantize_layer" in weight.name:
continue
weight_size = weight.numpy().size
zero_num = np.count_nonzero(weight == 0)
print(
f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
f"({zero_num}/{weight_size})",
)
检查模型是否已正确修剪。我们需要首先剥离修剪包装器。
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
print_model_weights_sparsity(stripped_pruned_model)
conv2d/kernel:0: 50.00% sparsity (54/108) conv2d/bias:0: 0.00% sparsity (0/12) dense/kernel:0: 50.00% sparsity (10140/20280) dense/bias:0: 0.00% sparsity (0/10)
对于此示例,与基线相比,修剪后测试准确率的损失很小。
_, pruned_model_accuracy = pruned_model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', pruned_model_accuracy)
Baseline test accuracy: 0.9812999963760376 Pruned test accuracy: 0.9769999980926514
应用 QAT 和 PQAT,并检查两种情况下对模型稀疏性的影响
接下来,我们对修剪后的模型应用 QAT 和修剪保留 QAT (PQAT),并观察 PQAT 是否保留了修剪模型的稀疏性。请注意,我们在应用 PQAT API 之前,使用 tfmot.sparsity.keras.strip_pruning
从修剪后的模型中剥离了修剪包装器。
# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_pruned_model)
qat_model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
# PQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
stripped_pruned_model)
pqat_model = tfmot.quantization.keras.quantize_apply(
quant_aware_annotate_model,
tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())
pqat_model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print('Train pqat model:')
pqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model: 422/422 [==============================] - 4s 7ms/step - loss: 0.0384 - accuracy: 0.9893 - val_loss: 0.0539 - val_accuracy: 0.9847 Train pqat model: 422/422 [==============================] - 4s 7ms/step - loss: 0.0395 - accuracy: 0.9890 - val_loss: 0.0543 - val_accuracy: 0.9850 <tf_keras.src.callbacks.History at 0x7f06cb25efa0>
print("QAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("PQAT Model sparsity:")
print_model_weights_sparsity(pqat_model)
QAT Model sparsity: conv2d/kernel:0: 15.74% sparsity (17/108) conv2d/bias:0: 0.00% sparsity (0/12) dense/kernel:0: 11.48% sparsity (2328/20280) dense/bias:0: 0.00% sparsity (0/10) PQAT Model sparsity: conv2d/kernel:0: 50.00% sparsity (54/108) conv2d/bias:0: 0.00% sparsity (0/12) dense/kernel:0: 50.00% sparsity (10140/20280) dense/bias:0: 0.00% sparsity (0/10)
查看 PQAT 模型的压缩优势
定义辅助函数以获取压缩的模型文件。
def get_gzipped_model_size(file):
# It returns the size of the gzipped model in kilobytes.
_, zipped_file = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
f.write(file)
return os.path.getsize(zipped_file)/1000
由于这是一个小型模型,因此两个模型之间的差异并不明显。将修剪和 PQAT 应用于更大的生产模型将产生更显著的压缩。
# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
f.write(qat_tflite_model)
# PQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pqat_tflite_model = converter.convert()
pqat_model_file = 'pqat_model.tflite'
# Save the model.
with open(pqat_model_file, 'wb') as f:
f.write(pqat_tflite_model)
print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PQAT model size: ", get_gzipped_model_size(pqat_model_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpe5bjb60h/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpe5bjb60h/assets /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway. warnings.warn( WARNING: All log messages before absl::InitializeLog() is called are written to STDERR W0000 00:00:1709988173.208136 34779 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format. W0000 00:00:1709988173.208187 34779 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkxxlbf0o/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkxxlbf0o/assets /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway. warnings.warn( W0000 00:00:1709988175.442658 34779 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format. W0000 00:00:1709988175.442688 34779 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency. QAT model size: 16.923 KB PQAT model size: 14.491 KB
查看从 TF 到 TFLite 的准确率持久性
定义一个辅助函数来评估测试数据集上的 TFLite 模型。
def eval_model(interpreter):
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
# Run predictions on every image in the "test" dataset.
prediction_digits = []
for i, test_image in enumerate(test_images):
if i % 1000 == 0:
print(f"Evaluated on {i} results so far.")
# Pre-processing: add batch dimension and convert to float32 to match with
# the model's input data format.
test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
interpreter.set_tensor(input_index, test_image)
# Run inference.
interpreter.invoke()
# Post-processing: remove batch dimension and find the digit with highest
# probability.
output = interpreter.tensor(output_index)
digit = np.argmax(output()[0])
prediction_digits.append(digit)
print('\n')
# Compare prediction results with ground truth labels to calculate accuracy.
prediction_digits = np.array(prediction_digits)
accuracy = (prediction_digits == test_labels).mean()
return accuracy
您评估已修剪和量化的模型,然后查看 TensorFlow 中的准确率是否在 TFLite 后端中保持。
interpreter = tf.lite.Interpreter(pqat_model_file)
interpreter.allocate_tensors()
pqat_test_accuracy = eval_model(interpreter)
print('Pruned and quantized TFLite test_accuracy:', pqat_test_accuracy)
print('Pruned TF test accuracy:', pruned_model_accuracy)
Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. INFO: Created TensorFlow Lite XNNPACK delegate for CPU. WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#12 is a dynamic-sized tensor). Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. Pruned and quantized TFLite test_accuracy: 0.9821 Pruned TF test accuracy: 0.9769999980926514
应用训练后量化,并与 PQAT 模型进行比较
接下来,我们对修剪后的模型使用正常的训练后量化(无微调),并检查其准确率与 PQAT 模型的准确率。这演示了为什么您需要使用 PQAT 来提高量化模型的准确率。
首先,为来自前 1000 个训练图像的校准数据集定义一个生成器。
def mnist_representative_data_gen():
for image in train_images[:1000]:
image = np.expand_dims(image, axis=0).astype(np.float32)
yield [image]
量化模型,并将准确率与之前获得的 PQAT 模型进行比较。请注意,使用微调量化的模型实现了更高的准确率。
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
f.write(post_training_tflite_model)
# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()
post_training_test_accuracy = eval_model(interpreter)
print('PQAT TFLite test_accuracy:', pqat_test_accuracy)
print('Post-training (no fine-tuning) TF test accuracy:', post_training_test_accuracy)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp4xwi7ko3/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp4xwi7ko3/assets /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:964: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway. warnings.warn( W0000 00:00:1709988177.152521 34779 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format. W0000 00:00:1709988177.152549 34779 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32 Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. PQAT TFLite test_accuracy: 0.9821 Post-training (no fine-tuning) TF test accuracy: 0.9764
结论
在本教程中,您学习了如何创建模型,使用稀疏性 API 对其进行修剪,并应用稀疏性保留量化感知训练 (PQAT) 来保留稀疏性,同时使用 QAT。最终的 PQAT 模型与 QAT 模型进行了比较,以表明前者保留了稀疏性,而后者则丢失了稀疏性。接下来,将模型转换为 TFLite,以展示将修剪和 PQAT 模型优化技术链接起来的压缩优势,并评估 TFLite 模型以确保准确率在 TFLite 后端中保持。最后,将 PQAT 模型与使用训练后量化 API 实现的量化修剪模型进行比较,以展示 PQAT 在恢复正常量化导致的准确率损失方面的优势。