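In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.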
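You can use tf.function to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This helps you create performant and portable models, and it is required to use SavedModel.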
This guide will help you understand how tf.function works under the hood, so you can use it effectively.
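The main takeaways and recommendations are:
- Debug in eager mode, then decorate with @tf.function.
- Don't rely on Python side effects like object mutation or list appends.
- tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.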
Setup
import tensorflow as tf
Define a helper function to demonstrate the kinds of errors you might encounter:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))
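Basics
Usage
A tf.function that you define (for example, by applying the @tf.function decorator) is just like a core TensorFlow operation: you can execute it eagerly, compute gradients through it, and so on.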
@tf.function # The decorator converts `add` into a `PolymorphicFunction`.
def add(a, b):
  return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
You can use tf.functions inside other tf.functions.
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
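tf.function can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.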
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
  return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.005143817999964995 Function conv: 0.005329717999984496 Note how there's not much difference in performance for convolutions
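Tracing
This section exposes how tf.function works under the hood, including implementation details that may change in the future. However, once you understand why and when tracing happens, it's much easier to use tf.function effectively!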
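What is "tracing"?
A tf.function runs your program in a TensorFlow graph. However, a tf.Graph cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but tf.Graph requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a tf.Graph.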
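tf.function bridges this gap by separating your code into two stages:
1) In the first stage, referred to as "tracing", tf.function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two tensors) are deferred: they are captured by the tf.Graph and not run.
2) In the second stage, a tf.Graph that contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.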
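Depending on its inputs, tf.function will not always run the first stage when it is called. See "Rules of tracing" below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.
When tf.function does decide to trace, the tracing stage is immediately followed by the second stage, so calling the tf.function both creates and runs the tf.Graph. Later you will see how you can run only the tracing stage with get_concrete_function.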
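To make the two stages concrete, here is a deliberately simplified, pure-Python model. It is not how tf.function is implemented. ToyFunction, double, and the "graph" represented as a list of recorded callables are all hypothetical, and the cache key is just the Python type of the input, standing in for a real TraceType. The first call for a new key runs the Python body once ("tracing") and records deferred ops. Later calls with the same key only replay those ops, so Python side effects such as appending to log do not repeat.

```python
from typing import Any, Callable, Dict, List

class ToyFunction:
  """Hypothetical model of trace-then-run; NOT tf.function's real implementation."""

  def __init__(self, python_fn: Callable[[List[Callable[[Any], Any]]], None]):
    self.python_fn = python_fn
    # Cache of "graphs" (lists of deferred ops), keyed by a stand-in for TraceType.
    self.cache: Dict[type, List[Callable[[Any], Any]]] = {}
    self.trace_count = 0

  def __call__(self, x: Any) -> Any:
    key = type(x)  # Toy stand-in for TraceType: just the Python type of the input.
    if key not in self.cache:
      # Stage 1, "tracing": run the Python body once and record the deferred ops.
      self.trace_count += 1
      ops: List[Callable[[Any], Any]] = []
      self.python_fn(ops)
      self.cache[key] = ops
    # Stage 2: replay only the recorded ops; the Python body does not run again.
    result = x
    for op in self.cache[key]:
      result = op(result)
    return result

log: List[str] = []

def double(ops: List[Callable[[Any], Any]]) -> None:
  log.append("tracing")        # Python side effect: happens only during tracing.
  ops.append(lambda a: a + a)  # "TensorFlow op": recorded now, executed on every call.

toy_double = ToyFunction(double)
results = [toy_double(1), toy_double(2), toy_double(1.5), toy_double(3.0)]
print(results)                 # [2, 4, 3.0, 6.0]
print(toy_double.trace_count)  # 2 (one trace for int, one for float)
print(log)                     # ['tracing', 'tracing']
```

The real tf.function keys its cache on much richer TraceTypes (dtype, shape, literal values, and so on), as the rules of tracing below describe.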
When you pass arguments of different types into a tf.function, both stages are run:
@tf.function
def double(a):
  print("Tracing with", a)
  return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
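Note that if you repeatedly call a tf.function with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, because the generated graph would be identical.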
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
You can use pretty_printed_concrete_signatures() to see all of the available traces:
print(double.pretty_printed_concrete_signatures())
Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None) Output Type: TensorSpec(shape=(), dtype=tf.int32, name=None) Captures: None Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None) Output Type: TensorSpec(shape=(), dtype=tf.float32, name=None) Captures: None Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None) Output Type: TensorSpec(shape=(), dtype=tf.string, name=None) Captures: None
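So far, you've seen that tf.function creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:
- A tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.
- Tracing is the process through which new tf.Graphs are generated from Python code.
- An instance of tf.Graph is specialized to the specific input types it was traced with. Differing types require retracing.
- Each traced tf.Graph has a corresponding ConcreteFunction.
- A tf.function manages a cache of ConcreteFunctions and picks the right one for your inputs.
- tf.function wraps the Python function that will be traced, returning a tf.types.experimental.PolymorphicFunction object.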
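Rules of tracing
When called, a tf.function first evaluates the type of each input argument using that argument's tf.types.experimental.TraceType. These types are used to construct a tf.types.experimental.FunctionType describing the signature of the desired ConcreteFunction. This FunctionType is compared to the FunctionTypes of the existing ConcreteFunctions. If a matching ConcreteFunction is found, the call is dispatched to it. If no match is found, a new ConcreteFunction is traced for the desired FunctionType.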
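If multiple matches are found, the most specific signature is chosen. Matching is done by subtyping, much like normal function calls in C++ or Java. For example, TensorShape([1, 2]) is a subtype of TensorShape([None, None]), so a call to the tf.function with TensorShape([1, 2]) can be dispatched to the ConcreteFunction produced with TensorShape([None, None]). However, if a ConcreteFunction with TensorShape([1, None]) also exists, it will be prioritized because it is more specific.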
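The following pure-Python sketch models the shape part of this matching under simplifying assumptions: shapes are tuples whose None entries are unknown dimensions, None as a whole means unknown rank, and dtypes and all other argument types are ignored. The helper names is_subtype and pick_most_specific are hypothetical and are not TensorFlow APIs. The sketch reproduces the example above: (1, 2) dispatches to (1, None) rather than (None, None).

```python
from typing import Optional, Sequence, Tuple

# Hypothetical shape model: a tuple of dims (None = unknown dim), or None for unknown rank.
Shape = Optional[Tuple[Optional[int], ...]]

def is_subtype(shape: Shape, other: Shape) -> bool:
  """True if `other` can represent every input that `shape` describes."""
  if other is None:            # Unknown rank accepts any shape.
    return True
  if shape is None or len(shape) != len(other):
    return False
  # A fixed dim is a subtype of the same fixed dim or of an unknown (None) dim.
  return all(o is None or s == o for s, o in zip(shape, other))

def pick_most_specific(shape: Shape, traced: Sequence[Shape]) -> Optional[Shape]:
  """Return the matching traced shape that is a subtype of all other matches."""
  matches = [t for t in traced if is_subtype(shape, t)]
  for candidate in matches:
    if all(is_subtype(candidate, other) for other in matches):
      return candidate
  return None  # No match: a real tf.function would trace a new ConcreteFunction.

traced = [(None, None), (1, None)]
print(pick_most_specific((1, 2), traced))     # (1, None)
print(pick_most_specific((3, 2), traced))     # (None, None)
print(pick_most_specific((1, 2, 3), traced))  # None
```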
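A TraceType is determined from the input arguments as follows:
- For a Tensor, the type is parameterized by the Tensor's dtype and shape; ranked shapes are a subtype of unranked shapes, and fixed dimensions are a subtype of unknown dimensions.
- For a Variable, the type is similar to Tensor, but also includes a unique resource ID of the variable, which is necessary to correctly wire control dependencies.
- For Python primitive values, the type corresponds to the value itself. For example, the TraceType of the value 3 is LiteralTraceType<3>, not int.
- For Python ordered containers such as list and tuple, the type is parameterized by the types of their elements. For example, the type of [1, 2] is ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>, and the type of [2, 1] is ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>, which is different.
- For Python mappings such as dict, the type is also a mapping from the same keys, but to the types of the values instead of the actual values. For example, the type of {1: 2, 3: 4} is MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>. However, unlike ordered containers, {1: 2, 3: 4} and {3: 4, 1: 2} have equivalent types.
- For Python objects that implement the __tf_tracing_type__ method, the type is whatever that method returns.
- For any other Python object, the type is a generic TraceType, and the matching procedure is:
  - First, it checks whether the object is the same object used in a previous trace (using Python id() or is). Note that this still matches if the object has been mutated, so if you use Python objects as tf.function arguments, it's best to use immutable ones.
  - Next, it checks whether the object is equal to the object used in a previous trace (using Python ==).
  Note that this procedure only keeps a weakref to the object, so it only works as long as the object is in scope and has not been deleted.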
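The container rules above can be modeled in a few lines of plain Python. This is a simplified sketch, not TensorFlow's TraceType machinery. toy_trace_key is a hypothetical helper that returns a hashable key. Literals are keyed by their value, lists and tuples by their element keys in order, and dicts by an unordered set of key-to-type pairs. Any other object is keyed only by identity, whereas the real rules also fall back to ==.

```python
from collections.abc import Hashable
from typing import Any

def toy_trace_key(value: Any) -> Hashable:
  """Hypothetical, simplified model of how a TraceType is derived from an argument."""
  if value is None or isinstance(value, (bool, int, float, str)):
    # Python primitive values: the type corresponds to the value itself.
    return ("literal", type(value).__name__, value)
  if isinstance(value, (list, tuple)):
    # Ordered containers: parameterized by element types, in order.
    return (type(value).__name__, tuple(toy_trace_key(v) for v in value))
  if isinstance(value, dict):
    # Mappings: keys map to value types; insertion order is ignored.
    return ("dict", frozenset((k, toy_trace_key(v)) for k, v in value.items()))
  # Any other object: identity only (the real rules also fall back to ==).
  return ("object", id(value))

print(toy_trace_key([1, 2]) == toy_trace_key([2, 1]))              # False
print(toy_trace_key({1: 2, 3: 4}) == toy_trace_key({3: 4, 1: 2}))  # True
print(toy_trace_key(3) == toy_trace_key(4))                        # False
```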
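Controlling retracing
Retracing, which is when your tf.function creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your tf.function retraces a new graph for every call, your code will execute more slowly than if you didn't use tf.function at all.
To control the tracing behavior, you can use the following techniques: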
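Pass a fixed input_signature to tf.function
This forces tf.function to constrain itself to only one tf.types.experimental.FunctionType, composed of the types enumerated by the input_signature. Calls that cannot be dispatched to this FunctionType will raise an error.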
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'TypeError'>: Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3657259638.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2, 2), dtype=int32, numpy= array([[1, 2], [3, 4]], dtype=int32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3657259638.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
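Use unknown dimensions for flexibility
Since TensorFlow matches tensors based on their shape, using a None dimension as a wildcard allows tf.function to reuse traces for variably sized input. Variably sized input can occur if you have sequences of different lengths, or images of different sizes in each batch. You can check out the Transformer and Deep Dream tutorials for examples.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x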
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
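Use reduce_retracing for automatic flexibility
When reduce_retracing is enabled, tf.function automatically identifies supertypes of the input types it is observing and chooses to trace more generalized graphs. It is less efficient than setting the input_signature directly, but useful when many types need to be supported.
@tf.function(reduce_retracing=True)
def g(x):
  print('Tracing with', x)
  return x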
# Traces once.
print(g(tf.constant([1, 2, 3])))
# Traces again, but more generalized this time.
print(g(tf.constant([1, 2, 3, 4, 5])))
# No more tracing!
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7])))
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])))
Tracing with Tensor("x:0", shape=(3,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32) tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32) tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)
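Instead of watching printed "Tracing with" lines, you can count traces programmatically with the tf.function method experimental_get_tracing_count(). The sketch below compares a function that has reduce_retracing enabled with one that uses the default setting, calling both with 1-D tensors of four different lengths. The function names are illustrative only. The count of 2 for the relaxed function matches the output above. The default function is expected to trace once per distinct shape.

```python
import tensorflow as tf

@tf.function(reduce_retracing=True)
def relaxed_identity(x):
  return x

@tf.function  # Default: reduce_retracing=False.
def strict_identity(x):
  return x

for n in (3, 5, 7, 9):
  relaxed_identity(tf.range(n))  # 1-D int32 tensors of different lengths.
  strict_identity(tf.range(n))

print(relaxed_identity.experimental_get_tracing_count())  # 2: shape (3,), then (None,)
print(strict_identity.experimental_get_tracing_count())   # 4: one trace per shape
```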
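Pass tensors instead of Python literals
Often, Python arguments are used to control hyperparameters and graph construction, for example num_layers=10, training=True, or nonlinearity='relu'. So, if the Python argument changes, it makes sense that you'd have to retrace the graph.
However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, the following training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so the retracing is unnecessary.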
def train_one_step():
  pass
@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
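The same effect can be checked by counting traces instead of reading printed output. In the sketch below, scale_py and scale_tensor are illustrative names for two identical functions. The first receives a different Python float on each call, so each value is a new literal TraceType. The second receives scalar float32 tensors, which all share one TraceType, so its single trace is reused.

```python
import tensorflow as tf

@tf.function
def scale_py(x, factor):
  return x * factor

@tf.function
def scale_tensor(x, factor):
  return x * factor

x = tf.constant(2.0)
for factor in (1.0, 2.0, 3.0):
  scale_py(x, factor)                   # Each new Python float is a new TraceType.
  scale_tensor(x, tf.constant(factor))  # Same dtype and shape, so the trace is reused.

print(scale_py.experimental_get_tracing_count())      # 3
print(scale_tensor.experimental_get_tracing_count())  # 1
```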
If you need to force retracing, create a new tf.function. Separate tf.function objects are guaranteed not to share traces.
def f():
  print('Tracing!')
  tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
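Use the tracing protocol
Where possible, you should prefer converting the Python type into a tf.experimental.ExtensionType instead. Moreover, the TraceType of an ExtensionType is the tf.TypeSpec associated with it. Therefore, if needed, you can simply override the default tf.TypeSpec to take control of an ExtensionType's Tracing Protocol. Refer to the "Customizing the ExtensionType's TypeSpec" section in the Extension types guide for details.
Otherwise, for direct control over when tf.function should retrace with respect to a particular Python type, you can implement the Tracing Protocol for it yourself.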
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
  return fruit_a.flavor + fruit_b.flavor
class Fruit:
  flavor = tf.constant([0, 0])
class Apple(Fruit):
  flavor = tf.constant([1, 2])
class Mango(Fruit):
  flavor = tf.constant([3, 4])
# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out of scope by then and been deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.
class FruitTraceType(tf.types.experimental.TraceType):
  def __init__(self, fruit):
    self.fruit_type = type(fruit)
    self.fruit_value = fruit
  def is_subtype_of(self, other):
    # True if self subtypes `other` and `other`'s type matches FruitTraceType.
    return (type(other) is FruitTraceType and
            self.fruit_type is other.fruit_type)
  def most_specific_common_supertype(self, others):
    # `self` is the specific common supertype if all input types match it.
    return self if all(self == other for other in others) else None
  def placeholder_value(self, placeholder_context=None):
    # Use the fruit itself instead of the type for correct tracing.
    return self.fruit_value
  def __eq__(self, other):
    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type
  def __hash__(self):
    return hash(self.fruit_type)
class FruitWithTraceType:
  def __tf_tracing_type__(self, context):
    return FruitTraceType(self)
class AppleWithTraceType(FruitWithTraceType):
  flavor = tf.constant([1, 2])
class MangoWithTraceType(FruitWithTraceType):
  flavor = tf.constant([3, 4])
# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>
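Obtaining concrete functions
Every time a function is traced, a new concrete function is created. You can obtain a concrete function directly by using get_concrete_function.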
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on a tf.TensorSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
Printing a ConcreteFunction displays a summary of its input arguments (with types) and its output type.
print(double_strings)
ConcreteFunction Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None) Output Type: TensorSpec(shape=(), dtype=tf.string, name=None) Captures: None
You can also directly retrieve a concrete function's signature.
print(double_strings.function_type)
(a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None)
Using a concrete trace with incompatible types will raise an error:
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs bound_arguments = function_type.bind_with_defaults( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults with_default_args[arg_name] = constraint.cast( TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None) The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl return self._call_with_structured_signature(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature function_type_utils.canonicalize_function_inputs( TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (<tf.Tensor: shape=(), dtype=int32, numpy=1>,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None). 
During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]
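You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.
@tf.function
def pow(a, b):
  return a ** b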
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction Input Parameters: a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=<unknown>, dtype=tf.float32, name=None) b (POSITIONAL_OR_KEYWORD): Literal[2] Output Type: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None) Captures: None
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py", line 442, in bind_function_inputs bound_arguments = function_type.bind_with_defaults( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py", line 277, in bind_with_defaults with_default_args[arg_name] = constraint.cast( ValueError: Can not cast 3 to Literal[2] The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1180, in _call_impl return self._call_with_structured_signature(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1260, in _call_with_structured_signature function_type_utils.canonicalize_function_inputs( TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None). During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1183, in _call_impl return self._call_with_flat_signature(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py", line 1234, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. 
During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (<tf.Tensor: shape=(), dtype=float32, numpy=10.0>,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=<unknown>, dtype=tf.float32, name=None). Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.
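Obtaining graphs
Although retrieving the actual tf.Graph object is not something you'll normally need to do, you can obtain it easily from any concrete function.
graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')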
[] -> a ['a', 'a'] -> add ['add'] -> Identity
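In reality, tf.Graphs are not directly callable. TensorFlow actually uses a tf.types.experimental.AtomicFunction to perform the computations described by the tf.Graph. You can access the AtomicFunction describing the traced tf.Graph and call it directly instead of the ConcreteFunction: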
atomic_fn = double_strings.inference_fn
atomic_fn(tf.constant("a"))
<tf.Tensor: shape=(), dtype=string, numpy=b'aa'>
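This has the advantage of lower Python overhead for high-performance scenarios. However, it should only be used for forward inference (gradients are not supported), and captured tensor values (if any) need to be supplied explicitly.
Debugging
In general, debugging code is easier in eager mode than inside tf.function. You should ensure that your code executes error-free in eager mode before decorating it with tf.function. To assist in the debugging process, you can call tf.config.run_functions_eagerly(True) to globally disable and re-enable tf.function.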
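As a sketch of what this switch changes, the example below uses a Python list append as an execution counter. The names python_calls and add_one are illustrative. While run_functions_eagerly(True) is in effect, the Python body runs on every call. After restoring the default, the body runs only once, during tracing. The try/finally makes sure the global setting is restored even if the calls fail.

```python
import tensorflow as tf

python_calls = []

@tf.function
def add_one(x):
  python_calls.append("ran")  # Python side effect, used here to count executions.
  return x + 1

tf.config.run_functions_eagerly(True)     # Run tf.functions as plain Python.
try:
  for _ in range(3):
    add_one(tf.constant(1))
finally:
  tf.config.run_functions_eagerly(False)  # Restore graph execution.
eager_runs = len(python_calls)  # The Python body ran on every call.

python_calls.clear()
for _ in range(3):
  add_one(tf.constant(1))
graph_runs = len(python_calls)  # The body ran once, during tracing.
print(eager_runs, graph_runs)   # 3 1
```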
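When tracking down issues that only appear within tf.function, here are some tips:
- Plain old Python print calls only execute during tracing, which helps you track down when your function gets (re)traced.
- tf.print calls execute every time, and can help you track down intermediate values during execution.
- tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Infs are created.
- pdb (the Python debugger) can help you understand what's going on during tracing. (Caveat: pdb will drop you into AutoGraph-transformed source code.)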
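If you only want to guard a specific tensor rather than enabling checks globally, tf.debugging.check_numerics raises tf.errors.InvalidArgumentError at execution time when its input contains NaN or Inf. The function name checked_log and the message text below are illustrative. Here the function passes finite logarithms through unchanged, and log(-1), which is NaN, triggers the error.

```python
import numpy as np
import tensorflow as tf

@tf.function
def checked_log(x):
  y = tf.math.log(x)
  # Raises InvalidArgumentError at execution time if `y` contains NaN or Inf.
  return tf.debugging.check_numerics(y, message="checked_log produced NaN/Inf")

values = checked_log(tf.constant([1.0, 2.0])).numpy()
print(np.allclose(values, np.log([1.0, 2.0])))  # True: finite values pass through

try:
  checked_log(tf.constant([-1.0]))  # log(-1) is NaN
  caught = False
except tf.errors.InvalidArgumentError:
  caught = True
print("caught:", caught)
```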
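AutoGraph transformations
AutoGraph is a library that is on by default in tf.function. It transforms a subset of Python eager code into graph-compatible TensorFlow ops, including control flow such as if, for, and while.
TensorFlow ops like tf.cond and tf.while_loop continue to work, but control flow is often easier to write and understand when written in Python.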
# A simple loop
@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x
f(tf.random.uniform([5]))
[0.143583655 0.698347807 0.767881036 0.857545733 0.0981599092] [0.142604977 0.603318036 0.645695686 0.694991 0.0978458375] [0.141646087 0.539406419 0.568765223 0.601178706 0.0975347683] [0.140706316 0.492538482 0.514451861 0.537887812 0.0972266495] [0.139785022 0.456228882 0.473406643 0.491387457 0.0969214365] [0.138881624 0.427005619 0.440947741 0.455316931 0.096619077] [0.137995526 0.402815819 0.414429694 0.426259726 0.0963195339] [0.137126192 0.38235572 0.392227381 0.402190775 0.0960227624] [0.136273116 0.364751518 0.373278826 0.381821901 0.0957287177] [0.135435775 0.349392414 0.356856316 0.364288628 0.0954373628] [0.134613693 0.335836589 0.342441976 0.34898597 0.0951486528] [0.133806422 0.323755413 0.329655707 0.335475951 0.094862543] [0.133013532 0.312898636 0.318211377 0.323432535 0.0945790112] [0.132234573 0.303071797 0.307888716 0.312607288 0.0942980051] [0.13146916 0.294121206 0.298515141 0.302807152 0.0940194875] [0.13071692 0.285923541 0.289953142 0.29387942 0.0937434211] [0.12997745 0.278378516 0.282091677 0.285701483 0.0934697762] [0.129250407 0.2714037 0.274839878 0.278173655 0.0931985155] [0.12853545 0.264930487 0.268122554 0.271213919 0.0929296] [0.127832219 0.258901358 0.261877 0.264754 0.092663005] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.12714042, 0.25326762, 0.2560503 , 0.25873667, 0.0923987 ], dtype=float32)>
If you're curious, you can inspect the code AutoGraph generates.
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1 ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
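For comparison, the loop above can also be written directly with tf.while_loop. The sketch below drops the tf.print call and uses a fixed input so the results are comparable. The function names are illustrative. With tf.while_loop, cond and body receive the loop variables, and body returns their updated values in the same structure as loop_vars. Both versions are expected to produce the same result.

```python
import numpy as np
import tensorflow as tf

@tf.function
def autograph_loop(x):
  # AutoGraph converts this tensor-dependent `while` into a while loop op.
  while tf.reduce_sum(x) > 1:
    x = tf.tanh(x)
  return x

@tf.function
def explicit_loop(x):
  # The same loop written by hand: `body` returns the updated loop variables
  # in the same structure as `loop_vars` (here, a 1-tuple).
  (result,) = tf.while_loop(
      cond=lambda x: tf.reduce_sum(x) > 1,
      body=lambda x: (tf.tanh(x),),
      loop_vars=(x,))
  return result

x0 = tf.constant([0.9, 0.8, 0.7, 0.6, 0.5])
a = autograph_loop(x0).numpy()
b = explicit_loop(x0).numpy()
print(np.allclose(a, b))  # True
```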
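Conditionals
AutoGraph converts some if <condition> statements into the equivalent tf.cond calls. This substitution is made if <condition> is a Tensor. Otherwise, the if statement is executed as a Python conditional.
A Python conditional executes during tracing, so exactly one branch of the conditional is added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there were data-dependent control flow.
tf.cond traces both branches of the conditional and adds them to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; check out AutoGraph tracing effects for more information.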
@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
See the reference documentation for additional restrictions on AutoGraph-converted if statements.
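The difference between a Python condition and a tensor condition can be observed by recording which branches run during tracing. In the sketch below, square_or_negate and traced_branches are illustrative names, and the list append is a Python side effect that only happens while tracing. Passing the Python bool True traces only one branch. Passing a boolean tensor makes AutoGraph emit tf.cond, which traces both branches.

```python
import tensorflow as tf

traced_branches = []

@tf.function
def square_or_negate(x, use_square):
  if use_square:
    traced_branches.append("square")  # Python side effect: runs only while tracing.
    y = x * x
  else:
    traced_branches.append("negate")
    y = -x
  return y

# Python bool: the `if` runs as plain Python, so only one branch is traced.
print(square_or_negate(tf.constant(3.0), True).numpy())  # 9.0
python_bool_branches = list(traced_branches)

# Tensor condition: AutoGraph emits tf.cond, which traces both branches.
traced_branches.clear()
print(square_or_negate(tf.constant(3.0), tf.constant(False)).numpy())  # -3.0
tensor_branches = sorted(traced_branches)
print(python_bool_branches, tensor_branches)
```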
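Loops
AutoGraph converts some for and while statements into the equivalent TensorFlow looping ops, like tf.while_loop. If not converted, the for or while loop is executed as a Python loop.
This substitution is made in the following situations:
- for x in y: if y is a Tensor, convert to tf.while_loop. In the special case where y is a tf.data.Dataset, a combination of tf.data.Dataset ops is generated.
- while <condition>: if <condition> is a Tensor, convert to tf.while_loop.
A Python loop executes during tracing, adding additional ops to the tf.Graph for every iteration of the loop.
A TensorFlow loop traces the body of the loop and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated tf.Graph.
See the reference documentation for additional restrictions on AutoGraph-converted for and while statements.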
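One way to see the difference between the two kinds of loops is to compare graph sizes. In the sketch below, python_loop iterates over a Python range, so its body is unrolled 50 times during tracing. tensor_loop iterates over tf.range, so it becomes a single while loop. The helper count_graph_nodes is illustrative and mirrors the measure_graph_size helper used in the next section. The exact node counts depend on the TensorFlow version, so only their relative size is meaningful.

```python
import tensorflow as tf

def count_graph_nodes(fn, *args):
  """Number of nodes in the top-level graph traced for `args`."""
  graph = fn.get_concrete_function(*args).graph
  return len(graph.as_graph_def().node)

@tf.function
def python_loop(x):
  for _ in range(50):     # Python `range`: unrolled while tracing.
    x = x + 1
  return x

@tf.function
def tensor_loop(x):
  for _ in tf.range(50):  # Tensor range: converted to a single while loop.
    x = x + 1
  return x

x = tf.constant(0)
print(python_loop(x).numpy(), tensor_loop(x).numpy())  # 50 50
unrolled_nodes = count_graph_nodes(python_loop, x)
rolled_nodes = count_graph_nodes(tensor_loop, x)
print(unrolled_nodes, rolled_nodes)  # The unrolled graph is much larger.
```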
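Looping over Python data
A common pitfall is to loop over Python/NumPy data within a tf.function. This loop executes during the tracing process, adding a copy of your model to the tf.Graph for each iteration of the loop.
If you want to wrap the entire training loop in tf.function, the safest way to do this is to wrap your data as a tf.data.Dataset so that AutoGraph dynamically unrolls the training loop.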
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
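When wrapping Python/NumPy data in a Dataset, be mindful of tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensor_slices. The former keeps the data in Python and fetches it via tf.py_function, which can have performance implications, whereas the latter bundles a copy of the data as one large tf.constant() node in the graph, which can have memory implications.
Reading data from files via TFRecordDataset, CsvDataset, and similar datasets is the most effective way to consume data, because TensorFlow itself can then manage the asynchronous loading and prefetching of data without having to involve Python. To learn more, see the "tf.data: Build TensorFlow input pipelines" guide.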
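Accumulating values in a loop
A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, because these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use tf.TensorArray to accumulate results from a dynamically unrolled loop.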
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
  return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]
  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.36914825, 0.64223015, 0.7850807 , 0.19980955], [0.4491886 , 1.4083506 , 1.0351617 , 0.20833313], [0.7401295 , 1.7583194 , 1.1593785 , 0.32083678]], [[0.09818649, 0.09965849, 0.28532243, 0.2933966 ], [0.56936 , 1.0815177 , 0.7327199 , 0.6250684 ], [1.5300128 , 1.22948 , 0.8870441 , 0.770558 ]]], dtype=float32)>
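A smaller instance of the same pattern is a running sum over a 1-D tensor whose length is only known at run time. In the sketch below, running_sum is an illustrative name. The TensorArray is sized with a tensor, filled inside a tf.range loop that AutoGraph converts, and stacked at the end. Note that write() returns a new TensorArray, and the result must be reassigned.

```python
import numpy as np
import tensorflow as tf

@tf.function
def running_sum(values):
  n = tf.shape(values)[0]
  acc = tf.TensorArray(tf.float32, size=n)  # Graph-friendly accumulator.
  total = tf.constant(0.0)
  for i in tf.range(n):
    total += values[i]
    acc = acc.write(i, total)  # write() returns a TensorArray; keep the result.
  return acc.stack()

values = tf.constant([1.0, 2.0, 3.0, 4.0])
sums = running_sum(values).numpy()
print(sums)  # [ 1.  3.  6. 10.]
```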
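Limitations
tf.function has a few limitations by design that you should be aware of when converting a Python function to a tf.function.
Executing Python side effects
Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a tf.function, sometimes executing twice or not at all. They only happen the first time you call a tf.function with a set of inputs. Afterwards, the traced tf.Graph is re-executed without executing the Python code.
The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code is executed by the TensorFlow runtime with each call.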
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
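If you would like to execute Python code during each invocation of a tf.function, tf.py_function is an escape hatch. The drawbacks of tf.py_function are that it is neither portable nor particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph, it casts all inputs/outputs to tensors.
@tf.py_function(Tout=tf.float32)
def py_plus(x, y):
  print('Executing eagerly.')
  return x + y
@tf.function
def tf_wrapper(x, y):
  print('Tracing.')
  return py_plus(x, y)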
The tf.function traces the first time it is called:
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Tracing. Executing eagerly. 3.0
But the inner tf.py_function executes eagerly every time:
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Executing eagerly. 3.0
Changing Python global and free variables
Changing Python global and free variables counts as a Python side effect, so it only happens during tracing.
external_list = []
@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
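Sometimes unexpected behaviors are very hard to notice. In the example below, counter is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the tf.function is used, the assign_add is recorded unconditionally in the underlying graph. Therefore, v increases by 1 every time the tf.function is called. This issue is common among users who try to migrate their graph-mode TensorFlow code to TensorFlow 2 using tf.function decorators, when Python side effects (counter in the example) are used to determine which ops to run (assign_add in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (for example, if the guarded operation is very costly).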
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0
  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)
    return self.v
m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1 2 3
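A workaround to achieve the expected behavior is to use tf.init_scope to lift the operations outside of the function graph. This ensures that the variable increment is only done once, during tracing. Note that init_scope has other side effects, including clearing the control flow context and the gradient tape. Sometimes the use of init_scope can become too complex to manage realistically.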
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0
  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)
    return self.v
m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1 1 1
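In summary, as a rule of thumb, you should avoid mutating Python objects such as integers or containers like lists that live outside the tf.function. Instead, use arguments and TF objects. For example, the "Accumulating values in a loop" section has one example of how list-like operations can be implemented.
You can, in some cases, capture and manipulate state if it is a tf.Variable. This is how the weights of Keras models are updated with repeated calls to the same ConcreteFunction.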
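The contrast can be shown side by side. In the sketch below, CallCounter and its attribute names are illustrative. A plain Python int attribute is incremented only while the method is traced. A captured tf.Variable is updated by assign_add on every call, because that op is recorded in the graph.

```python
import tensorflow as tf

class CallCounter(tf.Module):
  def __init__(self):
    super().__init__()
    self.python_count = 0           # Plain Python int: only updated while tracing.
    self.tf_count = tf.Variable(0)  # tf.Variable: updated on every call.

  @tf.function
  def __call__(self, x):
    self.python_count += 1
    self.tf_count.assign_add(1)
    return x * 2

counter = CallCounter()
for _ in range(3):
  counter(tf.constant(1))
print(counter.python_count, counter.tf_count.numpy())  # 1 3
```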
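Using Python iterators and generators
Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing.
@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))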
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
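Just as TensorFlow has a specialized tf.TensorArray for list constructs, it has a specialized tf.data.Iterator for iteration constructs. See the section on AutoGraph transformations for an overview. Also, the tf.data API can help implement generator patterns:
@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))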
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
tf.function 的所有输出都必须是返回值
除了 tf.Variable
之外,tf.function 必须返回其所有输出。尝试直接访问函数中的任何张量而不通过返回值会导致“泄漏”。
For example, the function below "leaks" a local tensor (a + 1, computed from the argument a) through the Python global x:
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'SymbolicTensor' object has no attribute 'numpy'
The same is true even if the leaked value is also returned:
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'SymbolicTensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://tensorflowcn.cn/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information. <tf.Tensor 'add:0' shape=() dtype=int32> was defined here: File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main File "/usr/lib/python3.9/runpy.py", line 87, in _run_code File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module> File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 1077, in launch_instance File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 195, in start File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 529, in dispatch_queue File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 518, in process_one File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 424, in dispatch_shell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 766, in execute_request File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 429, in do_execute File 
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 7, in <module> File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 832, in __call__ File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 888, in _call File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 695, in _initialize File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function File 
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 598, in wrapped_fn File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler File "/tmpfs/tmp/ipykernel_11117/566849597.py", line 4, in leaky_function File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1478, in binary_op_wrapper File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1871, in _add_dispatch File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 490, in add_v2 File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2652, in _create_op_internal File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1160, in from_node_def The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=140399462105440), which is out of scope.
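Usually, leaks like these occur when you use Python statements or data structures. In addition to leaking inaccessible tensors, such statements are also likely wrong, because they count as Python side effects and are not guaranteed to execute on every function call.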
Common ways to leak local tensors also include mutating an external Python collection or object:
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
@tf.function  # Without this decorator the tensors are eager and nothing leaks.
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
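A leak-free version of the same idea returns every tensor the caller needs and performs any Python mutation outside the tf.function. Here is a minimal sketch, with non_leaky_function as a hypothetical name:

```python
import tensorflow as tf

external_list = []

@tf.function
def non_leaky_function(a):
  # Return every tensor the caller needs instead of storing it in a global,
  # a Python list, or an object attribute.
  b = a + 1
  return a + 2, b

result, side_value = non_leaky_function(tf.constant(1))
# Mutate Python state outside the tf.function, using the eager return values.
external_list.append(side_value)
print(result.numpy(), external_list[0].numpy())  # 3 2
```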
Recursive tf.functions are not supported
Recursive tf.functions are not supported and can cause infinite loops. For example,
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File 
"/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in 
recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File 
"/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in 
recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File 
"/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in 
recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_11117/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/usr/lib/python3.9/abc.py", line 119, in __instancecheck__ return _abc_instancecheck(cls, instance) RecursionError: maximum recursion depth exceeded while calling a Python object
Even if a recursive tf.function seems to work, the Python function will be traced multiple times, which can have performance implications. For example,
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
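When the recursion depth depends on a tensor, one alternative is to rewrite the recursion as a loop. In this sketch, factorial is a hypothetical example and not taken from the guide above. AutoGraph turns the while statement on a tensor condition into a single tf.while_loop, so the function is traced only once regardless of n. The Python counter trace_count is only there to show how many traces happen:

```python
import tensorflow as tf

trace_count = 0

@tf.function
def factorial(n):
  global trace_count
  trace_count += 1  # Python side effect: runs only while tracing
  result = tf.constant(1)
  while n > 0:  # n is a tensor, so AutoGraph emits a tf.while_loop
    result *= n
    n -= 1
  return result

r5 = factorial(tf.constant(5))
r6 = factorial(tf.constant(6))  # same dtype and shape, so no retrace
print(r5.numpy(), r6.numpy(), trace_count)  # 120 720 1
```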
Known issues
If your tf.function is not evaluating correctly, the error may be explained by these known issues, which are planned to be fixed in the future.
Depending on Python global and free variables
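tf.function creates a new ConcreteFunction when called with a new value of a Python argument. However, it does not do that for the Python closure, globals, or nonlocals of that tf.function. If their values change between calls to the tf.function, the tf.function will still use the values they had when it was traced. This is different from how regular Python functions work.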
For that reason, you should follow a functional programming style that uses arguments instead of closing over outer names.
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
Another way to update a global value is to make it a tf.Variable and use the Variable.assign method instead.
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
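The rule from the start of this section can be observed directly. New Python argument values trigger new traces, while tensors of the same dtype and shape share one. A Python side effect runs only during tracing, so it can count traces. Here is a minimal sketch, using a hypothetical trace_count counter:

```python
import tensorflow as tf

trace_count = 0

@tf.function
def add_one(foo):
  global trace_count
  trace_count += 1  # Python side effect: counts traces, not calls
  return 1 + foo

add_one(1)
add_one(100)               # a new Python value -> a new trace
add_one(tf.constant(1))
add_one(tf.constant(100))  # same dtype/shape tensor -> reuses that trace
print(trace_count)  # 3
```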
Depending on Python objects
Passing custom Python objects as arguments to tf.function is supported, but with certain limitations.
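For maximum feature coverage, consider converting the objects into extension types before passing them to tf.function. You can also use Python primitives and tf.nest-compatible structures.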
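However, as covered in the rules of tracing, when a custom Python class does not provide a custom TraceType, tf.function is forced to use instance-based equality. This means it will not create a new trace when you pass the same object with modified attributes.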
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
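Using the same tf.function to evaluate the modified instance of the model is therefore buggy, because the instance still has the same instance-based TraceType as the original model.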
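For that reason, it is recommended that you write your tf.function to avoid depending on mutable object attributes, or implement the tracing protocol for the objects to inform tf.function about such attributes.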
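Following the earlier suggestion to use extension types, one option is to store the parameters as tensor fields of a tf.experimental.ExtensionType. This sketch uses the hypothetical names LinearParams and evaluate_params. Tensor fields are treated as function inputs, so a new instance with different values is picked up without a retrace. Extension types are immutable, so you create a new instance instead of mutating an attribute:

```python
import tensorflow as tf

class LinearParams(tf.experimental.ExtensionType):
  # Hypothetical extension type holding the model parameters as tensors.
  weight: tf.Tensor
  bias: tf.Tensor

@tf.function
def evaluate_params(params, x):
  return params.weight * x + params.bias

x = tf.constant(10.)
no_bias = LinearParams(weight=tf.constant(2.), bias=tf.constant(0.))
with_bias = LinearParams(weight=tf.constant(2.), bias=tf.constant(5.))
r_no_bias = evaluate_params(no_bias, x)
r_with_bias = evaluate_params(with_bias, x)  # new values are used
print(r_no_bias.numpy(), r_with_bias.numpy())  # 20.0 25.0
```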
If that is not possible, one workaround is to create a new tf.function each time you modify your object, to force retracing:
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`. `tf.function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new `tf.function` and `ConcreteFunction` since you modified `new_model`.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
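Because retracing can be expensive, you can use tf.Variables as object attributes to get a similar effect without a retrace. The variables can be mutated in place (for example with assign or assign_add), but must not be reassigned to a new object, so be careful!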
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Creating tf.Variables
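tf.function only supports singleton tf.Variables, which are created once on the first call and reused across subsequent function calls. The code snippet below would create a new tf.Variable on every function call, which results in a ValueError exception.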
Example:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmpfs/tmp/ipykernel_11117/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://tensorflowcn.cn/guide/function#creating_tfvariables for more information.
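A common pattern used to work around this limitation is to start with a Python None value, then conditionally create the tf.Variable only if the value is None: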
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
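If the variable's initial value is known up front, a simpler option is to create it in __init__, outside any tf.function, as the error message above also suggests. Here is a minimal sketch, using the hypothetical class name EagerCount:

```python
import tensorflow as tf

class EagerCount(tf.Module):
  def __init__(self):
    super().__init__()
    # Created once, eagerly, outside any tf.function trace.
    self.count = tf.Variable(0)

  @tf.function
  def __call__(self):
    return self.count.assign_add(1)

c = EagerCount()
values = [int(c().numpy()) for _ in range(2)]
print(values)  # [1, 2]
```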
Using with multiple Keras optimizers
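You may encounter ValueError: tf.function only supports singleton tf.Variables created on the first call. when using more than one Keras optimizer with a tf.function. This error occurs because optimizers internally create tf.Variables when they apply gradients for the first time.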
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1701138168.913099 11284 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_11117/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_11117/950644149.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmpfs/tmp/ipykernel_11117/950644149.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1223, in apply_gradients ** return super().apply_gradients(grads_and_vars, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 638, in apply_gradients self.build(trainable_variables) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py", line 145, in build self.add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1125, in add_variable_from_reference return super().add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 513, in add_variable_from_reference variable = tf.Variable( ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://tensorflowcn.cn/guide/function#creating_tfvariables for more information.
If you need to change the stateful object between calls, it's simplest to define a tf.Module subclass and create instances to hold those objects:
class TrainStep(tf.Module):
def __init__(self, optimizer):
self.optimizer = optimizer
@tf.function
def __call__(self, w, x, y):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
self.optimizer.apply_gradients(zip(gradients, [w]))
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
train_o1 = TrainStep(opt1)
train_o2 = TrainStep(opt2)
train_o1(w, x, y)
train_o2(w, x, y)
You could also do this manually by creating multiple instances of the @tf.function wrapper, one for each optimizer:
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new tf.function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y, opt1)
else:
train_step_2(w, x, y, opt2)
Using with multiple Keras models
You may also encounter ValueError: tf.function only supports singleton tf.Variables created on the first call. when passing different model instances to the same tf.function.
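This error occurs because Keras models (whose input shape has not been defined) and Keras layers create tf.Variables when they are first called. You may be attempting to initialize those variables inside a tf.function that has already been called. To avoid this error, try calling model.build(input_shape) to initialize all the weights before training the model.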
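Here is a sketch of that advice, assuming tf.keras is available and using the hypothetical names make_model and predict. Each model's weights are created eagerly with build before either model is passed to the shared tf.function. Exact behavior may vary across Keras versions, so treat this as an illustration rather than a definitive recipe:

```python
import tensorflow as tf

def make_model():
  # Hypothetical helper: a Sequential model with no input shape declared.
  return tf.keras.Sequential([tf.keras.layers.Dense(1)])

@tf.function
def predict(model, x):
  return model(x)

x = tf.ones([2, 3])
model_a = make_model()
model_b = make_model()

# Create all weights eagerly, before either model reaches the tf.function.
model_a.build(input_shape=(None, 3))
model_b.build(input_shape=(None, 3))

out_a = predict(model_a, x)
out_b = predict(model_b, x)  # no new tf.Variables are created inside the trace
print(out_a.shape, out_b.shape)
```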
Further reading
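To learn how to export and load a tf.function, see the SavedModel guide. To learn more about the graph optimizations performed after tracing, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.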