针对设备上推理的 XNNPACK 剪枝

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

欢迎来到 Keras 权重剪枝指南,该指南介绍了如何通过 XNNPACK 提高设备上推理的延迟。

本指南介绍了新引入的 tfmot.sparsity.keras.PruningPolicy API 的用法,并演示了如何使用它来加速现代 CPU 上的卷积模型,使用 XNNPACK 稀疏推理

本指南涵盖了模型创建过程的以下步骤

  • 构建和训练密集型基线
  • 使用剪枝微调模型
  • 转换为 TFLite
  • 设备上基准测试

本指南不涵盖使用剪枝进行微调的最佳实践。有关此主题的更详细的信息,请查看我们的 综合指南

设置

 pip install -q tensorflow
 pip install -q tensorflow-model-optimization
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot
import tf_keras as keras

%load_ext tensorboard

构建和训练密集型模型

我们在 CIFAR10 数据集上构建和训练了一个简单的基线 CNN 用于分类任务。

# Load CIFAR10 dataset.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'cifar10',
    split=['train[:90%]', 'train[90%:]', 'test'],
    as_supervised=True,
    with_info=True,
)

# Normalize the input image so that each pixel value is between 0 and 1.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.image.convert_image_dtype(image, tf.float32), label

# Load the data in batches of 128 images.
batch_size = 128
def prepare_dataset(ds, buffer_size=None):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  if buffer_size:
    ds = ds.shuffle(buffer_size)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds

ds_train = prepare_dataset(ds_train,
                           buffer_size=ds_info.splits['train'].num_examples)
ds_val = prepare_dataset(ds_val)
ds_test = prepare_dataset(ds_test)

# Build the dense baseline model.
dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Compile and train the dense model for 10 epochs.
dense_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

dense_model.fit(
  ds_train,
  epochs=10,
  validation_data=ds_val)

# Evaluate the dense model.
_, dense_model_accuracy = dense_model.evaluate(ds_test, verbose=0)
2024-03-09 12:24:36.121481: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
352/352 [==============================] - 30s 22ms/step - loss: 1.9770 - accuracy: 0.2809 - val_loss: 2.2183 - val_accuracy: 0.1870
Epoch 2/10
352/352 [==============================] - 5s 14ms/step - loss: 1.7223 - accuracy: 0.3653 - val_loss: 1.7735 - val_accuracy: 0.3536
Epoch 3/10
352/352 [==============================] - 5s 14ms/step - loss: 1.6209 - accuracy: 0.4032 - val_loss: 1.8308 - val_accuracy: 0.3450
Epoch 4/10
352/352 [==============================] - 5s 14ms/step - loss: 1.5506 - accuracy: 0.4355 - val_loss: 1.5608 - val_accuracy: 0.4204
Epoch 5/10
352/352 [==============================] - 5s 14ms/step - loss: 1.5062 - accuracy: 0.4489 - val_loss: 1.6044 - val_accuracy: 0.4158
Epoch 6/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4679 - accuracy: 0.4653 - val_loss: 1.5631 - val_accuracy: 0.4178
Epoch 7/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4425 - accuracy: 0.4773 - val_loss: 1.4628 - val_accuracy: 0.4752
Epoch 8/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4227 - accuracy: 0.4844 - val_loss: 1.5183 - val_accuracy: 0.4478
Epoch 9/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4066 - accuracy: 0.4886 - val_loss: 1.5305 - val_accuracy: 0.4382
Epoch 10/10
352/352 [==============================] - 5s 14ms/step - loss: 1.3929 - accuracy: 0.4952 - val_loss: 1.4030 - val_accuracy: 0.4894

构建稀疏模型

使用 综合指南 中的说明,我们应用 tfmot.sparsity.keras.prune_low_magnitude 函数,该函数的参数针对通过剪枝进行设备上加速,即 tfmot.sparsity.keras.PruneForLatencyOnXNNPack 策略。

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after after 5 epochs.
end_epoch = 5

num_iterations_per_epoch = len(ds_train)
end_step =  num_iterations_per_epoch * end_epoch

# Define parameters for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.25,
                                                               final_sparsity=0.75,
                                                               begin_step=0,
                                                               end_step=end_step),
      'pruning_policy': tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
}

# Try to apply pruning wrapper with pruning policy parameter.
try:
  model_for_pruning = prune_low_magnitude(dense_model, **pruning_params)
except ValueError as e:
  print(e)

调用 prune_low_magnitude 会导致 ValueError,并显示消息 Could not find a GlobalAveragePooling2D layer with keepdims = True in all output branches。该消息表明该模型不支持使用策略 tfmot.sparsity.keras.PruneForLatencyOnXNNPack 进行剪枝,具体来说,层 GlobalAveragePooling2D 需要参数 keepdims = True。让我们修复它并重新应用 prune_low_magnitude 函数。

fixed_dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(keepdims=True),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Use the pretrained model for pruning instead of training from scratch.
fixed_dense_model.set_weights(dense_model.get_weights())

# Try to reapply pruning wrapper.
model_for_pruning = prune_low_magnitude(fixed_dense_model, **pruning_params)

调用 prune_low_magnitude 已完成,没有任何错误,这意味着该模型完全支持 tfmot.sparsity.keras.PruneForLatencyOnXNNPack 策略,可以使用 XNNPACK 稀疏推理 进行加速。

微调稀疏模型

按照 剪枝示例,我们使用密集型模型的权重来微调稀疏模型。我们从 25% 的稀疏度(25% 的权重设置为零)开始微调模型,并以 75% 的稀疏度结束。

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

model_for_pruning.fit(
  ds_train,
  epochs=15,
  validation_data=ds_val,
  callbacks=callbacks)

# Evaluate the dense model.
_, pruned_model_accuracy = model_for_pruning.evaluate(ds_test, verbose=0)

print('Dense model test accuracy:', dense_model_accuracy)
print('Pruned model test accuracy:', pruned_model_accuracy)
Epoch 1/15
352/352 [==============================] - 11s 17ms/step - loss: 1.3992 - accuracy: 0.4897 - val_loss: 1.9449 - val_accuracy: 0.3402
Epoch 2/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4250 - accuracy: 0.4852 - val_loss: 1.7185 - val_accuracy: 0.3716
Epoch 3/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4584 - accuracy: 0.4666 - val_loss: 1.8855 - val_accuracy: 0.3426
Epoch 4/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4717 - accuracy: 0.4616 - val_loss: 1.8802 - val_accuracy: 0.3554
Epoch 5/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4519 - accuracy: 0.4727 - val_loss: 1.6495 - val_accuracy: 0.3972
Epoch 6/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4326 - accuracy: 0.4800 - val_loss: 1.4971 - val_accuracy: 0.4416
Epoch 7/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4205 - accuracy: 0.4860 - val_loss: 1.7675 - val_accuracy: 0.4002
Epoch 8/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4114 - accuracy: 0.4893 - val_loss: 1.5721 - val_accuracy: 0.4234
Epoch 9/15
352/352 [==============================] - 5s 15ms/step - loss: 1.4038 - accuracy: 0.4917 - val_loss: 1.6057 - val_accuracy: 0.4236
Epoch 10/15
352/352 [==============================] - 5s 15ms/step - loss: 1.3959 - accuracy: 0.4930 - val_loss: 1.5344 - val_accuracy: 0.4484
Epoch 11/15
352/352 [==============================] - 5s 14ms/step - loss: 1.3899 - accuracy: 0.4969 - val_loss: 1.4643 - val_accuracy: 0.4768
Epoch 12/15
352/352 [==============================] - 5s 14ms/step - loss: 1.3829 - accuracy: 0.4996 - val_loss: 1.5114 - val_accuracy: 0.4494
Epoch 13/15
352/352 [==============================] - 5s 14ms/step - loss: 1.3777 - accuracy: 0.5020 - val_loss: 1.5931 - val_accuracy: 0.4278
Epoch 14/15
352/352 [==============================] - 5s 14ms/step - loss: 1.3749 - accuracy: 0.5018 - val_loss: 1.4799 - val_accuracy: 0.4680
Epoch 15/15
352/352 [==============================] - 5s 15ms/step - loss: 1.3704 - accuracy: 0.5041 - val_loss: 1.5630 - val_accuracy: 0.4490
Dense model test accuracy: 0.49380001425743103
Pruned model test accuracy: 0.44940000772476196

日志显示了每个层的稀疏度进展。

#docs_infra: no_execute
%tensorboard --logdir={logdir}

在使用剪枝进行微调后,测试精度与密集型模型相比略有提高(从 43% 提高到 44%)。让我们使用 TFLite 基准测试 比较设备上的延迟。

模型转换和基准测试

要将剪枝模型转换为 TFLite,我们需要使用 strip_pruning 函数将 PruneLowMagnitude 包装器替换为原始层。此外,由于剪枝模型的权重 (model_for_pruning) 大部分为零,我们可以应用优化 tf.lite.Optimize.EXPERIMENTAL_SPARSITY 来有效地存储生成的 TFLite 模型。此优化标志对于密集型模型不是必需的。

converter = tf.lite.TFLiteConverter.from_keras_model(dense_model)
dense_tflite_model = converter.convert()

_, dense_tflite_file = tempfile.mkstemp('.tflite')
with open(dense_tflite_file, 'wb') as f:
  f.write(dense_tflite_model)

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.EXPERIMENTAL_SPARSITY]
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmprnl_sl6s/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmprnl_sl6s/assets
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709987241.973351   18472 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987241.973414   18472 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpk0lumuch/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpk0lumuch/assets
W0000 00:00:1709987245.660280   18472 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709987245.660323   18472 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.

按照 TFLite 模型基准测试工具 的说明,我们构建了该工具,并将其与密集型和剪枝后的 TFLite 模型一起上传到 Android 设备,并在设备上对这两个模型进行基准测试。

! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/dense_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1
/bin/bash: adb: command not found
! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/pruned_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1
/bin/bash: adb: command not found

在 Pixel 4 上的基准测试结果显示,密集型模型的平均推理时间为 *17us*,剪枝模型的平均推理时间为 *12us*。设备上的基准测试表明,即使对于如此小的模型,延迟也明显提高了 **5us** 或 **30%**。根据我们的经验,基于 MobileNetV3EfficientNet-lite 的更大模型显示出类似的性能改进。加速效果因 1x1 卷积对整个模型的相对贡献而异。

结论

在本教程中,我们展示了如何使用 TF MOT API 和 XNNPack 引入的新功能来创建稀疏模型,以实现更快的设备上性能。这些稀疏模型比其密集型对应模型更小、更快,同时保留甚至超越了其质量。

我们鼓励您尝试此新功能,它对于在设备上部署您的模型尤其重要。