扩展类型

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

设置

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

扩展类型

用户定义类型可以使项目更具可读性、模块化和可维护性。但是,大多数 TensorFlow API 对用户定义的 Python 类型支持非常有限。这包括高级 API(例如 Kerastf.functiontf.SavedModel)和低级 API(例如 tf.while_looptf.concat)。TensorFlow 扩展类型 可用于创建与 TensorFlow API 无缝协作的用户定义的面向对象类型。要创建扩展类型,只需定义一个以 tf.experimental.ExtensionType 为基类的 Python 类,并使用 类型注释 为每个字段指定类型。

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

tf.experimental.ExtensionType 基类的工作方式类似于标准 Python 库中的 typing.NamedTuple@dataclasses.dataclass。特别是,它会根据字段类型注释自动添加构造函数和特殊方法(例如 __repr____eq__)。

通常,扩展类型倾向于属于以下两种类别之一

  • 数据结构,它将一组相关值组合在一起,并可以根据这些值提供有用的操作。数据结构可能相当通用(例如上面的 TensorGraph 示例);或者它们可能高度定制到特定模型。

  • 类似张量的类型,它专门化或扩展了“张量”的概念。此类别的类型具有 rankshape,通常还具有 dtype;并且使用它们进行张量运算(例如 tf.stacktf.addtf.matmul)是有意义的。 MaskedTensorCSRSparseMatrix 是类似张量类型的示例。

支持的 API

以下 TensorFlow API 支持扩展类型

  • Keras:扩展类型可以用作 Keras ModelsLayers 的输入和输出。
  • tf.data.Dataset:扩展类型可以包含在 Datasets 中,并由数据集 Iterators 返回。
  • TensorFlow Hub:扩展类型可以用作 tf.hub 模块的输入和输出。
  • SavedModel:扩展类型可以用作 SavedModel 函数的输入和输出。
  • tf.function:扩展类型可以用作用 @tf.function 装饰器包装的函数的参数和返回值。
  • While 循环:扩展类型可以用作 tf.while_loop 中的循环变量,并且可以用作 while 循环主体函数的参数和返回值。
  • 条件语句:可以使用 tf.condtf.case 有条件地选择扩展类型。
  • tf.py_function:扩展类型可以用作 tf.py_functionfunc 参数的输入和输出。
  • 张量操作:扩展类型可以扩展以支持大多数接受张量输入的 TensorFlow 操作(例如 tf.matmultf.gathertf.reduce_sum)。有关更多信息,请参见下面的“调度”部分。
  • 分布式策略:扩展类型可以用作每个副本的值。

有关更多详细信息,请参见下面的“支持 ExtensionTypes 的 TensorFlow API”部分。

要求

字段类型

所有字段(实例变量)都必须声明,并且必须为每个字段提供类型注释。支持以下类型注释

类型 示例
Python 整数 i: int
Python 浮点数 f: float
Python 字符串 s: str
Python 布尔值 b: bool
Python None n: None
张量形状 shape: tf.TensorShape
张量 dtype dtype: tf.DType
张量 t: tf.Tensor
扩展类型 mt: MyMaskedTensor
Ragged 张量 rt: tf.RaggedTensor
稀疏张量 st: tf.SparseTensor
索引切片 s: tf.IndexedSlices
可选张量 o: tf.experimental.Optional
类型联合 int_or_float: typing.Union[int, float]
元组 params: typing.Tuple[int, float, tf.Tensor, int]
可变长度元组 lengths: typing.Tuple[int, ...]
映射 tags: typing.Mapping[str, tf.Tensor]
可选值 weight: typing.Optional[tf.Tensor]

可变性

扩展类型必须是不可变的。这确保了它们可以被 TensorFlow 的图跟踪机制正确跟踪。如果您发现自己想要修改扩展类型值,请考虑改为定义转换值的函数。例如,与其定义一个 set_mask 函数来修改 MaskedTensor,不如定义一个 replace_mask 函数,它返回一个新的 MaskedTensor

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

ExtensionType 添加的功能

ExtensionType 基类提供以下功能

  • 一个构造函数(__init__)。
  • 一个可打印表示函数(__repr__)。
  • 相等和不相等运算符(__eq__)。
  • 一个验证函数(__validate__)。
  • 强制不可变性。
  • 一个嵌套的 TypeSpec
  • 张量 API 调度支持。

有关自定义此功能的更多信息,请参见下面的“自定义 ExtensionType”部分。

构造函数

ExtensionType 添加的构造函数将每个字段作为命名参数(按它们在类定义中列出的顺序)进行接收。此构造函数将对每个参数进行类型检查,并在必要时进行转换。特别是,Tensor 字段使用 tf.convert_to_tensor 进行转换;Tuple 字段转换为 tupleMapping 字段转换为不可变字典。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)

如果字段值无法转换为其声明的类型,则构造函数将引发 TypeError

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")

可以通过在类级别设置字段的值来指定字段的默认值

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(length=0.5, color="blue")

可打印表示

ExtensionType 添加了一个默认的可打印表示函数(__repr__),其中包括类名和每个字段的值

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))

相等运算符

ExtensionType 添加了默认的相等运算符(__eq____ne__),如果两个值具有相同的类型并且所有字段都相等,则认为这两个值相等。如果张量字段具有相同的形状并且所有元素的逐元素相等,则认为它们相等。

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")

验证函数

ExtensionType 添加了一个 __validate__ 函数,它可以被重写以对字段执行验证检查。它在构造函数被调用后以及字段被类型检查并转换为其声明的类型后运行,因此它可以假设所有字段都具有其声明的类型。

以下示例更新了 MaskedTensor 以验证其字段的 shapedtype

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")

强制不可变性

ExtensionType 重写了 __setattr____delattr__ 函数以防止修改,确保扩展类型值是不可变的。

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")

嵌套 TypeSpec

每个 ExtensionType 类都有一个相应的 TypeSpec 类,该类是自动创建的,并存储为 <extension_type_name>.Spec

此类捕获了值中的所有信息,除了任何嵌套张量的值。特别是,值的 TypeSpec 是通过用其 TypeSpec 替换任何嵌套的张量、扩展类型或复合张量来创建的。

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.

TypeSpec 值可以显式构造,也可以使用 tf.type_spec_from_valueExtensionType 值构建。

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TypeSpec 用于 TensorFlow 将值划分为静态部分动态部分

  • 静态部分(在图构建时固定)使用 tf.TypeSpec 编码。
  • 动态部分(每次运行图时都可能变化)使用 tf.Tensor 列表编码。

例如,每当参数具有以前从未见过的 TypeSpec 时,tf.function 都会重新跟踪其包装的函数

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))

有关更多信息,请参见 tf.function 指南

自定义 ExtensionType

除了简单地声明字段及其类型外,扩展类型还可以

  • 重写默认的可打印表示(__repr__)。
  • 定义函数。
  • 定义 classmethodstaticmethod
  • 定义属性。
  • 重写默认的构造函数(__init__)。
  • 重写默认的相等运算符(__eq__)。
  • 定义运算符(例如 __add____lt__)。
  • 为字段声明默认值。
  • 定义子类。

重写默认的可打印表示

您可以重写扩展类型的默认字符串转换运算符。以下示例更新了 MaskedTensor 类,以便在以 Eager 模式打印值时生成更易读的字符串表示。

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)

定义函数

扩展类型可以定义函数,就像任何普通的 Python 类一样。例如,MaskedTensor 类型可以定义一个 with_default 函数,该函数返回一个 self 的副本,其中掩码值被给定的 default 值替换。函数可以选择用 @tf.function 装饰器进行注释。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)

定义 classmethodstaticmethod

扩展类型可以使用 @classmethod@staticmethod 装饰器定义函数。例如,MaskedTensor 类型可以定义一个工厂函数,该函数掩盖任何具有给定值的元素

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values != value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)

定义属性

扩展类型可以使用 @property 装饰器定义属性,就像任何普通的 Python 类一样。例如,MaskedTensor 类型可以定义一个 dtype 属性,它是值的 dtype 的简写

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype

重写默认的构造函数

您可以重写扩展类型的默认构造函数。自定义构造函数必须为每个声明的字段设置一个值;在自定义构造函数返回后,所有字段都将被类型检查,并且值将按上述方式进行转换。

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!

或者,您可以考虑保留默认的构造函数,但添加一个或多个工厂函数。例如

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))

重写默认的相等运算符(__eq__

您可以重写扩展类型的默认 __eq__ 运算符。以下示例更新了 MaskedTensor 以在比较相等性时忽略掩码元素。

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)

使用前向引用

如果字段的类型尚未定义,您可以使用包含类型名称的字符串。在以下示例中,字符串 "Node" 用于注释 children 字段,因为 Node 类型尚未(完全)定义。

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])

定义子类

扩展类型可以使用标准 Python 语法进行子类化。扩展类型子类可以添加新的字段、方法和属性;并且可以覆盖构造函数、可打印表示和相等运算符。以下示例定义了一个基本的 TensorGraph 类,它使用三个 Tensor 字段来编码节点之间的一组边。然后它定义了一个子类,该子类添加了一个 Tensor 字段来记录每个节点的“特征值”。子类还定义了一个方法来沿着边传播特征值。

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)

定义私有字段

扩展类型的字段可以通过在它们前面加上下划线(遵循标准 Python 约定)来标记为私有。这不会以任何方式影响 TensorFlow 处理字段的方式;而只是向扩展类型的任何用户发出信号,表明这些字段是私有的。

自定义 ExtensionTypeTypeSpec

每个 ExtensionType 类都有一个对应的 TypeSpec 类,该类是自动创建的,并存储为 <extension_type_name>.Spec。有关更多信息,请参见上面的“嵌套 TypeSpec”部分。

要自定义 TypeSpec,只需定义您自己的名为 Spec 的嵌套类,ExtensionType 将使用它作为自动构建的 TypeSpec 的基础。您可以通过以下方式自定义 Spec

  • 覆盖默认的可打印表示。
  • 覆盖默认构造函数。
  • 定义方法、classmethodstaticmethod 和属性。

以下示例自定义了 MaskedTensor.Spec 类,使其更易于使用

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

张量 API 调度

扩展类型可以是“类似张量的”,因为它们专门化或扩展了由 tf.Tensor 类型定义的接口。类似张量的扩展类型的示例包括 RaggedTensorSparseTensorMaskedTensor调度装饰器 可用于覆盖 TensorFlow 操作在应用于类似张量的扩展类型时的默认行为。TensorFlow 目前定义了三个调度装饰器

为单个 API 调度

tf.experimental.dispatch_for_api 装饰器覆盖了指定 TensorFlow 操作的默认行为,当它以指定的签名调用时。例如,您可以使用此装饰器来指定 tf.stack 如何处理 MaskedTensor

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

这将覆盖 tf.stack 的默认实现,只要它被调用时带有 MaskedTensor 值的列表(因为 values 参数用 typing.List[MaskedTensor] 进行了注释)。

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])

要允许 tf.stack 处理混合 MaskedTensorTensor 值的列表,您可以细化 values 参数的类型注释,并相应地更新函数体

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])

有关可以覆盖的 API 列表,请参见 tf.experimental.dispatch_for_api 的 API 文档。

为所有一元逐元素 API 调度

tf.experimental.dispatch_for_unary_elementwise_apis 装饰器覆盖了所有一元逐元素操作(例如 tf.math.cos)的默认行为,只要第一个参数的值(通常命名为 x)与类型注释 x_type 匹配。装饰的函数应接受两个参数

  • api_func:一个函数,它接受一个参数并执行逐元素操作(例如,tf.abs)。
  • x:逐元素操作的第一个参数。

以下示例更新所有一元逐元素操作以处理 MaskedTensor 类型

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

现在,只要在一元逐元素操作上调用 MaskedTensor,就会使用此函数。

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
print(tf.ones_like(x, dtype=tf.float32))

为所有二元逐元素 API 调度

类似地,tf.experimental.dispatch_for_binary_elementwise_apis 可用于更新所有二元逐元素操作以处理 MaskedTensor 类型

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)

有关被覆盖的逐元素 API 列表,请转到 tf.experimental.dispatch_for_unary_elementwise_apistf.experimental.dispatch_for_binary_elementwise_apis 的 API 文档。

可批处理的 ExtensionType

如果单个实例可以用来表示一批值,则 ExtensionType可批处理的。通常,这是通过向所有嵌套的 Tensor 添加批处理维度来实现的。以下 TensorFlow API 要求任何扩展类型输入都是可批处理的

默认情况下,BatchableExtensionType 通过对任何嵌套的 TensorCompositeTensorExtensionType 进行批处理来创建批处理值。如果这对您的类不适用,那么您需要使用 tf.experimental.ExtensionTypeBatchEncoder 来覆盖此默认行为。例如,通过简单地堆叠单个稀疏张量的 valuesindicesdense_shape 字段来创建一批 tf.SparseTensor 值将不合适——在大多数情况下,您无法堆叠这些张量,因为它们具有不兼容的形状;即使可以,结果也不会是有效的 SparseTensor

BatchableExtensionType 示例:Network

例如,考虑一个简单的 Network 类,用于负载均衡,它跟踪每个节点上还有多少工作要做,以及在节点之间移动工作还有多少带宽可用

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

要使此类型可批处理,请将基本类型更改为 BatchableExtensionType,并将每个字段的形状调整为包含可选的批处理维度。以下示例还添加了一个 shape 字段来跟踪批处理形状。此 shape 字段不是 tf.data.Datasettf.map_fn 所必需的,但它是 tf.keras 所必需的。

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")

然后,您可以使用 tf.data.Dataset 迭代一批网络

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")

您还可以使用 map_fn 将函数应用于每个批处理元素

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)

支持 ExtensionType 的 TensorFlow API

@tf.function

tf.function 是一个装饰器,它为 Python 函数预先计算 TensorFlow 图,这可以大幅提高 TensorFlow 代码的性能。扩展类型值可以与 @tf.function 装饰的函数透明地使用。

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)

如果您希望显式指定 tf.functioninput_signature,那么您可以使用扩展类型的 TypeSpec 来做到这一点。

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)

具体函数

具体函数封装了由 tf.function 构建的单个跟踪图。扩展类型可以与具体函数透明地使用。

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)

控制流操作

扩展类型受 TensorFlow 的控制流操作支持

# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])

自动图控制流

扩展类型也受 tf.function(使用自动图)中的控制流语句支持。在以下示例中,if 语句和 for 语句会自动转换为 tf.condtf.while_loop 操作,它们支持扩展类型。

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))

Keras

tf.keras 是 TensorFlow 用于构建和训练深度学习模型的高级 API。扩展类型可以作为输入传递给 Keras 模型,在 Keras 层之间传递,并由 Keras 模型返回。Keras 目前对扩展类型提出了两个要求

  • 它们必须是可批处理的(请参见上面的“可批处理的 ExtensionType”)。
  • 它们必须有一个名为 shape 的字段或属性。shape[0] 被假定为批处理维度。

以下两个小节给出了展示如何将扩展类型与 Keras 一起使用的示例。

Keras 示例:Network

对于第一个示例,考虑上面“可批处理的 ExtensionType”部分中定义的 Network 类,它可以用于在节点之间负载均衡工作。它的定义在此重复

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network with 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

您可以定义一个新的 Keras 层来处理 Network

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above in the "Batchable `ExtensionType`s" section.
    return balance_work_greedy(inputs)

然后,您可以使用这些层来创建一个简单的模型。要将 ExtensionType 输入模型,可以使用 tf.keras.layer.Input 层,并将 type_spec 设置为扩展类型的 TypeSpec。如果 Keras 模型将用于处理批次,则 type_spec 必须包含批次维度。

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

最后,您可以将模型应用于单个网络和一批网络。

model(single_network)
model(batch_of_networks)

Keras 示例:MaskedTensor

在此示例中,MaskedTensor 扩展为支持 Kerasshape 定义为一个属性,该属性是从 values 字段计算得出的。Keras 要求您将此属性添加到扩展类型及其 TypeSpec 中。 MaskedTensor 还定义了一个 __name__ 变量,这将是 SavedModel 序列化(如下)所必需的。

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

接下来,调度装饰器用于覆盖多个 TensorFlow API 的默认行为。由于这些 API 被标准 Keras 层(如 Dense 层)使用,因此覆盖这些 API 将允许我们使用这些层与 MaskedTensor 一起使用。在本示例中,matmul 用于掩码张量被定义为将掩码值视为零(即,不将它们包含在乘积中)。

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

然后,您可以使用标准 Keras 层构建一个接受 MaskedTensor 输入的 Keras 模型。

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))

SavedModel

SavedModel 是一个序列化 TensorFlow 程序,包括权重和计算。它可以从 Keras 模型或自定义模型构建。在这两种情况下,扩展类型都可以与 SavedModel 定义的函数和方法透明地使用。

SavedModel 可以保存处理扩展类型的模型、层和函数,只要扩展类型具有 __name__ 字段。此名称用于注册扩展类型,以便在加载模型时可以找到它。

示例:保存 Keras 模型

使用扩展类型的 Keras 模型可以使用 SavedModel 保存。

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)

示例:保存自定义模型

SavedModel 也可以用于保存具有处理扩展类型函数的自定义 tf.Module 子类。

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))

ExtensionType 不可用时加载 SavedModel

如果您加载一个使用 ExtensionTypeSavedModel,但该 ExtensionType 不可用(即,它尚未导入),那么您将收到警告,TensorFlow 将回退到使用“匿名扩展类型”对象。此对象将具有与原始类型相同的字段,但将缺少您为该类型添加的任何进一步自定义,例如自定义方法或属性。

ExtensionType 与 TensorFlow Serving 一起使用

目前,TensorFlow Serving(以及 SavedModel“签名”字典的其他使用者)要求所有输入和输出都是原始张量。如果您希望将 TensorFlow Serving 与使用扩展类型的模型一起使用,那么您可以添加包装方法,这些方法将扩展类型值从张量中组合或分解。例如

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)

Dataset

tf.data 是一个 API,它使您能够从简单、可重用的部分构建复杂的输入管道。它的核心数据结构是 tf.data.Dataset,它表示元素的序列,其中每个元素包含一个或多个组件。

使用扩展类型构建 Dataset

可以使用 Dataset.from_tensorsDataset.from_tensor_slicesDataset.from_generator 从扩展类型值构建数据集。

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)

使用扩展类型对 Dataset 进行批处理和取消批处理

可以使用 Dataset.batchDataset.unbatch 对具有扩展类型的数据集进行批处理和取消批处理。

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)