TensorFlow 1.x 与 TensorFlow 2 - 行为和 API

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看 下载笔记本

在幕后,TensorFlow 2 遵循与 TF1.x 根本不同的编程范式。

本指南介绍了 TF1.x 和 TF2 在行为和 API 方面的根本差异,以及这些差异与您的迁移过程之间的关系。

主要更改的概述

从根本上说,TF1.x 和 TF2 使用围绕执行(TF2 中的急切执行)、变量、控制流、张量形状和张量相等性比较的不同运行时行为集。要与 TF2 兼容,您的代码必须与 TF2 的所有行为兼容。在迁移过程中,您可以通过 tf.compat.v1.enable_*tf.compat.v1.disable_* API 分别启用或禁用大多数这些行为。唯一的例外是删除集合,这是启用/禁用急切执行的副作用。

从高级别来看,TensorFlow 2

以下部分提供了有关 TF1.x 和 TF2 之间差异的更多背景信息。要详细了解 TF2 背后的设计过程,请阅读 RFC设计文档

API 清理

许多 API 在 TF2 中 已消失或已移动。一些主要变化包括删除 tf.apptf.flagstf.logging 以支持现在开源的 absl-py,重新安置位于 tf.contrib 中的项目,以及通过将使用较少的函数移动到子包(如 tf.math)来清理主 tf.* 命名空间。一些 API 已被其 TF2 等效项替换 - tf.summarytf.keras.metricstf.keras.optimizers

tf.compat.v1: 遗留和兼容性 API 端点

位于 tf.compattf.compat.v1 命名空间下的符号不被视为 TF2 API。这些命名空间包含兼容性符号和 TF 1.x 中的旧版 API 端点。它们旨在帮助从 TF1.x 迁移到 TF2。但是,由于这些 compat.v1 API 并非 TF2 API 的惯用方式,因此不要使用它们编写全新的 TF2 代码。

单个 tf.compat.v1 符号可能是 TF2 兼容的,因为即使启用了 TF2 行为(例如 tf.compat.v1.losses.mean_squared_error),它们仍然可以工作,而其他符号则与 TF2 不兼容(例如 tf.compat.v1.metrics.accuracy)。许多 compat.v1 符号(虽然不是全部)在其文档中包含专门的迁移信息,解释了它们与 TF2 行为的兼容程度,以及如何将它们迁移到 TF2 API。

TF2 升级脚本 中,可以将许多 compat.v1 API 符号映射到等效的 TF2 API,前提是它们是别名或具有相同的参数,但顺序不同。您还可以使用升级脚本自动重命名 TF1.x API。

假朋友 API

在 TF2 tf 命名空间(不在 compat.v1 下)中有一组“假朋友”符号,它们实际上在内部忽略了 TF2 行为,或者与 TF2 行为的完整集不完全兼容。因此,这些 API 可能会与 TF2 代码发生错误,可能以静默的方式。

  • tf.estimator.*: Estimators 在内部创建和使用图和会话。因此,它们不应被视为 TF2 兼容。如果您的代码正在运行 Estimators,则它没有使用 TF2 行为。
  • keras.Model.model_to_estimator(...): 这会在内部创建一个 Estimator,如上所述,它与 TF2 不兼容。
  • tf.Graph().as_default(): 这将进入 TF1.x 图行为,并且不遵循标准的 TF2 兼容 tf.function 行为。进入此类图的代码通常会通过会话运行它们,不应被视为 TF2 兼容。
  • tf.feature_column.* 特征列 API 通常依赖于 TF1 风格的 tf.compat.v1.get_variable 变量创建,并假设创建的变量将通过全局集合访问。由于 TF2 不支持集合,因此在启用 TF2 行为的情况下运行 API 时,API 可能无法正常工作。

其他 API 更改

  • TF2 对设备放置算法进行了重大改进,这使得 tf.colocate_with 的使用变得不必要。如果删除它会导致性能下降,请 提交错误报告

  • 将所有 tf.v1.ConfigProto 的使用替换为来自 tf.config 的等效函数。

急切执行

TF1.x 要求您通过进行 tf.* API 调用手动将抽象语法树(图)拼接在一起,然后通过将一组输出张量和输入张量传递给 session.run 调用来手动编译抽象语法树。TF2 急切地执行(就像 Python 通常那样),使图和会话感觉像是实现细节。

急切执行的一个显著副产品是 tf.control_dependencies 不再需要,因为所有代码行都按顺序执行(在 tf.function 中,具有副作用的代码按写入顺序执行)。

不再有全局变量

TF1.x 严重依赖于隐式全局命名空间和集合。当您调用 tf.Variable 时,它将被放入默认图中的集合中,并且它将保留在那里,即使您丢失了指向它的 Python 变量。然后,您可以恢复该 tf.Variable,但前提是您知道它创建时的名称。如果您无法控制变量的创建,这将很难做到。因此,出现了各种机制来帮助您再次找到变量,以及框架找到用户创建的变量。其中一些包括:变量范围、全局集合、辅助方法(如 tf.get_global_steptf.global_variables_initializer)、优化器隐式地计算所有可训练变量的梯度,等等。TF2 消除了所有这些机制(Variables 2.0 RFC),转而使用默认机制 - 您跟踪自己的变量。如果您丢失了 tf.Variable,它将被垃圾回收。

跟踪变量的要求会增加一些额外的工作,但借助 模型垫片 等工具和 tf.Moduletf.keras.layers.Layer 中隐式面向对象的变量集合 等行为,负担会降到最低。

函数,而不是会话

session.run 调用几乎就像函数调用一样:您指定输入和要调用的函数,然后您会得到一组输出。在 TF2 中,您可以使用 tf.function 装饰 Python 函数,以将其标记为 JIT 编译,以便 TensorFlow 将其作为单个图运行(Functions 2.0 RFC)。这种机制使 TF2 能够获得图模式的所有优势

  • 性能:函数可以被优化(节点修剪、内核融合等)。
  • 可移植性:函数可以被导出/重新导入(SavedModel 2.0 RFC),允许您重用和共享模块化的 TensorFlow 函数。
# TF1.x
outputs = session.run(f(placeholder), feed_dict={placeholder: input})
# TF2
outputs = f(input)

有了自由地交织 Python 和 TensorFlow 代码的能力,您可以利用 Python 的表达能力。但是,可移植的 TensorFlow 在没有 Python 解释器的上下文中执行,例如移动设备、C++ 和 JavaScript。为了帮助您在添加 tf.function 时避免重写代码,请使用 AutoGraph 将 Python 结构的子集转换为其 TensorFlow 等效项

  • for/while -> tf.while_loop(支持 breakcontinue
  • if -> tf.cond
  • for _ in dataset -> dataset.reduce

AutoGraph 支持控制流的任意嵌套,这使得可以高效且简洁地实现许多复杂的 ML 程序,例如序列模型、强化学习、自定义训练循环等等。

适应 TF 2.x 行为更改

只有在您迁移到 TF2 行为的完整集之后,您的迁移到 TF2 才算完成。可以通过 tf.compat.v1.enable_v2_behaviorstf.compat.v1.disable_v2_behaviors 启用或禁用行为的完整集。以下部分将详细讨论每个主要的行为更改。

使用 tf.function

在迁移过程中,您程序中最大的更改可能来自从图和会话到急切执行和 tf.function 的基本编程模型范式转变。请参考 TF2 迁移指南,了解有关从与急切执行和 tf.function 不兼容的 API 迁移到与它们兼容的 API 的更多信息。

以下是一些与任何特定 API 无关的常见程序模式,这些模式在从 tf.Graphtf.compat.v1.Session 切换到使用 tf.function 的急切执行时可能会导致问题。

模式 1:旨在仅执行一次的 Python 对象操作和变量创建被多次执行

在依赖于图和会话的 TF1.x 程序中,通常期望程序中的所有 Python 逻辑只执行一次。但是,使用急切执行和 tf.function,可以合理地预期您的 Python 逻辑至少会执行一次,但可能执行多次(要么是急切地执行多次,要么是在不同的 tf.function 跟踪中执行多次)。有时,tf.function 甚至会在相同的输入上跟踪两次,导致意外行为(参见示例 1 和 2)。请参考 tf.function 指南,了解更多详细信息。

示例 1:变量创建

考虑下面的示例,其中函数在被调用时创建了一个变量

def f():
  v = tf.Variable(1.0)
  return v

with tf.Graph().as_default():
  with tf.compat.v1.Session() as sess:
    res = f()
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(res)

但是,简单地将包含变量创建的上述函数包装在 tf.function 中是不允许的。 tf.function 仅支持在第一次调用时 创建单例变量。为了强制执行这一点,当 tf.function 在第一次调用中检测到变量创建时,它将尝试再次跟踪,如果在第二次跟踪中存在变量创建,则会引发错误。

@tf.function
def f():
  print("trace") # This will print twice because the python body is run twice
  v = tf.Variable(1.0)
  return v

try:
  f()
except ValueError as e:
  print(e)

一种解决方法是在第一次调用中创建变量后缓存并重用它。

class Model(tf.Module):
  def __init__(self):
    self.v = None

  @tf.function
  def __call__(self):
    print("trace") # This will print twice because the python body is run twice
    if self.v is None:
      self.v = tf.Variable(0)
    return self.v

m = Model()
m()

示例 2:由于 tf.function 重新追踪导致的超出范围张量

如示例 1 所示,tf.function 在检测到第一次调用中创建变量时将重新追踪。这可能会造成额外的困惑,因为两次追踪将创建两个图。当来自重新追踪的第二个图尝试访问第一次追踪期间生成的图中的张量时,Tensorflow 将引发错误,抱怨该张量超出范围。为了演示这种情况,以下代码在第一次 tf.function 调用中创建数据集。这将按预期运行。

class Model(tf.Module):
  def __init__(self):
    self.dataset = None

  @tf.function
  def __call__(self):
    print("trace") # This will print once: only traced once
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    it = iter(self.dataset)
    return next(it)

m = Model()
m()

但是,如果我们也尝试在第一次 tf.function 调用中创建变量,代码将引发错误,抱怨数据集超出范围。这是因为数据集在第一个图中,而第二个图也试图访问它。

class Model(tf.Module):
  def __init__(self):
    self.v = None
    self.dataset = None

  @tf.function
  def __call__(self):
    print("trace") # This will print twice because the python body is run twice
    if self.v is None:
      self.v = tf.Variable(0)
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    it = iter(self.dataset)
    return [self.v, next(it)]

m = Model()
try:
  m()
except TypeError as e:
  print(e) # <tf.Tensor ...> is out of scope and cannot be used here.

最直接的解决方案是确保变量创建和数据集创建都在 tf.function 调用之外。例如

class Model(tf.Module):
  def __init__(self):
    self.v = None
    self.dataset = None

  def initialize(self):
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    if self.v is None:
      self.v = tf.Variable(0)

  @tf.function
  def __call__(self):
    it = iter(self.dataset)
    return [self.v, next(it)]

m = Model()
m.initialize()
m()

但是,有时在 tf.function 中创建变量是不可避免的(例如,某些 TF keras 优化器 中的槽变量)。尽管如此,我们仍然可以简单地将数据集创建移到 tf.function 调用之外。我们可以依赖这一点的原因是,tf.function 将接收数据集作为隐式输入,并且两个图都可以正确访问它。

class Model(tf.Module):
  def __init__(self):
    self.v = None
    self.dataset = None

  def initialize(self):
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])

  @tf.function
  def __call__(self):
    if self.v is None:
      self.v = tf.Variable(0)
    it = iter(self.dataset)
    return [self.v, next(it)]

m = Model()
m.initialize()
m()

示例 3:由于字典使用导致的意外 Tensorflow 对象重新创建

tf.function 对 Python 副作用(例如追加到列表或检查/添加到字典)的支持非常差。更多详细信息请参见 "使用 tf.function 提高性能"。在下面的示例中,代码使用字典来缓存数据集和迭代器。对于相同的键,每次调用模型都会返回数据集的相同迭代器。

class Model(tf.Module):
  def __init__(self):
    self.datasets = {}
    self.iterators = {}

  def __call__(self, key):
    if key not in self.datasets:
      self.datasets[key] = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3])
      self.iterators[key] = self.datasets[key].make_initializable_iterator()
    return self.iterators[key]

with tf.Graph().as_default():
  with tf.compat.v1.Session() as sess:
    m = Model()
    it = m('a')
    sess.run(it.initializer)
    for _ in range(3):
      print(sess.run(it.get_next())) # prints 1, 2, 3

但是,上述模式在 tf.function 中不会按预期工作。在追踪期间,tf.function 将忽略添加到字典的 Python 副作用。相反,它只记住创建新的数据集和迭代器。因此,每次调用模型都会始终返回一个新的迭代器。除非数值结果或性能足够显著,否则很难注意到这个问题。因此,我们建议用户在将 tf.function 随意包装到 Python 代码之前,仔细考虑代码。

class Model(tf.Module):
  def __init__(self):
    self.datasets = {}
    self.iterators = {}

  @tf.function
  def __call__(self, key):
    if key not in self.datasets:
      self.datasets[key] = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      self.iterators[key] = iter(self.datasets[key])
    return self.iterators[key]

m = Model()
for _ in range(3):
  print(next(m('a'))) # prints 1, 1, 1

我们可以使用 tf.init_scope 将数据集和迭代器创建提升到图之外,以实现预期的行为

class Model(tf.Module):
  def __init__(self):
    self.datasets = {}
    self.iterators = {}

  @tf.function
  def __call__(self, key):
    if key not in self.datasets:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.datasets[key] = tf.data.Dataset.from_tensor_slices([1, 2, 3])
        self.iterators[key] = iter(self.datasets[key])
    return self.iterators[key]

m = Model()
for _ in range(3):
  print(next(m('a'))) # prints 1, 2, 3

一般经验法则是避免在逻辑中依赖 Python 副作用,只将它们用于调试跟踪。

示例 4:操作全局 Python 列表

以下 TF1.x 代码使用一个全局损失列表,它用于仅维护当前训练步骤生成的损失列表。请注意,无论会话运行多少个训练步骤,追加损失到列表的 Python 逻辑只会被调用一次。

all_losses = []

class Model():
  def __call__(...):
    ...
    all_losses.append(regularization_loss)
    all_losses.append(label_loss_a)
    all_losses.append(label_loss_b)
    ...

g = tf.Graph()
with g.as_default():
  ...
  # initialize all objects
  model = Model()
  optimizer = ...
  ...
  # train step
  model(...)
  total_loss = tf.reduce_sum(all_losses)
  optimizer.minimize(total_loss)
  ...
...
sess = tf.compat.v1.Session(graph=g)
sess.run(...)  

但是,如果将此 Python 逻辑随意映射到具有急切执行的 TF2,则全局损失列表将在每个训练步骤中追加新值。这意味着以前预期列表只包含当前训练步骤的损失的训练步骤代码现在实际上看到了迄今为止运行的所有训练步骤的损失列表。这是一种意外的行为变化,需要在每个步骤开始时清除列表,或者将其设为训练步骤的局部变量。

all_losses = []

class Model():
  def __call__(...):
    ...
    all_losses.append(regularization_loss)
    all_losses.append(label_loss_a)
    all_losses.append(label_loss_b)
    ...

# initialize all objects
model = Model()
optimizer = ...

def train_step(...)
  ...
  model(...)
  total_loss = tf.reduce_sum(all_losses) # global list is never cleared,
  # Accidentally accumulates sum loss across all training steps
  optimizer.minimize(total_loss)
  ...

模式 2:在切换到急切执行时,意外地将意图在每个步骤中重新计算的符号张量与初始值一起缓存。

这种模式通常会导致您的代码在急切执行时在 tf.functions 之外静默地出现错误行为,但在 tf.function 内部发生初始值缓存时会引发 InaccessibleTensorError。但是,请注意,为了避免上述 模式 1,您通常会无意中以这样一种方式构建代码,即这种初始值缓存将在任何能够引发错误的 tf.function 之外发生。因此,如果您知道您的程序可能容易受到这种模式的影响,请格外小心。

解决此模式的通用方法是重新构建代码或在必要时使用 Python 可调用对象,以确保每次都重新计算值,而不是意外地将其缓存。

示例 1:依赖于全局步骤的学习率/超参数/等调度

在以下代码片段中,预期每次运行会话时都会读取最新的 global_step 值,并计算新的学习率。

g = tf.Graph()
with g.as_default():
  ...
  global_step = tf.Variable(0)
  learning_rate = 1.0 / global_step
  opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
  ...
  global_step.assign_add(1)
...
sess = tf.compat.v1.Session(graph=g)
sess.run(...)

但是,在尝试切换到急切执行时,请注意最终可能只计算一次学习率,然后重复使用,而不是按照预期的调度进行。

global_step = tf.Variable(0)
learning_rate = 1.0 / global_step # Wrong! Only computed once!
opt = tf.keras.optimizers.SGD(learning_rate)

def train_step(...):
  ...
  opt.apply_gradients(...)
  global_step.assign_add(1)
  ...

由于此特定示例是一种常见模式,并且优化器应该只初始化一次,而不是在每个训练步骤中初始化,因此 TF2 优化器支持 tf.keras.optimizers.schedules.LearningRateSchedule 调度或 Python 可调用对象作为学习率和其他超参数的参数。

示例 2:作为对象属性分配的符号随机数初始化,然后通过指针重复使用,在切换到急切执行时被意外地缓存

考虑以下 NoiseAdder 模块

class NoiseAdder(tf.Module):
  def __init__(shape, mean):
    self.noise_distribution = tf.random.normal(shape=shape, mean=mean)
    self.trainable_scale = tf.Variable(1.0, trainable=True)

  def add_noise(input):
    return (self.noise_distribution + input) * self.trainable_scale

在 TF1.x 中按如下方式使用它,将每次运行会话时计算一个新的随机噪声张量

g = tf.Graph()
with g.as_default():
  ...
  # initialize all variable-containing objects
  noise_adder = NoiseAdder(shape, mean)
  ...
  # computation pass
  x_with_noise = noise_adder.add_noise(x)
  ...
...
sess = tf.compat.v1.Session(graph=g)
sess.run(...)

但是,在 TF2 中,在开始时初始化 noise_adder 将导致 noise_distribution 只计算一次,并对所有训练步骤冻结

...
# initialize all variable-containing objects
noise_adder = NoiseAdder(shape, mean) # Freezes `self.noise_distribution`!
...
# computation pass
x_with_noise = noise_adder.add_noise(x)
...

为了解决这个问题,重构 NoiseAdder,以便每次需要新的随机张量时都调用 tf.random.normal,而不是每次都引用同一个张量对象。

class NoiseAdder(tf.Module):
  def __init__(shape, mean):
    self.noise_distribution = lambda: tf.random.normal(shape=shape, mean=mean)
    self.trainable_scale = tf.Variable(1.0, trainable=True)

  def add_noise(input):
    return (self.noise_distribution() + input) * self.trainable_scale

模式 3:TF1.x 代码直接依赖于并按名称查找张量

TF1.x 代码测试通常依赖于检查图中存在哪些张量或操作。在一些罕见的情况下,建模代码也将依赖于这些按名称的查找。

在急切执行时,在 tf.function 之外根本不会生成张量名称,因此所有 tf.Tensor.name 的使用必须发生在 tf.function 内部。请记住,即使在同一个 tf.function 中,实际生成的名称也很可能在 TF1.x 和 TF2 之间有所不同,并且 API 保证不确保跨 TF 版本生成的名称的稳定性。

模式 4:TF1.x 会话选择性地仅运行生成图的一部分

在 TF1.x 中,您可以构建一个图,然后通过选择一组不需要运行图中每个操作的输入和输出,选择性地仅运行图的子集。

例如,您可能在一个图中同时拥有生成器和鉴别器,并使用单独的 tf.compat.v1.Session.run 调用在仅训练鉴别器或仅训练生成器之间交替。

在 TF2 中,由于 tf.function 和急切执行中的自动控制依赖关系,不会对 tf.function 跟踪进行选择性修剪。即使例如,只有鉴别器或生成器的输出从 tf.function 输出,也会运行包含所有变量更新的完整图。

因此,您需要使用多个包含程序不同部分的 tf.function,或者使用 tf.function 的条件参数,您可以在该参数上进行分支,以便只执行您实际想要运行的内容。

集合移除

启用急切执行时,与图集合相关的 compat.v1 API(包括在幕后读取或写入集合的 API,例如 tf.compat.v1.trainable_variables)不再可用。有些可能会引发 ValueError,而另一些可能会静默地返回空列表。

TF1.x 中集合的最标准用法是维护初始化器、全局步骤、权重、正则化损失、模型输出损失以及需要运行的变量更新(例如来自 BatchNormalization 层)。

要处理这些标准用法中的每一个

  1. 初始化器 - 忽略。启用急切执行时,不需要手动变量初始化。
  2. 全局步长 - 请参阅 tf.compat.v1.train.get_or_create_global_step 的文档以获取迁移说明。
  3. 权重 - 按照 模型映射指南 中的指导,将您的模型映射到 tf.Modules/tf.keras.layers.Layers/tf.keras.Models,然后使用它们各自的权重跟踪机制,例如 tf.module.trainable_variables
  4. 正则化损失 - 按照 模型映射指南 中的指导,将您的模型映射到 tf.Modules/tf.keras.layers.Layers/tf.keras.Models,然后使用 tf.keras.losses。或者,您也可以手动跟踪您的正则化损失。
  5. 模型输出损失 - 使用 tf.keras.Model 损失管理机制,或者单独跟踪您的损失,而无需使用集合。
  6. 权重更新 - 忽略此集合。Eager 执行和 tf.function(使用 autograph 和 auto-control-dependencies)意味着所有变量更新将自动运行。因此,您不必在最后显式运行所有权重更新,但请注意,这意味着权重更新可能发生在与 TF1.x 代码中不同的时间,具体取决于您如何使用控制依赖项。
  7. 摘要 - 请参阅 迁移摘要 API 指南

更复杂的集合使用(例如使用自定义集合)可能需要您重构代码,以维护自己的全局存储,或者使其完全不依赖于全局存储。

ResourceVariables 而不是 ReferenceVariables

ResourceVariablesReferenceVariables 具有更强的读写一致性保证。这将导致关于您在使用变量时是否会观察到先前写入结果的更可预测、更容易理解的语义。这种变化极不可能导致现有代码引发错误或静默失败。

但是,虽然不太可能,但这些更强的一致性保证可能会增加特定程序的内存使用量。如果您发现这种情况,请提交 问题。此外,如果您有依赖于与对应于变量读取的图中的运算符名称进行精确字符串比较的单元测试,请注意,启用资源变量可能会稍微改变这些运算符的名称。

为了隔离这种行为变化对您代码的影响,如果禁用了 Eager 执行,您可以使用 tf.compat.v1.disable_resource_variables()tf.compat.v1.enable_resource_variables() 来全局禁用或启用这种行为变化。如果启用了 Eager 执行,将始终使用 ResourceVariables

控制流 v2

在 TF1.x 中,控制流运算符(例如 tf.condtf.while_loop)会内联低级运算符,例如 SwitchMerge 等。TF2 提供了改进的功能控制流运算符,这些运算符使用单独的 tf.function 跟踪来实现每个分支,并支持高阶微分。

为了隔离这种行为变化对您代码的影响,如果禁用了 Eager 执行,您可以使用 tf.compat.v1.disable_control_flow_v2()tf.compat.v1.enable_control_flow_v2() 来全局禁用或启用这种行为变化。但是,您只能在也禁用了 Eager 执行的情况下禁用控制流 v2。如果启用了它,将始终使用控制流 v2。

这种行为变化可以极大地改变使用控制流的生成 TF 程序的结构,因为它们将包含多个嵌套的函数跟踪,而不是一个扁平的图。因此,任何高度依赖于生成跟踪的精确语义的代码可能都需要一些修改。这包括

  • 依赖于运算符和张量名称的代码
  • 从 TensorFlow 控制流分支外部引用在该分支内创建的张量的代码。这可能会产生 InaccessibleTensorError

这种行为变化旨在对性能保持中立或积极,但如果您遇到控制流 v2 的性能比 TF1.x 控制流差的问题,请提交 问题,并附上重现步骤。

TensorShape API 行为变化

TensorShape 类已简化为保存 int,而不是 tf.compat.v1.Dimension 对象。因此,无需调用 .value 来获取 int

仍然可以通过 tf.TensorShape.dims 访问单个 tf.compat.v1.Dimension 对象。

为了隔离这种行为变化对您代码的影响,您可以使用 tf.compat.v1.disable_v2_tensorshape()tf.compat.v1.enable_v2_tensorshape() 来全局禁用或启用这种行为变化。

以下演示了 TF1.x 和 TF2 之间的差异。

import tensorflow as tf
# Create a shape and choose an index
i = 0
shape = tf.TensorShape([16, None, 256])
shape

如果您在 TF1.x 中有以下内容

value = shape[i].value

那么在 TF2 中执行以下操作

value = shape[i]
value

如果您在 TF1.x 中有以下内容

for dim in shape:
    value = dim.value
    print(value)

然后,在 TF2 中执行以下操作

for value in shape:
  print(value)

如果您在 TF1.x 中有以下内容(或使用任何其他维度方法)

dim = shape[i]
dim.assert_is_compatible_with(other_dim)

那么在 TF2 中执行以下操作

other_dim = 16
Dimension = tf.compat.v1.Dimension

if shape.rank is None:
  dim = Dimension(None)
else:
  dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method
shape = tf.TensorShape(None)

if shape:
  dim = shape.dims[i]
  dim.is_compatible_with(other_dim) # or any other dimension method

tf.TensorShape 的布尔值如果秩已知则为 True,否则为 False

print(bool(tf.TensorShape([])))      # Scalar
print(bool(tf.TensorShape([0])))     # 0-length vector
print(bool(tf.TensorShape([1])))     # 1-length vector
print(bool(tf.TensorShape([None])))  # Unknown-length vector
print(bool(tf.TensorShape([1, 10, 100])))       # 3D tensor
print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions
print()
print(bool(tf.TensorShape(None)))  # A tensor with unknown rank.

由于 TensorShape 更改而导致的潜在错误

TensorShape 行为变化不太可能静默地破坏您的代码。但是,您可能会看到与形状相关的代码开始引发 AttributeError,因为 intNone 没有与 tf.compat.v1.Dimension 相同的属性。以下是一些这些 AttributeError 的示例

try:
  # Create a shape and choose an index
  shape = tf.TensorShape([16, None, 256])
  value = shape[0].value
except AttributeError as e:
  # 'int' object has no attribute 'value'
  print(e)
try:
  # Create a shape and choose an index
  shape = tf.TensorShape([16, None, 256])
  dim = shape[1]
  other_dim = shape[2]
  dim.assert_is_compatible_with(other_dim)
except AttributeError as e:
  # 'NoneType' object has no attribute 'assert_is_compatible_with'
  print(e)

按值比较张量

在 TF2 中,变量和张量上的二元 ==!= 运算符已更改为按值比较,而不是像 TF1.x 中那样按对象引用比较。此外,张量和变量不再直接可哈希或在集合或字典键中使用,因为可能无法按值对它们进行哈希。相反,它们公开了一个 .ref() 方法,您可以使用它来获取对张量或变量的可哈希引用的。

为了隔离这种行为变化的影响,您可以使用 tf.compat.v1.disable_tensor_equality()tf.compat.v1.enable_tensor_equality() 来全局禁用或启用这种行为变化。

例如,在 TF1.x 中,当您使用 == 运算符时,两个具有相同值的变量将返回 false

tf.compat.v1.disable_tensor_equality()
x = tf.Variable(0.0)
y = tf.Variable(0.0)

x == y

而在 TF2 中,如果启用了张量相等性检查,x == y 将返回 True

tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)
y = tf.Variable(0.0)

x == y

因此,在 TF2 中,如果您需要按对象引用进行比较,请确保使用 isis not

tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)
y = tf.Variable(0.0)

x is y

哈希张量和变量

使用 TF1.x 行为,您过去可以直接将变量和张量添加到需要哈希的数据结构中,例如 setdict 键。

x = tf.Variable(0.0)
set([x, tf.constant(2.0)])

但是,在 TF2 中,如果启用了张量相等性,张量和变量将变得不可哈希,因为 ==!= 运算符的语义已更改为值相等性检查。

tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)

try:
  set([x, tf.constant(2.0)])
except TypeError as e:
  # TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
  print(e)

因此,在 TF2 中,如果您需要将张量或变量对象用作键或 set 内容,您可以使用 tensor.ref() 来获取可哈希引用,该引用可以用作键

tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)

tensor_set = set([x.ref(), tf.constant(2.0).ref()])
assert x.ref() in tensor_set

tensor_set

如果需要,您还可以使用 reference.deref() 从引用中获取张量或变量

referenced_var = x.ref().deref()
assert referenced_var is x
referenced_var

资源和进一步阅读

  • 访问 迁移到 TF2 部分,以详细了解如何从 TF1.x 迁移到 TF2。
  • 阅读 模型映射指南,了解有关将 TF1.x 模型直接映射到 TF2 中工作的更多信息。