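This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR-10 dataset.
Model Garden contains a collection of state-of-the-art vision models implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling and let users take full advantage of TensorFlow for research and product development.
This tutorial uses ResNet-18, an 18-layer convolutional neural network from the ResNet family of image classifiers.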
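This tutorial demonstrates how to:
- Use models from the TensorFlow Models package.
- Fine-tune a pre-built ResNet for image classification.
- Export the fine-tuned ResNet model.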
pip install -U -q "tf-models-official"
Import TensorFlow, TensorFlow Datasets, and a few helper libraries.
import pprint
import tempfile
from IPython import display
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
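The tensorflow_models package contains the ResNet vision model, and official.vision.serving contains the functions to save and export the fine-tuned model.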
import tensorflow_models as tfm
# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib
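Configure the ResNet-18 model for the CIFAR-10 dataset
The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images per class.
In Model Garden, the collection of parameters that defines a model is called a config. Model Garden can create a config from a known set of parameters via a factory.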
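The factory idea can be illustrated with a small, library-free sketch. Every name here (`_REGISTRY`, `register`, `ModelConfig`, `toy_resnet_imagenet`, `get_config`) is hypothetical and the default values are illustrative; this is not Model Garden's implementation, only the general pattern of looking up a named config and then overriding fields.

```python
import dataclasses

# Hypothetical registry mapping a config name to a function that builds it.
_REGISTRY = {}


def register(name):
    """Decorator that records a config-building function under `name`."""
    def decorator(fn):
        _REGISTRY[name] = fn
        return fn
    return decorator


@dataclasses.dataclass
class ModelConfig:
    # Illustrative defaults for an ImageNet-style setup (assumed values).
    num_classes: int = 1000
    input_size: list = dataclasses.field(default_factory=lambda: [224, 224, 3])
    model_id: int = 50


@register('toy_resnet_imagenet')
def toy_resnet_imagenet():
    return ModelConfig()


def get_config(name):
    """Build a fresh config object from the registered factory function."""
    return _REGISTRY[name]()


# Start from the named defaults, then override fields for a smaller dataset.
config = get_config('toy_resnet_imagenet')
config.num_classes = 10
config.input_size = [32, 32, 3]
config.model_id = 18
print(config)
```

Because each lookup builds a fresh object, overriding fields on one config does not affect later lookups of the same name.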
Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. The configuration is set up to train ResNet to convergence on ImageNet.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds, ds_info = tfds.load(tfds_name, with_info=True)
ds_info
tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_dir='gs://tensorflow-datasets/datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(32, 32, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
)
Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).
# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18
# Configure training and testing data
batch_size = 128
exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size
exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

if device=='CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps = 5000  # assumed value for GPU/TPU runs
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
Running on CPU is slow, so only train for a few steps.
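The trainer settings above combine a cosine learning-rate decay with a linear warmup. The sketch below shows one common formulation of that combination in plain Python. The function names are hypothetical, and both the warmup target (the cosine schedule's value at the end of warmup) and the clamping of the step count are assumptions; whether Model Garden computes the schedule exactly this way is not confirmed here.

```python
import math


def cosine_decay(step, initial_lr, decay_steps):
    """Cosine decay from initial_lr to 0 over decay_steps (assumed to clamp afterwards)."""
    step = min(step, decay_steps)
    return 0.5 * initial_lr * (1.0 + math.cos(math.pi * step / decay_steps))


def cosine_with_linear_warmup(step, warmup_steps, initial_lr, decay_steps):
    """Ramp linearly from 0 to the cosine value at warmup_steps, then follow the cosine curve."""
    if step < warmup_steps:
        # Assumption: the warmup target is the cosine schedule's value at warmup_steps.
        target = cosine_decay(warmup_steps, initial_lr, decay_steps)
        return target * step / warmup_steps
    return cosine_decay(step, initial_lr, decay_steps)


# Long run: 5000 decay steps with 100 warmup steps.
for step in (0, 50, 100, 2500, 5000):
    print(step, cosine_with_linear_warmup(step, 100, 0.1, 5000))

# Short CPU run: 20 decay steps, but warmup still spans 100 steps.
for step in (5, 10, 20):
    print(step, cosine_with_linear_warmup(step, 100, 0.1, 20))
```

Under this formulation the 20-step CPU configuration yields a learning rate of zero throughout, because warmup never finishes and the cosine target has already decayed to zero by step 100. If that matters for a short run, lowering warmup_steps below train_steps is one option to consider.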
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
Warning: this will be really slow.
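Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.
The Task object has all the methods needed to build the dataset, build the model, and run training and evaluation. These methods are driven by tfm.core.train_lib.run_experiment.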
with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')
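The data loader applies z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be displayed directly by standard tools. The visualization code needs to rescale the data into the [0,1] range.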
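The effect of the normalization, and the rescaling needed before plotting, can be sketched with NumPy. The `MEAN_RGB` and `STDDEV_RGB` values below are assumed ImageNet-style statistics used only for illustration, and `normalize` and `rescale_for_display` are hypothetical helpers, not the Model Garden functions.

```python
import numpy as np

# Assumed per-channel statistics on the 0-255 scale (illustrative values only).
MEAN_RGB = np.array([0.485, 0.456, 0.406]) * 255
STDDEV_RGB = np.array([0.229, 0.224, 0.225]) * 255


def normalize(images):
    """Z-score normalization: subtract the channel mean, divide by the channel stddev."""
    return (images.astype(np.float32) - MEAN_RGB) / STDDEV_RGB


def rescale_for_display(batch):
    """Min-max rescale a whole batch into [0, 1] so it can be shown as an image."""
    lo, hi = batch.min(), batch.max()
    if hi == lo:
        return np.zeros_like(batch)
    return (batch - lo) / (hi - lo)


# A random stand-in for a batch of 32x32 RGB images.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

normalized = normalize(raw)
displayable = rescale_for_display(normalized)

print('normalized range:', normalized.min(), normalized.max())
print('displayable range:', displayable.min(), displayable.max())
```

The normalized values fall outside [0, 1] and include negatives, which is why plotting them directly gives clipped or distorted images. Min-max rescaling over the batch restores a displayable range, although it does not recover the original pixel values exactly.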
Use ds_info (an instance of tfds.core.DatasetInfo) to look up the text description of each class ID.
label_info = ds_info.features['label']

def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # Rescale the whole batch into [0, 1] so imshow can display it.
  lo = images.numpy().min()
  hi = images.numpy().max()
  delta = hi - lo

  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i] - lo) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      # Green title for a correct prediction, red for a wrong one.
      if labels[i] == predictions[i]:
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
plt.figure(figsize=(10, 10))
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)

plt.figure(figsize=(10, 10));
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)
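model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)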
restoring or initializing model...
train | step: 0 | training until step 20...
train | step: 5 | steps/sec: 0.5 | output: {'accuracy': 0.103125, 'learning_rate': 0.0, 'top_5_accuracy': 0.4828125, 'training_loss': 2.7998607}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-5.
train | step: 10 | steps/sec: 0.8 | output: {'accuracy': 0.0828125, 'learning_rate': 0.0, 'top_5_accuracy': 0.4984375, 'training_loss': 2.8205295}
train | step: 15 | steps/sec: 0.8 | output: {'accuracy': 0.0921875, 'learning_rate': 0.0, 'top_5_accuracy': 0.503125, 'training_loss': 2.8169343}
train | step: 20 | steps/sec: 0.8 | output: {'accuracy': 0.1015625, 'learning_rate': 0.0, 'top_5_accuracy': 0.45, 'training_loss': 2.8760865}
eval | step: 20 | running 78 steps of evaluation...
eval | step: 20 | steps/sec: 24.4 | eval time: 3.2 sec | output: {'accuracy': 0.09485176, 'steps_per_second': 24.40085348913806, 'top_5_accuracy': 0.49589342, 'validation_loss': 2.5864375}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-20.
eval | step: 20 | running 78 steps of evaluation...
eval | step: 20 | steps/sec: 40.1 | eval time: 1.9 sec | output: {'accuracy': 0.09485176, 'steps_per_second': 40.14298727815298, 'top_5_accuracy': 0.49589342, 'validation_loss': 2.5864375}
# tf.keras.utils.plot_model(model, show_shapes=True)
Print the evaluation metrics, including accuracy and validation_loss.
for key, value in eval_logs.items():
  if isinstance(value, tf.Tensor):
    value = value.numpy()
  print(f'{key:20}: {value:.3f}')
accuracy            : 0.095
top_5_accuracy      : 0.496
validation_loss     : 2.586
steps_per_second    : 40.143
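For reference, top-1 and top-5 accuracy can be computed from raw logits as in the NumPy sketch below. The helper `top_k_accuracy` and the toy data are hypothetical and are not the metric implementation used by Model Garden.

```python
import numpy as np


def top_k_accuracy(logits, labels, k):
    """Fraction of examples whose true label is among the k highest-scoring classes."""
    # argsort is ascending, so the last k columns hold the k largest logits.
    top_k = np.argsort(logits, axis=-1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=-1)
    return hits.mean()


# Toy example: 3 examples, 3 classes.
logits = np.array([
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
])
labels = np.array([1, 1, 0])

print('top-1 accuracy:', top_k_accuracy(logits, labels, k=1))
print('top-2 accuracy:', top_k_accuracy(logits, labels, k=2))
```

With 10 balanced classes, random guessing gives about 0.10 top-1 accuracy and about 0.50 top-5 accuracy, which is roughly what the metrics above show after only 20 training steps.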
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device=='CPU':
  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')
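Export a SavedModel
The keras.Model object returned by train_lib.run_experiment expects the data to be normalized by the dataset loader using the same mean and variance statistics as in preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). The export function handles those details, so you can pass tf.uint8 images and get correct results.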
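# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')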
INFO:tensorflow:Assets written to: ./export/assets
# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
plt.figure(figsize=(10, 10))
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    # The serving signature takes one uint8 image with a leading batch dimension.
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device=='CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')
