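This tutorial demonstrates how to generate images of handwritten digits using a Deep Convolutional Generative Adversarial Network (DCGAN). The code is written using the Keras Sequential API with a tf.GradientTape training loop.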
What are GANs?
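Generative Adversarial Networks (GANs) are one of the most interesting ideas in computer science today. Two models are trained simultaneously by an adversarial process. A generator ("the artist") learns to create images that look real, while a discriminator ("the art critic") learns to tell real images apart from fakes.
During training, the generator progressively becomes better at creating images that look real, while the discriminator becomes better at telling them apart. The process reaches equilibrium when the discriminator can no longer distinguish real images from fakes.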
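For orientation, the adversarial game described above is usually written (Goodfellow et al., 2014) as the minimax objective below, where D(x) is the discriminator's probability that x is real and G(z) is the generator's output for noise z. This is background only: the loss functions later in this notebook express the same game as cross-entropy on logits, and the generator side uses the common non-saturating variant rather than this exact form.

```latex
\min_G \max_D \;
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```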
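This notebook demonstrates this process on the MNIST dataset. The following animation shows a series of images produced by the generator as it was trained for 50 epochs. The images begin as random noise and increasingly resemble handwritten digits over time.
To learn more about GANs, see MIT's Intro to Deep Learning course.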
Setup
import tensorflow as tf
tf.__version__
'2.16.1'
# To generate GIFs
!pip install imageio
!pip install git+https://github.com/tensorflow/docs
import glob
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
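Load and prepare the dataset
You will use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data.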
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
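A quick plain-Python check of the preprocessing arithmetic (no TensorFlow needed). The normalization maps pixel values 0 and 255 to -1 and 1, and because tf.data.Dataset.batch() keeps a final partial batch by default (drop_remainder=False), 60,000 images at a batch size of 256 give 235 batches per epoch, the last one holding 96 images. The helper names normalize and denormalize are hypothetical, introduced only for this sketch.

```python
import math

# Values taken from the cells above
NUM_IMAGES = 60000
BATCH_SIZE = 256

# Dataset.batch() keeps the final partial batch unless drop_remainder=True
num_batches = math.ceil(NUM_IMAGES / BATCH_SIZE)
last_batch_size = NUM_IMAGES - (num_batches - 1) * BATCH_SIZE

def normalize(pixel):
    """Map a raw pixel value in [0, 255] to [-1, 1], as in the preprocessing cell."""
    return (pixel - 127.5) / 127.5

def denormalize(value):
    """Inverse mapping, the same arithmetic used when plotting generated images."""
    return value * 127.5 + 127.5

print(num_batches, last_batch_size)
print(normalize(0), normalize(255))
```

Because train_step() draws BATCH_SIZE noise vectors regardless of the incoming batch, the last step of each epoch compares 96 real images with 256 generated ones; this still runs because each cross-entropy term is averaged over its own batch.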
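Create the models
Both the generator and the discriminator are defined using the Keras Sequential API.
The Generator
The generator uses tf.keras.layers.Conv2DTranspose (upsampling) layers to produce an image from a seed (random noise). Start with a Dense layer that takes this seed as input, then upsample several times until you reach the desired image size of 28x28x1. Each hidden Dense or Conv2DTranspose layer is followed by batch normalization and a tf.keras.layers.LeakyReLU activation; the output layer uses tanh instead, which matches the [-1, 1] range of the normalized training images.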
def make_generator_model():
    model = tf.keras.Sequential()
    # Declare the input with an Input object (preferred over input_shape= in Keras 3)
    model.add(layers.Input(shape=(100,)))
    model.add(layers.Dense(7*7*256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)
    return model
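The shape asserts in make_generator_model() follow from simple arithmetic. The framework-free sketch below reproduces them, assuming (as Keras does with the default output_padding=None) that a Conv2DTranspose layer with padding='same' produces an output size of input_size * stride. The helper names are hypothetical.

```python
def conv_transpose_same(size, stride):
    # Conv2DTranspose, padding='same', output_padding=None: output = input * stride
    return size * stride

# (filters, stride) for each Conv2DTranspose layer in make_generator_model()
GENERATOR_LAYERS = [(128, 1), (64, 2), (1, 2)]

def generator_shapes(start=(7, 7, 256)):
    """Return (height, width, channels) after the Reshape and each Conv2DTranspose."""
    h, w, _ = start
    shapes = [start]
    for filters, stride in GENERATOR_LAYERS:
        h, w = conv_transpose_same(h, stride), conv_transpose_same(w, stride)
        shapes.append((h, w, filters))
    return shapes

# The first Dense layer has no bias: 100 inputs x (7*7*256) outputs
dense_weights = 100 * 7 * 7 * 256

for shape in generator_shapes():
    print(shape)
print(dense_weights)
```

Only the two stride-2 layers change the spatial size, doubling 7 to 14 and then to 28; the stride-1 layer only reduces the channel count.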
Use the (as yet untrained) generator to create an image.
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
The Discriminator
The discriminator is a CNN-based image classifier.
def make_discriminator_model():
    model = tf.keras.Sequential()
    # Declare the input with an Input object (preferred over input_shape= in Keras 3)
    model.add(layers.Input(shape=(28, 28, 1)))
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Flatten())
    # A single raw logit: no activation, so the loss must use from_logits=True
    model.add(layers.Dense(1))
    return model
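The discriminator mirrors the generator: two stride-2 convolutions shrink 28x28 to 7x7 before a single-logit Dense layer. This sketch computes the resulting sizes and trainable-parameter count in plain Python, assuming Keras's rule that Conv2D with padding='same' outputs ceil(input / stride). discriminator_summary is a hypothetical helper name.

```python
import math

def conv_same(size, stride):
    # Conv2D with padding='same': output = ceil(input / stride)
    return math.ceil(size / stride)

def discriminator_summary(size=28, in_channels=1, kernel=5):
    """Return (final spatial size, flattened features, parameter count)."""
    params = 0
    for filters in (64, 128):  # the two Conv2D layers, each with stride 2
        size = conv_same(size, 2)
        params += kernel * kernel * in_channels * filters + filters  # weights + biases
        in_channels = filters
    flat = size * size * in_channels  # Flatten()
    params += flat + 1                # Dense(1): weights + bias
    return size, flat, params

print(discriminator_summary())
```

LeakyReLU, Dropout, and Flatten have no weights, so only the two convolutions and the final Dense layer contribute to the count.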
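Use the (as yet untrained) discriminator to classify the generated images as real or fake. The model will be trained to output positive values for real images and negative values for fake images.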
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print(decision)
tf.Tensor([[0.00231416]], shape=(1, 1), dtype=float32)
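Because the final Dense(1) layer has no activation, the discriminator's output is a raw logit, and the sign convention above corresponds to a sigmoid probability above or below 0.5. A minimal sketch: the logits -3.0 and 3.0 are made-up illustrations, while 0.00231416 is the value printed above (it will differ between runs, since the weights are randomly initialized).

```python
import math

def sigmoid(logit):
    """Convert a raw discriminator logit to the probability that the input is real."""
    return 1.0 / (1.0 + math.exp(-logit))

for logit in (-3.0, 0.0, 0.00231416, 3.0):
    print(logit, round(sigmoid(logit), 4))
```

An untrained discriminator produces logits near 0, that is, probabilities near 0.5: it has no opinion yet about whether the image is real.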
Define the loss and optimizers
Define loss functions and optimizers for both models.
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
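Discriminator loss
This method quantifies how well the discriminator is able to distinguish real images from fakes. It compares the discriminator's predictions on real images to an array of 1s, and the discriminator's predictions on fake (generated) images to an array of 0s.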
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
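To see exactly what discriminator_loss() computes, here is a plain-Python sketch of binary cross-entropy on logits. For label z and logit x, -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x)) simplifies to softplus(x) - z*x, and averaging over the batch matches the default mean reduction of BinaryCrossentropy(from_logits=True). The function names are hypothetical stand-ins, not TensorFlow APIs.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bce_from_logits(labels, logits):
    """Mean binary cross-entropy computed directly on raw logits."""
    losses = [softplus(x) - z * x for z, x in zip(labels, logits)]
    return sum(losses) / len(losses)

def discriminator_loss_sketch(real_logits, fake_logits):
    # Real images are compared with 1s, generated images with 0s
    real_loss = bce_from_logits([1.0] * len(real_logits), real_logits)
    fake_loss = bce_from_logits([0.0] * len(fake_logits), fake_logits)
    return real_loss + fake_loss

print(discriminator_loss_sketch([0.0, 0.0], [0.0, 0.0]))
print(discriminator_loss_sketch([5.0], [-5.0]))
```

An undecided discriminator (all logits 0) scores 2*log(2), about 1.386. The loss falls toward 0 as real logits become large and positive and fake logits become large and negative.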
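Generator loss
The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1). Here, compare the discriminator's decisions on the generated images to an array of 1s.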
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
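Comparing fake outputs with 1s means the generator minimizes -log(sigmoid(x)) = softplus(-x) for each fake logit x. This is the "non-saturating" form of the generator loss. The plain-Python sketch below contrasts its gradient with that of the original minimax form, log(1 - sigmoid(x)), for a fake that the discriminator confidently rejects. All function names are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softplus(x):
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def generator_loss_sketch(fake_logits):
    """Mean of softplus(-x): cross-entropy of fake logits against labels of 1."""
    return sum(softplus(-x) for x in fake_logits) / len(fake_logits)

def grad_non_saturating(x):
    # d/dx softplus(-x) = -sigmoid(-x)
    return -sigmoid(-x)

def grad_minimax(x):
    # Minimax generator loss: log(1 - sigmoid(x)) = -softplus(x)
    # d/dx [-softplus(x)] = -sigmoid(x)
    return -sigmoid(x)

# A fake that the discriminator confidently rejects (very negative logit)
print(grad_non_saturating(-10.0), grad_minimax(-10.0))
```

Early in training, when fakes are easy to spot, the minimax gradient almost vanishes while the non-saturating one stays close to 1 in magnitude. This is the usual motivation for the form used here.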
The discriminator and the generator optimizers are different since you will train two networks separately.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
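One reason to keep two optimizer objects is that Adam is stateful: each instance tracks its own step count and per-variable moment estimates. The sketch below is a minimal scalar version of the textbook Adam update, for illustration only (Keras's implementation arranges the bias correction slightly differently). TinyAdam is a hypothetical class, not a TensorFlow API, and its defaults are assumed to mirror Keras's beta_1=0.9, beta_2=0.999, epsilon=1e-7.

```python
import math

class TinyAdam:
    """Minimal scalar Adam: each instance keeps its own step counter and moments."""

    def __init__(self, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-7):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.t = 0
        self.m = {}  # first-moment estimate per parameter name
        self.v = {}  # second-moment estimate per parameter name

    def step(self, params, grads):
        self.t += 1
        for name, g in grads.items():
            m = self.beta1 * self.m.get(name, 0.0) + (1 - self.beta1) * g
            v = self.beta2 * self.v.get(name, 0.0) + (1 - self.beta2) * g * g
            self.m[name], self.v[name] = m, v
            m_hat = m / (1 - self.beta1 ** self.t)  # bias correction
            v_hat = v / (1 - self.beta2 ** self.t)
            params[name] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

disc_params = {"w": 1.0}
gen_params = {"a": 1.0}
disc_opt = TinyAdam(lr=1e-4)
gen_opt = TinyAdam(lr=1e-4)

# Updating the discriminator leaves the generator optimizer's state untouched
disc_opt.step(disc_params, {"w": 0.5})
print(disc_params["w"], disc_opt.t, gen_opt.t)
```

On the first step, bias correction makes the update close to lr times the sign of the gradient, so with a learning rate of 1e-4 each weight moves by about 1e-4.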
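Save checkpoints
This notebook also demonstrates how to save and restore models, which can be helpful in case a long-running training task is interrupted.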
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
Define the training loop
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
# You will reuse this seed over time (so it's easier
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])
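The training loop begins with the generator receiving a random seed as input. That seed is used to produce an image. The discriminator is then used to classify real images (drawn from the training set) and fake images (produced by the generator). The loss is calculated for each of these models, and the gradients are used to update the generator and discriminator.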
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
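The structure of train_step() can be illustrated without TensorFlow. The sketch below is a hypothetical one-dimensional toy GAN, not the tutorial's networks: the generator is G(z) = a*z + b and the discriminator returns a logit D(x) = w*x + c. Hand-derived gradients stand in for the two GradientTape objects, and plain SGD stands in for the two Adam optimizers. As in train_step(), both sets of gradients come from the same forward pass before either model is updated.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softplus(x):
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def mean(xs):
    return sum(xs) / len(xs)

def forward(p, real, noise):
    fake = [p["a"] * z + p["b"] for z in noise]   # toy generator G(z) = a*z + b
    d_real = [p["w"] * x + p["c"] for x in real]  # toy discriminator logits
    d_fake = [p["w"] * f + p["c"] for f in fake]
    return fake, d_real, d_fake

def toy_losses(p, real, noise):
    _, d_real, d_fake = forward(p, real, noise)
    gen_loss = mean([softplus(-d) for d in d_fake])  # fakes vs. labels of 1
    disc_loss = mean([softplus(-d) for d in d_real]) + mean([softplus(d) for d in d_fake])
    return gen_loss, disc_loss

def toy_train_step(p, real, noise, lr=0.01):
    fake, d_real, d_fake = forward(p, real, noise)
    # Hand-derived gradients, using d/du softplus(u) = sigmoid(u)
    gen_grad_a = mean([-sigmoid(-d) * p["w"] * z for d, z in zip(d_fake, noise)])
    gen_grad_b = mean([-sigmoid(-d) * p["w"] for d in d_fake])
    disc_grad_w = (mean([-sigmoid(-d) * x for d, x in zip(d_real, real)])
                   + mean([sigmoid(d) * f for d, f in zip(d_fake, fake)]))
    disc_grad_c = mean([-sigmoid(-d) for d in d_real]) + mean([sigmoid(d) for d in d_fake])
    # Apply both updates only after all gradients are computed
    new = dict(p)
    new["a"] -= lr * gen_grad_a
    new["b"] -= lr * gen_grad_b
    new["w"] -= lr * disc_grad_w
    new["c"] -= lr * disc_grad_c
    return new

params = {"a": 1.0, "b": 0.0, "w": 0.5, "c": 0.0}
real = [3.5, 4.0, 4.5]
noise = [-1.0, 0.0, 1.0]
updated = toy_train_step(params, real, noise)
print(updated)
```

After one step, the generator's offset b moves toward the real data, and each model's loss, measured against the other model's old parameters, goes down.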
def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()
        for image_batch in dataset:
            train_step(image_batch)
        # Produce images for the GIF as you go
        display.clear_output(wait=True)
        generate_and_save_images(generator,
                                 epoch + 1,
                                 seed)
        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epochs,
                             seed)
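The save condition in train() is worth checking against EPOCHS = 50. The plain-Python sketch below lists the epochs at which checkpoint.save() runs; checkpoint_epochs is a hypothetical helper that reproduces the loop's `(epoch + 1) % 15 == 0` test.

```python
EPOCHS = 50

def checkpoint_epochs(epochs, every=15):
    """1-based epoch numbers at which train() calls checkpoint.save()."""
    return [epoch + 1 for epoch in range(epochs) if (epoch + 1) % every == 0]

saved = checkpoint_epochs(EPOCHS)
print(saved)
```

With 50 epochs, the last save happens after epoch 45. Restoring the latest checkpoint later in this notebook therefore brings back the epoch-45 weights, not the weights from the final five epochs. If those epochs matter, one option is an extra checkpoint.save() call after the loop.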
Generate and save images
def generate_and_save_images(model, epoch, test_input):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        # Undo the [-1, 1] normalization for display
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
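Train the model
Call the train() method defined above to train the generator and discriminator simultaneously. Note that training GANs can be tricky: it's important that the generator and discriminator do not overpower each other (for example, that they train at a similar rate).
At the beginning of training, the generated images look like random noise. As training progresses, the generated digits will look increasingly real. After about 50 epochs, they resemble MNIST digits. This may take about one minute per epoch with the default settings on Colab.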
train(train_dataset, EPOCHS)
Restore the latest checkpoint.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
Create a GIF
# Display a single image using the epoch number
def display_image(epoch_no):
    return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)
Use imageio to create an animated GIF from the images saved during training.
anim_file = 'dcgan.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = glob.glob('image*.png')
    filenames = sorted(filenames)
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)
    # Append the last frame once more so the final image stays on screen longer
    image = imageio.imread(filename)
    writer.append_data(image)
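The GIF frames come out in training order only because generate_and_save_images() zero-pads the epoch number with '{:04d}': sorted() compares file names as strings, so padding makes lexicographic order match numeric order. A small plain-Python demonstration follows; frame_name is a hypothetical helper that reuses the same format string.

```python
def frame_name(epoch):
    # Same format string as generate_and_save_images()
    return 'image_at_epoch_{:04d}.png'.format(epoch)

padded = sorted(frame_name(e) for e in (10, 2, 1, 50))
unpadded = sorted('image_at_epoch_{}.png'.format(e) for e in (10, 2))

print(padded)
print(unpadded)
```

Without padding, 'image_at_epoch_10.png' sorts before 'image_at_epoch_2.png', which would scramble the animation.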
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)
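Next steps
This tutorial has shown the complete code necessary to write and train a GAN. As a next step, you might like to experiment with a different dataset, for example the Large-scale CelebFaces Attributes (CelebA) dataset available on Kaggle. To learn more about GANs, see the NIPS 2016 Tutorial: Generative Adversarial Networks.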