在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
本笔记本演示了如何在 MNIST 数据集上训练变分自动编码器 (VAE) (1, 2)。VAE 是自动编码器的一种概率方法,它将高维输入数据压缩成更小的表示。与将输入映射到潜在向量上的传统自动编码器不同,VAE 将输入数据映射到概率分布的参数中,例如高斯的均值和方差。这种方法产生了连续的、结构化的潜在空间,这对于图像生成很有用。
设置
pip install tensorflow-probability
# to generate gifs
pip install imageio
pip install git+https://github.com/tensorflow/docs
from IPython import display
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time
加载 MNIST 数据集
每个 MNIST 图像最初是一个包含 784 个整数的向量,每个整数介于 0-255 之间,表示像素的强度。在我们的模型中,使用伯努利分布对每个像素进行建模,并静态地对数据集进行二值化。
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
def preprocess_images(images):
images = images.reshape((images.shape[0], 28, 28, 1)) / 255.
return np.where(images > .5, 1.0, 0.0).astype('float32')
train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)
train_size = 60000
batch_size = 32
test_size = 10000
使用 tf.data 对数据进行批处理和洗牌
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
.shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
.shuffle(test_size).batch(batch_size))
使用 tf.keras.Sequential 定义编码器和解码器网络
在这个 VAE 示例中,使用两个小的卷积网络作为编码器和解码器网络。在文献中,这些网络也被分别称为推理/识别模型和生成模型。使用 tf.keras.Sequential
简化实现。在下述描述中,令 \(x\) 和 \(z\) 分别表示观测值和潜在变量。
编码器网络
这定义了近似后验分布 \(q(z|x)\),它将观测值作为输入,并输出一组参数,用于指定潜在表示 \(z\) 的条件分布。在本例中,只需将分布建模为对角高斯,网络输出分解高斯的均值和对数方差参数。为了数值稳定性,输出对数方差而不是方差本身。
解码器网络
这定义了观测值 \(p(x|z)\) 的条件分布,它将潜在样本 \(z\) 作为输入,并输出观测值的条件分布的参数。将潜在分布先验 \(p(z)\) 建模为单位高斯。
重新参数化技巧
为了在训练期间为解码器生成样本 \(z\),可以从编码器输出的参数定义的潜在分布中采样,给定输入观测值 \(x\)。但是,这种采样操作会创建一个瓶颈,因为反向传播无法流过随机节点。
为了解决这个问题,使用重新参数化技巧。在我们的示例中,使用解码器参数和另一个参数 \(\epsilon\) 近似 \(z\),如下所示
\[z = \mu + \sigma \odot \epsilon\]
其中 \(\mu\) 和 \(\sigma\) 分别表示高斯分布的均值和标准差。它们可以从解码器输出中推导出。\(\epsilon\) 可以被认为是用于保持 \(z\) 随机性的随机噪声。从标准正态分布中生成 \(\epsilon\)。
潜在变量 \(z\) 现在由 \(\mu\)、\(\sigma\) 和 \(\epsilon\) 的函数生成,这将使模型能够分别通过 \(\mu\) 和 \(\sigma\) 对编码器中的梯度进行反向传播,同时通过 \(\epsilon\) 保持随机性。
网络架构
对于编码器网络,使用两个卷积层,然后是一个全连接层。在解码器网络中,通过使用一个全连接层,然后是三个卷积转置层(在某些情况下也称为反卷积层)来镜像此架构。注意,在训练 VAE 时,通常的做法是避免使用批归一化,因为使用小批次带来的额外随机性可能会加剧采样带来的不稳定性。
class CVAE(tf.keras.Model):
"""Convolutional variational autoencoder."""
def __init__(self, latent_dim):
super(CVAE, self).__init__()
self.latent_dim = latent_dim
self.encoder = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
tf.keras.layers.Conv2D(
filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
tf.keras.layers.Flatten(),
# No activation
tf.keras.layers.Dense(latent_dim + latent_dim),
]
)
self.decoder = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),
tf.keras.layers.Reshape(target_shape=(7, 7, 32)),
tf.keras.layers.Conv2DTranspose(
filters=64, kernel_size=3, strides=2, padding='same',
activation='relu'),
tf.keras.layers.Conv2DTranspose(
filters=32, kernel_size=3, strides=2, padding='same',
activation='relu'),
# No activation
tf.keras.layers.Conv2DTranspose(
filters=1, kernel_size=3, strides=1, padding='same'),
]
)
@tf.function
def sample(self, eps=None):
if eps is None:
eps = tf.random.normal(shape=(100, self.latent_dim))
return self.decode(eps, apply_sigmoid=True)
def encode(self, x):
mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
return mean, logvar
def reparameterize(self, mean, logvar):
eps = tf.random.normal(shape=mean.shape)
return eps * tf.exp(logvar * .5) + mean
def decode(self, z, apply_sigmoid=False):
logits = self.decoder(z)
if apply_sigmoid:
probs = tf.sigmoid(logits)
return probs
return logits
定义损失函数和优化器
VAE 通过最大化边缘对数似然的证据下界 (ELBO) 来训练
\[\log p(x) \ge \text{ELBO} = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{q(z|x)}\right].\]
在实践中,优化此期望的单样本蒙特卡罗估计
\[\log p(x| z) + \log p(z) - \log q(z|x),\]
其中 \(z\) 从 \(q(z|x)\) 中采样。
optimizer = tf.keras.optimizers.Adam(1e-4)
def log_normal_pdf(sample, mean, logvar, raxis=1):
log2pi = tf.math.log(2. * np.pi)
return tf.reduce_sum(
-.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
axis=raxis)
def compute_loss(model, x):
mean, logvar = model.encode(x)
z = model.reparameterize(mean, logvar)
x_logit = model.decode(z)
cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
logpz = log_normal_pdf(z, 0., 0.)
logqz_x = log_normal_pdf(z, mean, logvar)
return -tf.reduce_mean(logpx_z + logpz - logqz_x)
@tf.function
def train_step(model, x, optimizer):
"""Executes one training step and returns the loss.
This function computes the loss and gradients, and uses the latter to
update the model's parameters.
"""
with tf.GradientTape() as tape:
loss = compute_loss(model, x)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
训练
- 首先迭代数据集
- 在每次迭代中,将图像传递给编码器以获得近似后验 \(q(z|x)\) 的均值和对数方差参数集
- 然后应用 重新参数化技巧 从 \(q(z|x)\) 中采样
- 最后,将重新参数化的样本传递给解码器以获得生成分布 \(p(x|z)\) 的 logits
- 注意:由于使用 keras 加载的数据集在训练集中有 60k 个数据点,在测试集中有 10k 个数据点,因此我们得到的测试集 ELBO 略高于文献中报道的结果,文献中使用的是 Larochelle 的 MNIST 的动态二值化。
生成图像
- 训练完成后,就可以生成一些图像了
- 首先从单位高斯先验分布 \(p(z)\) 中采样一组潜在向量
- 然后,生成器将潜在样本 \(z\) 转换为观测值的 logits,从而得到分布 \(p(x|z)\)
- 这里,绘制伯努利分布的概率
epochs = 10
# set the dimensionality of the latent space to a plane for visualization later
latent_dim = 2
num_examples_to_generate = 16
# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
random_vector_for_generation = tf.random.normal(
shape=[num_examples_to_generate, latent_dim])
model = CVAE(latent_dim)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/layers/core/input_layer.py:25: UserWarning: Argument `input_shape` is deprecated. Use `shape` instead. warnings.warn(
def generate_and_save_images(model, epoch, test_sample):
mean, logvar = model.encode(test_sample)
z = model.reparameterize(mean, logvar)
predictions = model.sample(z)
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i + 1)
plt.imshow(predictions[i, :, :, 0], cmap='gray')
plt.axis('off')
# tight_layout minimizes the overlap between 2 sub-plots
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
# Pick a sample of the test set for generating output images
assert batch_size >= num_examples_to_generate
for test_batch in test_dataset.take(1):
test_sample = test_batch[0:num_examples_to_generate, :, :, :]
2024-03-13 03:09:51.643821: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
generate_and_save_images(model, 0, test_sample)
for epoch in range(1, epochs + 1):
start_time = time.time()
for train_x in train_dataset:
train_step(model, train_x, optimizer)
end_time = time.time()
loss = tf.keras.metrics.Mean()
for test_x in test_dataset:
loss(compute_loss(model, test_x))
elbo = -loss.result()
display.clear_output(wait=False)
print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'
.format(epoch, elbo, end_time - start_time))
generate_and_save_images(model, epoch, test_sample)
Epoch: 10, Test set ELBO: -156.40652465820312, time elapse for current epoch: 7.916414976119995
显示来自最后一个训练周期的生成图像
def display_image(epoch_no):
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
plt.imshow(display_image(epoch))
plt.axis('off') # Display images
(-0.5, 399.5, 399.5, -0.5)
显示所有保存图像的动画 GIF
anim_file = 'cvae.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
filenames = glob.glob('image*.png')
filenames = sorted(filenames)
for filename in filenames:
image = imageio.imread(filename)
writer.append_data(image)
image = imageio.imread(filename)
writer.append_data(image)
/tmpfs/tmp/ipykernel_129481/1290275450.py:7: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly. image = imageio.imread(filename) /tmpfs/tmp/ipykernel_129481/1290275450.py:9: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly. image = imageio.imread(filename)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)
显示来自潜在空间的数字的二维流形
运行下面的代码将显示不同数字类的连续分布,每个数字在二维潜在空间中相互变形。使用 TensorFlow Probability 为潜在空间生成标准正态分布。
def plot_latent_images(model, n, digit_size=28):
"""Plots n x n digit images decoded from the latent space."""
norm = tfp.distributions.Normal(0, 1)
grid_x = norm.quantile(np.linspace(0.05, 0.95, n))
grid_y = norm.quantile(np.linspace(0.05, 0.95, n))
image_width = digit_size*n
image_height = image_width
image = np.zeros((image_height, image_width))
for i, yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z = np.array([[xi, yi]])
x_decoded = model.sample(z)
digit = tf.reshape(x_decoded[0], (digit_size, digit_size))
image[i * digit_size: (i + 1) * digit_size,
j * digit_size: (j + 1) * digit_size] = digit.numpy()
plt.figure(figsize=(10, 10))
plt.imshow(image, cmap='Greys_r')
plt.axis('Off')
plt.show()
plot_latent_images(model, 20)
下一步
本教程演示了如何使用 TensorFlow 实现卷积变分自动编码器。
下一步,您可以尝试通过增加网络大小来改进模型输出。例如,您可以尝试将每个 Conv2D
和 Conv2DTranspose
层的 filter
参数设置为 512。请注意,为了生成最终的二维潜在图像图,您需要将 latent_dim
保持为 2。此外,随着网络大小的增加,训练时间也会增加。
您还可以尝试使用其他数据集(例如 CIFAR-10)实现 VAE。
VAE 可以以多种不同的风格和不同的复杂程度实现。您可以在以下来源中找到其他实现
如果您想了解更多关于 VAE 的详细信息,请参考 变分自动编码器简介。