Welcome to the comprehensive guide for Keras weight pruning.
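This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.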
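The following use cases are covered:
- Define and train a pruned model.
  - Sequential and Functional.
  - Keras Model.fit and custom training loops.
- Checkpoint and deserialize a pruned model.
- Deploy a pruned model and see the compression benefits.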
For configuration of the pruning algorithm, refer to the tfmot.sparsity.keras.prune_low_magnitude API docs.
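The core idea behind prune_low_magnitude is easier to see with a small example. Below is a framework-free sketch of magnitude-based pruning on a flat list of weights. It is not tfmot's implementation, which works on tensors, keeps a mask variable and follows a pruning schedule. The names magnitude_prune and sparsity are hypothetical. The sketch only shows the basic rule: zero the smallest-magnitude weights until a target sparsity is reached.

```python
def magnitude_prune(weights, target_sparsity):
    """Return a copy of `weights` with the smallest-magnitude fraction set to 0.0.

    Conceptual sketch only (not tfmot code): weights are ranked by absolute
    value and the lowest round(target_sparsity * len(weights)) are zeroed.
    """
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError("target_sparsity must be in [0, 1]")
    n_prune = int(round(target_sparsity * len(weights)))
    # Indices ordered from smallest to largest magnitude.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_idx = set(order[:n_prune])
    return [0.0 if i in pruned_idx else w for i, w in enumerate(weights)]


def sparsity(weights):
    """Fraction of entries that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


kernel = [0.8, -0.05, 0.3, -0.9, 0.01, 0.2, -0.4, 0.07]
pruned = magnitude_prune(kernel, 0.5)
print(pruned)            # [0.8, 0.0, 0.3, -0.9, 0.0, 0.0, -0.4, 0.0]
print(sparsity(pruned))  # 0.5
```

The signs of the surviving weights are untouched. Only magnitude decides what is removed, which is why large negative weights such as -0.9 are kept.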
Setup
To find the APIs you need and understand their purpose, you can run this section but skip reading it.
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tf_keras as keras
%load_ext tensorboard
import tempfile
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=20)
def setup_model():
    model = keras.Sequential([
        keras.layers.Dense(20, input_shape=input_shape),
        keras.layers.Flatten()
    ])
    return model
def setup_pretrained_weights():
    model = setup_model()
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy']
    )
    model.fit(x_train, y_train)
    _, pretrained_weights = tempfile.mkstemp('.tf')
    model.save_weights(pretrained_weights)
    return pretrained_weights
def get_gzipped_model_size(model):
    # Returns size of gzipped model, in bytes.
    import os
    import zipfile
    _, keras_file = tempfile.mkstemp('.h5')
    model.save(keras_file, include_optimizer=False)
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(keras_file)
    return os.path.getsize(zipped_file)
setup_model()
pretrained_weights = setup_pretrained_weights()
Define model
Prune the whole model (Sequential and Functional)
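Tips for better model accuracy:
- Try "Prune some layers" to skip pruning the layers that reduce accuracy the most.
- It's generally better to fine-tune with pruning than to train from scratch.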
To make the whole model train with pruning, apply tfmot.sparsity.keras.prune_low_magnitude to the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
model_for_pruning.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_2 (PruneLowMagnitude)     (None, 20)     822
 prune_low_magnitude_flatten_2 (PruneLowMagnitude)   (None, 20)     1
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________
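You can reproduce the parameter counts in this summary by hand. The sketch below assumes that each PruneLowMagnitude wrapper adds three kinds of non-trainable variables: one mask the size of each prunable weight, one scalar threshold per prunable weight, and one scalar pruning step per wrapper. This assumption matches the numbers printed in this guide but is not quoted from the API docs. The helper names are hypothetical.

```python
def dense_params(units_in, units_out, use_bias=True):
    """Trainable parameters of a Dense layer: kernel plus optional bias."""
    return units_in * units_out + (units_out if use_bias else 0)


def prune_wrapper_extra_params(prunable_weight_sizes):
    """Non-trainable parameters we assume PruneLowMagnitude adds.

    Assumption: one mask of the same size per prunable weight, one scalar
    threshold per prunable weight, and one scalar pruning step per wrapper.
    """
    return sum(size + 1 for size in prunable_weight_sizes) + 1


kernel_size = 20 * 20  # Dense(20) on a 20-dim input
bias_size = 20

# By default only the Dense kernel is prunable.
dense_extra = prune_wrapper_extra_params([kernel_size])   # 400 mask + 1 threshold + 1 step
# Flatten has no weights, so its wrapper only carries the step counter.
flatten_extra = prune_wrapper_extra_params([])
print(dense_params(20, 20), dense_extra, flatten_extra)   # 420 402 1
print(dense_params(20, 20) + dense_extra + flatten_extra) # 823
```

Under the same assumption, the Functional example's Dense(10) layer gives 210 + 202 = 412. The MyDenseLayer example also prunes the bias, so prune_wrapper_extra_params([400, 20]) adds 423 and the total becomes 843. Both totals agree with the summaries printed for those examples.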
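Prune some layers (Sequential and Functional)
Pruning a model can hurt accuracy. You can selectively prune layers of a model to explore the trade-off between accuracy, speed, and model size.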
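Tips for better model accuracy:
- It's generally better to fine-tune with pruning than to train from scratch.
- Try pruning the later layers instead of the first layers.
- Avoid pruning critical layers (e.g. the attention mechanism).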
More: the tfmot.sparsity.keras.prune_low_magnitude API docs provide details on how to vary the pruning configuration per layer.
In the example below, only the Dense layers are pruned.
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
# Helper function uses `prune_low_magnitude` to make only the
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
    if isinstance(layer, keras.layers.Dense):
        return tfmot.sparsity.keras.prune_low_magnitude(layer)
    return layer
# Use `keras.models.clone_model` to apply `apply_pruning_to_dense`
# to the layers of the model.
model_for_pruning = keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_dense,
)
model_for_pruning.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_3 (PruneLowMagnitude)     (None, 20)     822
 flatten_3 (Flatten)                                 (None, 20)     0
=================================================================
Total params: 822 (3.21 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 402 (1.57 KB)
_________________________________________________________________
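While this example uses the layer type to decide what to prune, the easiest way to prune a particular layer is to set its name property and look for that name in the clone_function.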
print(base_model.layers[0].name)
dense_3
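In Keras, the name-based variant is a clone_function that checks layer.name (for example 'dense_3' above). For the layers you want to prune, it returns tfmot.sparsity.keras.prune_low_magnitude(layer). The sketch below shows only that dispatch logic, without TensorFlow. Layer and PrunedLayer are hypothetical stand-ins for a Keras layer and the pruning wrapper, and make_clone_function is a hypothetical helper. In real code, the returned function would be passed to keras.models.clone_model as clone_function.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    """Hypothetical stand-in for a Keras layer (only the fields used here)."""
    name: str
    kind: str


@dataclass
class PrunedLayer:
    """Hypothetical stand-in for the wrapper returned by prune_low_magnitude."""
    inner: Layer


def make_clone_function(names_to_prune):
    """Build a clone_function-style callback that wraps layers by name."""
    def clone_function(layer):
        if layer.name in names_to_prune:
            return PrunedLayer(layer)  # real code: prune_low_magnitude(layer)
        return layer                   # all other layers are left unchanged
    return clone_function


layers = [Layer("dense_3", "Dense"), Layer("flatten_3", "Flatten")]
fn = make_clone_function({"dense_3"})
cloned = [fn(layer) for layer in layers]
print([type(l).__name__ for l in cloned])  # ['PrunedLayer', 'Layer']
```

Matching on names rather than types lets you prune one Dense layer while leaving another Dense layer in the same model untouched.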
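More readable but potentially lower model accuracy
This is not compatible with fine-tuning with pruning, which is why it may be less accurate than the examples above, which do support fine-tuning.
While prune_low_magnitude can be applied when defining the initial model, loading the weights afterwards does not work in the examples below.
Functional example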
# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(10))(i)
o = keras.layers.Flatten()(x)
model_for_pruning = keras.Model(inputs=i, outputs=o)
model_for_pruning.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 input_1 (InputLayer)                                [(None, 20)]   0
 prune_low_magnitude_dense_4 (PruneLowMagnitude)     (None, 10)     412
 flatten_4 (Flatten)                                 (None, 10)     0
=================================================================
Total params: 412 (1.61 KB)
Trainable params: 210 (840.00 Byte)
Non-trainable params: 202 (812.00 Byte)
_________________________________________________________________
Sequential example
# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = keras.Sequential([
    tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(20, input_shape=input_shape)),
    keras.layers.Flatten()
])
model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_5 (PruneLowMagnitude)     (None, 20)     822
 flatten_5 (Flatten)                                 (None, 20)     0
=================================================================
Total params: 822 (3.21 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 402 (1.57 KB)
_________________________________________________________________
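Prune a custom Keras layer or modify which parts of a layer to prune
Common mistake: pruning the bias usually harms model accuracy too much.
tfmot.sparsity.keras.PrunableLayer serves two use cases:
- Prune a custom Keras layer.
- Modify which parts of a built-in Keras layer are pruned.
For example, by default the API prunes only the kernel of the Dense layer. The example below also prunes the bias.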
class MyDenseLayer(keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):
    def get_prunable_weights(self):
        # Prune bias also, though that usually harms model accuracy too much.
        return [self.kernel, self.bias]
# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = keras.Sequential([
    tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
    keras.layers.Flatten()
])
model_for_pruning.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_my_dense_layer (PruneLowMagnitude)   (None, 20)   843
 flatten_6 (Flatten)                                 (None, 20)     0
=================================================================
Total params: 843 (3.30 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 423 (1.66 KB)
_________________________________________________________________
Train model
Model.fit
Call the tfmot.sparsity.keras.UpdatePruningStep callback during training by passing it to fit.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]
model_for_pruning.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy']
)
model_for_pruning.fit(
    x_train,
    y_train,
    callbacks=callbacks,
    epochs=2,
)
%tensorboard --logdir={log_dir}
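For non-Colab users, the results of a previous run of this code block were published on TensorBoard.dev. TensorBoard.dev has since been shut down, so that link may no longer work.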
Custom training loop
Call the tfmot.sparsity.keras.UpdatePruningStep callback during training.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Boilerplate
loss = keras.losses.categorical_crossentropy
optimizer = keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.
# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)
step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
    log_callback.on_epoch_begin(epoch=unused_arg)  # run pruning callback
    for _ in range(batches):
        step_callback.on_train_batch_begin(batch=unused_arg)  # run pruning callback
        with tf.GradientTape() as tape:
            logits = model_for_pruning(x_train, training=True)
            loss_value = loss(y_train, logits)
        # Compute and apply the gradients once the forward pass has been recorded.
        grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
        optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))
    step_callback.on_epoch_end(batch=unused_arg)  # run pruning callback
%tensorboard --logdir={log_dir}
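For non-Colab users, the results of a previous run of this code block were published on TensorBoard.dev. TensorBoard.dev has since been shut down, so that link may no longer work.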
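Improve pruned model accuracy
First, look at the tfmot.sparsity.keras.prune_low_magnitude API docs to understand what a pruning schedule is and the math behind each type of pruning schedule.
Tips:
- Use a learning rate that is neither too high nor too low while the model is being pruned.
- Treat the pruning schedule as a hyperparameter.
- As a quick test, try pruning the model to its final sparsity right at the start of training. To do this, set begin_step to 0 with a tfmot.sparsity.keras.ConstantSparsity schedule. You might get lucky and obtain good results.
- Do not prune too frequently, so that the model has time to recover. The pruning schedule provides a decent default frequency.
- For general ideas on improving model accuracy, look for the tips for your use case(s) under "Define model".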
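The schedule math is compact enough to sketch in a few lines. The functions below are a hedged, framework-free reading of two pieces of tfmot. The first is tfmot.sparsity.keras.PolynomialDecay. The second is the begin_step / end_step / frequency logic that ConstantSparsity also uses. The sketch assumes a default power of 3 and a default frequency of 100 steps, and the function names are hypothetical. Check the API docs for the exact behavior.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` for a polynomial-decay schedule (sketch).

    Progress p runs from 0 at begin_step to 1 at end_step (clipped), and the
    target moves from initial_sparsity to final_sparsity as (1 - p) ** power shrinks.
    """
    p = (step - begin_step) / (end_step - begin_step)
    p = min(1.0, max(0.0, p))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - p) ** power


def should_prune(step, begin_step, end_step=-1, frequency=100):
    """Whether masks are updated at `step` (end_step == -1 means no end; sketch)."""
    in_range = step >= begin_step and (end_step == -1 or step <= end_step)
    return in_range and (step - begin_step) % frequency == 0


for step in (0, 50, 100, 200):
    print(step, polynomial_decay_sparsity(step, 0.0, 0.5, 0, 100))
# 0 0.0
# 50 0.4375
# 100 0.5
# 200 0.5
```

With ConstantSparsity and begin_step=0, the target already equals the final sparsity at the first step, so only should_prune decides when the masks change. A larger frequency leaves more training steps between mask updates, which gives the model time to recover.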
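Checkpoint and deserialize
You must preserve the optimizer step when checkpointing, because the pruning schedule is driven by the training step. This means that while you can use Keras HDF5 models for checkpointing, you cannot use Keras HDF5 weights.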
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
_, keras_model_file = tempfile.mkstemp('.h5')
# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
The above applies generally. The code below is only needed for the HDF5 model format (not for HDF5 weights or other formats).
# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
    loaded_model = keras.models.load_model(keras_model_file)
loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_6 (PruneLowMagnitude)     (None, 20)     822
 prune_low_magnitude_flatten_7 (PruneLowMagnitude)   (None, 20)     1
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________
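Deploy pruned model
Export model with size compression
Common mistake: you need both strip_pruning and a standard compression algorithm (e.g. gzip) to see the compression benefits of pruning.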
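Pruning sets weights to zero but does not shrink the stored tensors. A float32 file therefore has the same raw size before and after pruning. The zeros only pay off once a standard compressor such as gzip squeezes the long runs of zero bytes. The stdlib sketch below illustrates this with synthetic Gaussian weights, not the model from this guide, and zeroes roughly 80% of them by magnitude. Exact compressed sizes depend on the data, so only the comparison is meaningful.

```python
import gzip
import random
import struct


def float32_bytes(values):
    """Serialize floats as little-endian float32, as a weight file would store them."""
    return struct.pack(f"<{len(values)}f", *values)


random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(10_000)]
# About 80% of N(0, 1) samples satisfy |w| < 1.2816, so this zeroes roughly 80%.
pruned = [w if abs(w) >= 1.2816 else 0.0 for w in weights]

raw_dense = len(float32_bytes(weights))
raw_pruned = len(float32_bytes(pruned))
gz_dense = len(gzip.compress(float32_bytes(weights), compresslevel=9))
gz_pruned = len(gzip.compress(float32_bytes(pruned), compresslevel=9))

print("raw bytes:", raw_dense, raw_pruned)  # raw bytes: 40000 40000
print("gzipped bytes:", gz_dense, gz_pruned)
```

The raw sizes are identical, and only the gzipped pruned buffer is smaller. The same reasoning explains why the mask and threshold variables must be removed with strip_pruning before export: they add bytes without adding zeros to the weights.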
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Typically you train the model here.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("final model")
model_for_export.summary()
print("\n")
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
final model
Model: "sequential_7"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 dense_7 (Dense)                                     (None, 20)     420
 flatten_8 (Flatten)                                 (None, 20)     0
=================================================================
Total params: 420 (1.64 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Size of gzipped pruned model without stripping: 3455.00 bytes
Size of gzipped pruned model with stripping: 2939.00 bytes
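Hardware-specific optimizations
Once different backends enable pruning to improve latency, block sparsity can improve latency on certain hardware.
Increasing the block size decreases the peak sparsity that is achievable for a target model accuracy. Despite this, latency can still improve.
For details on what is supported for block sparsity, see the tfmot.sparsity.keras.prune_low_magnitude API docs.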
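Block sparsity zeroes weights in fixed-shape groups instead of one at a time, so a kernel can skip a whole register's worth of values at once. The 1x16 block size used in the next code cell follows from simple arithmetic: a 128-bit register holds 128 / 8 = 16 8-bit weights. The sketch below is framework-free and only illustrative. block_prune is a hypothetical helper that scores each 1 x block_cols block by its mean absolute value, in the spirit of the AVG block pooling option of prune_low_magnitude as we understand it. It then zeroes the lowest-scoring blocks as whole units.

```python
LANES = 128 // 8  # 8-bit weights per 128-bit register: 16


def block_prune(matrix, block_cols, target_sparsity):
    """Zero whole 1 x block_cols blocks of a row-major matrix (sketch).

    Each row is split into contiguous blocks, each block is scored by the mean
    absolute value of its entries, and the lowest-scoring blocks are zeroed.
    """
    cols = len(matrix[0])
    if cols % block_cols:
        raise ValueError("number of columns must be divisible by block_cols")
    blocks = []  # (score, row, start_col)
    for r, row in enumerate(matrix):
        for c in range(0, cols, block_cols):
            block = row[c:c + block_cols]
            blocks.append((sum(abs(v) for v in block) / block_cols, r, c))
    n_prune = int(round(target_sparsity * len(blocks)))
    pruned = [list(row) for row in matrix]
    for _, r, c in sorted(blocks)[:n_prune]:
        for j in range(c, c + block_cols):
            pruned[r][j] = 0.0
    return pruned


w = [
    [0.9, 0.8, 0.01, 0.02],
    [0.03, 0.01, 0.7, 0.6],
]
print(block_prune(w, block_cols=2, target_sparsity=0.5))
# [[0.9, 0.8, 0.0, 0.0], [0.0, 0.0, 0.7, 0.6]]
```

Every block ends up either entirely zero or untouched. This all-or-nothing constraint is why a larger block size limits the sparsity you can reach at a given accuracy.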
base_model = setup_model()
# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)
model_for_pruning.summary()
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                                        Output Shape   Param #
=================================================================
 prune_low_magnitude_dense_8 (PruneLowMagnitude)     (None, 20)     822
 prune_low_magnitude_flatten_9 (PruneLowMagnitude)   (None, 20)     1
=================================================================
Total params: 823 (3.22 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 403 (1.58 KB)
_________________________________________________________________