Sparsity preserving clustering Keras example


Overview

This is an end-to-end example showing the usage of the **sparsity preserving clustering** API, part of the collaborative optimization pipeline of the TensorFlow Model Optimization Toolkit.

Other pages

For an introduction to the pipeline and the available techniques, see the collaborative optimization overview page.

Contents

In this tutorial, you will:

  1. Train a keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with pruning, check the accuracy, and observe that the model was successfully pruned.
  3. Apply weight clustering to the pruned model and observe the loss of sparsity.
  4. Apply sparsity preserving clustering on the pruned model and observe that the sparsity applied earlier has been preserved.
  5. Generate a TFLite model and check that the accuracy has been preserved in the pruned clustered model.
  6. Compare the sizes of the different models to observe the compression benefits of applying sparsity followed by the collaborative optimization technique of sparsity preserving clustering (a compact sketch of the full flow follows this list).
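
For orientation, the following is a compact sketch of the full flow on a tiny stand-in model. It is not part of the original notebook: the model, data, and shapes are hypothetical placeholders for the MNIST model trained below, and it assumes the packages installed in the Setup section.

import numpy as np
import tensorflow as tf
import tf_keras as keras
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster

# Hypothetical stand-in model and data; the tutorial uses MNIST instead.
base_model = keras.Sequential([
    keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    keras.layers.Dense(4)])
x = np.random.rand(128, 8).astype(np.float32)
y = np.random.randint(0, 4, size=(128,))

# 1. Prune the model to a constant 50% sparsity and fine-tune it.
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    base_model,
    pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100))
pruned.compile(optimizer='adam',
               loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))
pruned.fit(x, y, epochs=1, verbose=0,
           callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# 2. Strip the pruning wrappers, then cluster the weights while preserving the zeros.
stripped = tfmot.sparsity.keras.strip_pruning(pruned)
clustered = cluster.cluster_weights(
    stripped,
    number_of_clusters=8,
    cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.KMEANS_PLUS_PLUS,
    preserve_sparsity=True)
clustered.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))
clustered.fit(x, y, epochs=1, verbose=0)

# 3. Strip the clustering wrappers and convert to a post-training quantized TFLite model.
final_model = tfmot.clustering.keras.strip_clustering(clustered)
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_bytes = converter.convert()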

Setup

You can run this Jupyter Notebook in your local virtualenv or Colab. For details on setting up dependencies, please refer to the installation guide.

! pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a keras model for MNIST to be pruned and clustered

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
2024-03-09 12:54:09.347032: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
1688/1688 [==============================] - 21s 4ms/step - loss: 0.3119 - accuracy: 0.9120 - val_loss: 0.1272 - val_accuracy: 0.9640
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1224 - accuracy: 0.9655 - val_loss: 0.0870 - val_accuracy: 0.9770
Epoch 3/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0908 - accuracy: 0.9740 - val_loss: 0.0740 - val_accuracy: 0.9800
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0759 - accuracy: 0.9775 - val_loss: 0.0639 - val_accuracy: 0.9830
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0659 - accuracy: 0.9810 - val_loss: 0.0653 - val_accuracy: 0.9832
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0593 - accuracy: 0.9822 - val_loss: 0.0675 - val_accuracy: 0.9805
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0538 - accuracy: 0.9839 - val_loss: 0.0615 - val_accuracy: 0.9825
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0489 - accuracy: 0.9848 - val_loss: 0.0619 - val_accuracy: 0.9832
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0457 - accuracy: 0.9862 - val_loss: 0.0639 - val_accuracy: 0.9838
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0422 - accuracy: 0.9870 - val_loss: 0.0593 - val_accuracy: 0.9835
<tf_keras.src.callbacks.History at 0x7f0f2dac7b80>

Evaluate the baseline model and save it for later usage

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.98089998960495
Saving model to:  /tmpfs/tmp/tmp98l4xiax.h5
/tmpfs/tmp/ipykernel_44770/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)
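
The warning above is informational; the legacy HDF5 file works fine for this tutorial. If you prefer the native TF-Keras format it recommends, a minimal alternative (with a hypothetical filename) is shown below; include_optimizer is omitted, as the native format may not accept that argument.

# Optional alternative (hypothetical filename): save in the native TF-Keras
# format recommended by the warning instead of legacy HDF5.
model.save('baseline_model.keras')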

Prune and fine-tune the model to 50% sparsity

Apply the prune_low_magnitude() API to prune the whole pre-trained model to obtain the model that will be clustered in the next step. For how best to use the API to achieve the best compression rate while maintaining your target accuracy, refer to the pruning comprehensive guide.

Define the model and apply the sparsity API

Note that the pre-trained model is used.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

pruned_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_dense   (None, 10)                40572     
 (PruneLowMagnitude)                                             
                                                                 
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________

Fine-tune the model, check sparsity, and evaluate the accuracy against the baseline

Fine-tune the model with pruning for 3 epochs.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)
Epoch 1/3
1688/1688 [==============================] - 10s 4ms/step - loss: 0.0805 - accuracy: 0.9729 - val_loss: 0.0834 - val_accuracy: 0.9753
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0636 - accuracy: 0.9791 - val_loss: 0.0735 - val_accuracy: 0.9798
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0557 - accuracy: 0.9825 - val_loss: 0.0688 - val_accuracy: 0.9813
<tf_keras.src.callbacks.History at 0x7f0f2d9bc220>

Define a helper function to calculate and print the sparsity of the model.

def print_model_weights_sparsity(model):

    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

Check that the model kernels were correctly pruned. We need to strip the pruning wrapper first. We also create a deep copy of the model, which will be used in the next step.

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

stripped_pruned_model_copy = keras.models.clone_model(stripped_pruned_model)
stripped_pruned_model_copy.set_weights(stripped_pruned_model.get_weights())
conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)

Apply clustering and sparsity preserving clustering and check the effect on model sparsity in both cases

Next, we apply both clustering and sparsity preserving clustering on the pruned model and observe that the latter preserves the sparsity of the pruned model. Note that we stripped the pruning wrappers from the pruned model with tfmot.sparsity.keras.strip_pruning before applying the clustering API.

# Clustering
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train clustering model:')
clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)


stripped_pruned_model.save("stripped_pruned_model_clustered.h5")

# Sparsity preserving clustering
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model_copy, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)
Train clustering model:
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0477 - accuracy: 0.9846 - val_loss: 0.0659 - val_accuracy: 0.9813
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0464 - accuracy: 0.9851 - val_loss: 0.0611 - val_accuracy: 0.9825
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0452 - accuracy: 0.9855 - val_loss: 0.0728 - val_accuracy: 0.9797
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_keras/src/engine/training.py:3098: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  saving_api.save_model(
Train sparsity preserving clustering model:
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0471 - accuracy: 0.9853 - val_loss: 0.0669 - val_accuracy: 0.9823
Epoch 2/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0446 - accuracy: 0.9863 - val_loss: 0.0661 - val_accuracy: 0.9817
Epoch 3/3
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0435 - accuracy: 0.9857 - val_loss: 0.0714 - val_accuracy: 0.9813
<tf_keras.src.callbacks.History at 0x7f0e4415cc40>

Check sparsity for both models.

print("Clustered Model sparsity:\n")
print_model_weights_sparsity(clustered_model)
print("\nSparsity preserved clustered Model sparsity:\n")
print_model_weights_sparsity(sparsity_clustered_model)
Clustered Model sparsity:

conv2d/kernel:0: 0.00% sparsity  (0/108)
dense/kernel:0: 0.34% sparsity  (69/20280)

Sparsity preserved clustered Model sparsity:

conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)
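
In addition to sparsity, you can confirm that clustering actually reduced each kernel to at most 8 unique values. The helper below is not part of the original notebook; it is a minimal sketch that uses strip_clustering to obtain an unwrapped copy of a model and counts the unique kernel values with np.unique.

def print_model_weight_clusters(model):
    # Obtain a stripped copy so the kernels can be inspected directly.
    stripped = tfmot.clustering.keras.strip_clustering(model)
    for layer in stripped.layers:
        for weight in layer.weights:
            if "kernel" in weight.name:
                print(f"{weight.name}: {len(np.unique(weight.numpy()))} unique values")

print("Clustered model:")
print_model_weight_clusters(clustered_model)
print("\nSparsity preserved clustered model:")
print_model_weight_clusters(sparsity_clustered_model)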

Create 1.6x smaller models from clustering

Define a helper function to get a zipped model file.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000
# Clustered model
clustered_model_file = 'clustered_model.h5'

# Save the model.
clustered_model.save(clustered_model_file)

# Sparsity preserving clustered model
sparsity_clustered_model_file = 'sparsity_clustered_model.h5'

# Save the model.
sparsity_clustered_model.save(sparsity_clustered_model_file)

print("Clustered Model size: ", get_gzipped_model_size(clustered_model_file), ' KB')
print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')
Clustered Model size:  247.191  KB
Sparsity preserved clustered Model size:  155.272  KB

Create a TFLite model from combining sparsity preserving weight clustering and post-training quantization

Strip the clustering wrappers and convert to TFLite.

stripped_sparsity_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_sparsity_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
sparsity_clustered_quant_model = converter.convert()

_, pruned_and_clustered_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_and_clustered_tflite_file, 'wb') as f:
  f.write(sparsity_clustered_quant_model)

print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')
print("Sparsity preserved clustered and quantized TFLite model size:",
       get_gzipped_model_size(pruned_and_clustered_tflite_file), ' KB')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpzf7jus7v/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpzf7jus7v/assets
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1709989008.133294   44770 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709989008.133348   44770 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
Sparsity preserved clustered Model size:  155.272  KB
Sparsity preserved clustered and quantized TFLite model size: 8.183  KB
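
For additional context, you can also compare against the gzipped size of the unoptimized baseline Keras model saved earlier; this comparison is not part of the original notebook.

# Optional comparison: gzipped baseline Keras model versus the pruned,
# clustered and quantized TFLite model whose size is printed above.
print("Baseline Keras model size: ", get_gzipped_model_size(keras_file), ' KB')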

See the persistence of accuracy from TF to TFLite

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

You evaluate the pruned, clustered and quantized model, and then see that the accuracy from TensorFlow persists in the TFLite backend.

# Keras model evaluation
stripped_sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
_, sparsity_clustered_keras_accuracy = stripped_sparsity_clustered_model.evaluate(
    test_images, test_labels, verbose=0)

# TFLite model evaluation
interpreter = tf.lite.Interpreter(pruned_and_clustered_tflite_file)
interpreter.allocate_tensors()

sparsity_clustered_tflite_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized Keras model accuracy:', sparsity_clustered_keras_accuracy)
print('Pruned, clustered and quantized TFLite model accuracy:', sparsity_clustered_tflite_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
WARNING: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#13 is a dynamic-sized tensor).
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned, clustered and quantized Keras model accuracy: 0.9782999753952026
Pruned, clustered and quantized TFLite model accuracy: 0.9784

Conclusion

In this tutorial, you learned how to create a model, prune it using the prune_low_magnitude() API, and apply sparsity preserving clustering to preserve sparsity while clustering the weights. The sparsity preserving clustered model was compared to a plainly clustered one to show that sparsity is preserved in the former but lost in the latter. Next, the pruned clustered model was converted to TFLite to show the compression benefits of chaining the pruning and sparsity preserving clustering optimization techniques, and finally, the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend.