This notebook demonstrates how you can set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.

In TensorFlow 2, there are three ways to implement early stopping:
- Use a built-in Keras callback, tf.keras.callbacks.EarlyStopping, and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).

Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Begin by defining functions for loading and preprocessing the MNIST dataset, and the model definition to be used with tf.estimator.Estimator:
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label
def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train
def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test
def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)
  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass make_early_stopping_hook a should_stop_fn parameter, which accepts a function without any arguments. Training stops once should_stop_fn returns True.

The following example demonstrates how to implement an early stopping technique that limits training to a maximum of 20 seconds:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpp8x_ipb8
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:loss = 2.2961025, step = 0
INFO:tensorflow:loss = 1.318093, step = 100 (0.300 sec)
...
INFO:tensorflow:loss = 0.27511528, step = 8200 (0.206 sec)
INFO:tensorflow:Requesting early stopping at global step 8286
INFO:tensorflow:Saving checkpoints for 8287 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Starting evaluation at 2023-10-04T01:37:52
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Loss for final step: 0.28788015.
({'loss': 0.21179305, 'global_step': 8287}, [])
TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
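Optionally, you can also cache and prefetch the batched datasets to speed up the input pipeline. This is a minimal, optional sketch and not part of the original setup:

# Optional: cache the preprocessed examples and prefetch batches so that
# input preparation overlaps with training.
ds_train = ds_train.cache().prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.cache().prefetch(tf.data.AUTOTUNE)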
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.005),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing the built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)
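For instance, a commonly used configuration, sketched below with illustrative parameter values, monitors the validation loss, only counts sufficiently large improvements, and restores the weights from the best epoch when training stops:

# A sketch of a typical EarlyStopping configuration (values are illustrative):
# watch `val_loss`, count an epoch as "no improvement" unless the loss drops
# by at least `min_delta`, and roll back to the best weights when stopping.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=5,
    restore_best_weights=True)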
The example used in the rest of this section monitors the training loss and stops training once the number of epochs that show no improvement reaches 3 (patience):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 4ms/step - loss: 0.2326 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.1214 - val_sparse_categorical_accuracy: 0.9630
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1004 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.1009 - val_sparse_categorical_accuracy: 0.9684
...
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1472 - val_sparse_categorical_accuracy: 0.9746
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1454 - val_sparse_categorical_accuracy: 0.9755
16
TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).

In this example, the training process is stopped once self.model.stop_training is set to True.
class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9785
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1728 - val_sparse_categorical_accuracy: 0.9755
...
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0112 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2267 - val_sparse_categorical_accuracy: 0.9776
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0105 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2758 - val_sparse_categorical_accuracy: 0.9733
21
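A custom callback can also implement metric-based early stopping, similar to the built-in EarlyStopping callback. The class below is an illustrative sketch (its name and default patience value are not from the original guide): it stops training when val_loss has not improved for a given number of epochs:

class EarlyStopAtMinValLoss(tf.keras.callbacks.Callback):
  # Illustrative sketch: stop training when `val_loss` stops improving.
  def __init__(self, patience=3):
    super().__init__()
    self.patience = patience

  def on_train_begin(self, logs=None):
    self.wait = 0
    self.best = float('inf')

  def on_epoch_end(self, epoch, logs=None):
    current = (logs or {}).get('val_loss')
    if current is None:
      return  # `val_loss` is only available when validation data is passed.
    if current < self.best:
      self.best = current
      self.wait = 0
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.model.stop_training = True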
TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.

Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
Define the parameter update functions with tf.GradientTape, and decorate them with @tf.function for a speedup:
@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value
@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:
epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
  print("\nStart of epoch %d" % (epoch,))
  start_time = time.time()

  for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
    loss_value = train_step(x_batch_train, y_batch_train)
    if step % 200 == 0:
      print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
      print("Seen so far: %s samples" % ((step + 1) * 128))
  train_acc = train_acc_metric.result()
  train_loss = train_loss_metric.result()
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()
  print("Training acc over epoch: %.4f" % (train_acc.numpy()))

  for x_batch_val, y_batch_val in ds_test:
    test_step(x_batch_val, y_batch_val)
  val_acc = val_acc_metric.result()
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()
  print("Validation acc: %.4f" % (float(val_acc),))
  print("Time taken: %.2fs" % (time.time() - start_time))

  # The early stopping strategy: stop the training if `val_loss` does not
  # decrease over a certain number of epochs.
  wait += 1
  if val_loss < best:
    best = val_loss
    wait = 0
  if wait >= patience:
    break
Start of epoch 0
Training loss at step 0: 2.4644
Seen so far: 128 samples
Training loss at step 200: 0.2622
Seen so far: 25728 samples
Training loss at step 400: 0.2129
Seen so far: 51328 samples
Training acc over epoch: 0.9297
Validation acc: 0.9610
Time taken: 2.07s

...

Start of epoch 29
Training loss at step 0: 0.0092
Seen so far: 128 samples
Training loss at step 200: 0.0039
Seen so far: 25728 samples
Training loss at step 400: 0.0358
Seen so far: 51328 samples
Training acc over epoch: 0.9959
Validation acc: 0.9791
Time taken: 1.33s
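If you also want the equivalent of restore_best_weights in a custom loop, one option, shown here as a sketch that reuses the model, datasets, metrics, and step functions defined above (it is not part of the original example), is to snapshot the weights whenever the validation loss improves and restore them after the loop ends:

# Sketch: the same early stopping rule, extended to keep a copy of the best
# weights and roll back to them once training stops.
wait = 0
best = float('inf')
best_weights = None

for epoch in range(epochs):
  for x_batch_train, y_batch_train in ds_train:
    train_step(x_batch_train, y_batch_train)
  train_acc_metric.reset_states()
  train_loss_metric.reset_states()

  for x_batch_val, y_batch_val in ds_test:
    test_step(x_batch_val, y_batch_val)
  val_loss = val_loss_metric.result()
  val_acc_metric.reset_states()
  val_loss_metric.reset_states()

  wait += 1
  if val_loss < best:
    best = val_loss
    best_weights = model.get_weights()  # snapshot the weights at the best epoch
    wait = 0
  if wait >= patience:
    break

if best_weights is not None:
  model.set_weights(best_weights)  # restore the best weights after stopping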
Next steps

- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn to write custom Keras callbacks, including early stopping at a minimum loss.
- Learn about training and evaluation with the built-in Keras methods.
- Explore common regularization techniques in the Overfit and underfit tutorial, which uses the EarlyStopping callback.