在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
概述
TensorFlow 中有 4 种类型提升选项。
- 默认情况下,TensorFlow 会引发错误,而不是为混合类型操作提升类型。
- 运行
tf.numpy.experimental_enable_numpy_behavior()
会将 TensorFlow 切换为使用 NumPy 类型提升规则。 - 本文档描述了将在 TensorFlow 2.15 中提供(或当前在
tf-nightly
中提供)的两个新选项。
pip install -q tf_nightly
设置
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
print("Using TensorFlow version %s" % tf.__version__)
Using TensorFlow version 2.17.0-dev20240210
启用新的类型提升
为了在 TF-Numpy 中使用 类似 JAX 的类型提升,在为 TensorFlow 启用 NumPy 行为时,将 'all'
或 'safe'
指定为数据类型转换模式。
这个新系统(使用 dtype_conversion_mode="all"
)是结合律、交换律,并且可以轻松控制最终使用的浮点数宽度(它不会自动转换为更宽的浮点数)。它确实引入了一些溢出和精度损失的风险,但是 dtype_conversion_mode="safe"
会强制您显式处理这些情况。这两种模式将在 下一节 中详细说明。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
两种模式:ALL 模式与 SAFE 模式
在新类型提升系统中,我们引入了两种模式:ALL
模式和 SAFE
模式。SAFE
模式用于缓解对可能导致精度损失或位宽扩展的“有风险”提升的担忧。
数据类型
为了简洁,我们将使用以下缩写。
b
表示tf.bool
u8
表示tf.uint8
i16
表示tf.int16
i32
表示tf.int32
bf16
表示tf.bfloat16
f32
表示tf.float32
f64
表示tf.float64
i32*
表示 Pythonint
或弱类型i32
f32*
表示 Pythonfloat
或弱类型f32
c128*
表示 Pythoncomplex
或弱类型c128
星号 (*) 表示相应的类型是“弱”的 - 这种数据类型由系统临时推断,并且可能推迟到其他数据类型。这个概念将在 这里 详细解释。
精度丢失操作的示例
在以下示例中,i32
+ f32
在 ALL
模式下允许,但在 SAFE
模式下不允许,因为存在精度丢失的风险。
# i32 + f32 returns a f32 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
a + b # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int32'>, weak=False) and (<dtype: 'float32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
位扩展操作的示例
在以下示例中,i8
+ u32
在 ALL
模式下允许,但在 SAFE
模式下不允许,因为存在位扩展,这意味着使用比输入位数更多的位。请注意,新的类型提升语义只允许必要的位扩展。
# i8 + u32 returns an i64 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
a + b
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=int64, numpy=15>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int8'>, weak=False) and (<dtype: 'uint32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
基于格子的系统
类型提升格
新的类型提升行为通过以下类型提升格来确定
更具体地说,任何两种类型之间的提升是通过找到这两个节点的第一个公共子节点(包括节点本身)来确定的。
例如,在上图中,i8
和 i32
的第一个公共子节点是 i32
,因为这两个节点在沿着箭头方向时第一次在 i32
处相交。
类似地,作为另一个示例,u64
和 f16
之间的结果提升类型将是 f16
。
类型提升表
遵循格会生成下面的二进制提升表
新类型提升的优势
我们采用类似 JAX 的基于格子的系统来进行新的类型提升,它提供了以下优势
基于格子的系统的优势
首先,使用基于格子的系统确保了三个非常重要的属性
- 存在性:对于任何类型的组合,都存在唯一的提升结果类型。
- 交换律:
a + b = b + a
- 结合律:
a + (b + c) = (a + b) = c
这三个属性对于构建一致且可预测的类型提升语义至关重要。
类似 JAX 的格系统的优势
类似 JAX 的格系统的另一个关键优势是,在无符号整数之外,它避免了所有不必要的更宽的提升。这意味着您无法在没有 64 位输入的情况下获得 64 位结果。这对于在加速器上工作特别有利,因为它避免了不必要的 64 位值,这在旧的类型提升中很常见。
但是,这带来了一个权衡:混合浮点/整数提升很容易导致精度丢失。例如,在下面的示例中,i64
+ f16
会导致将 i64
提升为 f16
。
# The first input is promoted to f16 in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
为了缓解这个问题,我们引入了 SAFE
模式,它将不允许这些“有风险”的提升。
弱张量
概述
WeakTensor
的数据类型由系统临时推断,并且可能推迟到其他数据类型。这个概念是在新的类型提升中引入的,以防止在 TF 值和没有明确用户指定类型的值(例如 Python 标量文字)之间的二元操作中出现不必要的类型提升。
例如,在下面的示例中,tf.constant(1.2)
被认为是“弱”的,因为它没有特定的数据类型。因此,tf.constant(1.2)
推迟到 tf.constant(3.1, tf.float16)
的类型,导致 f16
输出。
tf.constant(1.2) + tf.constant(3.1, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>
<tf.Tensor: shape=(), dtype=float16, numpy=4.3>
弱张量构造
如果您在创建张量时没有指定数据类型,则会创建弱张量,结果是弱张量。您可以通过检查张量字符串表示末尾的弱属性来检查张量是否为“弱”的。
第一种情况:当 tf.constant
被调用时,输入没有用户指定的数据类型。
tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
tf.constant([5.0, 10.0, 3]) # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
# A normal Tensor is created when dtype arg is specified.
tf.constant(5, tf.int32) # <tf.Tensor: shape=(), dtype=int32, numpy=5>
<tf.Tensor: shape=(), dtype=int32, numpy=5>
第二种情况:当没有用户指定的数据类型的输入被传递到 支持弱张量的 API 中时。
tf.math.abs([100.0, 4.0]) # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
启用新的类型提升的效果
以下是启用新的类型提升后发生的一些非详尽的更改列表。
- 更一致且可预测的提升结果。
- 降低了位扩展的风险。
tf.Tensor
数学 dunder 方法使用新的类型提升。tf.constant
可以返回WeakTensor
。tf.constant
允许隐式转换,当传递具有与dtype
参数不同的数据类型的张量输入时。tf.Variable
原地操作 (assign
,assign-add
,assign-sub
) 允许隐式转换。tnp.array(1)
和tnp.array(1.0)
返回 32 位弱张量。- 将为 支持弱张量的单目和二元 API 创建和使用
WeakTensor
。
更一致且可预测的提升结果
使用 基于格子的系统 使新的类型提升能够产生一致且可预测的类型提升结果。
旧的类型提升
使用旧的类型提升,更改操作顺序会产生不一致的结果。
# Setup
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
print(f'{type(e)}: {e}') # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>
新的类型提升
新的类型提升无论顺序如何都会产生一致的结果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
# (a + b) + c returns a f16 result.
tf.add(tf.add(a, b), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns a f16 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
降低了位扩展的风险
旧的类型提升
旧的类型提升通常会导致 64 位结果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
新的类型提升
新的类型提升返回具有最小必要位数的结果。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>
<tf.Tensor: shape=(), dtype=float16, numpy=54.2>
tf.Tensor 数学 dunder 方法
所有 tf.Tensor
数学 dunder 方法都将遵循新的类型提升。
-tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>
tf.Variable 原地操作
将允许 tf.Variable
原地操作中的隐式转换。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, tf.int32)
a.assign_add(tf.constant(5, tf.int16)) # <tf.Variable shape=() dtype=int32, numpy=15>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>
tf.constant 隐式转换
在旧的类型提升中,tf.constant
要求输入张量具有与 dtype 参数相同的数据类型。但是,在新的类型提升中,我们隐式地将张量转换为指定的数据类型。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
TF-NumPy 数组
tnp.array
使用新的类型提升,对于 Python 输入默认使用 i32*
和 f32*
。
tnp.array(1) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>
输入类型推断
这是在新的类型提升中推断不同输入类型的过程。
tf.Tensor
:由于tf.Tensor
具有 dtype 属性,因此我们不会进行进一步的推断。- NumPy 类型:这包括
np.array(1)
,np.int16(1)
和np.float
等类型。由于 NumPy 输入也具有 dtype 属性,因此我们将 dtype 属性作为结果推断类型。请注意,NumPy 默认使用i64
和f64
。 - Python 标量/嵌套类型:这包括
1
,[1, 2, 3]
和(1.0, 2.0)
等类型。- Python
int
被推断为i32*
。 - Python
float
被推断为f32*
。 - Python
complex
被推断为c128*
。
- Python
- 如果输入不属于上述任何类别,但具有 dtype 属性,我们将 dtype 属性作为结果推断类型。
进一步阅读
新的类型提升与 JAX-NumPy 的类型提升非常相似。如果您想了解有关新的类型提升和设计选择的更多详细信息,请查看以下资源。
参考资料
支持弱张量的 API
以下是支持 WeakTensor
的 API 列表。
对于单目操作,这意味着如果传递了没有用户指定类型的输入,它将返回 WeakTensor
。
对于二元操作,它将遵循 这里 的提升表。它可能会或可能不会返回 WeakTensor
,具体取决于两个输入的提升结果。
tf.bitwise.invert
tf.clip_by_value
tf.debugging.check_numerics
tf.expand_dims
tf.identity
tf.image.adjust_brightness
tf.image.adjust_gamma
tf.image.extract_patches
tf.image.random_brightness
tf.image.stateless_random_brightness
tf.linalg.diag
tf.linalg.diag_part
tf.linalg.matmul
tf.linalg.matrix_transpose
tf.linalg.tensor_diag_part
tf.linalg.trace
tf.math.abs
tf.math.acos
tf.math.acosh
tf.math.add
tf.math.angle
tf.math.asin
tf.math.asinh
tf.math.atan
tf.math.atanh
tf.math.ceil
tf.math.conj
tf.math.cos
tf.math.cosh
tf.math.digamma
tf.math.divide_no_nan
tf.math.divide
tf.math.erf
tf.math.erfc
tf.math.erfcinv
tf.math.erfinv
tf.math.exp
tf.math.expm1
tf.math.floor
tf.math.floordiv
tf.math.floormod
tf.math.imag
tf.math.lgamma
tf.math.log1p
tf.math.log_sigmoid
tf.math.log
tf.math.multiply_no_nan
tf.math.multiply
tf.math.ndtri
tf.math.negative
tf.math.pow
tf.math.real
tf.math.real
tf.math.reciprocal_no_nan
tf.math.reciprocal
tf.math.reduce_euclidean_norm
tf.math.reduce_logsumexp
tf.math.reduce_max
tf.math.reduce_mean
tf.math.reduce_min
tf.math.reduce_prod
tf.math.reduce_std
tf.math.reduce_sum
tf.math.reduce_variance
tf.math.rint
tf.math.round
tf.math.rsqrt
tf.math.scalar_mul
tf.math.sigmoid
tf.math.sign
tf.math.sin
tf.math.sinh
tf.math.softplus
tf.math.special.bessel_i0
tf.math.special.bessel_i0e
tf.math.special.bessel_i1
tf.math.special.bessel_i1e
tf.math.special.bessel_j0
tf.math.special.bessel_j1
tf.math.special.bessel_k0
tf.math.special.bessel_k0e
tf.math.special.bessel_k1
tf.math.special.bessel_k1e
tf.math.special.bessel_y0
tf.math.special.bessel_y1
tf.math.special.dawsn
tf.math.special.expint
tf.math.special.fresnel_cos
tf.math.special.fresnel_sin
tf.math.special.spence
tf.math.sqrt
tf.math.square
tf.math.subtract
tf.math.tan
tf.math.tanh
tf.nn.depth_to_space
tf.nn.elu
tf.nn.gelu
tf.nn.leaky_relu
tf.nn.log_softmax
tf.nn.relu6
tf.nn.relu
tf.nn.selu
tf.nn.softsign
tf.nn.space_to_depth
tf.nn.swish
tf.ones_like
tf.realdiv
tf.reshape
tf.squeeze
tf.stop_gradient
tf.transpose
tf.truncatediv
tf.truncatemod
tf.zeros_like
tf.experimental.numpy.abs
tf.experimental.numpy.absolute
tf.experimental.numpy.amax
tf.experimental.numpy.amin
tf.experimental.numpy.angle
tf.experimental.numpy.arange
tf.experimental.numpy.arccos
tf.experimental.numpy.arccosh
tf.experimental.numpy.arcsin
tf.experimental.numpy.arcsinh
tf.experimental.numpy.arctan
tf.experimental.numpy.arctanh
tf.experimental.numpy.around
tf.experimental.numpy.array
tf.experimental.numpy.asanyarray
tf.experimental.numpy.asarray
tf.experimental.numpy.ascontiguousarray
tf.experimental.numpy.average
tf.experimental.numpy.bitwise_not
tf.experimental.numpy.cbrt
tf.experimental.numpy.ceil
tf.experimental.numpy.conj
tf.experimental.numpy.conjugate
tf.experimental.numpy.copy
tf.experimental.numpy.cos
tf.experimental.numpy.cosh
tf.experimental.numpy.cumprod
tf.experimental.numpy.cumsum
tf.experimental.numpy.deg2rad
tf.experimental.numpy.diag
tf.experimental.numpy.diagflat
tf.experimental.numpy.diagonal
tf.experimental.numpy.diff
tf.experimental.numpy.empty_like
tf.experimental.numpy.exp2
tf.experimental.numpy.exp
tf.experimental.numpy.expand_dims
tf.experimental.numpy.expm1
tf.experimental.numpy.fabs
tf.experimental.numpy.fix
tf.experimental.numpy.flatten
tf.experimental.numpy.flip
tf.experimental.numpy.fliplr
tf.experimental.numpy.flipud
tf.experimental.numpy.floor
tf.experimental.numpy.full_like
tf.experimental.numpy.imag
tf.experimental.numpy.log10
tf.experimental.numpy.log1p
tf.experimental.numpy.log2
tf.experimental.numpy.log
tf.experimental.numpy.max
tf.experimental.numpy.mean
tf.experimental.numpy.min
tf.experimental.numpy.moveaxis
tf.experimental.numpy.nanmean
tf.experimental.numpy.negative
tf.experimental.numpy.ones_like
tf.experimental.numpy.positive
tf.experimental.numpy.prod
tf.experimental.numpy.rad2deg
tf.experimental.numpy.ravel
tf.experimental.numpy.real
tf.experimental.numpy.reciprocal
tf.experimental.numpy.repeat
tf.experimental.numpy.reshape
tf.experimental.numpy.rot90
tf.experimental.numpy.round
tf.experimental.numpy.signbit
tf.experimental.numpy.sin
tf.experimental.numpy.sinc
tf.experimental.numpy.sinh
tf.experimental.numpy.sort
tf.experimental.numpy.sqrt
tf.experimental.numpy.square
tf.experimental.numpy.squeeze
tf.experimental.numpy.std
tf.experimental.numpy.sum
tf.experimental.numpy.swapaxes
tf.experimental.numpy.tan
tf.experimental.numpy.tanh
tf.experimental.numpy.trace
tf.experimental.numpy.transpose
tf.experimental.numpy.triu
tf.experimental.numpy.vander
tf.experimental.numpy.var
tf.experimental.numpy.zeros_like