Image classification with Model Garden


This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR dataset.

Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.

This tutorial uses a ResNet model, a state-of-the-art image classifier. Specifically, it uses ResNet-18, a convolutional neural network that is 18 layers deep.

This tutorial demonstrates how to:

  1. Use models from the TensorFlow Models package.
  2. Fine-tune a pre-built ResNet for image classification.
  3. Export the fine-tuned ResNet model.

Setup

Install and import the necessary modules.

pip install -U -q "tf-models-official"

Import TensorFlow, TensorFlow Datasets, and a few helper libraries.

import pprint
import tempfile

from IPython import display
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
2023-10-17 11:52:54.005237: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-17 11:52:54.005294: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-17 11:52:54.005338: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

The tensorflow_models package contains the ResNet vision model, and the official.vision.serving module contains the functions for saving and exporting the fine-tuned model.

import tensorflow_models as tfm

# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib

Configure the ResNet-18 model for the Cifar-10 dataset

The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images per class.

In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.

Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. The configuration is set up to train ResNet to convergence on ImageNet.

exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds, ds_info = tfds.load(
    tfds_name,
    with_info=True)
ds_info
2023-10-17 11:52:59.285390: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflowcn.cn/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_dir='gs://tensorflow-datasets/datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(32, 32, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
)
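Before overriding any fields, it can be useful to inspect the defaults that the resnet_imagenet factory produced for the part of the config you are about to change. A minimal sketch, using the same as_dict() accessor as the full config dump further below:

# Peek at the default model sub-config before overriding it for CIFAR-10.
pprint.pprint(exp_config.task.model.as_dict())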

Adjust the model and dataset configurations so that they work with Cifar-10 (cifar10).

# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18

# Configure training and testing data
batch_size = 128

exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size

exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size

Adjust the trainer configuration.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

if device=='CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps=5000
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps =  ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
Running on CPU is slow, so only train for a few steps.
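Model Garden constructs the actual learning-rate schedule internally from this config. Purely as an intuition aid, the sketch below approximates the configured schedule (a 100-step linear warmup into a cosine decay over train_steps) with the standard tf.keras.optimizers.schedules.CosineDecay and plots it; it is an illustration, not part of the original notebook. Note that on the short CPU run (20 training steps against a 100-step warmup) the ramp barely leaves zero, which appears consistent with the learning_rate values reported in the training logs below.

# Rough, standalone approximation of the configured schedule: linear warmup over
# the first 100 steps into a cosine decay from 0.1 towards 0 over `train_steps`.
warmup_steps = 100
cosine = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.1, decay_steps=train_steps)

def approx_lr(step):
  if step < warmup_steps:
    # Ramp linearly up to the value the cosine schedule reaches at the end of warmup.
    return float(cosine(warmup_steps)) * step / warmup_steps
  return float(cosine(step))

steps = list(range(max(train_steps, warmup_steps) + 1))
plt.plot(steps, [approx_lr(s) for s in steps])
plt.xlabel('step')
plt.ylabel('learning rate')
plt.title('Approximate warmup + cosine decay schedule');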

Print the modified configuration.

pprint.pprint(exp_config.as_dict())

display.Javascript("google.colab.output.setIframeHeight('300px');")
{'runtime': {'all_reduce_alg': None,
             'batchnorm_spatial_persistent': False,
             'dataset_num_private_threads': None,
             'default_shard_dim': -1,
             'distribution_strategy': 'mirrored',
             'enable_xla': True,
             'gpu_thread_mode': None,
             'loss_scale': None,
             'mixed_precision_dtype': None,
             'num_cores_per_replica': 1,
             'num_gpus': 0,
             'num_packs': 1,
             'per_gpu_thread_count': 0,
             'run_eagerly': False,
             'task_index': -1,
             'tpu': None,
             'tpu_enable_xla_dynamic_padder': None,
             'use_tpu_mp_strategy': False,
             'worker_hosts': None},
 'task': {'allow_image_summary': False,
          'differential_privacy_config': None,
          'eval_input_partition_dims': [],
          'evaluation': {'precision_and_recall_thresholds': None,
                         'report_per_class_precision_and_recall': False,
                         'top_k': 5},
          'freeze_backbone': False,
          'init_checkpoint': None,
          'init_checkpoint_modules': 'all',
          'losses': {'l2_weight_decay': 0.0001,
                     'label_smoothing': 0.0,
                     'loss_weight': 1.0,
                     'one_hot': True,
                     'soft_labels': False,
                     'use_binary_cross_entropy': False},
          'model': {'add_head_batch_norm': False,
                    'backbone': {'resnet': {'bn_trainable': True,
                                            'depth_multiplier': 1.0,
                                            'model_id': 18,
                                            'replace_stem_max_pool': False,
                                            'resnetd_shortcut': False,
                                            'scale_stem': True,
                                            'se_ratio': 0.0,
                                            'stem_type': 'v0',
                                            'stochastic_depth_drop_rate': 0.0},
                                 'type': 'resnet'},
                    'dropout_rate': 0.0,
                    'input_size': [32, 32, 3],
                    'kernel_initializer': 'random_uniform',
                    'norm_activation': {'activation': 'relu',
                                        'norm_epsilon': 1e-05,
                                        'norm_momentum': 0.9,
                                        'use_sync_bn': False},
                    'num_classes': 10,
                    'output_softmax': False},
          'model_output_keys': [],
          'name': None,
          'train_data': {'apply_tf_data_service_before_batching': False,
                         'aug_crop': True,
                         'aug_policy': None,
                         'aug_rand_hflip': True,
                         'aug_type': None,
                         'autotune_algorithm': None,
                         'block_length': 1,
                         'cache': False,
                         'center_crop_fraction': 0.875,
                         'color_jitter': 0.0,
                         'crop_area_range': (0.08, 1.0),
                         'cycle_length': 10,
                         'decode_jpeg_only': True,
                         'decoder': {'simple_decoder': {'attribute_names': [],
                                                        'mask_binarize_threshold': None,
                                                        'regenerate_source_id': False},
                                     'type': 'simple_decoder'},
                         'deterministic': None,
                         'drop_remainder': True,
                         'dtype': 'float32',
                         'enable_shared_tf_data_service_between_parallel_trainers': False,
                         'enable_tf_data_service': False,
                         'file_type': 'tfrecord',
                         'global_batch_size': 128,
                         'image_field_key': 'image/encoded',
                         'input_path': '',
                         'is_multilabel': False,
                         'is_training': True,
                         'label_field_key': 'image/class/label',
                         'mixup_and_cutmix': None,
                         'prefetch_buffer_size': None,
                         'randaug_magnitude': 10,
                         'random_erasing': None,
                         'repeated_augment': None,
                         'seed': None,
                         'sharding': True,
                         'shuffle_buffer_size': 10000,
                         'tf_data_service_address': None,
                         'tf_data_service_job_name': None,
                         'tf_resize_method': 'bilinear',
                         'tfds_as_supervised': False,
                         'tfds_data_dir': '',
                         'tfds_name': 'cifar10',
                         'tfds_skip_decoding_feature': '',
                         'tfds_split': 'train',
                         'three_augment': False,
                         'trainer_id': None,
                         'weights': None},
          'train_input_partition_dims': [],
          'validation_data': {'apply_tf_data_service_before_batching': False,
                              'aug_crop': True,
                              'aug_policy': None,
                              'aug_rand_hflip': True,
                              'aug_type': None,
                              'autotune_algorithm': None,
                              'block_length': 1,
                              'cache': False,
                              'center_crop_fraction': 0.875,
                              'color_jitter': 0.0,
                              'crop_area_range': (0.08, 1.0),
                              'cycle_length': 10,
                              'decode_jpeg_only': True,
                              'decoder': {'simple_decoder': {'attribute_names': [],
                                                             'mask_binarize_threshold': None,
                                                             'regenerate_source_id': False},
                                          'type': 'simple_decoder'},
                              'deterministic': None,
                              'drop_remainder': True,
                              'dtype': 'float32',
                              'enable_shared_tf_data_service_between_parallel_trainers': False,
                              'enable_tf_data_service': False,
                              'file_type': 'tfrecord',
                              'global_batch_size': 128,
                              'image_field_key': 'image/encoded',
                              'input_path': '',
                              'is_multilabel': False,
                              'is_training': False,
                              'label_field_key': 'image/class/label',
                              'mixup_and_cutmix': None,
                              'prefetch_buffer_size': None,
                              'randaug_magnitude': 10,
                              'random_erasing': None,
                              'repeated_augment': None,
                              'seed': None,
                              'sharding': True,
                              'shuffle_buffer_size': 10000,
                              'tf_data_service_address': None,
                              'tf_data_service_job_name': None,
                              'tf_resize_method': 'bilinear',
                              'tfds_as_supervised': False,
                              'tfds_data_dir': '',
                              'tfds_name': 'cifar10',
                              'tfds_skip_decoding_feature': '',
                              'tfds_split': 'test',
                              'three_augment': False,
                              'trainer_id': None,
'weights': None}},
 'trainer': {'allow_tpu_summary': False,
             'best_checkpoint_eval_metric': '',
             'best_checkpoint_export_subdir': '',
             'best_checkpoint_metric_comp': 'higher',
             'checkpoint_interval': 20,
             'continuous_eval_timeout': 3600,
             'eval_tf_function': True,
             'eval_tf_while_loop': False,
             'loss_upper_bound': 1000000.0,
             'max_to_keep': 5,
             'optimizer_config': {'ema': None,
                                  'learning_rate': {'cosine': {'alpha': 0.0,
                                                               'decay_steps': 20,
                                                               'initial_learning_rate': 0.1,
                                                               'name': 'CosineDecay',
                                                               'offset': 0},
                                                    'type': 'cosine'},
                                  'optimizer': {'sgd': {'clipnorm': None,
                                                        'clipvalue': None,
                                                        'decay': 0.0,
                                                        'global_clipnorm': None,
                                                        'momentum': 0.9,
                                                        'name': 'SGD',
                                                        'nesterov': False},
                                                'type': 'sgd'},
                                  'warmup': {'linear': {'name': 'linear',
                                                        'warmup_learning_rate': 0,
                                                        'warmup_steps': 100},
'type': 'linear'}},
             'preemption_on_demand_checkpoint': True,
             'recovery_begin_steps': 0,
             'recovery_max_trials': 0,
             'steps_per_loop': 5,
             'summary_interval': 100,
             'train_steps': 20,
             'train_tf_function': True,
             'train_tf_while_loop': True,
             'validation_interval': 1000,
             'validation_steps': 78,
'validation_summary_subdir': 'validation'}}
<IPython.core.display.Javascript object>

Set up the distribution strategy.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
Warning: this will be really slow.
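A quick way to confirm which strategy is actually in use is to check how many replicas it synchronizes across (a small sanity check, not part of the original notebook):

# OneDeviceStrategy on CPU reports 1 replica; MirroredStrategy/TPUStrategy report one per device.
print(f'Strategy: {type(distribution_strategy).__name__}, '
      f'replicas: {distribution_strategy.num_replicas_in_sync}')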

From the config_definitions.TaskConfig, create the Task object (tfm.core.base_task.Task).

The Task object has all the methods necessary for building the dataset, building the model, and running training and evaluation. These methods are driven by tfm.core.train_lib.run_experiment.

with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

#  tf.keras.utils.plot_model(task.build_model(), show_shapes=True)
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16}  labels.dtype: {labels.dtype!r}')
images.shape: (128, 32, 32, 3)  images.dtype: tf.float32
labels.shape: (128,)            labels.dtype: tf.int32
2023-10-17 11:53:02.248801: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
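The same Task can also build the Keras model itself, as the commented-out plot_model call above hints. A minimal sketch that builds the ResNet-18 classifier and reports its size, assuming the model is built under the same strategy scope (the resnet18 variable is illustrative and separate from the model trained below):

# Build the ResNet-18 classifier defined by the config and count its parameters.
with distribution_strategy.scope():
  resnet18 = task.build_model()
print(f'Total parameters: {resnet18.count_params():,}')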

Visualize the training data

The data loader applies a z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0, 1] range.

plt.hist(images.numpy().flatten());
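The histogram should confirm the z-score claim: pixel values are roughly centered on zero with a spread of order one, rather than lying in [0, 255]. Because the offsets are ImageNet statistics, they will not match CIFAR-10 exactly; a quick check of the batch statistics (an illustrative addition):

# Mean should be near 0 and standard deviation of order 1 after normalization.
print(f'mean: {images.numpy().mean():.3f}, std: {images.numpy().std():.3f}')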

Use ds_info (which is an instance of tfds.core.DatasetInfo) to look up the text description of each class ID.

label_info = ds_info.features['label']
label_info.int2str(1)
'automobile'
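The same ClassLabel feature can also list all ten class names at once, which is handy when titling plots (a small illustrative addition):

# All CIFAR-10 class names, indexed by label ID.
print(label_info.names)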

Visualize a batch of the data.

def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # Rescale the z-score-normalized images back into [0, 1] for display.
  min_value = images.numpy().min()
  max_value = images.numpy().max()
  delta = max_value - min_value

  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i] - min_value) / delta)
    if predictions is None:
      plt.title(label_info.int2str(labels[i]))
    else:
      # Green titles mark correct predictions, red titles mark errors.
      if labels[i] == predictions[i]:
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(predictions[i]), color=color)
    plt.axis("off")

plt.figure(figsize=(10, 10))
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)
2023-10-17 11:53:04.198417: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Visualize the testing data

Visualize a batch of images from the validation dataset.

plt.figure(figsize=(10, 10));
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)
2023-10-17 11:53:07.007846: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Train and evaluate

model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
restoring or initializing model...
INFO:tensorflow:Customized initialization is done through the passed `init_fn`.
INFO:tensorflow:Customized initialization is done through the passed `init_fn`.
train | step:      0 | training until step 20...
2023-10-17 11:53:09.849007: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
train | step:      5 | steps/sec:    0.5 | output: 
    {'accuracy': 0.103125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.4828125,
     'training_loss': 2.7998607}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-5.
train | step:     10 | steps/sec:    0.8 | output: 
    {'accuracy': 0.0828125,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.4984375,
     'training_loss': 2.8205295}
train | step:     15 | steps/sec:    0.8 | output: 
    {'accuracy': 0.0921875,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.503125,
     'training_loss': 2.8169343}
train | step:     20 | steps/sec:    0.8 | output: 
    {'accuracy': 0.1015625,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.45,
     'training_loss': 2.8760865}
 eval | step:     20 | running 78 steps of evaluation...
 eval | step:     20 | steps/sec:   24.4 | eval time:    3.2 sec | output: 
    {'accuracy': 0.09485176,
     'steps_per_second': 24.40085348913806,
     'top_5_accuracy': 0.49589342,
     'validation_loss': 2.5864375}
saved checkpoint to /tmpfs/tmp/tmpu0ate1h5/ckpt-20.
2023-10-17 11:53:43.844533: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflowcn.cn/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5253: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://tensorflowcn.cn/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5253: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://tensorflowcn.cn/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
eval | step:     20 | running 78 steps of evaluation...
2023-10-17 11:53:45.627213: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
eval | step:     20 | steps/sec:   40.1 | eval time:    1.9 sec | output: 
    {'accuracy': 0.09485176,
     'steps_per_second': 40.14298727815298,
     'top_5_accuracy': 0.49589342,
     'validation_loss': 2.5864375}
#  tf.keras.utils.plot_model(model, show_shapes=True)

Print the accuracy, top_5_accuracy, and validation_loss evaluation metrics.

for key, value in eval_logs.items():
  if isinstance(value, tf.Tensor):
    value = value.numpy()
  print(f'{key:20}: {value:.3f}')
accuracy            : 0.095
top_5_accuracy      : 0.496
validation_loss     : 2.586
steps_per_second    : 40.143

Run a batch of the processed training data through the model, and view the results.

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device=='CPU':
  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')
2023-10-17 11:53:49.840600: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
4/4 [==============================] - 1s 13ms/step
2023-10-17 11:53:50.778301: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
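Since the predictions and labels for this batch are already in hand, you can also compute the single-batch accuracy directly, as a rough sanity check (illustrative, and noisy for a single batch of 128 images):

# Fraction of correct predictions on this one batch of training images.
batch_accuracy = tf.reduce_mean(
    tf.cast(tf.cast(predictions, labels.dtype) == labels, tf.float32))
print(f'Accuracy on this batch: {batch_accuracy.numpy():.3f}')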

Export a SavedModel

The keras.Model object returned by train_lib.run_experiment expects the data to be normalized by the dataset loader using the same mean and variance statistics from preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). This export function handles those details, so you can pass tf.uint8 images and get correct results.

# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')
INFO:tensorflow:Assets written to: ./export/assets
INFO:tensorflow:Assets written to: ./export/assets

Test the exported model.

# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
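Before feeding images through the signature, you can inspect what it expects and returns; with input_type='image_tensor' and batch_size=1 it should accept a [1, 32, 32, 3] tf.uint8 tensor and return a dict that includes 'logits' (an inspection sketch, not part of the original notebook):

# The exported signature documents its expected inputs and outputs.
print(model_fn.structured_input_signature)
print(model_fn.structured_outputs)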

Visualize the predictions.

plt.figure(figsize=(10, 10))
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device=='CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')
2023-10-17 11:54:01.438509: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.