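Welcome to the comprehensive guide for Keras quantization aware training.
This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.
The following use cases are covered:
- Deploy a model with 8-bit quantization with these steps.
  - Define a quantization aware model.
  - For Keras HDF5 models only, use special checkpointing and deserialization logic. Training is otherwise standard.
  - Create a quantized model from the quantization aware one.
- Experiment with quantization.
  - Anything for experimentation has no supported path to deployment.
  - Custom Keras layers fall under experimentation.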
Setup
To find the APIs you need and understand their purpose, you can run this section without reading it.
! pip install -q tensorflow
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tf_keras as keras
import tempfile
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
# Draw a valid integer class label in [0, 20) before one-hot encoding.
y_train = keras.utils.to_categorical(np.random.randint(20, size=1), num_classes=20)

def setup_model():
    model = keras.Sequential([
        keras.layers.Dense(20, input_shape=input_shape),
        keras.layers.Flatten()
    ])
    return model

def setup_pretrained_weights():
    model = setup_model()
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy']
    )
    model.fit(x_train, y_train)
    _, pretrained_weights = tempfile.mkstemp('.tf')
    model.save_weights(pretrained_weights)
    return pretrained_weights

def setup_pretrained_model():
    model = setup_model()
    pretrained_weights = setup_pretrained_weights()
    model.load_weights(pretrained_weights)
    return model

setup_model()
pretrained_weights = setup_pretrained_weights()
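Define quantization aware model
Defining models in the following ways gives you available deployment paths to the backends listed in the overview page. By default, 8-bit quantization is used.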
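Quantization aware training emulates 8-bit inference during training by rounding values onto a small integer grid and mapping them back to floats. The sketch below shows that round trip in plain Python, without TensorFlow. The `fake_quantize` helper and its range handling are illustrative assumptions, not the library's implementation, which among other things also adjusts the range so that zero is exactly representable.

```python
def fake_quantize(values, num_bits=8):
    """Quantize floats to `num_bits` integer codes and map them back to floats.

    Hypothetical helper for illustration only; it is not part of tfmot.
    """
    low = min(min(values), 0.0)     # widen the range so that it contains zero
    high = max(max(values), 0.0)
    levels = 2 ** num_bits - 1      # 255 steps for 8 bits
    scale = (high - low) / levels if high > low else 1.0

    result = []
    for value in values:
        code = round((value - low) / scale)   # integer code in [0, levels]
        code = max(0, min(levels, code))
        result.append(low + code * scale)     # back to a float on the grid
    return result


weights = [-1.0, -0.25, 0.0, 0.4, 1.0]
approx = fake_quantize(weights)
max_error = max(abs(a - w) for a, w in zip(approx, weights))
print(approx)
print("largest rounding error:", max_error)
```

In this sketch the rounding error is bounded by half of one grid step, that is (high - low) / 255 / 2 for 8 bits.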
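Quantize whole model
Your use case:
- Subclassed models are not supported.
Tips for better model accuracy:
- Try "Quantize some layers" to skip quantizing the layers that reduce accuracy the most.
- Fine-tuning with quantization aware training generally works better than training from scratch.
To make the whole model aware of quantization, apply tfmot.quantization.keras.quantize_model to the model.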
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer (QuantizeLa  (None, 20)                3
 yer)
 quant_dense_2 (QuantizeWra  (None, 20)                425
 pperV2)
 quant_flatten_2 (QuantizeW  (None, 20)                1
 rapperV2)
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________
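Quantize some layers
Quantizing a model can have a negative effect on accuracy. You can selectively quantize layers of a model to explore the trade-off between accuracy, speed, and model size.
Your use case:
- To deploy to a backend that only works well with fully quantized models (e.g. EdgeTPU v1, most DSPs), try "Quantize whole model".
Tips for better model accuracy:
- Fine-tuning with quantization aware training generally works better than training from scratch.
- Try quantizing the later layers instead of the first layers.
- Avoid quantizing critical layers (e.g. attention mechanism).
In the example below, only the Dense layers are quantized.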
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
# Helper function uses `quantize_annotate_layer` to annotate that only the
# Dense layers should be quantized.
def apply_quantization_to_dense(layer):
    if isinstance(layer, keras.layers.Dense):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

# Use `keras.models.clone_model` to apply `apply_quantization_to_dense`
# to the layers of the model.
annotated_model = keras.models.clone_model(
    base_model,
    clone_function=apply_quantization_to_dense,
)

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer_1 (Quantize  (None, 20)                3
 Layer)
 quant_dense_3 (QuantizeWra  (None, 20)                425
 pperV2)
 flatten_3 (Flatten)         (None, 20)                0
=================================================================
Total params: 428 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 8 (32.00 Byte)
_________________________________________________________________
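While this example uses the type of the layer to decide what to quantize, the easiest way to quantize a particular layer is to set its name property and look for that name in the clone_function.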
print(base_model.layers[0].name)
dense_3
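More readable but potentially lower model accuracy
This approach is not compatible with fine-tuning with quantization aware training, which is why it may be less accurate than the examples above.
Functional example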
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
i = keras.Input(shape=(20,))
x = tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(10))(i)
o = keras.layers.Flatten()(x)
annotated_model = keras.Model(inputs=i, outputs=o)
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
quant_aware_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_1 (InputLayer)        [(None, 20)]              0
 quantize_layer_2 (Quantize  (None, 20)                3
 Layer)
 quant_dense_4 (QuantizeWra  (None, 10)                215
 pperV2)
 flatten_4 (Flatten)         (None, 10)                0
=================================================================
Total params: 218 (872.00 Byte)
Trainable params: 210 (840.00 Byte)
Non-trainable params: 8 (32.00 Byte)
_________________________________________________________________
Sequential example
# Use `quantize_annotate_layer` to annotate that the `Dense` layer
# should be quantized.
annotated_model = keras.Sequential([
    tfmot.quantization.keras.quantize_annotate_layer(keras.layers.Dense(20, input_shape=input_shape)),
    keras.layers.Flatten()
])
# Use `quantize_apply` to actually make the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
quant_aware_model.summary()
Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer_3 (Quantize  (None, 20)                3
 Layer)
 quant_dense_5 (QuantizeWra  (None, 20)                425
 pperV2)
 flatten_5 (Flatten)         (None, 20)                0
=================================================================
Total params: 428 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 8 (32.00 Byte)
_________________________________________________________________
Checkpoint and deserialize
Your use case: this code is only needed for the HDF5 model format (not HDF5 weights or other formats).
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
quant_aware_model.save(keras_model_file)
# `quantize_scope` is needed for deserializing HDF5 models.
with tfmot.quantization.keras.quantize_scope():
    loaded_model = keras.models.load_model(keras_model_file)

loaded_model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer_4 (Quantize  (None, 20)                3
 Layer)
 quant_dense_6 (QuantizeWra  (None, 20)                425
 pperV2)
 quant_flatten_6 (QuantizeW  (None, 20)                1
 rapperV2)
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________
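Create and deploy quantized model
In general, refer to the documentation for the deployment backend that you will use.
This is an example for the TFLite backend.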
base_model = setup_pretrained_model()
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
# Typically you train the model here.
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()
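Experiment with quantization
Your use case: using the following APIs means that there is no supported path to deployment. For instance, TFLite conversion and kernel implementations only support 8-bit quantization. These features are also experimental and not subject to backward compatibility.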
tfmot.quantization.keras.QuantizeConfig
tfmot.quantization.keras.quantizers.Quantizer
tfmot.quantization.keras.quantizers.LastValueQuantizer
tfmot.quantization.keras.quantizers.MovingAverageQuantizer
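As their names suggest, the last two quantizers differ in how they pick the value range to quantize over: one follows the most recent batch, the other a moving average across batches. The plain-Python sketch below illustrates that difference. The class names, the decay of 0.9, and the initialization from the first batch are assumptions made for illustration and do not mirror the library's internals.

```python
class LastValueRange:
    """Tracks the (min, max) of the most recent batch only."""

    def __init__(self):
        self.low, self.high = 0.0, 0.0

    def update(self, batch):
        self.low, self.high = min(batch), max(batch)
        return self.low, self.high


class MovingAverageRange:
    """Tracks an exponential moving average of the per-batch (min, max)."""

    def __init__(self, decay=0.9):  # the decay value is an assumption
        self.decay = decay
        self.low = self.high = None

    def update(self, batch):
        batch_low, batch_high = min(batch), max(batch)
        if self.low is None:
            # First batch: start from its range.
            self.low, self.high = batch_low, batch_high
        else:
            self.low = self.decay * self.low + (1 - self.decay) * batch_low
            self.high = self.decay * self.high + (1 - self.decay) * batch_high
        return self.low, self.high


# The last batch contains outliers.
batches = [[-1.0, 0.5, 1.0], [-1.0, 0.2, 1.0], [-10.0, 0.0, 10.0]]
last_value, moving_average = LastValueRange(), MovingAverageRange()
for batch in batches:
    last_range = last_value.update(batch)
    average_range = moving_average.update(batch)

print(last_range)     # follows the outlier batch completely
print(average_range)  # moves only part of the way toward the outliers
```

A range that jumps with every batch suits values that change slowly, such as weights, while a smoothed range is less sensitive to a single unusual batch of activations.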
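Setup: DefaultDenseQuantizeConfig
Experimenting requires tfmot.quantization.keras.QuantizeConfig, which describes how to quantize the weights, activations, and outputs of a layer.
Below is an example that defines the same QuantizeConfig that the API defaults use for the Dense layer.
During forward propagation in this example, the LastValueQuantizer returned by get_weights_and_quantizers is called with layer.kernel as the input, producing an output. Through the logic defined in set_quantize_weights, that output replaces layer.kernel in the original forward propagation of the Dense layer. The same idea applies to the activations and outputs.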
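The flow described above can be sketched without TensorFlow. Everything in this block (ToyDense, ToyQuantizeConfig, round_to_tenths, quantization_aware_call) is a hypothetical stand-in that mimics only the order of the calls; it is not how the library's wrapper is implemented.

```python
class ToyDense:
    """Stand-in for a Dense layer with a single scalar kernel."""

    def __init__(self, kernel):
        self.kernel = kernel

    def call(self, inputs):
        return [self.kernel * value for value in inputs]


def round_to_tenths(weight):
    """Stand-in quantizer: snaps the weight to a grid with step 0.1."""
    return round(weight * 10) / 10


class ToyQuantizeConfig:
    def get_weights_and_quantizers(self, layer):
        # Pair each weight with the quantizer that should process it.
        return [(layer.kernel, round_to_tenths)]

    def set_quantize_weights(self, layer, quantize_weights):
        # Same order as the list returned by `get_weights_and_quantizers`.
        layer.kernel = quantize_weights[0]


def quantization_aware_call(layer, pairs, config, inputs):
    """Quantize each float weight, install it on the layer, then run the layer."""
    quantized = [quantizer(weight) for weight, quantizer in pairs]
    config.set_quantize_weights(layer, quantized)
    return layer.call(inputs)


layer = ToyDense(kernel=0.7342)
config = ToyQuantizeConfig()
# The float weight is captured once and remains the value that gets quantized.
pairs = config.get_weights_and_quantizers(layer)
output = quantization_aware_call(layer, pairs, config, [1.0, 2.0])
print(layer.kernel, output)
```

The forward pass sees the quantized kernel, while the original float value stays available in pairs for the next call.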
LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
        return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
        return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]

    def set_quantize_weights(self, layer, quantize_weights):
        # Add this line for each item returned in `get_weights_and_quantizers`,
        # in the same order.
        layer.kernel = quantize_weights[0]

    def set_quantize_activations(self, layer, quantize_activations):
        # Add this line for each item returned in `get_activations_and_quantizers`,
        # in the same order.
        layer.activation = quantize_activations[0]

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
        return []

    def get_config(self):
        return {}
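Quantize custom Keras layer
This example uses the DefaultDenseQuantizeConfig to quantize the CustomLayer.
Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the CustomLayer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue quantizing the rest of the model with the API defaults.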
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class CustomLayer(keras.layers.Dense):
    pass

model = quantize_annotate_model(keras.Sequential([
    quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# as well as the custom Keras layer.
with quantize_scope(
    {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
     'CustomLayer': CustomLayer}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer_6 (Quantize  (None, 20)                3
 Layer)
 quant_custom_layer (Quanti  (None, 20)                425
 zeWrapperV2)
 quant_flatten_9 (QuantizeW  (None, 20)                1
 rapperV2)
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________
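Modify quantization parameters
Common mistake: quantizing the bias to fewer than 32 bits usually harms model accuracy too much.
This example modifies the Dense layer to use 4 bits for its weights instead of the default 8 bits. The rest of the model continues to use the API defaults.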
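To get a feel for what dropping from 8 bits to 4 bits costs, the plain-Python sketch below compares the worst-case rounding error of a symmetric quantize-dequantize round trip at both widths. The helpers are hypothetical and simplified (per-tensor scale, integer range of ±(2^(bits-1) - 1)); they are not the arithmetic that LastValueQuantizer performs.

```python
def symmetric_fake_quantize(values, num_bits):
    """Hypothetical helper: symmetric, per-tensor quantize-dequantize round trip."""
    max_abs = max(abs(v) for v in values) or 1.0
    q_max = 2 ** (num_bits - 1) - 1    # 127 for 8 bits, 7 for 4 bits
    scale = max_abs / q_max
    return [max(-q_max, min(q_max, round(v / scale))) * scale for v in values]


def max_rounding_error(values, num_bits):
    approx = symmetric_fake_quantize(values, num_bits)
    return max(abs(a - v) for a, v in zip(approx, values))


weights = [i / 50.0 - 1.0 for i in range(101)]   # evenly spaced in [-1, 1]
error_8bit = max_rounding_error(weights, 8)
error_4bit = max_rounding_error(weights, 4)
print("8-bit worst-case error:", error_8bit)
print("4-bit worst-case error:", error_4bit)
```

With only 15 levels instead of 255, each weight can land much further from its original value, which is why lower bit widths call for careful accuracy checks.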
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to quantize with 4-bit instead of 8-bits.
    def get_weights_and_quantizers(self, layer):
        return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]
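Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue quantizing the rest of the model with the API defaults.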
model = quantize_annotate_model(keras.Sequential([
    # Pass in modified `QuantizeConfig` to modify this Dense layer.
    quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
    {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer_7 (Quantize  (None, 20)                3
 Layer)
 quant_dense_9 (QuantizeWra  (None, 20)                425
 pperV2)
 quant_flatten_10 (Quantize  (None, 20)                1
 WrapperV2)
=================================================================
Total params: 429 (1.68 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 9 (36.00 Byte)
_________________________________________________________________
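Modify parts of layer to quantize
This example modifies the Dense layer to skip quantizing the activation. The rest of the model continues to use the API defaults.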
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    def get_activations_and_quantizers(self, layer):
        # Skip quantizing activations.
        return []

    def set_quantize_activations(self, layer, quantize_activations):
        # Empty since `get_activations_and_quantizers` returns
        # an empty list.
        return
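Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue quantizing the rest of the model with the API defaults.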
model = quantize_annotate_model(keras.Sequential([
    # Pass in modified `QuantizeConfig` to modify this Dense layer.
    quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
    {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer_8 (Quantize  (None, 20)                3
 Layer)
 quant_dense_10 (QuantizeWr  (None, 20)                423
 apperV2)
 quant_flatten_11 (Quantize  (None, 20)                1
 WrapperV2)
=================================================================
Total params: 427 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 7 (28.00 Byte)
_________________________________________________________________
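Use custom quantization algorithm
The tfmot.quantization.keras.quantizers.Quantizer class is a callable that can apply any algorithm to its inputs.
In this example, the inputs are the weights, and the math in the FixedRangeQuantizer __call__ function is applied to them. Instead of the original weight values, the output of the FixedRangeQuantizer is now passed to whatever would have used the weights.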
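The math in question is a clip to the range [-1, 1]. A plain-Python equivalent, using a hypothetical fixed_range helper and made-up weight values, shows the effect on a small kernel:

```python
def fixed_range(weights, low=-1.0, high=1.0):
    """Plain-Python equivalent of clipping each weight to [low, high]."""
    return [max(low, min(high, weight)) for weight in weights]


kernel = [-2.5, -1.0, -0.3, 0.0, 0.8, 1.7]   # made-up weight values
clipped = fixed_range(kernel)
print(clipped)  # [-1.0, -1.0, -0.3, 0.0, 0.8, 1.0]
```

Values already inside the range pass through unchanged, so this quantizer limits the range of the weights without rounding them to discrete levels.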
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_scope = tfmot.quantization.keras.quantize_scope

class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
    """Quantizer which forces outputs to be between -1 and 1."""

    def build(self, tensor_shape, name, layer):
        # Not needed. No new TensorFlow variables needed.
        return {}

    def __call__(self, inputs, training, weights, **kwargs):
        return keras.backend.clip(inputs, -1.0, 1.0)

    def get_config(self):
        # Not needed. No __init__ parameters to serialize.
        return {}

class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):
    # Configure weights to use the custom quantizer.
    def get_weights_and_quantizers(self, layer):
        # Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.
        return [(layer.kernel, FixedRangeQuantizer())]
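Applying the configuration is the same across the "Experiment with quantization" use cases.
- Apply tfmot.quantization.keras.quantize_annotate_layer to the Dense layer and pass in the QuantizeConfig.
- Use tfmot.quantization.keras.quantize_annotate_model to continue quantizing the rest of the model with the API defaults.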
model = quantize_annotate_model(keras.Sequential([
    # Pass in modified `QuantizeConfig` to modify this `Dense` layer.
    quantize_annotate_layer(keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),
    keras.layers.Flatten()
]))

# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:
with quantize_scope(
    {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

quant_aware_model.summary()
Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 quantize_layer_9 (Quantize  (None, 20)                3
 Layer)
 quant_dense_11 (QuantizeWr  (None, 20)                423
 apperV2)
 quant_flatten_12 (Quantize  (None, 20)                1
 WrapperV2)
=================================================================
Total params: 427 (1.67 KB)
Trainable params: 420 (1.64 KB)
Non-trainable params: 7 (28.00 Byte)
_________________________________________________________________