自定义联邦算法,第一部分:联邦核心介绍

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

本教程是两部分系列教程的第一部分,演示了如何使用 联邦核心 (FC) 在 TensorFlow 联邦 (TFF) 中实现自定义类型的联邦算法 - 一组低级接口,作为我们实现 联邦学习 (FL) 层的基础。

第一部分更偏向概念性;我们介绍了 TFF 中使用的一些关键概念和编程抽象,并通过一个非常简单的分布式温度传感器数组示例演示了它们的用法。在 本系列的第二部分 中,我们将使用此处介绍的机制来实现联邦训练和评估算法的简单版本。作为后续步骤,我们鼓励您学习 联邦平均的实现tff.learning 中。

在本系列结束时,您应该能够认识到联邦核心的应用并不一定局限于学习。我们提供的编程抽象非常通用,可以用于实现分析和其他自定义类型的分布式数据计算。

虽然本教程旨在自成一体,但我们鼓励您首先阅读有关 图像分类文本生成 的教程,以获得对 TensorFlow 联邦框架和 联邦学习 API (tff.learning) 的更高级和更温和的介绍,因为它将帮助您将我们在此处描述的概念置于上下文中。

预期用途

简而言之,联邦核心 (FC) 是一个开发环境,它使您可以紧凑地表达将 TensorFlow 代码与分布式通信运算符相结合的程序逻辑,例如在 联邦平均 中使用的运算符 - 计算系统中一组客户端设备上的分布式总和、平均值和其他类型的分布式聚合,将模型和参数广播到这些设备,等等。

您可能知道 tf.contrib.distribute,此时自然会问:这个框架有哪些不同?毕竟,这两个框架都试图使 TensorFlow 计算分布式。

一种思考方式是,tf.contrib.distribute 的目标是允许用户使用现有的模型和训练代码,只需进行最小的更改即可启用分布式训练,并且重点是如何利用分布式基础设施来提高现有训练代码的效率,而 TFF 的联邦核心的目标是让研究人员和从业人员明确控制他们在系统中使用的特定分布式通信模式。FC 的重点是提供一种灵活且可扩展的语言来表达分布式数据流算法,而不是一组具体的已实现分布式训练功能。

TFF 的 FC API 的主要目标受众之一是希望尝试新的联邦学习算法并评估影响分布式系统中数据流编排方式的细微设计选择的后果的研究人员和从业人员,而无需陷入系统实现细节。FC API 旨在实现的抽象级别大致对应于在研究出版物中描述联邦学习算法机制的伪代码 - 系统中存在哪些数据以及如何转换它,但不会下降到单个点对点网络消息交换的级别。

TFF 作为一个整体,针对的是数据分布且必须保持这种状态的场景,例如出于隐私原因,以及在集中位置收集所有数据可能不可行的情况。与所有数据都可以在数据中心集中位置累积的场景相比,这会对需要更高程度的显式控制的机器学习算法的实现产生影响。

在我们开始之前

在我们深入代码之前,请尝试运行以下“Hello World”示例,以确保您的环境已正确设置。如果它不起作用,请参阅安装指南以获取说明。

pip install --quiet --upgrade tensorflow-federated
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

联邦数据

TFF 的一个区别特征是它允许您简洁地表达基于 TensorFlow 的联邦数据上的计算。在本教程中,我们将使用术语联邦数据来指代分布式系统中一组设备上托管的数据项集合。例如,在移动设备上运行的应用程序可能会收集数据并将其存储在本地,而不会上传到集中位置。或者,一组分布式传感器可能会收集并存储其位置的温度读数。

TFF 中将像上述示例中的那些联邦数据视为一等公民,即它们可以作为函数的参数和结果出现,并且它们具有类型。为了加强这一概念,我们将联邦数据集称为联邦值,或称为联邦类型的值

需要理解的重要一点是,我们将所有设备上的所有数据项集合(例如,来自分布式阵列中所有传感器的所有温度读数的整个集合)建模为单个联邦值。

例如,以下是如何在 TFF 中定义由一组客户端设备托管的联邦浮点数类型。跨分布式传感器阵列实现的温度读数集合可以建模为此联邦类型的值。

federated_float_on_clients = tff.FederatedType(np.float32, tff.CLIENTS)

更一般地说,TFF 中的联邦类型是通过指定其成员组成部分的类型T(驻留在单个设备上的数据项)和托管此类型联邦值的设备组G(加上我们将很快提到的第三个可选信息位)来定义的。我们将托管联邦值的设备组G称为该值的放置。因此,tff.CLIENTS 是放置的示例。

str(federated_float_on_clients.member)
'float32'
str(federated_float_on_clients.placement)
'CLIENTS'

具有成员组成部分T 和放置G 的联邦类型可以用{T}@G 简洁地表示,如下所示。

str(federated_float_on_clients)
'{float32}@CLIENTS'

此简洁表示法中的花括号{} 提醒我们,成员组成部分(不同设备上的数据项)可能不同,正如您对温度传感器读数的预期一样,因此客户端作为一个组共同托管了一个多重集T 类型项,这些项共同构成联邦值。

重要的是要注意,联邦值的成员组成部分通常对程序员来说是不透明的,即联邦值不应被视为一个简单的dict,它以系统中设备的标识符为键 - 这些值旨在仅通过联邦运算符进行集体转换,这些运算符抽象地表示各种分布式通信协议(例如聚合)。如果这听起来太抽象了,别担心 - 我们将很快回到这一点,并将用具体的例子来说明它。

TFF 中的联邦类型有两种形式:那些联邦值的成员组成部分可能不同(如上面所见),以及那些已知都相等的类型。这是由tff.FederatedType 构造函数中的第三个可选all_equal 参数控制的(默认为False)。

federated_float_on_clients.all_equal
False

具有放置G 的联邦类型,其中所有T 类型成员组成部分都已知相等,可以用T@G 简洁地表示(而不是{T}@G,即,省略花括号以反映成员组成部分的多重集仅包含一个项目)。

str(tff.FederatedType(np.float32, tff.CLIENTS, all_equal=True))
'float32@CLIENTS'

在实际场景中可能出现的这种类型联邦值的示例是超参数(例如学习率、裁剪范数等),它已由服务器广播到参与联邦训练的一组设备。

另一个示例是在服务器上预先训练的机器学习模型的参数集,然后将其广播到一组客户端设备,在那里可以为每个用户个性化这些参数。

例如,假设我们有一对float32 参数ab,用于简单的单维线性回归模型。我们可以构建此类模型的(非联邦)类型,以便在 TFF 中使用,如下所示。打印类型字符串中的尖括号<> 是 TFF 的简洁表示法,用于命名或未命名的元组。

simple_regression_model_type = (
    tff.StructType([('a', np.float32), ('b', np.float32)]))

str(simple_regression_model_type)
'<a=float32,b=float32>'

请注意,我们只在上面指定了dtype。也支持非标量类型。在上面的代码中,np.float32 是更一般的tff.TensorType(np.float32, []) 的快捷表示法。

当此模型广播到客户端时,结果联邦值的类型可以表示如下。

str(tff.FederatedType(
    simple_regression_model_type, tff.CLIENTS, all_equal=True))
'<a=float32,b=float32>@CLIENTS'

与上面的联邦浮点数对称,我们将这种类型称为联邦元组。更一般地说,我们通常使用术语联邦 XYZ 来指代成员组成部分是XYZ 类的联邦值。因此,我们将讨论诸如联邦元组联邦序列联邦模型等内容。

现在,回到float32@CLIENTS - 虽然它似乎在多个设备上复制,但它实际上是一个单一的float32,因为所有成员都是相同的。一般来说,您可以将任何全等联邦类型(即T@G 形式的类型)视为与非联邦类型T 同构,因为在这两种情况下,实际上只有一个(虽然可能复制)类型为T 的项目。

鉴于TT@G 之间的同构性,您可能想知道后者类型可能有什么用处(如果有的话)。继续阅读。

放置

设计概述

在上一节中,我们介绍了放置的概念 - 可能共同托管联邦值的系统参与者组,并且我们演示了使用tff.CLIENTS 作为放置规范的示例。

为了解释为什么放置的概念如此基本,以至于我们需要将其纳入 TFF 类型系统,请回想一下我们在本教程开头提到的 TFF 的一些预期用途。

虽然在本教程中,您只会看到 TFF 代码在模拟环境中本地执行,但我们的目标是让 TFF 能够编写可以在分布式系统中一组物理设备上执行的代码,这些设备可能包括运行 Android 的移动设备或嵌入式设备。这些设备中的每一个都将收到一组单独的指令,以根据其在系统中的角色(最终用户设备、集中式协调器、多层架构中的中间层等)在本地执行。能够推断哪些设备子集执行哪些代码以及数据的不同部分可能在物理上实现的位置非常重要。

这在处理移动设备上的应用程序数据时尤其重要。由于数据是私密的并且可能是敏感的,我们需要能够静态地验证此数据永远不会离开设备(并证明有关如何处理数据的结论)。放置规范是旨在支持此功能的机制之一。

TFF 被设计为以数据为中心的编程环境,因此,与一些专注于操作及其运行位置的现有框架不同,TFF 专注于数据、数据实现的位置以及如何转换它。因此,放置在 TFF 中被建模为数据的属性,而不是数据的操作的属性。事实上,正如您将在下一节中看到的那样,一些 TFF 操作跨越位置,并在网络中运行,而不是由单个机器或一组机器执行。

将某个值的类型表示为T@G{T}@G(而不是仅仅T)使数据放置决策明确,并且与对 TFF 中编写的程序的静态分析一起,它可以作为为敏感的设备上数据提供正式隐私保证的基础。

然而,在这一点上需要注意的是,虽然我们鼓励 TFF 用户明确参与设备的(放置)来托管数据,但程序员永远不会处理单个参与者的原始数据或身份。

在 TFF 代码主体中,通过设计,无法枚举构成tff.CLIENTS 所表示的组的设备,也无法探测组中是否存在特定设备。在 Federated Core API、底层架构抽象集或我们提供的支持模拟的核心运行时基础设施中,没有任何设备或客户端身份的概念。您编写的所有计算逻辑都将表示为对整个客户端组的操作。

请回想一下我们之前提到的联邦类型的值与 Pythondict 不同,因为无法简单地枚举它们的成员组成部分。将您的 TFF 程序逻辑操作的值视为与放置(组)相关联,而不是与单个参与者相关联。

放置在 TFF 中也被设计为一等公民,可以作为参数和结果出现在 placement 类型中(在 API 中由 tff.PlacementType 表示)。将来,我们计划提供各种运算符来转换或组合放置,但这超出了本教程的范围。目前,将 placement 视为 TFF 中不透明的内置基本类型就足够了,类似于 intbool 是 Python 中不透明的内置类型,tff.CLIENTS 是这种类型的常量字面量,类似于 1int 类型的常量字面量。

指定放置

TFF 提供了两个基本的放置字面量,tff.CLIENTStff.SERVER,以便轻松表达各种实际场景,这些场景自然地被建模为客户端-服务器架构,其中多个客户端设备(移动电话、嵌入式设备、分布式数据库、传感器等)由单个集中式服务器协调器进行协调。TFF 还被设计为支持自定义放置、多个客户端组、多层级和其他更通用的分布式架构,但讨论这些内容超出了本教程的范围。

TFF 并没有规定 tff.CLIENTStff.SERVER 实际上代表什么。

特别是,tff.SERVER 可以是单个物理设备(单例组的成员),但它也可能是一个运行状态机复制的容错集群中的副本组 - 我们不作任何特殊的架构假设。相反,我们使用上一节中提到的 all_equal 位来表达这样一个事实,即我们通常只处理服务器上的单个数据项。

同样,在某些应用程序中,tff.CLIENTS 可能代表系统中的所有客户端 - 在联邦学习的背景下,我们有时将其称为总体,但在例如 联邦平均的生产实现 中,它可能代表一个队列 - 在特定训练轮次中被选中参与的客户端子集。当包含它们的计算被部署以执行(或像本教程中演示的那样,在模拟环境中像 Python 函数一样被调用)时,抽象定义的放置将被赋予具体的含义。在我们的本地模拟中,客户端组由作为输入提供的联邦数据决定。

联邦计算

声明联邦计算

TFF 被设计为一个强类型函数式编程环境,支持模块化开发。

TFF 中的基本组合单元是联邦计算 - 一段逻辑,它可以接受联邦值作为输入,并返回联邦值作为输出。以下是定义一个计算的示例,该计算计算我们之前示例中传感器阵列报告的温度的平均值。

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def get_average_temperature(sensor_readings):
  return tff.federated_mean(sensor_readings)

查看上面的代码,此时您可能会问 - 在 TensorFlow 中,难道没有用于定义可组合单元的装饰器结构,例如 tf.function 吗?如果有,为什么还要引入另一个装饰器,它有什么不同?

简短的回答是,由 tff.federated_computation 包装器生成的代码既不是 TensorFlow,也不是 Python - 它是在内部平台无关的粘合语言中对分布式系统的规范。目前,这无疑听起来很神秘,但请记住这种对联邦计算的直观解释,即它是一个分布式系统的抽象规范。我们将在稍后解释它。

首先,让我们玩一下这个定义。TFF 计算通常被建模为函数 - 有参数或没有参数,但具有明确定义的类型签名。您可以通过查询其 type_signature 属性来打印计算的类型签名,如下所示。

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

类型签名告诉我们,该计算接受客户端设备上不同传感器读数的集合,并在服务器上返回单个平均值。

在我们继续之前,让我们花点时间思考一下 - 该计算的输入和输出位于不同的位置(在 CLIENTS 上与在 SERVER 上)。回想一下我们在上一节关于放置中所说的,TFF 运算可能会跨越位置,并在网络中运行,以及我们刚刚关于联邦计算代表分布式系统的抽象规范所说的。我们刚刚定义了一个这样的计算 - 一个简单的分布式系统,其中数据在客户端设备上被消费,而聚合结果出现在服务器上。

在许多实际场景中,代表顶级任务的计算往往会在服务器上接受其输入并报告其输出 - 这反映了这样一个想法,即计算可能是由查询触发的,这些查询起源于服务器并终止于服务器。

但是,FC API 并没有强加这种假设,我们内部使用的许多构建块(包括您可能在 API 中找到的许多 tff.federated_... 运算符)具有不同放置的输入和输出,因此一般来说,您不应该将联邦计算视为在服务器上运行由服务器执行的东西。服务器只是联邦计算中的一种参与者类型。在考虑此类计算的机制时,最好始终默认使用全局网络范围的视角,而不是单个集中式协调器的视角。

一般来说,函数类型签名被紧凑地表示为 (T -> U),分别用于类型 TU 的输入和输出。形式参数的类型(在本例中为 sensor_readings)被指定为装饰器的参数。您不需要指定结果的类型 - 它会自动确定。

虽然 TFF 提供了有限形式的多态性,但强烈建议程序员明确说明他们使用的数据的类型,因为这使得理解、调试和正式验证代码属性变得更容易。在某些情况下,明确指定类型是必需的(例如,多态计算目前不可直接执行)。

执行联邦计算

为了支持开发和调试,TFF 允许您像 Python 函数一样直接调用以这种方式定义的计算,如下所示。如果计算期望一个具有 all_equal 位设置为 False 的联邦类型的值,您可以将其作为 Python 中的普通 list 提供,对于具有 all_equal 位设置为 True 的联邦类型,您可以直接提供(单个)成员成分。这也是结果如何报告给您的方式。

get_average_temperature([68.5, 70.3, 69.8])
69.53334

在模拟模式下运行计算时,您充当具有系统范围视图的外部观察者,能够在网络中的任何位置提供输入和消费输出,正如这里所发生的那样 - 您在输入处提供了客户端值,并消费了服务器结果。

现在,让我们回到之前关于 tff.federated_computation 装饰器在粘合语言中发出代码的说明。虽然 TFF 计算的逻辑可以用 Python 中的普通函数来表达(您只需要用 tff.federated_computation 装饰它们,就像我们在上面所做的那样),并且您可以在这个笔记本中像任何其他 Python 函数一样直接用 Python 参数调用它们,但在幕后,正如我们之前提到的,TFF 计算实际上不是 Python。

我们的意思是,当 Python 解释器遇到一个用 tff.federated_computation 装饰的函数时,它会跟踪该函数主体中的语句一次(在定义时),然后构建一个 序列化表示,用于未来的使用 - 无论是用于执行,还是作为子组件被合并到另一个计算中。

您可以通过添加一个打印语句来验证这一点,如下所示

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def get_average_temperature(sensor_readings):

  print ('Getting traced, the argument is "{}".'.format(
      type(sensor_readings).__name__))

  return tff.federated_mean(sensor_readings)
Getting traced, the argument is "Value".

您可以将定义联邦计算的 Python 代码视为类似于您如何看待在非急切上下文中构建 TensorFlow 图的 Python 代码(如果您不熟悉 TensorFlow 的非急切用法,请将您的 Python 代码视为定义一个要在稍后执行的操作图,但不要在运行时实际执行它们)。TensorFlow 中的非急切图构建代码是 Python,但由该代码构建的 TensorFlow 图是平台无关且可序列化的。

同样,TFF 计算是在 Python 中定义的,但其主体中的 Python 语句,例如我们刚刚展示的示例中的 tff.federated_mean,在幕后被编译成一个可移植的、平台无关的可序列化表示。

作为开发人员,您不需要关心此表示的细节,因为您永远不需要直接使用它,但您应该知道它的存在,TFF 计算本质上是非急切的,并且不能捕获任意 Python 状态。包含在 TFF 计算主体中的 Python 代码在定义时执行,当用 tff.federated_computation 装饰的 Python 函数的主体在被序列化之前被跟踪。它不会在调用时再次被跟踪(除非该函数是多态的;请参阅文档页面以了解详细信息)。

您可能想知道为什么我们选择引入一个专门的内部非 Python 表示。一个原因是,最终,TFF 计算旨在部署到真实的物理环境中,并在移动设备或嵌入式设备上托管,而这些设备可能没有 Python。

另一个原因是,TFF 计算表达了分布式系统的全局行为,而不是表达单个参与者的本地行为的 Python 程序。您可以在上面的简单示例中看到这一点,它使用特殊的运算符 tff.federated_mean 接受客户端设备上的数据,但将结果存储在服务器上。

运算符 tff.federated_mean 不能轻易地被建模为 Python 中的普通运算符,因为它不会在本地执行 - 正如我们之前提到的,它代表一个协调多个系统参与者行为的分布式系统。我们将此类运算符称为联邦运算符,以区别于 Python 中的普通(本地)运算符。

因此,TFF 类型系统和 TFF 语言中支持的基本运算符集与 Python 中的那些有很大不同,因此需要使用专门的表示。

组合联邦计算

如上所述,联合计算及其组成部分最好理解为分布式系统的模型,您可以将联合计算的组合视为从更简单的分布式系统中组合更复杂的分布式系统。您可以将 tff.federated_mean 运算符视为一种内置模板联合计算,其类型签名为 ({T}@CLIENTS -> T@SERVER)(实际上,就像您编写的计算一样,此运算符也具有复杂的结构 - 在幕后,我们将它分解为更简单的运算符)。

联合计算的组合也是如此。计算 get_average_temperature 可以在用 tff.federated_computation 装饰的另一个 Python 函数的函数体中调用 - 这样做会导致它被嵌入到父函数的函数体中,就像 tff.federated_mean 之前被嵌入到自己的函数体中一样。

需要注意的一个重要限制是,用 tff.federated_computation 装饰的 Python 函数的函数体必须仅包含联合运算符,即它们不能直接包含 TensorFlow 运算符。例如,您不能直接使用 tf.nest 接口来添加一对联合值。TensorFlow 代码必须限制在用 tff.tensorflow.computation 装饰的代码块中,下一节将讨论此代码块。只有以这种方式包装后,包装的 TensorFlow 代码才能在 tff.federated_computation 的函数体中调用。

这种分离的原因是技术性的(很难欺骗像 tf.add 这样的运算符来处理非张量),以及架构性的。联合计算的语言(即从用 tff.federated_computation 装饰的 Python 函数的序列化函数体构建的逻辑)旨在用作平台无关的粘合语言。这种粘合语言目前用于从嵌入的 TensorFlow 代码段(限制在 tff.tensorflow.computation 块中)构建分布式系统。随着时间的推移,我们预计需要嵌入其他非 TensorFlow 逻辑的部分,例如可能代表输入管道的关系数据库查询,所有这些都使用相同的粘合语言(tff.federated_computation 块)连接在一起。

TensorFlow 逻辑

声明 TensorFlow 计算

TFF 旨在与 TensorFlow 一起使用。因此,您在 TFF 中编写的绝大多数代码很可能是普通的(即本地执行的)TensorFlow 代码。为了将这种代码与 TFF 一起使用,如上所述,它只需要用 tff.tensorflow.computation 装饰即可。

例如,以下是如何实现一个函数,该函数接受一个数字并向其添加 0.5

@tff.tensorflow.computation(np.float32)
def add_half(x):
  return tf.add(x, 0.5)

再次,看到这一点,您可能想知道为什么我们应该定义另一个装饰器 tff.tensorflow.computation 而不是简单地使用现有的机制,例如 tf.function。与上一节不同,这里我们处理的是一个普通的 TensorFlow 代码块。

有几个原因,完整的处理超出了本教程的范围,但值得一提的是主要原因。

  • 为了将使用 TensorFlow 代码实现的可重用构建块嵌入到联合计算的函数体中,它们需要满足某些属性 - 例如在定义时被跟踪和序列化,具有类型签名等。这通常需要某种形式的装饰器。

一般来说,我们建议尽可能使用 TensorFlow 的原生机制进行组合,例如 tf.function,因为 TFF 的装饰器与急切函数交互的确切方式预计会发生变化。

现在,回到上面的示例代码片段,我们刚刚定义的计算 add_half 可以像任何其他 TFF 计算一样被 TFF 处理。特别是,它具有 TFF 类型签名。

str(add_half.type_signature)
'(float32 -> float32)'

请注意,此类型签名没有位置。TensorFlow 计算不能使用或返回联合类型。

您现在也可以使用 add_half 作为其他计算的构建块。例如,以下是如何使用 tff.federated_map 运算符将 add_half 逐点应用于客户端设备上联合浮点数的所有成员组成部分。

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)
str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

执行 TensorFlow 计算

tff.tensorflow.computation 定义的计算的执行遵循与我们为 tff.federated_computation 描述的相同规则。它们可以像 Python 中的普通可调用对象一样被调用,如下所示。

add_half_on_clients([1.0, 3.0, 2.0])
[<tf.Tensor: shape=(), dtype=float32, numpy=1.5>,
 <tf.Tensor: shape=(), dtype=float32, numpy=3.5>,
 <tf.Tensor: shape=(), dtype=float32, numpy=2.5>]

再次,值得注意的是,以这种方式调用计算 add_half_on_clients 模拟了一个分布式过程。数据在客户端被消费,并在客户端返回。实际上,此计算让每个客户端执行一个本地操作。在这个系统中没有明确提到 tff.SERVER(即使在实践中,协调这种处理可能涉及一个)。将以这种方式定义的计算视为概念上类似于 MapReduce 中的 Map 阶段。

此外,请记住,我们在上一节中关于 TFF 计算在定义时被序列化的内容对于 tff.tensorflow.computation 代码也是如此 - add_half_on_clients 的 Python 函数体在定义时被跟踪一次。在随后的调用中,TFF 使用其序列化表示。

tff.federated_computation 装饰的 Python 方法与用 tff.tensorflow.computation 装饰的 Python 方法之间的唯一区别是,后者被序列化为 TensorFlow 图(而前者不允许直接嵌入其中的 TensorFlow 代码)。

在幕后,每个用 tff.tensorflow.computation 装饰的方法都会暂时禁用急切执行,以便能够捕获计算的结构。虽然急切执行在本地被禁用,但您可以随意使用急切 TensorFlow、AutoGraph、TensorFlow 2.0 结构等,只要您以一种能够正确序列化的方式编写计算的逻辑即可。

例如,以下代码将失败

try:

  # Eager mode
  constant_10 = tf.constant(10.)

  @tff.tensorflow.computation(np.float32)
  def add_ten(x):
    return x + constant_10

except Exception as err:
  print (err)
Attempting to capture an EagerTensor without building a function.

上面的代码失败是因为 constant_10 已经在 tff.tensorflow.computation 在序列化过程中在 add_ten 的函数体内内部构建的图之外被构建。

另一方面,在 tff.tensorflow.computation 中调用在被调用时修改当前图的 Python 函数是可以的

def get_constant_10():
  return tf.constant(10.)

@tff.tensorflow.computation(np.float32)
def add_ten(x):
  return x + get_constant_10()

add_ten(5.0)
15.0

请注意,TensorFlow 中的序列化机制正在不断发展,我们预计 TFF 序列化计算的细节也会随之发展。

使用 tf.data.Dataset

如前所述,tff.tensorflow.computation 的一个独特功能是,它们允许您使用 tf.data.Dataset,这些数据被抽象地定义为您的代码的形式参数。需要在 TensorFlow 中表示为数据集的参数需要使用 tff.SequenceType 构造函数声明。

例如,类型规范 tff.SequenceType(np.float32) 在 TFF 中定义了一个抽象的浮点元素序列。序列可以包含张量或复杂的嵌套结构(我们将在后面看到这些结构的示例)。T 类型项序列的简明表示为 T*

float32_sequence = tff.SequenceType(np.float32)

str(float32_sequence)
'float32*'

假设在我们的温度传感器示例中,每个传感器不仅保存一个温度读数,还保存多个读数。以下是如何使用 tf.data.Dataset.reduce 运算符在 TensorFlow 中定义一个 TFF 计算,该计算使用单个本地数据集计算温度的平均值。

@tff.tensorflow.computation(tff.SequenceType(np.float32))
def get_local_temperature_average(local_temperatures):
  sum_and_count = (
      local_temperatures.reduce((0.0, 0), lambda x, y: (x[0] + y, x[1] + 1)))
  return sum_and_count[0] / tf.cast(sum_and_count[1], tf.float32)
str(get_local_temperature_average.type_signature)
'(float32* -> float32)'

在用 tff.tensorflow.computation 装饰的方法的函数体中,TFF 序列类型的形式参数简单地表示为行为类似于 tf.data.Dataset 的对象,即支持相同的属性和方法(它们目前没有实现为该类型的子类 - 随着 TensorFlow 中对数据集的支持不断发展,这可能会发生变化)。

您可以轻松地验证这一点,如下所示。

@tff.tensorflow.computation(tff.SequenceType(np.int32))
def foo(x):
  return x.reduce(np.int32(0), lambda x, y: x + y)

foo([1, 2, 3])
6

请记住,与普通的 tf.data.Dataset 不同,这些类似数据集的对象是占位符。它们不包含任何元素,因为它们代表抽象的序列类型参数,在具体上下文中使用时,将绑定到具体数据。对抽象定义的占位符数据集的支持目前仍然有限,在 TFF 的早期阶段,您可能会遇到某些限制,但我们不必在本教程中担心它们(有关详细信息,请参阅文档页面)。

在模拟模式下本地执行接受序列的计算时,例如在本教程中,您可以将序列作为 Python 列表提供,如下所示(以及其他方式,例如,在急切模式下作为 tf.data.Dataset,但现在,我们将保持简单)。

get_local_temperature_average([68.5, 70.3, 69.8])
69.53333

与所有其他 TFF 类型一样,上面定义的序列可以使用 tff.StructType 构造函数来定义嵌套结构。例如,以下是如何声明一个计算,该计算接受 AB 对的序列,并返回其乘积的总和。我们在计算的函数体中包含跟踪语句,以便您可以看到 TFF 类型签名如何转换为数据集的 output_typesoutput_shapes

@tff.tensorflow.computation(tff.SequenceType(collections.OrderedDict([('A', np.int32), ('B', np.int32)])))
def foo(ds):
  print('element_structure = {}'.format(ds.element_spec))
  return ds.reduce(np.int32(0), lambda total, x: total + x['A'] * x['B'])
element_structure = OrderedDict([('A', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('B', TensorSpec(shape=(), dtype=tf.int32, name=None))])
str(foo.type_signature)
'(<A=int32,B=int32>* -> int32)'
foo([{'A': 2, 'B': 3}, {'A': 4, 'B': 5}])
26

使用 tf.data.Datasets 作为形式参数的支持仍然有限,并且正在不断发展,尽管在像本教程中使用的简单场景中是有效的。

将所有内容整合在一起

现在,让我们尝试再次在联合环境中使用我们的 TensorFlow 计算。假设我们有一组传感器,每个传感器都有一个本地温度读数序列。我们可以通过对传感器本地平均值求平均来计算全局温度平均值,如下所示。

@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(np.float32), tff.CLIENTS))
def get_global_temperature_average(sensor_readings):
  return tff.federated_mean(
      tff.federated_map(get_local_temperature_average, sensor_readings))

请注意,这不是对所有客户端所有本地温度读数的简单平均值,因为这将需要根据每个客户端本地维护的读数数量来权衡来自不同客户端的贡献。我们将此留给读者作为练习来更新上面的代码;tff.federated_mean 运算符接受权重作为可选的第二个参数(预期为联邦浮点数)。

另请注意,get_global_temperature_average 的输入现在变成了一个联邦浮点数序列。联邦序列是我们通常在联邦学习中表示设备上数据的形式,序列元素通常代表数据批次(您很快就会看到这方面的示例)。

str(get_global_temperature_average.type_signature)
'({float32*}@CLIENTS -> float32@SERVER)'

以下是如何在 Python 中使用样本数据在本地执行计算。请注意,我们现在以 listlist 形式提供输入。外部列表遍历由 tff.CLIENTS 表示的组中的设备,内部列表遍历每个设备本地序列中的元素。

get_global_temperature_average([[68.0, 70.0], [71.0], [68.0, 72.0, 70.0]])
70.0

本教程的第一部分到此结束... 我们鼓励您继续学习 第二部分