Effective TensorFlow 2


Overview

This guide provides a list of best practices for writing code with TensorFlow 2 (TF2). It is written for users who have recently switched over from TensorFlow 1 (TF1). Refer to the migrate section of the guide for more info on migrating your TF1 code to TF2.

Setup

Import TensorFlow and other dependencies for the examples in this guide.

import tensorflow as tf
import tensorflow_datasets as tfds

Recommendations for idiomatic TensorFlow 2

Refactor your code into smaller modules

A good practice is to refactor your code into smaller functions that are called as needed. For best performance, you should try to decorate the largest blocks of computation that you can in a tf.function (note that the nested Python functions called by a tf.function do not require their own separate decorations, unless you want to use different jit_compile settings for the tf.function). Depending on your use case, this could be multiple training steps or even your whole training loop. For inference use cases, it might be a single model forward pass.
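
For instance, here is a minimal sketch of this pattern. It assumes a model, loss_fn, and optimizer defined as in the full examples later in this guide; the names one_step and many_steps are illustrative only.

# Sketch only: `model`, `loss_fn`, and `optimizer` are assumed to be defined
# as in the complete examples later in this guide.
def one_step(inputs, labels):
  # A nested Python function called from a tf.function; it does not need its
  # own @tf.function decoration.
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(inputs, training=True))
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

@tf.function
def many_steps(dataset):
  # Decorating the largest block you can (here, a whole pass over the
  # dataset) gives the best performance.
  for inputs, labels in dataset:
    one_step(inputs, labels)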

Adjust the default learning rate for some tf.keras.optimizers

Certain Keras optimizers have different default learning rates in TF2. If you see a change in convergence behavior for your models, check the default learning rates.

There are no changes for optimizers.SGD, optimizers.Adam, or optimizers.RMSprop.

The following default learning rates have changed:

  • optimizers.Adagrad from 0.01 to 0.001
  • optimizers.Adadelta from 1.0 to 0.001
  • optimizers.Adamax from 0.002 to 0.001
  • optimizers.Nadam from 0.002 to 0.001
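
If your model was tuned against an old default, one option is to pass the learning rate explicitly when constructing the optimizer, as in this small sketch:

# Sketch: override the TF2 default learning rate explicitly if your model
# was tuned against the previous value.
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.01)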

Use tf.Modules and Keras layers to manage variables

tf.Modules and tf.keras.layers.Layers offer the convenient variables and trainable_variables properties, which recursively gather up all dependent variables. This makes it easy to manage variables locally to where they are being used.
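
As a quick sketch (not from the original guide) of what these properties collect:

# Sketch: trainable_variables recursively gathers the variables of every
# nested layer.
layers = [tf.keras.layers.Dense(10) for _ in range(3)]
perceptron = tf.keras.Sequential(layers)
perceptron.build(input_shape=(None, 8))  # create the variables

print(len(layers[0].trainable_variables))   # one layer: kernel + bias = 2
print(len(perceptron.trainable_variables))  # all three layers: 6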

Keras layers/models inherit from tf.train.Checkpointable and are integrated with @tf.function, which makes it possible to directly checkpoint or export SavedModels from Keras objects. You do not necessarily have to use Keras' Model.fit API to take advantage of these integrations.

Read the section on transfer learning and fine-tuning in the Keras guide to learn how to collect a subset of relevant variables using Keras.

Combine tf.data.Datasets and tf.function

The TensorFlow Datasets package (tfds) contains utilities for loading predefined datasets as tf.data.Dataset objects. For this example, you can load the MNIST dataset using tfds:

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Then prepare the data for training:

  • Re-scale each image.
  • Shuffle the order of the examples.
  • Collect batches of images and labels.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

To keep the example short, trim the dataset to only return 5 batches:

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))

Use regular Python iteration to iterate over training data that fits in memory. Otherwise, tf.data.Dataset is the best way to stream training data from disk. Datasets are iterables (not iterators), and work just like other Python iterables in eager execution. You can fully utilize dataset async prefetching/streaming features by wrapping your code in tf.function, which replaces Python iteration with the equivalent graph operations using AutoGraph.

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

If you use the Keras Model.fit API, you won't have to worry about dataset iteration:

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Use Keras training loops

If you don't need low-level control of your training process, using Keras's built-in fit, evaluate, and predict methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or sub-classed).

The advantages of these methods include:

  • They accept Numpy arrays, Python generators and tf.data.Datasets.
  • They apply regularization and activation losses automatically.
  • They support tf.distribute, where the training code remains the same regardless of the hardware configuration.
  • They support arbitrary callables as losses and metrics.
  • They support callbacks like tf.keras.callbacks.TensorBoard, and custom callbacks.
  • They are performant, automatically using TensorFlow graphs.

Here is an example of training a model using a Dataset. For details on how this works, check out the tutorials.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 2s 44ms/step - loss: 1.6644 - accuracy: 0.4906
Epoch 2/5
5/5 [==============================] - 0s 9ms/step - loss: 0.5173 - accuracy: 0.9062
Epoch 3/5
5/5 [==============================] - 0s 9ms/step - loss: 0.3418 - accuracy: 0.9469
Epoch 4/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2707 - accuracy: 0.9781
Epoch 5/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2195 - accuracy: 0.9812
5/5 [==============================] - 0s 4ms/step - loss: 1.6036 - accuracy: 0.6250
Loss 1.6036441326141357, Accuracy 0.625

Customize training and write your own loop

If Keras models work for you, but you need more flexibility and control of the training step or the outer training loops, you can implement your own training steps or even entire training loops. See the Keras guide on customizing fit to learn more.

You can also implement many things as a tf.keras.callbacks.Callback.
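
As an illustration (not from the original guide), a minimal custom callback that prints the loss at the end of each epoch could look like this:

# Sketch of a custom callback; Keras fills in the `logs` dict.
class PrintLossCallback(tf.keras.callbacks.Callback):

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    print('Epoch {} ended with loss {}'.format(epoch, logs.get('loss')))

# It would be passed to Model.fit, for example:
# model.fit(train_data, epochs=NUM_EPOCHS, callbacks=[PrintLossCallback()])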

This approach has many of the advantages mentioned previously, but it gives you control of the train step and even the outer loop.

There are three steps to a standard training loop:

  1. Iterate over a Python generator or tf.data.Dataset to get batches of examples.
  2. Use tf.GradientTape to collect gradients.
  3. Use one of the tf.keras.optimizers to apply weight updates to the model's variables.

Remember:

  • Always include a training argument on the call method of subclassed layers and models.
  • Make sure to call the model with the training argument set correctly.
  • Depending on usage, model variables may not exist until the model is run on a batch of data.
  • You need to manually handle things like regularization losses for the model.

There is no need to run variable initializers or to add manual control dependencies. tf.function handles automatic control dependencies and variable initialization on creation for you.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4

Take advantage of tf.function with Python control flow

tf.function provides a way to convert data-dependent control flow into graph-mode equivalents like tf.cond and tf.while_loop.
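
For example, here is a small sketch (not from the original guide) of a data-dependent Python if that AutoGraph converts into a tf.cond:

@tf.function
def clip_if_negative_sum(x):
  # The condition depends on a tensor value, so AutoGraph lowers this
  # Python `if` to a tf.cond in the traced graph.
  if tf.reduce_sum(x) < 0:
    x = tf.zeros_like(x)
  return x

print(clip_if_negative_sum(tf.constant([-1.0, -2.0])))  # [0. 0.]
print(clip_if_negative_sum(tf.constant([3.0, 4.0])))    # [3. 4.]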

One common place where data-dependent control flow appears is in sequence models. tf.keras.layers.RNN wraps an RNN cell, allowing you to either statically or dynamically unroll the recurrence. As an example, you could reimplement dynamic unroll as follows.

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super().__init__()
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

Read the tf.function guide for more information.

New-style metrics and losses

Metrics and losses are both objects that work eagerly and in tf.functions.

A loss object is callable, and expects (y_true, y_pred) as arguments:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Use metrics to collect and display data

You can use tf.metrics to aggregate data and tf.summary to log summaries and redirect them to a writer using a context manager. The summaries are emitted directly to the writer, which means that you must provide the step value at the call site.

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

Use tf.metrics to aggregate data before logging them as summaries. Metrics are stateful: they accumulate values and return a cumulative result when you call the result method (for example, Mean.result). Clear accumulated values with the metric's reset_states method (for example, Mean.reset_states).

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

Visualize the generated summaries by pointing TensorBoard at the summary log directory:

tensorboard --logdir /tmp/summaries

Use the tf.summary API to write summary data for visualization in TensorBoard. For more info, read the tf.summary guide.

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.176
  accuracy: 0.994
Epoch:  1
  loss:     0.153
  accuracy: 0.991
Epoch:  2
  loss:     0.134
  accuracy: 0.994
Epoch:  3
  loss:     0.108
  accuracy: 1.000
Epoch:  4
  loss:     0.095
  accuracy: 1.000

Keras metric names

Keras models are consistent about handling metric names. When you pass a string in the list of metrics, that exact string is used as the metric's name. These names are visible in the history object returned by model.fit and in the logs passed to keras.callbacks, and they are set to the string you passed in the metric list.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 9ms/step - loss: 0.1077 - acc: 0.9937 - accuracy: 0.9937 - my_accuracy: 0.9937
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Debugging

Use eager execution to run your code step-by-step to inspect shapes, data types and values. Certain APIs, like tf.function, tf.keras, etc. are designed to use graph execution for performance and portability. When debugging, use tf.config.run_functions_eagerly(True) to use eager execution inside this code.

For example:

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

This also works inside Keras models and other APIs that support eager execution:

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

Notes

Do not keep tf.Tensors in your objects

These tensor objects might get created either in a tf.function or in the eager context, and they behave differently. Always use tf.Tensors only for intermediate values.

To track state, use tf.Variables, as they are always usable from both contexts. Read the tf.Variable guide to learn more.
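
As a small, hypothetical example of this advice, the Counter module below keeps its state in a tf.Variable rather than a tf.Tensor:

# Sketch: state lives in a tf.Variable, which works both eagerly and inside
# a tf.function; tf.Tensors are only used for intermediate values.
class Counter(tf.Module):

  def __init__(self):
    self.count = tf.Variable(0)

  @tf.function
  def increment(self, amount=1):
    self.count.assign_add(amount)
    return self.count.read_value()

counter = Counter()
print(counter.increment())  # 1
print(counter.increment())  # 2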

Resources and further reading

  • Read the TF2 guides and tutorials to learn more about how to use TF2.

  • If you previously used TF1.x, it is highly recommended you migrate your code to TF2. Read the migration guide to learn more.