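This notebook shows how to do lossy data compression using neural networks and TensorFlow Compression.
The examples below use an autoencoder-like model to compress images from the MNIST dataset. The method is based on the paper End-to-end Optimized Image Compression.
For more background on learned data compression, see this paper, aimed at people familiar with classical data compression, or this survey, aimed at a machine learning audience.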
Install TensorFlow Compression via `pip`.
```bash
# Installs the latest version of TFC compatible with the installed TF version.
read MAJOR MINOR <<< "$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+)\.(\d+).*/\1 \2/sg')"
pip install "tensorflow-compression<$MAJOR.$(($MINOR+1))"
```
```python
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds
```
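The trainer model consists of three parts:
- the **analysis** (or encoder) transform, which maps the image into a latent space,
- the **synthesis** (or decoder) transform, which maps the latent space back into image space, and
- a **prior** and entropy model, which model the marginal probabilities of the latents.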
```python
def make_analysis_transform(latent_dims):
    """Creates the analysis (encoder) transform."""
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            20, 5, use_bias=True, strides=2, padding="same",
            activation="leaky_relu", name="conv_1"),
        tf.keras.layers.Conv2D(
            50, 5, use_bias=True, strides=2, padding="same",
            activation="leaky_relu", name="conv_2"),
        # Flatten the feature map so the dense layers yield one vector per image.
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(
            500, use_bias=True, activation="leaky_relu", name="fc_1"),
        tf.keras.layers.Dense(
            latent_dims, use_bias=True, activation=None, name="fc_2"),
    ], name="analysis_transform")


def make_synthesis_transform():
    """Creates the synthesis (decoder) transform."""
    return tf.keras.Sequential([
        tf.keras.layers.Dense(
            500, use_bias=True, activation="leaky_relu", name="fc_1"),
        tf.keras.layers.Dense(
            2450, use_bias=True, activation="leaky_relu", name="fc_2"),
        tf.keras.layers.Reshape((7, 7, 50)),
        tf.keras.layers.Conv2DTranspose(
            20, 5, use_bias=True, strides=2, padding="same",
            activation="leaky_relu", name="conv_1"),
        tf.keras.layers.Conv2DTranspose(
            1, 5, use_bias=True, strides=2, padding="same",
            activation="leaky_relu", name="conv_2"),
    ], name="synthesis_transform")
```
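The layer sizes in the two transforms are tied together by simple shape arithmetic. The sketch below is plain Python and independent of TensorFlow. It assumes the usual Keras rule that a stride-`s` convolution with `padding="same"` maps a spatial size `n` to `ceil(n / s)`, and that the matching transposed convolution maps `n` to `n * s`. The helper names are made up for illustration.

```python
import math


def conv_same_out(size, stride):
    """Spatial output size of a strided convolution with padding="same"."""
    return math.ceil(size / stride)


def conv_transpose_same_out(size, stride):
    """Spatial output size of a strided transposed convolution with padding="same"."""
    return size * stride


# Analysis side: a 28x28 MNIST digit passes through two stride-2 convolutions.
size = 28
for _ in range(2):
    size = conv_same_out(size, stride=2)
flat_features = size * size * 50  # conv_2 has 50 output channels

# Synthesis side: Reshape((7, 7, 50)) followed by two stride-2 transposed convolutions.
recon_size = size
for _ in range(2):
    recon_size = conv_transpose_same_out(recon_size, stride=2)

print(size, flat_features, recon_size)  # prints: 7 2450 28
```

This is why the synthesis transform's `fc_2` layer has 2450 units and is followed by `Reshape((7, 7, 50))`: it mirrors the feature map that the analysis convolutions produce.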
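The trainer's `call` method computes:
- the **rate**, an estimate of the number of bits needed to represent the batch of digits, and
- the **distortion**, the mean absolute difference between the pixels of the original digits and their reconstructions.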
```python
class MNISTCompressionTrainer(tf.keras.Model):
    """Model that trains a compressor/decompressor for MNIST."""

    def __init__(self, latent_dims):
        super().__init__()
        self.analysis_transform = make_analysis_transform(latent_dims)
        self.synthesis_transform = make_synthesis_transform()
        self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

    @property
    def prior(self):
        # Rebuilt on every access so it always reflects the current scales.
        return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

    def call(self, x, training):
        """Computes rate and distortion losses."""
        # Ensure inputs are floats in the range (0, 1).
        x = tf.cast(x, self.compute_dtype) / 255.
        x = tf.reshape(x, (-1, 28, 28, 1))

        # Compute latent space representation y, perturb it and model its entropy,
        # then compute the reconstructed pixel-level representation x_tilde.
        y = self.analysis_transform(x)
        entropy_model = tfc.ContinuousBatchedEntropyModel(
            self.prior, coding_rank=1, compression=False)
        y_tilde, rate = entropy_model(y, training=training)
        x_tilde = self.synthesis_transform(y_tilde)

        # Average number of bits per MNIST digit.
        rate = tf.reduce_mean(rate)

        # Mean absolute difference across pixels.
        distortion = tf.reduce_mean(abs(x - x_tilde))

        return dict(rate=rate, distortion=distortion)
```
Let's walk through this step by step, using a single image from the validation set. First, load the MNIST dataset for training and validation:
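```python
training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    # Yield (image, label) tuples, which the unpacking below relies on.
    as_supervised=True,
    with_info=False,
)
```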
Then extract one image \(x\):
```python
(x, _), = validation_dataset.take(1)
print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
```
```
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
```
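To get the latent representation \(y\), we cast the image to `float32`, add a batch dimension, and pass it through the analysis transform: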
```python
x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)
print("y:", y)
```
```
y: tf.Tensor(
[[ 0.02126924 -0.06894962  0.09302588 -0.03838679 -0.07932404  0.04434497
  -0.01607926  0.03123217 -0.00262138 -0.10917155]], shape=(1, 10), dtype=float32)
```
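The latents will be quantized at test time. To model this in a differentiable way during training, we add uniform noise in the interval \((-.5, .5)\) and call the result \(\tilde y\). This is the same terminology as used in the paper End-to-end Optimized Image Compression.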
```python
y_tilde = y + tf.random.uniform(y.shape, -.5, .5)
print("y_tilde:", y_tilde)
```
```
y_tilde: tf.Tensor(
[[ 0.3478785   0.08000787 -0.18317918  0.25482622 -0.31788376 -0.13823514
   0.362665   -0.05973584  0.34111002 -0.1322978 ]], shape=(1, 10), dtype=float32)
```
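To see why uniform noise is a reasonable stand-in for quantization, it helps to compare the two errors directly. The following is a plain-Python sketch with made-up latent values, not part of the TFC API. Rounding to the nearest integer and adding noise from \((-.5, .5)\) both perturb each latent by at most half a unit, but only the noisy version varies smoothly with \(y\).

```python
import random

random.seed(0)

# Hypothetical latent values, chosen only for illustration.
y = [0.02, -1.37, 2.61, -0.49, 3.50]

# Test time: hard quantization to the nearest integer.
y_hat = [round(v) for v in y]
# Training time: additive uniform noise of width one.
y_tilde = [v + random.uniform(-0.5, 0.5) for v in y]

quantization_error = [q - v for q, v in zip(y_hat, y)]
noise_error = [t - v for t, v in zip(y_tilde, y)]

print("y_hat:", y_hat)
print("max |quantization error|:", max(abs(e) for e in quantization_error))
print("max |noise error|:", max(abs(e) for e in noise_error))
```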
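The "prior" is a probability density that we train to model the marginal distribution of the noisy latents. For example, it could be a set of independent logistic distributions with a different scale for each latent dimension. `tfc.NoisyLogistic` accounts for the fact that the latents have additive noise. As the scale approaches zero, a logistic distribution approaches a Dirac delta (a spike), but the added noise makes the "noisy" distribution approach the uniform distribution instead.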
```python
prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))

# Evaluate all 10 component densities on a common grid and plot them.
y_grid = tf.linspace(-6., 6., 501)[:, None]
plt.plot(y_grid, prior.prob(y_grid));
```
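The limiting behaviour described above can be checked without TensorFlow. The sketch below assumes that the "noisy" density is a logistic distribution convolved with unit-width uniform noise, so that it equals a difference of two logistic CDFs. The function names are illustrative and not the TFC implementation.

```python
import math


def logistic_cdf(x, loc=0.0, scale=1.0):
    """CDF of a logistic distribution, written to avoid overflow in exp."""
    z = (x - loc) / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def noisy_logistic_pdf(y, loc=0.0, scale=1.0):
    """Density of a logistic variable plus uniform noise on (-0.5, 0.5)."""
    return logistic_cdf(y + 0.5, loc, scale) - logistic_cdf(y - 0.5, loc, scale)


for scale in (0.01, 2.0):
    inside = noisy_logistic_pdf(0.0, scale=scale)
    outside = noisy_logistic_pdf(1.0, scale=scale)
    print(f"scale={scale}: p(0)={inside:.4f}, p(1)={outside:.4f}")
```

With the small scale, the density is close to 1 inside \((-.5, .5)\) and close to 0 outside, which is the uniform distribution the text refers to. With the large scale, the density is spread out and nearly flat around zero.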
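During training, `tfc.ContinuousBatchedEntropyModel` adds the uniform noise itself and uses the noise and the prior to compute a differentiable estimate of the rate, the number of bits needed to encode the latent representation. This estimate can be minimized as a loss.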
```python
entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)
print("rate:", rate)
print("y_tilde:", y_tilde)
```
```
rate: tf.Tensor([18.012472], shape=(1,), dtype=float32)
y_tilde: tf.Tensor(
[[-0.09349963 -0.02974442 -0.1147738   0.39858678  0.33601853 -0.13327162
  -0.10957606  0.2511304  -0.15129623  0.25686544]], shape=(1, 10), dtype=float32)
```
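The printed rate can be sanity-checked by hand. The sketch below assumes that the rate is the negative base-2 logarithm of the noisy prior density at each noisy latent, summed over the latent dimensions, and that the noisy density is a difference of logistic CDFs. It is a plain-Python approximation, not the library's code path. Under those assumptions the result should land close to the value printed above.

```python
import math


def logistic_cdf(x, scale):
    """CDF of a zero-mean logistic distribution, stable for large |x / scale|."""
    z = x / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def noisy_logistic_pdf(y, scale):
    """Density of a zero-mean logistic variable plus uniform noise on (-0.5, 0.5)."""
    return logistic_cdf(y + 0.5, scale) - logistic_cdf(y - 0.5, scale)


# Noisy latents copied from the output above.
y_tilde = [-0.09349963, -0.02974442, -0.1147738, 0.39858678, 0.33601853,
           -0.13327162, -0.10957606, 0.2511304, -0.15129623, 0.25686544]
# The same ten scales as tf.linspace(.01, 2., 10).
scales = [0.01 + (2.0 - 0.01) * i / 9 for i in range(10)]

bits_per_dim = [-math.log2(noisy_logistic_pdf(y, s)) for y, s in zip(y_tilde, scales)]
rate = sum(bits_per_dim)
print(f"rate: {rate:.3f} bits")
```

The dimensions with small scales contribute almost nothing, because the prior is nearly certain about them, while the dimensions with wide priors cost a few bits each.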
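Lastly, the noisy latents are passed back through the synthesis transform to produce an image reconstruction \(\tilde x\). Distortion is the error between the original image and the reconstruction. Obviously, with the transforms untrained, the reconstruction is not very useful.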
```python
x_tilde = make_synthesis_transform()(y_tilde)

# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)

x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
```
```
distortion: tf.Tensor(0.17072156, shape=(), dtype=float32)
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
```
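The distortion is measured on pixel values scaled to the range 0 to 1, so the 0.17 printed above corresponds to an average error of roughly 44 gray levels out of 255. As a small plain-Python illustration of the same measure, using a made-up four-pixel image:

```python
def mean_absolute_difference(original, reconstruction):
    """Mean absolute per-pixel difference between two equally sized images."""
    pairs = list(zip(original, reconstruction))
    return sum(abs(a - b) for a, b in pairs) / len(pairs)


# A hypothetical 2x2 image, flattened, with pixel values in [0, 1].
x = [0.0, 1.0, 0.5, 0.25]
x_tilde = [0.1, 0.8, 0.5, 0.45]

distortion = mean_absolute_difference(x, x_tilde)
print(f"distortion: {distortion:.3f}")
print(f"in 8-bit gray levels: {distortion * 255:.1f}")
```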
For every batch of digits, calling the `MNISTCompressionTrainer` produces the rate and distortion averaged over that batch:
```python
(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)
print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
```
```
rate:  tf.Tensor(20.296253, shape=(), dtype=float32)
distortion:  tf.Tensor(0.14659302, shape=(), dtype=float32)
```
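We compile the trainer so that it optimizes the rate-distortion Lagrangian, that is, a sum of rate and distortion in which one of the terms is weighted by the Lagrange parameter \(\lambda\).
This loss function affects the different parts of the model differently:
- The analysis transform is trained to produce a latent representation that achieves the desired trade-off between rate and distortion.
- The synthesis transform is trained to minimize the distortion, given the latent representation.
- The parameters of the prior are trained to minimize the rate, given the latent representation. This is identical to fitting the prior to the marginal distribution of the latents in a maximum likelihood sense.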
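As a concrete reading of the Lagrangian, the sketch below weights the distortion by \(\lambda\) and leaves the rate unweighted, matching the `loss_weights=dict(rate=1., distortion=lmbda)` setting used in this tutorial. The function name is made up. The inputs are the rate and distortion printed earlier for the untrained example batch.

```python
def rd_lagrangian(rate, distortion, lmbda):
    """Rate-distortion Lagrangian: rate plus lambda-weighted distortion."""
    return rate + lmbda * distortion


# Rate (bits) and distortion printed above for the untrained example batch.
rate, distortion = 20.296253, 0.14659302

for lmbda in (500, 2000):
    print(f"lmbda={lmbda}: loss={rd_lagrangian(rate, distortion, lmbda):.2f}")
```

A larger \(\lambda\) makes each unit of distortion more expensive relative to a bit of rate, so training is pushed toward higher-fidelity and longer codes.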
```python
def pass_through_loss(_, x):
    # Since rate and distortion are unsupervised, the loss doesn't need a target.
    return x


def make_mnist_compression_trainer(lmbda, latent_dims=50):
    trainer = MNISTCompressionTrainer(latent_dims)
    trainer.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        # Just pass through rate and distortion as losses/metrics.
        loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
        metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
        loss_weights=dict(rate=1., distortion=lmbda),
    )
    return trainer
```
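Next, train the model. Human annotations are not needed here, since we only want to compress the images, so we drop the labels with a `map` and add "dummy" targets for rate and distortion instead.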
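```python
def add_rd_targets(image, label):
    # Training is unsupervised, so labels aren't necessary here. However, we
    # need to add "dummy" targets for rate and distortion.
    return image, dict(rate=0., distortion=0.)


def train_mnist_model(lmbda):
    trainer = make_mnist_compression_trainer(lmbda)
    trainer.fit(
        training_dataset.map(add_rd_targets).batch(128).prefetch(8),
        epochs=15,
        validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
    )
    return trainer


trainer = train_mnist_model(lmbda=2000)
```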
```
Epoch 1/15 - loss: 219.3287 - distortion_loss: 0.0598 - rate_loss: 99.8235 - val_loss: 176.6952 - val_distortion_loss: 0.0424 - val_rate_loss: 91.9455
Epoch 2/15 - loss: 166.0387 - distortion_loss: 0.0410 - rate_loss: 83.9876 - val_loss: 156.0399 - val_distortion_loss: 0.0397 - val_rate_loss: 76.5447
Epoch 3/15 - loss: 150.7590 - distortion_loss: 0.0397 - rate_loss: 71.2698 - val_loss: 143.7437 - val_distortion_loss: 0.0399 - val_rate_loss: 64.0326
Epoch 4/15 - loss: 142.2607 - distortion_loss: 0.0395 - rate_loss: 63.2149 - val_loss: 136.2439 - val_distortion_loss: 0.0406 - val_rate_loss: 55.1126
Epoch 5/15 - loss: 136.6350 - distortion_loss: 0.0392 - rate_loss: 58.2967 - val_loss: 131.8454 - val_distortion_loss: 0.0418 - val_rate_loss: 48.2437
Epoch 6/15 - loss: 132.8357 - distortion_loss: 0.0388 - rate_loss: 55.1559 - val_loss: 126.8537 - val_distortion_loss: 0.0413 - val_rate_loss: 44.1647
Epoch 7/15 - loss: 129.6230 - distortion_loss: 0.0383 - rate_loss: 52.9344 - val_loss: 123.1619 - val_distortion_loss: 0.0406 - val_rate_loss: 41.8766
Epoch 8/15 - loss: 127.0022 - distortion_loss: 0.0379 - rate_loss: 51.1485 - val_loss: 119.7338 - val_distortion_loss: 0.0393 - val_rate_loss: 41.0842
Epoch 9/15 - loss: 124.5822 - distortion_loss: 0.0375 - rate_loss: 49.6230 - val_loss: 118.7188 - val_distortion_loss: 0.0396 - val_rate_loss: 39.5464
Epoch 10/15 - loss: 122.2946 - distortion_loss: 0.0370 - rate_loss: 48.2343 - val_loss: 116.0431 - val_distortion_loss: 0.0378 - val_rate_loss: 40.5313
Epoch 11/15 - loss: 120.2840 - distortion_loss: 0.0366 - rate_loss: 47.0336 - val_loss: 115.0391 - val_distortion_loss: 0.0374 - val_rate_loss: 40.2651
Epoch 12/15 - loss: 118.5191 - distortion_loss: 0.0363 - rate_loss: 45.9115 - val_loss: 113.4488 - val_distortion_loss: 0.0362 - val_rate_loss: 41.1207
Epoch 13/15 - loss: 116.7709 - distortion_loss: 0.0358 - rate_loss: 45.1241 - val_loss: 112.9705 - val_distortion_loss: 0.0358 - val_rate_loss: 41.3719
Epoch 14/15 - loss: 115.5197 - distortion_loss: 0.0355 - rate_loss: 44.5284 - val_loss: 111.5821 - val_distortion_loss: 0.0356 - val_rate_loss: 40.3043
Epoch 15/15 - loss: 114.3993 - distortion_loss: 0.0352 - rate_loss: 44.0437 - val_loss: 110.9734 - val_distortion_loss: 0.0350 - val_rate_loss: 40.9841
```
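Compress some MNIST images.
For compression and decompression at test time, the trained model is split into two parts:
- The encoder side consists of the analysis transform and the entropy model.
- The decoder side consists of the synthesis transform and the same entropy model.
At test time, the latents will not have additive noise, but they will be quantized and then losslessly compressed, so we give them new names. We call the latents and the image reconstruction \(\hat y\) and \(\hat x\), respectively (following End-to-end Optimized Image Compression).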
```python
class MNISTCompressor(tf.keras.Model):
    """Compresses MNIST images to strings."""

    def __init__(self, analysis_transform, entropy_model):
        super().__init__()
        self.analysis_transform = analysis_transform
        self.entropy_model = entropy_model

    def call(self, x):
        # Ensure inputs are floats in the range (0, 1).
        x = tf.cast(x, self.compute_dtype) / 255.
        y = self.analysis_transform(x)
        # Also return the exact information content of each digit.
        _, bits = self.entropy_model(y, training=False)
        return self.entropy_model.compress(y), bits


class MNISTDecompressor(tf.keras.Model):
    """Decompresses MNIST images from strings."""

    def __init__(self, entropy_model, synthesis_transform):
        super().__init__()
        self.entropy_model = entropy_model
        self.synthesis_transform = synthesis_transform

    def call(self, string):
        y_hat = self.entropy_model.decompress(string, ())
        x_hat = self.synthesis_transform(y_hat)
        # Scale and cast back to 8-bit integer.
        return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)
```
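The final line of the decompressor matters because the last layer of the synthesis transform uses a leaky ReLU, so its outputs are not guaranteed to stay within 0 and 1. The plain-Python sketch below mimics the round-then-saturate step for single values. It assumes round-half-to-even rounding followed by clamping to the 8-bit range, and the helper name is made up.

```python
def to_uint8(value):
    """Scales a float to 8 bits: round to nearest, then clamp to [0, 255]."""
    scaled = round(value * 255.0)  # round() rounds halves to the nearest even integer
    return max(0, min(255, scaled))


print([to_uint8(v) for v in (-0.2, 0.0, 0.5, 1.0, 1.3)])  # prints: [0, 0, 128, 255, 255]
```

Without the clamp, values below 0 or above 1 would wrap around or be undefined when cast to an unsigned 8-bit type.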
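When instantiated with `compression=True`, the entropy model converts the learned prior into tables for a range coding algorithm. When `compress()` is called, this algorithm converts the latent vector into a bit string.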
```python
def make_mnist_codec(trainer, **kwargs):
    # The entropy model must be created with `compression=True` and the same
    # instance must be shared between compressor and decompressor.
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        trainer.prior, coding_rank=1, compression=True, **kwargs)
    compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
    decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
    return compressor, decompressor


compressor, decompressor = make_mnist_codec(trainer)
```
Grab 16 images from the validation dataset. You can select a different subset by changing the argument to `skip`.
```python
(originals, _), = validation_dataset.batch(16).skip(3).take(1)

# Compress the images to strings and keep track of each one's information content in bits.
strings, entropies = compressor(originals)
print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
```
```
String representation of first digit in hexadecimal: 0x4d63d90ed1
Number of bits actually needed to represent it: 37.99
```
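The two numbers printed above do not match exactly: the string occupies a whole number of bytes, while the information content is a fractional number of bits. A short plain-Python check, using only the values from that output, makes the gap explicit. The variable names are illustrative.

```python
import math

encoded = bytes.fromhex("4d63d90ed1")  # string printed above
information_bits = 37.99               # information content printed above

string_bits = 8 * len(encoded)
min_whole_bytes = math.ceil(information_bits / 8)
overhead_bits = string_bits - information_bits

print(f"string length: {len(encoded)} bytes = {string_bits} bits")
print(f"whole bytes needed for {information_bits} bits: {min_whole_bytes}")
print(f"overhead: {overhead_bits:.2f} bits")
```

Here the string is as short as byte alignment allows. In general the correspondence is only approximate, especially for short strings, since the coder works with discretized probabilities.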
```python
# Decompress the images back from the strings.
reconstructions = decompressor(strings)
```
Display each of the 16 original digits together with its compressed binary representation and its reconstruction.
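The `display_digits` helper called next has to be defined before it can be used. A minimal sketch of such a helper follows, assuming Matplotlib and NumPy are available and that the inputs are indexable batches of 28×28 images, byte strings (or scalar string tensors), and bit counts. The 4×4 grid, the spacer between the two images, and the label layout are illustrative choices.

```python
import matplotlib.pyplot as plt
import numpy as np


def _as_bytes(s):
    """Returns raw bytes for a Python bytes object or a scalar string tensor."""
    return s.numpy() if hasattr(s, "numpy") else bytes(s)


def display_digits(originals, strings, entropies, reconstructions):
    """Shows 16 digits next to their reconstructions, labeled with code and bit count."""
    fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(12.5, 5))
    axes = axes.ravel()
    for i in range(len(axes)):
        original = np.squeeze(np.asarray(originals[i]))
        reconstruction = np.squeeze(np.asarray(reconstructions[i]))
        spacer = np.zeros((28, 14), dtype=original.dtype)
        # Original on the left, reconstruction on the right.
        axes[i].imshow(np.concatenate([original, spacer, reconstruction], axis=1))
        axes[i].text(
            .5, .5,
            f"→ 0x{_as_bytes(strings[i]).hex()} →\n{float(entropies[i]):0.2f} bits",
            ha="center", va="top", color="white", fontsize="small",
            transform=axes[i].transAxes)
        axes[i].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    return fig
```

Each panel shows the hexadecimal code and its information content between the original and the reconstructed digit.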
```python
display_digits(originals, strings, entropies, reconstructions)
```
Above, the model was trained for one specific trade-off (given by `lmbda=2000`) between rate and distortion.
Let's start by reducing \(\lambda\) to 500.
```python
def train_and_visualize_model(lmbda):
    trainer = train_mnist_model(lmbda=lmbda)
    compressor, decompressor = make_mnist_codec(trainer)
    strings, entropies = compressor(originals)
    reconstructions = decompressor(strings)
    display_digits(originals, strings, entropies, reconstructions)


train_and_visualize_model(lmbda=500)
```
```
Epoch 1/15 - loss: 127.5305 - distortion_loss: 0.0700 - rate_loss: 92.5392 - val_loss: 107.4705 - val_distortion_loss: 0.0548 - val_rate_loss: 80.0521
Epoch 2/15 - loss: 97.2982 - distortion_loss: 0.0541 - rate_loss: 70.2615 - val_loss: 86.4681 - val_distortion_loss: 0.0599 - val_rate_loss: 56.5326
Epoch 3/15 - loss: 81.2346 - distortion_loss: 0.0562 - rate_loss: 53.1251 - val_loss: 71.9185 - val_distortion_loss: 0.0682 - val_rate_loss: 37.8106
Epoch 4/15 - loss: 71.5505 - distortion_loss: 0.0593 - rate_loss: 41.8880 - val_loss: 63.3954 - val_distortion_loss: 0.0762 - val_rate_loss: 25.3134
Epoch 5/15 - loss: 65.9471 - distortion_loss: 0.0622 - rate_loss: 34.8656 - val_loss: 57.6301 - val_distortion_loss: 0.0786 - val_rate_loss: 18.3391
Epoch 6/15 - loss: 62.4032 - distortion_loss: 0.0642 - rate_loss: 30.3145 - val_loss: 55.5859 - val_distortion_loss: 0.0848 - val_rate_loss: 13.1697
Epoch 7/15 - loss: 59.9169 - distortion_loss: 0.0656 - rate_loss: 27.1412 - val_loss: 51.1707 - val_distortion_loss: 0.0758 - val_rate_loss: 13.2823
Epoch 8/15 - loss: 57.7210 - distortion_loss: 0.0660 - rate_loss: 24.7239 - val_loss: 49.3867 - val_distortion_loss: 0.0736 - val_rate_loss: 12.5935
Epoch 9/15 - loss: 55.7091 - distortion_loss: 0.0658 - rate_loss: 22.7991 - val_loss: 48.4335 - val_distortion_loss: 0.0695 - val_rate_loss: 13.6779
Epoch 10/15 - loss: 53.9030 - distortion_loss: 0.0651 - rate_loss: 21.3707 - val_loss: 47.2973 - val_distortion_loss: 0.0689 - val_rate_loss: 12.8318
Epoch 11/15 - loss: 52.3183 - distortion_loss: 0.0641 - rate_loss: 20.2617 - val_loss: 46.6293 - val_distortion_loss: 0.0671 - val_rate_loss: 13.0571
Epoch 12/15 - loss: 51.0015 - distortion_loss: 0.0633 - rate_loss: 19.3680 - val_loss: 46.1679 - val_distortion_loss: 0.0642 - val_rate_loss: 14.0836
Epoch 13/15 - loss: 49.9189 - distortion_loss: 0.0627 - rate_loss: 18.5747 - val_loss: 45.6781 - val_distortion_loss: 0.0640 - val_rate_loss: 13.6870
Epoch 14/15 - loss: 48.9510 - distortion_loss: 0.0621 - rate_loss: 17.9217 - val_loss: 45.1058 - val_distortion_loss: 0.0614 - val_rate_loss: 14.4060
Epoch 15/15 - loss: 48.1553 - distortion_loss: 0.0615 - rate_loss: 17.4100 - val_loss: 44.9707 - val_distortion_loss: 0.0609 - val_rate_loss: 14.5456
```
Let's reduce \(\lambda\) further, to 300, with `train_and_visualize_model(lmbda=300)`.
```
Epoch 1/15 - loss: 114.0398 - distortion_loss: 0.0768 - rate_loss: 90.9927 - val_loss: 96.7790 - val_distortion_loss: 0.0688 - val_rate_loss: 76.1469
Epoch 2/15 - loss: 85.8627 - distortion_loss: 0.0613 - rate_loss: 67.4866 - val_loss: 74.3796 - val_distortion_loss: 0.0796 - val_rate_loss: 50.5063
Epoch 3/15 - loss: 68.7960 - distortion_loss: 0.0645 - rate_loss: 49.4494 - val_loss: 58.5856 - val_distortion_loss: 0.0911 - val_rate_loss: 31.2701
Epoch 4/15 - loss: 58.2693 - distortion_loss: 0.0693 - rate_loss: 37.4914 - val_loss: 48.5439 - val_distortion_loss: 0.0983 - val_rate_loss: 19.0433
Epoch 5/15 - loss: 51.9585 - distortion_loss: 0.0734 - rate_loss: 29.9320 - val_loss: 42.2598 - val_distortion_loss: 0.0994 - val_rate_loss: 12.4350
Epoch 6/15 - loss: 48.0547 - distortion_loss: 0.0769 - rate_loss: 24.9776 - val_loss: 38.5121 - val_distortion_loss: 0.1000 - val_rate_loss: 8.4996
Epoch 7/15 - loss: 45.3418 - distortion_loss: 0.0798 - rate_loss: 21.4144 - val_loss: 36.2517 - val_distortion_loss: 0.0992 - val_rate_loss: 6.4969
Epoch 8/15 - loss: 43.1311 - distortion_loss: 0.0812 - rate_loss: 18.7581 - val_loss: 34.4578 - val_distortion_loss: 0.0923 - val_rate_loss: 6.7545
Epoch 9/15 - loss: 41.1835 - distortion_loss: 0.0815 - rate_loss: 16.7441 - val_loss: 33.3371 - val_distortion_loss: 0.0870 - val_rate_loss: 7.2369
Epoch 10/15 - loss: 39.4039 - distortion_loss: 0.0806 - rate_loss: 15.2191 - val_loss: 32.8390 - val_distortion_loss: 0.0840 - val_rate_loss: 7.6290
Epoch 11/15 - loss: 37.9107 - distortion_loss: 0.0793 - rate_loss: 14.1125 - val_loss: 32.6137 - val_distortion_loss: 0.0816 - val_rate_loss: 8.1349
Epoch 12/15 - loss: 36.6647 - distortion_loss: 0.0780 - rate_loss: 13.2545 - val_loss: 32.2822 - val_distortion_loss: 0.0814 - val_rate_loss: 7.8685
Epoch 13/15 - loss: 35.6982 - distortion_loss: 0.0771 - rate_loss: 12.5623 - val_loss: 31.9754 - val_distortion_loss: 0.0791 - val_rate_loss: 8.2589
Epoch 14/15 - loss: 34.9956 - distortion_loss: 0.0765 - rate_loss: 12.0543 - val_loss: 31.8855 - val_distortion_loss: 0.0786 - val_rate_loss: 8.3127
Epoch 15/15 - loss: 34.4267 - distortion_loss: 0.0759 - rate_loss: 11.6673 - val_loss: 31.7163 - val_distortion_loss: 0.0777 - val_rate_loss: 8.3933
```
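Taken together, the three training logs trace out the rate-distortion trade-off: as \(\lambda\) shrinks, the rate falls and the distortion rises. Because the logged loss is the rate plus \(\lambda\) times the distortion, \(\lambda\) can also be recovered from the final epoch of each run as a consistency check. The numbers below are copied from the epoch-15 training columns of the logs above. The logged distortion has only four decimals, so the recovered values are approximate.

```python
# (loss, rate_loss, distortion_loss) at epoch 15 of each run, copied from the logs.
final_epochs = {
    "run 1": (114.3993, 44.0437, 0.0352),
    "run 2": (48.1553, 17.4100, 0.0615),
    "run 3": (34.4267, 11.6673, 0.0759),
}

implied_lmbda = {}
for name, (loss, rate, distortion) in final_epochs.items():
    # loss = rate + lmbda * distortion  =>  lmbda = (loss - rate) / distortion
    implied_lmbda[name] = (loss - rate) / distortion
    print(f"{name}: rate={rate:.1f} bits, distortion={distortion:.4f}, "
          f"implied lmbda ~ {implied_lmbda[name]:.0f}")
```

The recovered values land near 2000, 500, and 300. Across the three runs the rate drops from about 44 bits to about 12 bits per digit, while the distortion roughly doubles.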
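Finally, the decoder can be used as a generative model by feeding it random bits:
```python
import os

# Re-create the codec without the sanity check on decoded strings, since the
# random strings below were not produced by the compressor.
compressor, decompressor = make_mnist_codec(trainer, decode_sanity_check=False)

# Decode 16 random 8-byte strings into digit samples.
strings = tf.constant([os.urandom(8) for _ in range(16)])
samples = decompressor(strings)

fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(5, 5))
axes = axes.ravel()
for i in range(len(axes)):
    axes[i].imshow(tf.squeeze(samples[i]))
    axes[i].axis("off")
plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
```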
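Why should random bits decode to plausible outputs? An entropy decoder maps bit strings to symbols in such a way that probable symbols correspond to large portions of the space of possible bit strings. Fed uniformly random bits, it therefore emits symbols roughly in proportion to the model's probabilities. The toy sketch below illustrates this with a made-up four-symbol distribution and a simple inverse-CDF lookup. This is a simplification of range decoding, not the TFC implementation.

```python
import bisect
import itertools
import random

random.seed(0)

# Hypothetical model probabilities for four symbols.
pmf = {"a": 0.5, "b": 0.25, "c": 0.125, "d": 0.125}
symbols = list(pmf)
cumulative = list(itertools.accumulate(pmf.values()))  # [0.5, 0.75, 0.875, 1.0]


def decode_symbol(random_bits, num_bits=16):
    """Interprets the bits as a fraction in [0, 1) and looks up its symbol."""
    fraction = random_bits / (1 << num_bits)
    return symbols[bisect.bisect_right(cumulative, fraction)]


num_samples = 20000
counts = dict.fromkeys(symbols, 0)
for _ in range(num_samples):
    counts[decode_symbol(random.getrandbits(16))] += 1

for s in symbols:
    print(f"{s}: model {pmf[s]:.3f}, sampled {counts[s] / num_samples:.3f}")
```

The sampled frequencies track the model probabilities. In the tutorial's codec, the same principle applies to the latent vector: random strings decode to latents distributed approximately like the learned prior, which the synthesis transform then turns into images.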