TF-NumPy 类型提升

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

概述

TensorFlow 中有 4 种类型提升选项。

  • 默认情况下,TensorFlow 会引发错误,而不是为混合类型操作提升类型。
  • 运行 tf.numpy.experimental_enable_numpy_behavior() 会将 TensorFlow 切换为使用 NumPy 类型提升规则
  • 本文档描述了将在 TensorFlow 2.15 中提供(或当前在 tf-nightly 中提供)的两个新选项。
pip install -q tf_nightly

设置

import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

print("Using TensorFlow version %s" % tf.__version__)
Using TensorFlow version 2.17.0-dev20240210

启用新的类型提升

为了在 TF-Numpy 中使用 类似 JAX 的类型提升,在为 TensorFlow 启用 NumPy 行为时,将 'all''safe' 指定为数据类型转换模式。

这个新系统(使用 dtype_conversion_mode="all")是结合律、交换律,并且可以轻松控制最终使用的浮点数宽度(它不会自动转换为更宽的浮点数)。它确实引入了一些溢出和精度损失的风险,但是 dtype_conversion_mode="safe" 会强制您显式处理这些情况。这两种模式将在 下一节 中详细说明。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.

两种模式:ALL 模式与 SAFE 模式

在新类型提升系统中,我们引入了两种模式:ALL 模式和 SAFE 模式。SAFE 模式用于缓解对可能导致精度损失或位宽扩展的“有风险”提升的担忧。

数据类型

为了简洁,我们将使用以下缩写。

星号 (*) 表示相应的类型是“弱”的 - 这种数据类型由系统临时推断,并且可能推迟到其他数据类型。这个概念将在 这里 详细解释。

精度丢失操作的示例

在以下示例中,i32 + f32ALL 模式下允许,但在 SAFE 模式下不允许,因为存在精度丢失的风险。

# i32 + f32 returns a f32 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
a + b  # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=float32, numpy=15.0>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
try:
  a + b
except TypeError as e:
   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int32'>, weak=False) and (<dtype: 'float32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).

位扩展操作的示例

在以下示例中,i8 + u32ALL 模式下允许,但在 SAFE 模式下不允许,因为存在位扩展,这意味着使用比输入位数更多的位。请注意,新的类型提升语义只允许必要的位扩展。

# i8 + u32 returns an i64 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
a + b
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=int64, numpy=15>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
try:
  a + b
except TypeError as e:
   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int8'>, weak=False) and (<dtype: 'uint32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).

基于格子的系统

类型提升格

新的类型提升行为通过以下类型提升格来确定

Type Promotion Lattice

更具体地说,任何两种类型之间的提升是通过找到这两个节点的第一个公共子节点(包括节点本身)来确定的。

例如,在上图中,i8i32 的第一个公共子节点是 i32,因为这两个节点在沿着箭头方向时第一次在 i32 处相交。

类似地,作为另一个示例,u64f16 之间的结果提升类型将是 f16

类型提升表

遵循格会生成下面的二进制提升表

Type Promotion Table

新类型提升的优势

我们采用类似 JAX 的基于格子的系统来进行新的类型提升,它提供了以下优势

基于格子的系统的优势

首先,使用基于格子的系统确保了三个非常重要的属性

  • 存在性:对于任何类型的组合,都存在唯一的提升结果类型。
  • 交换律:a + b = b + a
  • 结合律:a + (b + c) = (a + b) = c

这三个属性对于构建一致且可预测的类型提升语义至关重要。

类似 JAX 的格系统的优势

类似 JAX 的格系统的另一个关键优势是,在无符号整数之外,它避免了所有不必要的更宽的提升。这意味着您无法在没有 64 位输入的情况下获得 64 位结果。这对于在加速器上工作特别有利,因为它避免了不必要的 64 位值,这在旧的类型提升中很常见。

但是,这带来了一个权衡:混合浮点/整数提升很容易导致精度丢失。例如,在下面的示例中,i64 + f16 会导致将 i64 提升为 f16

# The first input is promoted to f16 in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=float16, numpy=4.2>

为了缓解这个问题,我们引入了 SAFE 模式,它将不允许这些“有风险”的提升。

弱张量

概述

弱张量 是“弱类型”的张量,类似于 JAX 中的概念.

WeakTensor 的数据类型由系统临时推断,并且可能推迟到其他数据类型。这个概念是在新的类型提升中引入的,以防止在 TF 值和没有明确用户指定类型的值(例如 Python 标量文字)之间的二元操作中出现不必要的类型提升。

例如,在下面的示例中,tf.constant(1.2) 被认为是“弱”的,因为它没有特定的数据类型。因此,tf.constant(1.2) 推迟到 tf.constant(3.1, tf.float16) 的类型,导致 f16 输出。

tf.constant(1.2) + tf.constant(3.1, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>
<tf.Tensor: shape=(), dtype=float16, numpy=4.3>

弱张量构造

如果您在创建张量时没有指定数据类型,则会创建弱张量,结果是弱张量。您可以通过检查张量字符串表示末尾的弱属性来检查张量是否为“弱”的。

第一种情况:当 tf.constant 被调用时,输入没有用户指定的数据类型。

tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
tf.constant([5.0, 10.0, 3])  # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10.,  3.], dtype=float32), weak=True>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10.,  3.], dtype=float32), weak=True>
# A normal Tensor is created when dtype arg is specified.
tf.constant(5, tf.int32)  # <tf.Tensor: shape=(), dtype=int32, numpy=5>
<tf.Tensor: shape=(), dtype=int32, numpy=5>

第二种情况:当没有用户指定的数据类型的输入被传递到 支持弱张量的 API 中时。

tf.math.abs([100.0, 4.0])  # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([100.,   4.], dtype=float32), weak=True>

启用新的类型提升的效果

以下是启用新的类型提升后发生的一些非详尽的更改列表。

  • 更一致且可预测的提升结果。
  • 降低了位扩展的风险。
  • tf.Tensor 数学 dunder 方法使用新的类型提升。
  • tf.constant 可以返回 WeakTensor
  • tf.constant 允许隐式转换,当传递具有与 dtype 参数不同的数据类型的张量输入时。
  • tf.Variable 原地操作 (assign, assign-add, assign-sub) 允许隐式转换。
  • tnp.array(1)tnp.array(1.0) 返回 32 位弱张量。
  • 将为 支持弱张量的单目和二元 API 创建和使用 WeakTensor

更一致且可预测的提升结果

使用 基于格子的系统 使新的类型提升能够产生一致且可预测的类型提升结果。

旧的类型提升

使用旧的类型提升,更改操作顺序会产生不一致的结果。

# Setup
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
  tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
  print(f'{type(e)}: {e}')  # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>

新的类型提升

新的类型提升无论顺序如何都会产生一致的结果。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
# (a + b) + c returns a f16 result.
tf.add(tf.add(a, b), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns a f16 result.
tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>

降低了位扩展的风险

旧的类型提升

旧的类型提升通常会导致 64 位结果。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>

新的类型提升

新的类型提升返回具有最小必要位数的结果。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>
<tf.Tensor: shape=(), dtype=float16, numpy=54.2>

tf.Tensor 数学 dunder 方法

所有 tf.Tensor 数学 dunder 方法都将遵循新的类型提升。

-tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>

tf.Variable 原地操作

将允许 tf.Variable 原地操作中的隐式转换。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, tf.int32)
a.assign_add(tf.constant(5, tf.int16))  # <tf.Variable shape=() dtype=int32, numpy=15>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>

tf.constant 隐式转换

在旧的类型提升中,tf.constant 要求输入张量具有与 dtype 参数相同的数据类型。但是,在新的类型提升中,我们隐式地将张量转换为指定的数据类型。

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>

TF-NumPy 数组

tnp.array 使用新的类型提升,对于 Python 输入默认使用 i32*f32*

tnp.array(1)  # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0)  # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>

输入类型推断

这是在新的类型提升中推断不同输入类型的过程。

  • tf.Tensor:由于 tf.Tensor 具有 dtype 属性,因此我们不会进行进一步的推断。
  • NumPy 类型:这包括 np.array(1), np.int16(1)np.float 等类型。由于 NumPy 输入也具有 dtype 属性,因此我们将 dtype 属性作为结果推断类型。请注意,NumPy 默认使用 i64f64
  • Python 标量/嵌套类型:这包括 1, [1, 2, 3](1.0, 2.0) 等类型。
    • Python int 被推断为 i32*
    • Python float 被推断为 f32*
    • Python complex 被推断为 c128*
  • 如果输入不属于上述任何类别,但具有 dtype 属性,我们将 dtype 属性作为结果推断类型。

进一步阅读

新的类型提升与 JAX-NumPy 的类型提升非常相似。如果您想了解有关新的类型提升和设计选择的更多详细信息,请查看以下资源。

参考资料

支持弱张量的 API

以下是支持 WeakTensor 的 API 列表。

对于单目操作,这意味着如果传递了没有用户指定类型的输入,它将返回 WeakTensor

对于二元操作,它将遵循 这里 的提升表。它可能会或可能不会返回 WeakTensor,具体取决于两个输入的提升结果。