Migrate early stopping


This notebook demonstrates how to set up model training with early stopping: first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training when, for example, the validation loss reaches a certain threshold.

In TensorFlow 2, there are three ways to implement early stopping: use the built-in Keras callback, tf.keras.callbacks.EarlyStopping, with Model.fit; define a custom callback and pass it to Model.fit; or write a custom early stopping rule in a custom training loop (with tf.GradientTape). The sections below walk through each approach in turn.

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
2023-10-04 01:37:29.125012: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-04 01:37:29.125061: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-04 01:37:29.125095: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions for MNIST dataset loading and preprocessing, and the model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to make_early_stopping_hook through its should_stop_fn parameter. Training stops as soon as should_stop_fn returns True.

The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_29043/1011025907.py:1: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpp8x_ipb8
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpp8x_ipb8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_29043/1011025907.py:9: make_early_stopping_hook (from tensorflow_estimator.python.estimator.early_stopping) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/early_stopping.py:474: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_29043/1011025907.py:15: TrainSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_29043/1011025907.py:19: EvalSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_29043/1011025907.py:21: train_and_evaluate (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
2023-10-04 01:37:32.147816: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflowcn.cn/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_29043/1468818800.py:37: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_29043/1468818800.py:37: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
2023-10-04 01:37:34.401540: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflowcn.cn/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 2.2961025, step = 0
INFO:tensorflow:loss = 2.2961025, step = 0
INFO:tensorflow:global_step/sec: 335.823
INFO:tensorflow:global_step/sec: 335.823
INFO:tensorflow:loss = 1.318093, step = 100 (0.300 sec)
INFO:tensorflow:loss = 1.318093, step = 100 (0.300 sec)
INFO:tensorflow:global_step/sec: 423.864
INFO:tensorflow:global_step/sec: 423.864
INFO:tensorflow:loss = 0.8593272, step = 200 (0.236 sec)
INFO:tensorflow:loss = 0.8593272, step = 200 (0.236 sec)
INFO:tensorflow:global_step/sec: 436.472
INFO:tensorflow:global_step/sec: 436.472
INFO:tensorflow:loss = 0.74179834, step = 300 (0.229 sec)
INFO:tensorflow:loss = 0.74179834, step = 300 (0.229 sec)
INFO:tensorflow:global_step/sec: 432.337
INFO:tensorflow:global_step/sec: 432.337
INFO:tensorflow:loss = 0.64878374, step = 400 (0.232 sec)
INFO:tensorflow:loss = 0.64878374, step = 400 (0.232 sec)
INFO:tensorflow:global_step/sec: 425.537
INFO:tensorflow:global_step/sec: 425.537
INFO:tensorflow:loss = 0.49600804, step = 500 (0.234 sec)
INFO:tensorflow:loss = 0.49600804, step = 500 (0.234 sec)
INFO:tensorflow:global_step/sec: 505.423
INFO:tensorflow:global_step/sec: 505.423
INFO:tensorflow:loss = 0.44587734, step = 600 (0.198 sec)
INFO:tensorflow:loss = 0.44587734, step = 600 (0.198 sec)
INFO:tensorflow:global_step/sec: 518.525
INFO:tensorflow:global_step/sec: 518.525
INFO:tensorflow:loss = 0.38416976, step = 700 (0.193 sec)
INFO:tensorflow:loss = 0.38416976, step = 700 (0.193 sec)
INFO:tensorflow:global_step/sec: 509.36
INFO:tensorflow:global_step/sec: 509.36
INFO:tensorflow:loss = 0.5128687, step = 800 (0.197 sec)
INFO:tensorflow:loss = 0.5128687, step = 800 (0.197 sec)
INFO:tensorflow:global_step/sec: 531.176
INFO:tensorflow:global_step/sec: 531.176
INFO:tensorflow:loss = 0.38511667, step = 900 (0.188 sec)
INFO:tensorflow:loss = 0.38511667, step = 900 (0.188 sec)
INFO:tensorflow:global_step/sec: 465.699
INFO:tensorflow:global_step/sec: 465.699
INFO:tensorflow:loss = 0.43401116, step = 1000 (0.215 sec)
INFO:tensorflow:loss = 0.43401116, step = 1000 (0.215 sec)
INFO:tensorflow:global_step/sec: 495.799
INFO:tensorflow:global_step/sec: 495.799
INFO:tensorflow:loss = 0.44691753, step = 1100 (0.203 sec)
INFO:tensorflow:loss = 0.44691753, step = 1100 (0.203 sec)
INFO:tensorflow:global_step/sec: 504.403
INFO:tensorflow:global_step/sec: 504.403
INFO:tensorflow:loss = 0.40272003, step = 1200 (0.198 sec)
INFO:tensorflow:loss = 0.40272003, step = 1200 (0.198 sec)
INFO:tensorflow:global_step/sec: 482.232
INFO:tensorflow:global_step/sec: 482.232
INFO:tensorflow:loss = 0.47271937, step = 1300 (0.207 sec)
INFO:tensorflow:loss = 0.47271937, step = 1300 (0.207 sec)
INFO:tensorflow:global_step/sec: 501.553
INFO:tensorflow:global_step/sec: 501.553
INFO:tensorflow:loss = 0.29635084, step = 1400 (0.200 sec)
INFO:tensorflow:loss = 0.29635084, step = 1400 (0.200 sec)
INFO:tensorflow:global_step/sec: 459.337
INFO:tensorflow:global_step/sec: 459.337
INFO:tensorflow:loss = 0.293486, step = 1500 (0.218 sec)
INFO:tensorflow:loss = 0.293486, step = 1500 (0.218 sec)
INFO:tensorflow:global_step/sec: 511.454
INFO:tensorflow:global_step/sec: 511.454
INFO:tensorflow:loss = 0.40195698, step = 1600 (0.196 sec)
INFO:tensorflow:loss = 0.40195698, step = 1600 (0.196 sec)
INFO:tensorflow:global_step/sec: 487.452
INFO:tensorflow:global_step/sec: 487.452
INFO:tensorflow:loss = 0.38753498, step = 1700 (0.206 sec)
INFO:tensorflow:loss = 0.38753498, step = 1700 (0.206 sec)
INFO:tensorflow:global_step/sec: 500.287
INFO:tensorflow:global_step/sec: 500.287
INFO:tensorflow:loss = 0.3344679, step = 1800 (0.199 sec)
INFO:tensorflow:loss = 0.3344679, step = 1800 (0.199 sec)
INFO:tensorflow:global_step/sec: 476.868
INFO:tensorflow:global_step/sec: 476.868
INFO:tensorflow:loss = 0.49753922, step = 1900 (0.210 sec)
INFO:tensorflow:loss = 0.49753922, step = 1900 (0.210 sec)
INFO:tensorflow:global_step/sec: 484.689
INFO:tensorflow:global_step/sec: 484.689
INFO:tensorflow:loss = 0.21684857, step = 2000 (0.207 sec)
INFO:tensorflow:loss = 0.21684857, step = 2000 (0.207 sec)
INFO:tensorflow:global_step/sec: 496.691
INFO:tensorflow:global_step/sec: 496.691
INFO:tensorflow:loss = 0.28068116, step = 2100 (0.202 sec)
INFO:tensorflow:loss = 0.28068116, step = 2100 (0.202 sec)
INFO:tensorflow:global_step/sec: 499.682
INFO:tensorflow:global_step/sec: 499.682
INFO:tensorflow:loss = 0.3000077, step = 2200 (0.200 sec)
INFO:tensorflow:loss = 0.3000077, step = 2200 (0.200 sec)
INFO:tensorflow:global_step/sec: 507.391
INFO:tensorflow:global_step/sec: 507.391
INFO:tensorflow:loss = 0.34870982, step = 2300 (0.197 sec)
INFO:tensorflow:loss = 0.34870982, step = 2300 (0.197 sec)
INFO:tensorflow:global_step/sec: 475.414
INFO:tensorflow:global_step/sec: 475.414
INFO:tensorflow:loss = 0.24876948, step = 2400 (0.211 sec)
INFO:tensorflow:loss = 0.24876948, step = 2400 (0.211 sec)
INFO:tensorflow:global_step/sec: 499.39
INFO:tensorflow:global_step/sec: 499.39
INFO:tensorflow:loss = 0.21644332, step = 2500 (0.200 sec)
INFO:tensorflow:loss = 0.21644332, step = 2500 (0.200 sec)
INFO:tensorflow:global_step/sec: 500.659
INFO:tensorflow:global_step/sec: 500.659
INFO:tensorflow:loss = 0.152693, step = 2600 (0.200 sec)
INFO:tensorflow:loss = 0.152693, step = 2600 (0.200 sec)
INFO:tensorflow:global_step/sec: 500.201
INFO:tensorflow:global_step/sec: 500.201
INFO:tensorflow:loss = 0.33327985, step = 2700 (0.200 sec)
INFO:tensorflow:loss = 0.33327985, step = 2700 (0.200 sec)
INFO:tensorflow:global_step/sec: 517.488
INFO:tensorflow:global_step/sec: 517.488
INFO:tensorflow:loss = 0.47266263, step = 2800 (0.193 sec)
INFO:tensorflow:loss = 0.47266263, step = 2800 (0.193 sec)
INFO:tensorflow:global_step/sec: 466.55
INFO:tensorflow:global_step/sec: 466.55
INFO:tensorflow:loss = 0.24876104, step = 2900 (0.215 sec)
INFO:tensorflow:loss = 0.24876104, step = 2900 (0.215 sec)
INFO:tensorflow:global_step/sec: 499.818
INFO:tensorflow:global_step/sec: 499.818
INFO:tensorflow:loss = 0.33199376, step = 3000 (0.200 sec)
INFO:tensorflow:loss = 0.33199376, step = 3000 (0.200 sec)
INFO:tensorflow:global_step/sec: 477.253
INFO:tensorflow:global_step/sec: 477.253
INFO:tensorflow:loss = 0.19820198, step = 3100 (0.210 sec)
INFO:tensorflow:loss = 0.19820198, step = 3100 (0.210 sec)
INFO:tensorflow:global_step/sec: 512.079
INFO:tensorflow:global_step/sec: 512.079
INFO:tensorflow:loss = 0.4163157, step = 3200 (0.195 sec)
INFO:tensorflow:loss = 0.4163157, step = 3200 (0.195 sec)
INFO:tensorflow:global_step/sec: 461.596
INFO:tensorflow:global_step/sec: 461.596
INFO:tensorflow:loss = 0.3364423, step = 3300 (0.216 sec)
INFO:tensorflow:loss = 0.3364423, step = 3300 (0.216 sec)
INFO:tensorflow:global_step/sec: 488.092
INFO:tensorflow:global_step/sec: 488.092
INFO:tensorflow:loss = 0.25606278, step = 3400 (0.206 sec)
INFO:tensorflow:loss = 0.25606278, step = 3400 (0.206 sec)
INFO:tensorflow:global_step/sec: 490.26
INFO:tensorflow:global_step/sec: 490.26
INFO:tensorflow:loss = 0.20572862, step = 3500 (0.204 sec)
INFO:tensorflow:loss = 0.20572862, step = 3500 (0.204 sec)
INFO:tensorflow:global_step/sec: 487.323
INFO:tensorflow:global_step/sec: 487.323
INFO:tensorflow:loss = 0.2212799, step = 3600 (0.206 sec)
INFO:tensorflow:loss = 0.2212799, step = 3600 (0.206 sec)
INFO:tensorflow:global_step/sec: 492.262
INFO:tensorflow:global_step/sec: 492.262
INFO:tensorflow:loss = 0.282781, step = 3700 (0.203 sec)
INFO:tensorflow:loss = 0.282781, step = 3700 (0.203 sec)
INFO:tensorflow:global_step/sec: 449.227
INFO:tensorflow:global_step/sec: 449.227
INFO:tensorflow:loss = 0.365446, step = 3800 (0.223 sec)
INFO:tensorflow:loss = 0.365446, step = 3800 (0.223 sec)
INFO:tensorflow:global_step/sec: 521.924
INFO:tensorflow:global_step/sec: 521.924
INFO:tensorflow:loss = 0.22579709, step = 3900 (0.191 sec)
INFO:tensorflow:loss = 0.22579709, step = 3900 (0.191 sec)
INFO:tensorflow:global_step/sec: 535.844
INFO:tensorflow:global_step/sec: 535.844
INFO:tensorflow:loss = 0.30844557, step = 4000 (0.187 sec)
INFO:tensorflow:loss = 0.30844557, step = 4000 (0.187 sec)
INFO:tensorflow:global_step/sec: 496.551
INFO:tensorflow:global_step/sec: 496.551
INFO:tensorflow:loss = 0.21947613, step = 4100 (0.202 sec)
INFO:tensorflow:loss = 0.21947613, step = 4100 (0.202 sec)
INFO:tensorflow:global_step/sec: 508.086
INFO:tensorflow:global_step/sec: 508.086
INFO:tensorflow:loss = 0.26513258, step = 4200 (0.197 sec)
INFO:tensorflow:loss = 0.26513258, step = 4200 (0.197 sec)
INFO:tensorflow:global_step/sec: 462.149
INFO:tensorflow:global_step/sec: 462.149
INFO:tensorflow:loss = 0.29323363, step = 4300 (0.217 sec)
INFO:tensorflow:loss = 0.29323363, step = 4300 (0.217 sec)
INFO:tensorflow:global_step/sec: 503.226
INFO:tensorflow:global_step/sec: 503.226
INFO:tensorflow:loss = 0.31204918, step = 4400 (0.198 sec)
INFO:tensorflow:loss = 0.31204918, step = 4400 (0.198 sec)
INFO:tensorflow:global_step/sec: 510.874
INFO:tensorflow:global_step/sec: 510.874
INFO:tensorflow:loss = 0.26014802, step = 4500 (0.196 sec)
INFO:tensorflow:loss = 0.26014802, step = 4500 (0.196 sec)
INFO:tensorflow:global_step/sec: 511.332
INFO:tensorflow:global_step/sec: 511.332
INFO:tensorflow:loss = 0.33227044, step = 4600 (0.196 sec)
INFO:tensorflow:loss = 0.33227044, step = 4600 (0.196 sec)
INFO:tensorflow:global_step/sec: 472.525
INFO:tensorflow:global_step/sec: 472.525
INFO:tensorflow:loss = 0.13610935, step = 4700 (0.211 sec)
INFO:tensorflow:loss = 0.13610935, step = 4700 (0.211 sec)
INFO:tensorflow:global_step/sec: 505.262
INFO:tensorflow:global_step/sec: 505.262
INFO:tensorflow:loss = 0.28594193, step = 4800 (0.198 sec)
INFO:tensorflow:loss = 0.28594193, step = 4800 (0.198 sec)
INFO:tensorflow:global_step/sec: 515.453
INFO:tensorflow:global_step/sec: 515.453
INFO:tensorflow:loss = 0.38484165, step = 4900 (0.195 sec)
INFO:tensorflow:loss = 0.38484165, step = 4900 (0.195 sec)
INFO:tensorflow:global_step/sec: 479.259
INFO:tensorflow:global_step/sec: 479.259
INFO:tensorflow:loss = 0.27081215, step = 5000 (0.208 sec)
INFO:tensorflow:loss = 0.27081215, step = 5000 (0.208 sec)
INFO:tensorflow:global_step/sec: 491.361
INFO:tensorflow:global_step/sec: 491.361
INFO:tensorflow:loss = 0.33877313, step = 5100 (0.204 sec)
INFO:tensorflow:loss = 0.33877313, step = 5100 (0.204 sec)
INFO:tensorflow:global_step/sec: 471.256
INFO:tensorflow:global_step/sec: 471.256
INFO:tensorflow:loss = 0.2074028, step = 5200 (0.212 sec)
INFO:tensorflow:loss = 0.2074028, step = 5200 (0.212 sec)
INFO:tensorflow:global_step/sec: 506.672
INFO:tensorflow:global_step/sec: 506.672
INFO:tensorflow:loss = 0.24718614, step = 5300 (0.197 sec)
INFO:tensorflow:loss = 0.24718614, step = 5300 (0.197 sec)
INFO:tensorflow:global_step/sec: 491.333
INFO:tensorflow:global_step/sec: 491.333
INFO:tensorflow:loss = 0.16439602, step = 5400 (0.203 sec)
INFO:tensorflow:loss = 0.16439602, step = 5400 (0.203 sec)
INFO:tensorflow:global_step/sec: 500.638
INFO:tensorflow:global_step/sec: 500.638
INFO:tensorflow:loss = 0.22073755, step = 5500 (0.200 sec)
INFO:tensorflow:loss = 0.22073755, step = 5500 (0.200 sec)
INFO:tensorflow:global_step/sec: 508.206
INFO:tensorflow:global_step/sec: 508.206
INFO:tensorflow:loss = 0.18545151, step = 5600 (0.196 sec)
INFO:tensorflow:loss = 0.18545151, step = 5600 (0.196 sec)
INFO:tensorflow:global_step/sec: 450.634
INFO:tensorflow:global_step/sec: 450.634
INFO:tensorflow:loss = 0.17126478, step = 5700 (0.222 sec)
INFO:tensorflow:loss = 0.17126478, step = 5700 (0.222 sec)
INFO:tensorflow:global_step/sec: 525.95
INFO:tensorflow:global_step/sec: 525.95
INFO:tensorflow:loss = 0.2689212, step = 5800 (0.190 sec)
INFO:tensorflow:loss = 0.2689212, step = 5800 (0.190 sec)
INFO:tensorflow:global_step/sec: 483.04
INFO:tensorflow:global_step/sec: 483.04
INFO:tensorflow:loss = 0.21353054, step = 5900 (0.207 sec)
INFO:tensorflow:loss = 0.21353054, step = 5900 (0.207 sec)
INFO:tensorflow:global_step/sec: 480.064
INFO:tensorflow:global_step/sec: 480.064
INFO:tensorflow:loss = 0.25207376, step = 6000 (0.209 sec)
INFO:tensorflow:loss = 0.25207376, step = 6000 (0.209 sec)
INFO:tensorflow:global_step/sec: 451.147
INFO:tensorflow:global_step/sec: 451.147
INFO:tensorflow:loss = 0.18487616, step = 6100 (0.221 sec)
INFO:tensorflow:loss = 0.18487616, step = 6100 (0.221 sec)
INFO:tensorflow:global_step/sec: 490.621
INFO:tensorflow:global_step/sec: 490.621
INFO:tensorflow:loss = 0.267612, step = 6200 (0.204 sec)
INFO:tensorflow:loss = 0.267612, step = 6200 (0.204 sec)
INFO:tensorflow:global_step/sec: 499.641
INFO:tensorflow:global_step/sec: 499.641
INFO:tensorflow:loss = 0.24492177, step = 6300 (0.200 sec)
INFO:tensorflow:loss = 0.24492177, step = 6300 (0.200 sec)
INFO:tensorflow:global_step/sec: 498.738
INFO:tensorflow:global_step/sec: 498.738
INFO:tensorflow:loss = 0.28542638, step = 6400 (0.200 sec)
INFO:tensorflow:loss = 0.28542638, step = 6400 (0.200 sec)
INFO:tensorflow:global_step/sec: 523.196
INFO:tensorflow:global_step/sec: 523.196
INFO:tensorflow:loss = 0.26353425, step = 6500 (0.191 sec)
INFO:tensorflow:loss = 0.26353425, step = 6500 (0.191 sec)
INFO:tensorflow:global_step/sec: 455.054
INFO:tensorflow:global_step/sec: 455.054
INFO:tensorflow:loss = 0.19190696, step = 6600 (0.220 sec)
INFO:tensorflow:loss = 0.19190696, step = 6600 (0.220 sec)
INFO:tensorflow:global_step/sec: 522.506
INFO:tensorflow:global_step/sec: 522.506
INFO:tensorflow:loss = 0.25657248, step = 6700 (0.191 sec)
INFO:tensorflow:loss = 0.25657248, step = 6700 (0.191 sec)
INFO:tensorflow:global_step/sec: 506.766
INFO:tensorflow:global_step/sec: 506.766
INFO:tensorflow:loss = 0.39108855, step = 6800 (0.197 sec)
INFO:tensorflow:loss = 0.39108855, step = 6800 (0.197 sec)
INFO:tensorflow:global_step/sec: 481.896
INFO:tensorflow:global_step/sec: 481.896
INFO:tensorflow:loss = 0.15236983, step = 6900 (0.208 sec)
INFO:tensorflow:loss = 0.15236983, step = 6900 (0.208 sec)
INFO:tensorflow:global_step/sec: 506.833
INFO:tensorflow:global_step/sec: 506.833
INFO:tensorflow:loss = 0.3356074, step = 7000 (0.197 sec)
INFO:tensorflow:loss = 0.3356074, step = 7000 (0.197 sec)
INFO:tensorflow:global_step/sec: 467.501
INFO:tensorflow:global_step/sec: 467.501
INFO:tensorflow:loss = 0.1956916, step = 7100 (0.214 sec)
INFO:tensorflow:loss = 0.1956916, step = 7100 (0.214 sec)
INFO:tensorflow:global_step/sec: 494.573
INFO:tensorflow:global_step/sec: 494.573
INFO:tensorflow:loss = 0.24700633, step = 7200 (0.203 sec)
INFO:tensorflow:loss = 0.24700633, step = 7200 (0.203 sec)
INFO:tensorflow:global_step/sec: 522.627
INFO:tensorflow:global_step/sec: 522.627
INFO:tensorflow:loss = 0.17747968, step = 7300 (0.191 sec)
INFO:tensorflow:loss = 0.17747968, step = 7300 (0.191 sec)
INFO:tensorflow:global_step/sec: 486.106
INFO:tensorflow:global_step/sec: 486.106
INFO:tensorflow:loss = 0.28143722, step = 7400 (0.206 sec)
INFO:tensorflow:loss = 0.28143722, step = 7400 (0.206 sec)
INFO:tensorflow:global_step/sec: 479.817
INFO:tensorflow:global_step/sec: 479.817
INFO:tensorflow:loss = 0.14399035, step = 7500 (0.209 sec)
INFO:tensorflow:loss = 0.14399035, step = 7500 (0.209 sec)
INFO:tensorflow:global_step/sec: 487.355
INFO:tensorflow:global_step/sec: 487.355
INFO:tensorflow:loss = 0.2522769, step = 7600 (0.205 sec)
INFO:tensorflow:loss = 0.2522769, step = 7600 (0.205 sec)
INFO:tensorflow:global_step/sec: 526.322
INFO:tensorflow:global_step/sec: 526.322
INFO:tensorflow:loss = 0.25871283, step = 7700 (0.190 sec)
INFO:tensorflow:loss = 0.25871283, step = 7700 (0.190 sec)
INFO:tensorflow:global_step/sec: 501.783
INFO:tensorflow:global_step/sec: 501.783
INFO:tensorflow:loss = 0.16747415, step = 7800 (0.199 sec)
INFO:tensorflow:loss = 0.16747415, step = 7800 (0.199 sec)
INFO:tensorflow:global_step/sec: 485.409
INFO:tensorflow:global_step/sec: 485.409
INFO:tensorflow:loss = 0.15161222, step = 7900 (0.205 sec)
INFO:tensorflow:loss = 0.15161222, step = 7900 (0.205 sec)
INFO:tensorflow:global_step/sec: 462.017
INFO:tensorflow:global_step/sec: 462.017
INFO:tensorflow:loss = 0.18137535, step = 8000 (0.216 sec)
INFO:tensorflow:loss = 0.18137535, step = 8000 (0.216 sec)
INFO:tensorflow:global_step/sec: 488.439
INFO:tensorflow:global_step/sec: 488.439
INFO:tensorflow:loss = 0.19478491, step = 8100 (0.205 sec)
INFO:tensorflow:loss = 0.19478491, step = 8100 (0.205 sec)
INFO:tensorflow:global_step/sec: 485.238
INFO:tensorflow:global_step/sec: 485.238
INFO:tensorflow:loss = 0.27511528, step = 8200 (0.206 sec)
INFO:tensorflow:loss = 0.27511528, step = 8200 (0.206 sec)
INFO:tensorflow:Requesting early stopping at global step 8286
INFO:tensorflow:Requesting early stopping at global step 8286
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8287...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8287...
INFO:tensorflow:Saving checkpoints for 8287 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Saving checkpoints for 8287 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8287...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8287...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-10-04T01:37:52
INFO:tensorflow:Starting evaluation at 2023-10-04T01:37:52
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
2023-10-04 01:37:52.957202: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://tensorflowcn.cn/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.51915s
INFO:tensorflow:Inference Time : 0.51915s
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8287: /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8287: /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Loss for final step: 0.28788015.
INFO:tensorflow:Loss for final step: 0.28788015.
({'loss': 0.21179305, 'global_step': 8287}, [])
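
Wall-clock time is only one possible should_stop_fn criterion. The Estimator API also ships ready-made hooks that stop on a monitored metric; below is a minimal sketch using tf1.estimator.experimental.stop_if_no_decrease_hook together with the estimator and input functions defined above (not executed in this notebook, and the step threshold is illustrative):

# Stop once the evaluation loss has not decreased for 1,000 training steps.
# The hook reads the metric from the estimator's eval directory, so periodic
# evaluation (as performed by train_and_evaluate) is required.
no_decrease_hook = tf1.estimator.experimental.stop_if_no_decrease_hook(
    estimator=estimator,
    metric_name='loss',
    max_steps_without_decrease=1000)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[no_decrease_hook])
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)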

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (Check the Training and evaluation with the built-in methods guide and the API docs for more information.)

Below is an example of an early stopping callback that monitors the loss and stops training after three epochs (patience) with no improvement:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Training stops after far fewer than 100 epochs (16 in the run shown below).
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 4ms/step - loss: 0.2326 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.1214 - val_sparse_categorical_accuracy: 0.9630
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1004 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.1009 - val_sparse_categorical_accuracy: 0.9684
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0689 - sparse_categorical_accuracy: 0.9792 - val_loss: 0.1061 - val_sparse_categorical_accuracy: 0.9674
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0528 - sparse_categorical_accuracy: 0.9835 - val_loss: 0.1244 - val_sparse_categorical_accuracy: 0.9640
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0424 - sparse_categorical_accuracy: 0.9866 - val_loss: 0.1001 - val_sparse_categorical_accuracy: 0.9730
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0355 - sparse_categorical_accuracy: 0.9883 - val_loss: 0.1034 - val_sparse_categorical_accuracy: 0.9728
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0315 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.1029 - val_sparse_categorical_accuracy: 0.9748
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0319 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.1267 - val_sparse_categorical_accuracy: 0.9707
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1128 - val_sparse_categorical_accuracy: 0.9733
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0281 - sparse_categorical_accuracy: 0.9907 - val_loss: 0.1246 - val_sparse_categorical_accuracy: 0.9718
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0270 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1184 - val_sparse_categorical_accuracy: 0.9786
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0217 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1231 - val_sparse_categorical_accuracy: 0.9756
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0191 - sparse_categorical_accuracy: 0.9935 - val_loss: 0.1466 - val_sparse_categorical_accuracy: 0.9716
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0220 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1475 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1472 - val_sparse_categorical_accuracy: 0.9746
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1454 - val_sparse_categorical_accuracy: 0.9755
16
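
In practice you typically monitor a validation metric rather than the training loss. A minimal sketch of such a configuration, assuming the same model and datasets as above (not run in this notebook; the min_delta value is illustrative):

# Stop when `val_loss` has not improved by at least 1e-3 for 3 epochs,
# and roll the model back to the weights of its best epoch.
val_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=3,
    restore_best_weights=True)

model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[val_callback])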

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).

In this example, the training process is stopped once self.model.stop_training is set to True:

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9785
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1728 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1567 - val_sparse_categorical_accuracy: 0.9787
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0198 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9740
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0146 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1658 - val_sparse_categorical_accuracy: 0.9756
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2150 - val_sparse_categorical_accuracy: 0.9726
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0174 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1890 - val_sparse_categorical_accuracy: 0.9769
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1852 - val_sparse_categorical_accuracy: 0.9772
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.2051 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9727
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0167 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.2008 - val_sparse_categorical_accuracy: 0.9757
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2283 - val_sparse_categorical_accuracy: 0.9755
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0106 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2270 - val_sparse_categorical_accuracy: 0.9739
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2169 - val_sparse_categorical_accuracy: 0.9775
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0121 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2282 - val_sparse_categorical_accuracy: 0.9773
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2723 - val_sparse_categorical_accuracy: 0.9738
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2223 - val_sparse_categorical_accuracy: 0.9784
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0168 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2489 - val_sparse_categorical_accuracy: 0.9770
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2607 - val_sparse_categorical_accuracy: 0.9753
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0112 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2267 - val_sparse_categorical_accuracy: 0.9776
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0105 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2758 - val_sparse_categorical_accuracy: 0.9733
21
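
A custom callback can also stop on a quality target rather than wall-clock time. The sketch below (the class name is hypothetical and it is not run in this notebook) inspects the logs passed to on_epoch_end and stops once validation accuracy reaches a threshold:

class StopAtValAccuracy(tf.keras.callbacks.Callback):
  def __init__(self, target_acc):
    super().__init__()
    self.target_acc = target_acc

  def on_epoch_end(self, epoch, logs=None):
    # `logs` holds the metrics Model.fit computed for this epoch.
    val_acc = (logs or {}).get('val_sparse_categorical_accuracy')
    if val_acc is not None and val_acc >= self.target_acc:
      self.model.stop_training = True

# Stop as soon as validation accuracy reaches 97.5%.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[StopAtValAccuracy(0.975)])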

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, if you are not training and evaluating with the built-in Keras methods, you can implement early stopping in a custom training loop.

Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()

Define the parameter update functions with tf.GradientTape, and decorate them with @tf.function for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
      logits = model(x, training=True)
      loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))        
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.4644
Seen so far: 128 samples
Training loss at step 200: 0.2622
Seen so far: 25728 samples
Training loss at step 400: 0.2129
Seen so far: 51328 samples
Training acc over epoch: 0.9297
Validation acc: 0.9610
Time taken: 2.07s

Start of epoch 1
Training loss at step 0: 0.0756
Seen so far: 128 samples
Training loss at step 200: 0.1215
Seen so far: 25728 samples
Training loss at step 400: 0.1359
Seen so far: 51328 samples
Training acc over epoch: 0.9685
Validation acc: 0.9664
Time taken: 1.36s

Start of epoch 2
Training loss at step 0: 0.0390
Seen so far: 128 samples
Training loss at step 200: 0.0995
Seen so far: 25728 samples
Training loss at step 400: 0.1270
Seen so far: 51328 samples
Training acc over epoch: 0.9778
Validation acc: 0.9690
Time taken: 1.37s

Start of epoch 3
Training loss at step 0: 0.0282
Seen so far: 128 samples
Training loss at step 200: 0.0673
Seen so far: 25728 samples
Training loss at step 400: 0.0700
Seen so far: 51328 samples
Training acc over epoch: 0.9826
Validation acc: 0.9695
Time taken: 1.35s

Start of epoch 4
Training loss at step 0: 0.0249
Seen so far: 128 samples
Training loss at step 200: 0.0403
Seen so far: 25728 samples
Training loss at step 400: 0.0324
Seen so far: 51328 samples
Training acc over epoch: 0.9862
Validation acc: 0.9693
Time taken: 1.37s

Start of epoch 5
Training loss at step 0: 0.0279
Seen so far: 128 samples
Training loss at step 200: 0.0529
Seen so far: 25728 samples
Training loss at step 400: 0.0592
Seen so far: 51328 samples
Training acc over epoch: 0.9875
Validation acc: 0.9660
Time taken: 1.32s

Start of epoch 6
Training loss at step 0: 0.0200
Seen so far: 128 samples
Training loss at step 200: 0.0344
Seen so far: 25728 samples
Training loss at step 400: 0.0142
Seen so far: 51328 samples
Training acc over epoch: 0.9884
Validation acc: 0.9720
Time taken: 1.36s

Start of epoch 7
Training loss at step 0: 0.0325
Seen so far: 128 samples
Training loss at step 200: 0.0499
Seen so far: 25728 samples
Training loss at step 400: 0.0167
Seen so far: 51328 samples
Training acc over epoch: 0.9898
Validation acc: 0.9711
Time taken: 1.39s

Start of epoch 8
Training loss at step 0: 0.0155
Seen so far: 128 samples
Training loss at step 200: 0.0219
Seen so far: 25728 samples
Training loss at step 400: 0.0213
Seen so far: 51328 samples
Training acc over epoch: 0.9914
Validation acc: 0.9715
Time taken: 1.34s

Start of epoch 9
Training loss at step 0: 0.0068
Seen so far: 128 samples
Training loss at step 200: 0.0288
Seen so far: 25728 samples
Training loss at step 400: 0.0320
Seen so far: 51328 samples
Training acc over epoch: 0.9921
Validation acc: 0.9743
Time taken: 1.34s

Start of epoch 10
Training loss at step 0: 0.0069
Seen so far: 128 samples
Training loss at step 200: 0.0419
Seen so far: 25728 samples
Training loss at step 400: 0.0458
Seen so far: 51328 samples
Training acc over epoch: 0.9919
Validation acc: 0.9720
Time taken: 1.34s

Start of epoch 11
Training loss at step 0: 0.0015
Seen so far: 128 samples
Training loss at step 200: 0.0113
Seen so far: 25728 samples
Training loss at step 400: 0.0285
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9740
Time taken: 1.33s

Start of epoch 12
Training loss at step 0: 0.0010
Seen so far: 128 samples
Training loss at step 200: 0.0450
Seen so far: 25728 samples
Training loss at step 400: 0.0561
Seen so far: 51328 samples
Training acc over epoch: 0.9926
Validation acc: 0.9744
Time taken: 1.34s

Start of epoch 13
Training loss at step 0: 0.0038
Seen so far: 128 samples
Training loss at step 200: 0.0309
Seen so far: 25728 samples
Training loss at step 400: 0.0124
Seen so far: 51328 samples
Training acc over epoch: 0.9937
Validation acc: 0.9740
Time taken: 1.36s

Start of epoch 14
Training loss at step 0: 0.0011
Seen so far: 128 samples
Training loss at step 200: 0.0068
Seen so far: 25728 samples
Training loss at step 400: 0.0110
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9736
Time taken: 1.35s

Start of epoch 15
Training loss at step 0: 0.0043
Seen so far: 128 samples
Training loss at step 200: 0.0085
Seen so far: 25728 samples
Training loss at step 400: 0.0042
Seen so far: 51328 samples
Training acc over epoch: 0.9942
Validation acc: 0.9752
Time taken: 1.34s

Start of epoch 16
Training loss at step 0: 0.0208
Seen so far: 128 samples
Training loss at step 200: 0.0042
Seen so far: 25728 samples
Training loss at step 400: 0.1063
Seen so far: 51328 samples
Training acc over epoch: 0.9947
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 17
Training loss at step 0: 0.0067
Seen so far: 128 samples
Training loss at step 200: 0.0277
Seen so far: 25728 samples
Training loss at step 400: 0.0787
Seen so far: 51328 samples
Training acc over epoch: 0.9951
Validation acc: 0.9729
Time taken: 1.33s

Start of epoch 18
Training loss at step 0: 0.0017
Seen so far: 128 samples
Training loss at step 200: 0.0131
Seen so far: 25728 samples
Training loss at step 400: 0.0431
Seen so far: 51328 samples
Training acc over epoch: 0.9943
Validation acc: 0.9739
Time taken: 1.40s

Start of epoch 19
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0220
Seen so far: 25728 samples
Training loss at step 400: 0.0662
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9738
Time taken: 1.34s

Start of epoch 20
Training loss at step 0: 0.0003
Seen so far: 128 samples
Training loss at step 200: 0.0306
Seen so far: 25728 samples
Training loss at step 400: 0.0083
Seen so far: 51328 samples
Training acc over epoch: 0.9955
Validation acc: 0.9753
Time taken: 1.37s

Start of epoch 21
Training loss at step 0: 0.0016
Seen so far: 128 samples
Training loss at step 200: 0.0003
Seen so far: 25728 samples
Training loss at step 400: 0.0069
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9729
Time taken: 1.36s

Start of epoch 22
Training loss at step 0: 0.0626
Seen so far: 128 samples
Training loss at step 200: 0.0013
Seen so far: 25728 samples
Training loss at step 400: 0.0278
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 23
Training loss at step 0: 0.0318
Seen so far: 128 samples
Training loss at step 200: 0.0514
Seen so far: 25728 samples
Training loss at step 400: 0.0001
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9758
Time taken: 1.36s

Start of epoch 24
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0043
Seen so far: 25728 samples
Training loss at step 400: 0.0339
Seen so far: 51328 samples
Training acc over epoch: 0.9956
Validation acc: 0.9752
Time taken: 1.37s

Start of epoch 25
Training loss at step 0: 0.0000
Seen so far: 128 samples
Training loss at step 200: 0.0057
Seen so far: 25728 samples
Training loss at step 400: 0.1485
Seen so far: 51328 samples
Training acc over epoch: 0.9961
Validation acc: 0.9733
Time taken: 1.35s

Start of epoch 26
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0992
Seen so far: 25728 samples
Training loss at step 400: 0.0033
Seen so far: 51328 samples
Training acc over epoch: 0.9972
Validation acc: 0.9776
Time taken: 1.32s

Start of epoch 27
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0402
Seen so far: 25728 samples
Training loss at step 400: 0.0002
Seen so far: 51328 samples
Training acc over epoch: 0.9966
Validation acc: 0.9784
Time taken: 1.34s

Start of epoch 28
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0307
Seen so far: 25728 samples
Training loss at step 400: 0.1466
Seen so far: 51328 samples
Training acc over epoch: 0.9948
Validation acc: 0.9742
Time taken: 1.37s

Start of epoch 29
Training loss at step 0: 0.0092
Seen so far: 128 samples
Training loss at step 200: 0.0039
Seen so far: 25728 samples
Training loss at step 400: 0.0358
Seen so far: 51328 samples
Training acc over epoch: 0.9959
Validation acc: 0.9791
Time taken: 1.33s
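
A common refinement of this loop is to remember the best weights seen so far and restore them after stopping. Below is a minimal sketch of that variant, reusing the model, metrics, train_step, and test_step defined above (not executed in this notebook):

patience = 5
wait = 0
best = float('inf')
best_weights = None

for epoch in range(100):
  for x_batch, y_batch in ds_train:
    train_step(x_batch, y_batch)
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()

  for x_batch, y_batch in ds_test:
    test_step(x_batch, y_batch)
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()

  wait += 1
  if val_loss < best:
    best = val_loss
    best_weights = model.get_weights()  # snapshot the improving weights
    wait = 0
  if wait >= patience:
    break

# Roll the model back to its best epoch before using it further.
if best_weights is not None:
  model.set_weights(best_weights)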

Next steps