Overview
This guide provides a list of best practices for writing code with TensorFlow 2 (TF2), aimed at users who have recently switched over from TensorFlow 1 (TF1). Refer to the migrate section of the guide for more information on migrating your TF1 code to TF2.
Setup
Import TensorFlow and other dependencies for the examples in this guide.
import tensorflow as tf
import tensorflow_datasets as tfds
Recommendations for idiomatic TensorFlow 2
Refactor your code into smaller modules
A good practice is to refactor your code into smaller functions that are called as needed. For best performance, you should try to decorate the largest blocks of computation that you can in a tf.function (note that the nested Python functions called by a tf.function do not require their own separate decorations, unless you want to use different jit_compile settings for the tf.function). Depending on your use case, this could be multiple training steps or even your whole training loop. For inference use cases, it might be a single model forward pass.
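For example, here is a minimal sketch of that layering; the model, optimizer, and loss_fn names are placeholders for objects like the ones built later in this guide:

def apply_gradients(tape, loss, model, optimizer):
  # An ordinary Python helper: it is called from inside a tf.function,
  # so it is traced as part of the outer graph and needs no decoration
  # of its own.
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

@tf.function  # decorate the largest block of computation you can
def train_step(model, optimizer, loss_fn, x, y):
  with tf.GradientTape() as tape:
    prediction = model(x, training=True)
    loss = loss_fn(y, prediction)
  apply_gradients(tape, loss, model, optimizer)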
Adjust the default learning rate for some tf.keras.optimizers
Some Keras optimizers have different learning rates in TF2. If you see a change in convergence behavior for your models, check the default learning rates.
There are no changes for optimizers.SGD, optimizers.Adam, or optimizers.RMSprop.
The following default learning rates have changed:
- optimizers.Adagrad from 0.01 to 0.001
- optimizers.Adadelta from 1.0 to 0.001
- optimizers.Adamax from 0.002 to 0.001
- optimizers.Nadam from 0.002 to 0.001
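If you relied on one of the old defaults, you can pass the learning rate explicitly when constructing the optimizer. For example, a minimal sketch restoring the previous Adagrad default:

# Pass the learning rate explicitly to keep the old default (0.01)
# instead of the new TF2 default (0.001).
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.01)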
Use tf.Modules and Keras layers to manage variables
tf.Module and tf.keras.layers.Layer offer the convenient variables and trainable_variables properties, which recursively gather up all dependent variables. This makes it easy to manage variables locally to where they are being used.
Keras layers/models inherit from tf.train.Checkpointable and are integrated with @tf.function, which makes it possible to directly checkpoint or export SavedModels from Keras objects. You do not necessarily have to use Keras' Model.fit API to take advantage of these integrations.
Read the section on transfer learning and fine-tuning in the Keras guide to learn how to collect a subset of relevant variables using Keras.
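Here is a minimal sketch of what these properties give you; the module name and layer sizes below are arbitrary, not part of the guide's own example:

class MLP(tf.Module):
  def __init__(self):
    super().__init__()
    # Keras layers stored as attributes are tracked by tf.Module.
    self.dense1 = tf.keras.layers.Dense(16, activation='relu')
    self.dense2 = tf.keras.layers.Dense(4)

  def __call__(self, x):
    return self.dense2(self.dense1(x))

mlp = MLP()
mlp(tf.zeros([1, 8]))  # run once so the layer variables are created
print(len(mlp.variables))            # 4: two kernels and two biases
print(len(mlp.trainable_variables))  # 4

# With a Keras model, freezing a layer removes its weights from
# trainable_variables -- handy for transfer learning.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(4),
])
keras_model.layers[0].trainable = False
print(len(keras_model.trainable_variables))  # 2: only the last layer's weights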
Combine tf.data.Datasets and tf.function
The TensorFlow Datasets package (tfds) contains utilities for loading predefined datasets as tf.data.Dataset objects. For this example, you can load the MNIST dataset using tfds:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Then prepare the data for training:
- Re-scale each image.
- Shuffle the order of the examples.
- Collect batches of images and labels.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label
To keep the example short, trim the dataset to only return 5 batches:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
Use regular Python iteration to iterate over training data that fits in memory. Otherwise, tf.data.Dataset is the best way to stream training data from disk. Datasets are iterables (not iterators), and work just like other Python iterables in eager execution. You can fully utilize dataset async prefetching/streaming features by wrapping your code in a tf.function, which replaces Python iteration with the equivalent graph operations using AutoGraph.
@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      # Loss objects expect (y_true, y_pred), as described later in this guide.
      loss = loss_fn(y, prediction)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
If you use the Keras Model.fit API, you won't have to worry about dataset iteration:
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Use Keras training loops
If you don't need low-level control of your training process, using Keras's built-in fit, evaluate, and predict methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or sub-classed).
The advantages of these methods include:
- They accept Numpy arrays, Python generators and tf.data.Datasets.
- They apply regularization and activation losses automatically.
- They support tf.distribute, where the training code remains the same regardless of the hardware configuration.
- They support arbitrary callables as losses and metrics.
- They support callbacks like tf.keras.callbacks.TensorBoard, and custom callbacks.
- They are performant, automatically using TensorFlow graphs.
Here is an example of training a model using a Dataset. For details on how this works, check out the tutorials.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 2s 44ms/step - loss: 1.6644 - accuracy: 0.4906
Epoch 2/5
5/5 [==============================] - 0s 9ms/step - loss: 0.5173 - accuracy: 0.9062
Epoch 3/5
5/5 [==============================] - 0s 9ms/step - loss: 0.3418 - accuracy: 0.9469
Epoch 4/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2707 - accuracy: 0.9781
Epoch 5/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2195 - accuracy: 0.9812
5/5 [==============================] - 0s 4ms/step - loss: 1.6036 - accuracy: 0.6250
Loss 1.6036441326141357, Accuracy 0.625
Customize training and write your own loop
If Keras models work for you, but you need more flexibility and control of the training step or the outer training loops, you can implement your own training steps or even entire training loops. See the Keras guide on customizing fit to learn more.
You can also implement many things as a tf.keras.callbacks.Callback.
This method has many of the advantages mentioned previously, but it gives you control of the train step and even the outer loop.
A standard training loop has three steps:
- Iterate over a Python generator or tf.data.Dataset to get batches of examples.
- Use tf.GradientTape to collect gradients.
- Use one of the tf.keras.optimizers to apply weight updates to the model's variables.
Remember:
- Always include a training argument on the call method of subclassed layers and models.
- Make sure to call the model with the training argument set correctly.
- Depending on the usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
There is no need to run variable initializers or to add manual control dependencies. tf.function handles automatic control dependencies and variable initialization on creation for you.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4
Take advantage of tf.function with Python control flow
tf.function provides a way to convert data-dependent control flow into graph-mode equivalents like tf.cond and tf.while_loop.
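As a small illustration (not part of the original example), AutoGraph rewrites a Python if on a tensor value into a graph conditional:

@tf.function
def clip_negative(x):
  # `x > 0` is a tensor, so AutoGraph converts this `if` into tf.cond.
  if x > 0:
    return x
  else:
    return tf.zeros_like(x)

print(clip_negative(tf.constant(3.0)))   # tf.Tensor(3.0, ...)
print(clip_negative(tf.constant(-3.0)))  # tf.Tensor(0.0, ...)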
One common place where data-dependent control flow appears is in sequence models. tf.keras.layers.RNN wraps an RNN cell, allowing you to either statically or dynamically unroll the recurrence. As an example, you could reimplement dynamic unroll as follows.
class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super().__init__()
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps = tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Read the tf.function guide for more information.
New-style metrics and losses
Metrics and losses are both objects that work eagerly and in tf.functions.
A loss object is callable, and expects (y_true, y_pred) as arguments:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Use metrics to collect and display data
You can use tf.metrics to aggregate data and tf.summary to log summaries and redirect them to a writer using a context manager. The summaries are emitted directly to the writer, which means that you must provide the step value at the callsite.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)
Use tf.metrics to aggregate data before logging them as summaries. Metrics are stateful; they accumulate values and return a cumulative result when you call the result method (for example, Mean.result). Clear the accumulated values with the metric's reset_states method.
def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(test_y, model(test_x, training=False))
  tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)
Visualize the generated summaries by pointing TensorBoard at the summary log directory:
tensorboard --logdir /tmp/summaries
Use the tf.summary API to write summary data for visualization in TensorBoard. For more info, read the tf.summary guide.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)

  # Get the metric results
  mean_loss = loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.176
  accuracy: 0.994
Epoch:  1
  loss:     0.153
  accuracy: 0.991
Epoch:  2
  loss:     0.134
  accuracy: 0.994
Epoch:  3
  loss:     0.108
  accuracy: 1.000
Epoch:  4
  loss:     0.095
  accuracy: 1.000
Keras metric names
Keras models are consistent about handling metric names. When you pass a string in the list of metrics, that exact string is used as the metric's name. These names are visible in the history object returned by model.fit and in the logs passed to keras.callbacks, set to the string you passed in the metrics list.
model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 9ms/step - loss: 0.1077 - acc: 0.9937 - accuracy: 0.9937 - my_accuracy: 0.9937
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Debugging
Use eager execution to run your code step-by-step to inspect shapes, data types, and values. Certain APIs, like tf.function, tf.keras, and so on, are designed to use graph execution for performance and portability. When debugging, use tf.config.run_functions_eagerly(True) to use eager execution inside this code.
For example:
@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
This also works inside Keras models and other APIs that support eager execution:
class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Note:
- tf.keras.Model methods such as fit, evaluate, and predict execute as graphs with tf.function under the hood.
- When using tf.keras.Model.compile, set run_eagerly = True to disable the Model logic from being wrapped in a tf.function.
- Use tf.data.experimental.enable_debug_mode to enable debug mode for tf.data. Read the API docs for more details. A small sketch combining both debugging switches follows this list.
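Here is a minimal sketch of both switches, assuming a compiled model and the train_data pipeline from earlier in this guide:

# Keep the Model logic out of tf.function so breakpoints and prints work.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              run_eagerly=True)

# Enable tf.data debug mode; ideally call this early, before building
# your input pipelines.
tf.data.experimental.enable_debug_mode()

model.fit(train_data, epochs=1)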
Do not keep tf.Tensors in your objects
These tensor objects might get created either in a tf.function or in the eager context, and these tensors behave differently. Always use tf.Tensors only for intermediate values.
To track state, use tf.Variables, as they are always usable from both contexts. Read the tf.Variable guide to learn more.
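For example, a minimal sketch (the counter module here is purely illustrative):

class Counter(tf.Module):
  def __init__(self):
    super().__init__()
    # Good: a tf.Variable holds state that behaves the same eagerly
    # and inside tf.function.
    self.count = tf.Variable(0, dtype=tf.int64)

  @tf.function
  def increment(self):
    # Avoid `self.count = self.count + 1`: that would store an
    # intermediate tf.Tensor tied to this trace instead of updating state.
    self.count.assign_add(1)
    return self.count

counter = Counter()
print(counter.increment())  # tf.Tensor(1, ...)
print(counter.increment())  # tf.Tensor(2, ...)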