在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
设置
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
扩展类型
用户定义类型可以使项目更具可读性、模块化和可维护性。但是,大多数 TensorFlow API 对用户定义的 Python 类型支持非常有限。这包括高级 API(例如 Keras、tf.function、tf.SavedModel
)和低级 API(例如 tf.while_loop
和 tf.concat
)。TensorFlow 扩展类型 可用于创建与 TensorFlow API 无缝协作的用户定义的面向对象类型。要创建扩展类型,只需定义一个以 tf.experimental.ExtensionType
为基类的 Python 类,并使用 类型注释 为每个字段指定类型。
class TensorGraph(tf.experimental.ExtensionType):
"""A collection of labeled nodes connected by weighted edges."""
edge_weights: tf.Tensor # shape=[num_nodes, num_nodes]
node_labels: Mapping[str, tf.Tensor] # shape=[num_nodes]; dtype=any
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for missing/invalid values.
class CSRSparseMatrix(tf.experimental.ExtensionType):
"""Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
values: tf.Tensor # shape=[num_nonzero]; dtype=any
col_index: tf.Tensor # shape=[num_nonzero]; dtype=int64
row_index: tf.Tensor # shape=[num_rows+1]; dtype=int64
tf.experimental.ExtensionType
基类的工作方式类似于标准 Python 库中的 typing.NamedTuple
和 @dataclasses.dataclass
。特别是,它会根据字段类型注释自动添加构造函数和特殊方法(例如 __repr__
和 __eq__
)。
通常,扩展类型倾向于属于以下两种类别之一
数据结构,它将一组相关值组合在一起,并可以根据这些值提供有用的操作。数据结构可能相当通用(例如上面的
TensorGraph
示例);或者它们可能高度定制到特定模型。类似张量的类型,它专门化或扩展了“张量”的概念。此类别的类型具有
rank
、shape
,通常还具有dtype
;并且使用它们进行张量运算(例如tf.stack
、tf.add
或tf.matmul
)是有意义的。MaskedTensor
和CSRSparseMatrix
是类似张量类型的示例。
支持的 API
以下 TensorFlow API 支持扩展类型
- Keras:扩展类型可以用作 Keras
Models
和Layers
的输入和输出。 tf.data.Dataset
:扩展类型可以包含在Datasets
中,并由数据集Iterators
返回。- TensorFlow Hub:扩展类型可以用作
tf.hub
模块的输入和输出。 - SavedModel:扩展类型可以用作
SavedModel
函数的输入和输出。 tf.function
:扩展类型可以用作用@tf.function
装饰器包装的函数的参数和返回值。- While 循环:扩展类型可以用作
tf.while_loop
中的循环变量,并且可以用作 while 循环主体函数的参数和返回值。 - 条件语句:可以使用
tf.cond
和tf.case
有条件地选择扩展类型。 tf.py_function
:扩展类型可以用作tf.py_function
的func
参数的输入和输出。- 张量操作:扩展类型可以扩展以支持大多数接受张量输入的 TensorFlow 操作(例如
tf.matmul
、tf.gather
和tf.reduce_sum
)。有关更多信息,请参见下面的“调度”部分。 - 分布式策略:扩展类型可以用作每个副本的值。
有关更多详细信息,请参见下面的“支持 ExtensionTypes 的 TensorFlow API”部分。
要求
字段类型
所有字段(实例变量)都必须声明,并且必须为每个字段提供类型注释。支持以下类型注释
类型 | 示例 |
---|---|
Python 整数 | i: int |
Python 浮点数 | f: float |
Python 字符串 | s: str |
Python 布尔值 | b: bool |
Python None |
n: None |
张量形状 | shape: tf.TensorShape |
张量 dtype |
dtype: tf.DType |
张量 | t: tf.Tensor |
扩展类型 | mt: MyMaskedTensor |
Ragged 张量 | rt: tf.RaggedTensor |
稀疏张量 | st: tf.SparseTensor |
索引切片 | s: tf.IndexedSlices |
可选张量 | o: tf.experimental.Optional |
类型联合 | int_or_float: typing.Union[int, float] |
元组 | params: typing.Tuple[int, float, tf.Tensor, int] |
可变长度元组 | lengths: typing.Tuple[int, ...] |
映射 | tags: typing.Mapping[str, tf.Tensor] |
可选值 | weight: typing.Optional[tf.Tensor] |
可变性
扩展类型必须是不可变的。这确保了它们可以被 TensorFlow 的图跟踪机制正确跟踪。如果您发现自己想要修改扩展类型值,请考虑改为定义转换值的函数。例如,与其定义一个 set_mask
函数来修改 MaskedTensor
,不如定义一个 replace_mask
函数,它返回一个新的 MaskedTensor
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def replace_mask(self, new_mask):
self.values.shape.assert_is_compatible_with(new_mask.shape)
return MaskedTensor(self.values, new_mask)
由 ExtensionType
添加的功能
ExtensionType
基类提供以下功能
- 一个构造函数(
__init__
)。 - 一个可打印表示函数(
__repr__
)。 - 相等和不相等运算符(
__eq__
)。 - 一个验证函数(
__validate__
)。 - 强制不可变性。
- 一个嵌套的
TypeSpec
。 - 张量 API 调度支持。
有关自定义此功能的更多信息,请参见下面的“自定义 ExtensionType
”部分。
构造函数
由 ExtensionType
添加的构造函数将每个字段作为命名参数(按它们在类定义中列出的顺序)进行接收。此构造函数将对每个参数进行类型检查,并在必要时进行转换。特别是,Tensor
字段使用 tf.convert_to_tensor
进行转换;Tuple
字段转换为 tuple
;Mapping
字段转换为不可变字典。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)
如果字段值无法转换为其声明的类型,则构造函数将引发 TypeError
try:
MaskedTensor([1, 2, 3], None)
except TypeError as e:
print(f"Got expected TypeError: {e}")
可以通过在类级别设置字段的值来指定字段的默认值
class Pencil(tf.experimental.ExtensionType):
color: str = "black"
has_erasor: bool = True
length: tf.Tensor = 1.0
Pencil()
Pencil(length=0.5, color="blue")
可打印表示
ExtensionType
添加了一个默认的可打印表示函数(__repr__
),其中包括类名和每个字段的值
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
相等运算符
ExtensionType
添加了默认的相等运算符(__eq__
和 __ne__
),如果两个值具有相同的类型并且所有字段都相等,则认为这两个值相等。如果张量字段具有相同的形状并且所有元素的逐元素相等,则认为它们相等。
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
验证函数
ExtensionType
添加了一个 __validate__
函数,它可以被重写以对字段执行验证检查。它在构造函数被调用后以及字段被类型检查并转换为其声明的类型后运行,因此它可以假设所有字段都具有其声明的类型。
以下示例更新了 MaskedTensor
以验证其字段的 shape
和 dtype
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor
def __validate__(self):
self.values.shape.assert_is_compatible_with(self.mask.shape)
assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
MaskedTensor([1, 2, 3], [0, 1, 0]) # Wrong `dtype` for mask.
except AssertionError as e:
print(f"Got expected AssertionError: {e}")
try:
MaskedTensor([1, 2, 3], [True, False]) # shapes don't match.
except ValueError as e:
print(f"Got expected ValueError: {e}")
强制不可变性
ExtensionType
重写了 __setattr__
和 __delattr__
函数以防止修改,确保扩展类型值是不可变的。
mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
mt.mask = [True, True, True]
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
try:
mt.mask[0] = False
except TypeError as e:
print(f"Got expected TypeError: {e}")
try:
del mt.mask
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
嵌套 TypeSpec
每个 ExtensionType
类都有一个相应的 TypeSpec
类,该类是自动创建的,并存储为 <extension_type_name>.Spec
。
此类捕获了值中的所有信息,除了任何嵌套张量的值。特别是,值的 TypeSpec
是通过用其 TypeSpec
替换任何嵌套的张量、扩展类型或复合张量来创建的。
class Player(tf.experimental.ExtensionType):
name: tf.Tensor
attributes: Mapping[str, tf.Tensor]
anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name) # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes) # Records keys and TensorSpecs for values.
TypeSpec
值可以显式构造,也可以使用 tf.type_spec_from_value
从 ExtensionType
值构建。
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TypeSpec
用于 TensorFlow 将值划分为静态部分和动态部分
- 静态部分(在图构建时固定)使用
tf.TypeSpec
编码。 - 动态部分(每次运行图时都可能变化)使用
tf.Tensor
列表编码。
例如,每当参数具有以前从未见过的 TypeSpec
时,tf.function
都会重新跟踪其包装的函数
@tf.function
def anonymize_player(player):
print("<<TRACING>>")
return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
有关更多信息,请参见 tf.function 指南。
自定义 ExtensionType
除了简单地声明字段及其类型外,扩展类型还可以
- 重写默认的可打印表示(
__repr__
)。 - 定义函数。
- 定义
classmethod
和staticmethod
。 - 定义属性。
- 重写默认的构造函数(
__init__
)。 - 重写默认的相等运算符(
__eq__
)。 - 定义运算符(例如
__add__
和__lt__
)。 - 为字段声明默认值。
- 定义子类。
重写默认的可打印表示
您可以重写扩展类型的默认字符串转换运算符。以下示例更新了 MaskedTensor
类,以便在以 Eager 模式打印值时生成更易读的字符串表示。
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for invalid values.
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def masked_tensor_str(values, mask):
if isinstance(values, tf.Tensor):
if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
else:
return f'MaskedTensor(values={values}, mask={mask})'
if len(values.shape) == 1:
items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
else:
items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
return '[%s]' % ', '.join(items)
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
print(mt)
定义函数
扩展类型可以定义函数,就像任何普通的 Python 类一样。例如,MaskedTensor
类型可以定义一个 with_default
函数,该函数返回一个 self
的副本,其中掩码值被给定的 default
值替换。函数可以选择用 @tf.function
装饰器进行注释。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def with_default(self, default):
return tf.where(self.mask, self.values, default)
MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
定义 classmethod
和 staticmethod
扩展类型可以使用 @classmethod
和 @staticmethod
装饰器定义函数。例如,MaskedTensor
类型可以定义一个工厂函数,该函数掩盖任何具有给定值的元素
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
@staticmethod
def from_tensor_and_value_to_mask(values, value_to_mask):
return MaskedTensor(values, values != value_to_mask)
x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
定义属性
扩展类型可以使用 @property
装饰器定义属性,就像任何普通的 Python 类一样。例如,MaskedTensor
类型可以定义一个 dtype
属性,它是值的 dtype
的简写
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@property
def dtype(self):
return self.values.dtype
MaskedTensor([1, 2, 3], [True, False, True]).dtype
重写默认的构造函数
您可以重写扩展类型的默认构造函数。自定义构造函数必须为每个声明的字段设置一个值;在自定义构造函数返回后,所有字段都将被类型检查,并且值将按上述方式进行转换。
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
def __init__(self, name, price, discount=0):
self.name = name
self.price = price * (1 - discount)
print(Toy("ball", 5.0, discount=0.2)) # On sale -- 20% off!
或者,您可以考虑保留默认的构造函数,但添加一个或多个工厂函数。例如
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
@staticmethod
def new_toy_with_discount(name, price, discount):
return Toy(name, price * (1 - discount))
print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
重写默认的相等运算符(__eq__
)
您可以重写扩展类型的默认 __eq__
运算符。以下示例更新了 MaskedTensor
以在比较相等性时忽略掩码元素。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def __eq__(self, other):
result = tf.math.equal(self.values, other.values)
result = result | ~(self.mask & other.mask)
return tf.reduce_all(result)
x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
使用前向引用
如果字段的类型尚未定义,您可以使用包含类型名称的字符串。在以下示例中,字符串 "Node"
用于注释 children
字段,因为 Node
类型尚未(完全)定义。
class Node(tf.experimental.ExtensionType):
value: tf.Tensor
children: Tuple["Node", ...] = ()
Node(3, [Node(5), Node(2)])
定义子类
扩展类型可以使用标准 Python 语法进行子类化。扩展类型子类可以添加新的字段、方法和属性;并且可以覆盖构造函数、可打印表示和相等运算符。以下示例定义了一个基本的 TensorGraph
类,它使用三个 Tensor
字段来编码节点之间的一组边。然后它定义了一个子类,该子类添加了一个 Tensor
字段来记录每个节点的“特征值”。子类还定义了一个方法来沿着边传播特征值。
class TensorGraph(tf.experimental.ExtensionType):
num_nodes: tf.Tensor
edge_src: tf.Tensor # edge_src[e] = index of src node for edge e.
edge_dst: tf.Tensor # edge_dst[e] = index of dst node for edge e.
class TensorGraphWithNodeFeature(TensorGraph):
node_features: tf.Tensor # node_features[n] = feature value for node n.
def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
updates = tf.gather(self.node_features, self.edge_src) * weight
new_node_features = tf.tensor_scatter_nd_add(
self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
return TensorGraphWithNodeFeature(
self.num_nodes, self.edge_src, self.edge_dst, new_node_features)
g = TensorGraphWithNodeFeature( # Edges: 0->1, 4->3, 2->2, 2->1
num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])
print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
定义私有字段
扩展类型的字段可以通过在它们前面加上下划线(遵循标准 Python 约定)来标记为私有。这不会以任何方式影响 TensorFlow 处理字段的方式;而只是向扩展类型的任何用户发出信号,表明这些字段是私有的。
自定义 ExtensionType
的 TypeSpec
每个 ExtensionType
类都有一个对应的 TypeSpec
类,该类是自动创建的,并存储为 <extension_type_name>.Spec
。有关更多信息,请参见上面的“嵌套 TypeSpec”部分。
要自定义 TypeSpec
,只需定义您自己的名为 Spec
的嵌套类,ExtensionType
将使用它作为自动构建的 TypeSpec
的基础。您可以通过以下方式自定义 Spec
类
- 覆盖默认的可打印表示。
- 覆盖默认构造函数。
- 定义方法、
classmethod
、staticmethod
和属性。
以下示例自定义了 MaskedTensor.Spec
类,使其更易于使用
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def with_values(self, new_values):
return MaskedTensor(new_values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
def __repr__(self):
return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
张量 API 调度
扩展类型可以是“类似张量的”,因为它们专门化或扩展了由 tf.Tensor
类型定义的接口。类似张量的扩展类型的示例包括 RaggedTensor
、SparseTensor
和 MaskedTensor
。调度装饰器 可用于覆盖 TensorFlow 操作在应用于类似张量的扩展类型时的默认行为。TensorFlow 目前定义了三个调度装饰器
@tf.experimental.dispatch_for_api(tf_api)
@tf.experimental.dispatch_for_unary_elementwise_apis(x_type)
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
为单个 API 调度
该 tf.experimental.dispatch_for_api
装饰器覆盖了指定 TensorFlow 操作的默认行为,当它以指定的签名调用时。例如,您可以使用此装饰器来指定 tf.stack
如何处理 MaskedTensor
值
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
这将覆盖 tf.stack
的默认实现,只要它被调用时带有 MaskedTensor
值的列表(因为 values
参数用 typing.List[MaskedTensor]
进行了注释)。
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
要允许 tf.stack
处理混合 MaskedTensor
和 Tensor
值的列表,您可以细化 values
参数的类型注释,并相应地更新函数体
tf.experimental.unregister_dispatch_for(masked_stack)
def convert_to_masked_tensor(x):
if isinstance(x, MaskedTensor):
return x
else:
return MaskedTensor(x, tf.ones_like(x, tf.bool))
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
values = [convert_to_masked_tensor(v) for v in values]
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
有关可以覆盖的 API 列表,请参见 tf.experimental.dispatch_for_api
的 API 文档。
为所有一元逐元素 API 调度
该 tf.experimental.dispatch_for_unary_elementwise_apis
装饰器覆盖了所有一元逐元素操作(例如 tf.math.cos
)的默认行为,只要第一个参数的值(通常命名为 x
)与类型注释 x_type
匹配。装饰的函数应接受两个参数
api_func
:一个函数,它接受一个参数并执行逐元素操作(例如,tf.abs
)。x
:逐元素操作的第一个参数。
以下示例更新所有一元逐元素操作以处理 MaskedTensor
类型
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
现在,只要在一元逐元素操作上调用 MaskedTensor
,就会使用此函数。
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
print(tf.ones_like(x, dtype=tf.float32))
为所有二元逐元素 API 调度
类似地,tf.experimental.dispatch_for_binary_elementwise_apis
可用于更新所有二元逐元素操作以处理 MaskedTensor
类型
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
有关被覆盖的逐元素 API 列表,请转到 tf.experimental.dispatch_for_unary_elementwise_apis
和 tf.experimental.dispatch_for_binary_elementwise_apis
的 API 文档。
可批处理的 ExtensionType
如果单个实例可以用来表示一批值,则 ExtensionType
是可批处理的。通常,这是通过向所有嵌套的 Tensor
添加批处理维度来实现的。以下 TensorFlow API 要求任何扩展类型输入都是可批处理的
tf.data.Dataset
(batch
、unbatch
、from_tensor_slices
)tf.keras
(fit
、evaluate
、predict
)tf.map_fn
默认情况下,BatchableExtensionType
通过对任何嵌套的 Tensor
、CompositeTensor
和 ExtensionType
进行批处理来创建批处理值。如果这对您的类不适用,那么您需要使用 tf.experimental.ExtensionTypeBatchEncoder
来覆盖此默认行为。例如,通过简单地堆叠单个稀疏张量的 values
、indices
和 dense_shape
字段来创建一批 tf.SparseTensor
值将不合适——在大多数情况下,您无法堆叠这些张量,因为它们具有不兼容的形状;即使可以,结果也不会是有效的 SparseTensor
。
BatchableExtensionType
示例:Network
例如,考虑一个简单的 Network
类,用于负载均衡,它跟踪每个节点上还有多少工作要做,以及在节点之间移动工作还有多少带宽可用
class Network(tf.experimental.ExtensionType): # This version is not batchable.
work: tf.Tensor # work[n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[n1, n2] = bandwidth from n1->n2
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
要使此类型可批处理,请将基本类型更改为 BatchableExtensionType
,并将每个字段的形状调整为包含可选的批处理维度。以下示例还添加了一个 shape
字段来跟踪批处理形状。此 shape
字段不是 tf.data.Dataset
或 tf.map_fn
所必需的,但它是 tf.keras
所必需的。
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
def network_repr(network):
work = network.work
bandwidth = network.bandwidth
if hasattr(work, 'numpy'):
work = ' '.join(str(work.numpy()).split())
if hasattr(bandwidth, 'numpy'):
bandwidth = ' '.join(str(bandwidth.numpy()).split())
return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
work=tf.stack([net1.work, net2.work]),
bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
然后,您可以使用 tf.data.Dataset
迭代一批网络
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
print(f"Batch element {i}: {network}")
您还可以使用 map_fn
将函数应用于每个批处理元素
def balance_work_greedy(network):
delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
delta /= 4
delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
new_work = network.work + tf.reduce_sum(delta, -1)
return Network(new_work, network.bandwidth)
tf.map_fn(balance_work_greedy, batch_of_networks)
支持 ExtensionType
的 TensorFlow API
@tf.function
tf.function
是一个装饰器,它为 Python 函数预先计算 TensorFlow 图,这可以大幅提高 TensorFlow 代码的性能。扩展类型值可以与 @tf.function
装饰的函数透明地使用。
class Pastry(tf.experimental.ExtensionType):
sweetness: tf.Tensor # 2d embedding that encodes sweetness
chewiness: tf.Tensor # 2d embedding that encodes chewiness
@tf.function
def combine_pastry_features(x: Pastry):
return (x.sweetness + x.chewiness) / 2
cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
如果您希望显式指定 tf.function
的 input_signature
,那么您可以使用扩展类型的 TypeSpec
来做到这一点。
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))
@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
return Pastry(x.sweetness + delta, x.chewiness)
increase_sweetness(cookie)
具体函数
具体函数封装了由 tf.function
构建的单个跟踪图。扩展类型可以与具体函数透明地使用。
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
控制流操作
扩展类型受 TensorFlow 的控制流操作支持
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
自动图控制流
扩展类型也受 tf.function
(使用自动图)中的控制流语句支持。在以下示例中,if
语句和 for
语句会自动转换为 tf.cond
和 tf.while_loop
操作,它们支持扩展类型。
@tf.function
def fn(x, b):
if b:
x = MaskedTensor(x, tf.less(x, 0))
else:
x = MaskedTensor(x, tf.greater(x, 0))
for i in tf.range(5 if b else 7):
x = x.with_values(x.values + 1 / 2)
return x
print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
Keras
tf.keras 是 TensorFlow 用于构建和训练深度学习模型的高级 API。扩展类型可以作为输入传递给 Keras 模型,在 Keras 层之间传递,并由 Keras 模型返回。Keras 目前对扩展类型提出了两个要求
- 它们必须是可批处理的(请参见上面的“可批处理的
ExtensionType
”)。 - 它们必须有一个名为
shape
的字段或属性。shape[0]
被假定为批处理维度。
以下两个小节给出了展示如何将扩展类型与 Keras 一起使用的示例。
Keras 示例:Network
对于第一个示例,考虑上面“可批处理的 ExtensionType
”部分中定义的 Network
类,它可以用于在节点之间负载均衡工作。它的定义在此重复
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
single_network = Network( # A single network with 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
您可以定义一个新的 Keras 层来处理 Network
。
class BalanceNetworkLayer(tf.keras.layers.Layer):
"""Layer that balances work between nodes in a network.
Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
"""
def call(self, inputs):
# This function is defined above in the "Batchable `ExtensionType`s" section.
return balance_work_greedy(inputs)
然后,您可以使用这些层来创建一个简单的模型。要将 ExtensionType
输入模型,可以使用 tf.keras.layer.Input
层,并将 type_spec
设置为扩展类型的 TypeSpec
。如果 Keras 模型将用于处理批次,则 type_spec
必须包含批次维度。
input_spec = Network.Spec(shape=None,
work=tf.TensorSpec(None, tf.float32),
bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
BalanceNetworkLayer(),
])
最后,您可以将模型应用于单个网络和一批网络。
model(single_network)
model(batch_of_networks)
Keras 示例:MaskedTensor
在此示例中,MaskedTensor
扩展为支持 Keras
。 shape
定义为一个属性,该属性是从 values
字段计算得出的。Keras 要求您将此属性添加到扩展类型及其 TypeSpec
中。 MaskedTensor
还定义了一个 __name__
变量,这将是 SavedModel
序列化(如下)所必需的。
class MaskedTensor(tf.experimental.BatchableExtensionType):
# __name__ is required for serialization in SavedModel; see below for details.
__name__ = 'extension_type_colab.MaskedTensor'
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_default(self, default):
return tf.where(self.mask, self.values, default)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_shape(self):
return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
tf.TensorSpec(shape, self.mask.dtype))
接下来,调度装饰器用于覆盖多个 TensorFlow API 的默认行为。由于这些 API 被标准 Keras 层(如 Dense
层)使用,因此覆盖这些 API 将允许我们使用这些层与 MaskedTensor
一起使用。在本示例中,matmul
用于掩码张量被定义为将掩码值视为零(即,不将它们包含在乘积中)。
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
return MaskedTensor(op(x.values), x.mask)
@tf.experimental.dispatch_for_binary_elementwise_apis(
Union[MaskedTensor, tf.Tensor],
Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
x = convert_to_masked_tensor(x)
y = convert_to_masked_tensor(y)
return MaskedTensor(op(x.values, y.values), x.mask & y.mask)
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
transpose_a=False, transpose_b=False,
adjoint_a=False, adjoint_b=False,
a_is_sparse=False, b_is_sparse=False,
output_type=None):
if isinstance(a, MaskedTensor):
a = a.with_default(0)
if isinstance(b, MaskedTensor):
b = b.with_default(0)
return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
adjoint_b, a_is_sparse, b_is_sparse, output_type)
然后,您可以使用标准 Keras 层构建一个接受 MaskedTensor
输入的 Keras 模型。
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
SavedModel
SavedModel 是一个序列化 TensorFlow 程序,包括权重和计算。它可以从 Keras 模型或自定义模型构建。在这两种情况下,扩展类型都可以与 SavedModel 定义的函数和方法透明地使用。
SavedModel 可以保存处理扩展类型的模型、层和函数,只要扩展类型具有 __name__
字段。此名称用于注册扩展类型,以便在加载模型时可以找到它。
示例:保存 Keras 模型
使用扩展类型的 Keras 模型可以使用 SavedModel
保存。
masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
示例:保存自定义模型
SavedModel 也可以用于保存具有处理扩展类型函数的自定义 tf.Module
子类。
class CustomModule(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def grow(self, x: MaskedTensor):
"""Increase values in `x` by multiplying them by `self.v`."""
return MaskedTensor(x.values * self.v, x.mask)
module = CustomModule(100.0)
module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
在 ExtensionType
不可用时加载 SavedModel
如果您加载一个使用 ExtensionType
的 SavedModel
,但该 ExtensionType
不可用(即,它尚未导入),那么您将收到警告,TensorFlow 将回退到使用“匿名扩展类型”对象。此对象将具有与原始类型相同的字段,但将缺少您为该类型添加的任何进一步自定义,例如自定义方法或属性。
将 ExtensionType
与 TensorFlow Serving 一起使用
目前,TensorFlow Serving(以及 SavedModel“签名”字典的其他使用者)要求所有输入和输出都是原始张量。如果您希望将 TensorFlow Serving 与使用扩展类型的模型一起使用,那么您可以添加包装方法,这些方法将扩展类型值从张量中组合或分解。例如
class CustomModuleWrapper(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def var_weighted_mean(self, x: MaskedTensor):
"""Mean value of unmasked values in x, weighted by self.v."""
x = MaskedTensor(x.values * self.v, x.mask)
return (tf.reduce_sum(x.with_default(0)) /
tf.reduce_sum(tf.cast(x.mask, x.dtype)))
@tf.function()
def var_weighted_mean_wrapper(self, x_values, x_mask):
"""Raw tensor wrapper for var_weighted_mean."""
return self.var_weighted_mean(MaskedTensor(x_values, x_mask))
module = CustomModuleWrapper([3., 2., 8., 5.])
module.var_weighted_mean_wrapper.get_concrete_function(
tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
Dataset
tf.data
是一个 API,它使您能够从简单、可重用的部分构建复杂的输入管道。它的核心数据结构是 tf.data.Dataset
,它表示元素的序列,其中每个元素包含一个或多个组件。
使用扩展类型构建 Dataset
可以使用 Dataset.from_tensors
、Dataset.from_tensor_slices
或 Dataset.from_generator
从扩展类型值构建数据集。
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
print(value)
def value_gen():
for i in range(2, 7):
yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])
ds = tf.data.Dataset.from_generator(
value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
print(value)
使用扩展类型对 Dataset
进行批处理和取消批处理
可以使用 Dataset.batch
和 Dataset.unbatch
对具有扩展类型的数据集进行批处理和取消批处理。
batched_ds = ds.batch(2)
for value in batched_ds:
print(value)
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
print(value)