联邦学习

概述

本文档介绍了用于促进联邦学习任务的接口,例如使用 TensorFlow 中实现的现有机器学习模型进行联邦训练或评估。在设计这些接口时,我们的主要目标是使人们能够在无需了解其内部工作原理的情况下进行联邦学习实验,并在各种现有模型和数据上评估已实现的联邦学习算法。我们鼓励您为平台做出贡献。TFF 的设计考虑了可扩展性和可组合性,我们欢迎贡献;我们很高兴看到您想出什么!

此层提供的接口包含以下三个关键部分

  • 模型。类和辅助函数,允许您包装现有模型以供 TFF 使用。包装模型可以像调用单个包装函数一样简单(例如,tff.learning.models.from_keras_model),或者定义 tff.learning.models.VariableModel 接口的子类以实现完全自定义。

  • 联邦计算构建器。辅助函数,使用您的现有模型构建用于训练或评估的联邦计算。

  • 数据集。您可以下载并在 Python 中访问的预制数据集合,用于模拟联邦学习场景。虽然联邦学习旨在用于无法在集中位置简单下载的去中心化数据,但在研究和开发阶段,使用可以下载和本地操作的数据进行初步实验通常很方便,特别是对于可能不熟悉这种方法的开发人员。

这些接口主要在 tff.learning 命名空间中定义,除了研究数据集和其他与模拟相关的功能,这些功能已在 tff.simulation 中分组。此层使用 联邦核心 (FC) 提供的更低级接口实现,该接口还提供运行时环境。

在继续之前,我们建议您首先查看有关 图像分类文本生成 的教程,因为它们使用具体示例介绍了此处描述的大多数概念。如果您有兴趣了解有关 TFF 工作原理的更多信息,您可能需要浏览一下 自定义算法 教程,作为对我们用来表达联邦计算逻辑的更低级接口的介绍,并研究 tff.learning 接口的现有实现。

模型

架构假设

序列化

TFF 旨在支持各种分布式学习场景,在这些场景中,您编写的机器学习模型代码可能在大量具有不同功能的异构客户端上执行。虽然在频谱的一端,在某些应用程序中,这些客户端可能是功能强大的数据库服务器,但我们平台打算支持的许多重要用途涉及资源有限的移动设备和嵌入式设备。我们不能假设这些设备能够托管 Python 运行时;目前我们唯一可以假设的是它们能够托管本地 TensorFlow 运行时。因此,我们在 TFF 中做出的一个基本架构假设是,您的模型代码必须可序列化为 TensorFlow 图。

您仍然可以(并且应该)按照最新的最佳实践开发您的 TF 代码,例如使用急切模式。但是,最终代码必须可序列化(例如,可以包装为 tf.function,用于急切模式代码)。这确保了在执行时所需的任何 Python 状态或控制流都可以序列化(可能借助 Autograph)。

目前,TensorFlow 并不完全支持序列化和反序列化急切模式 TensorFlow。因此,TFF 中的序列化目前遵循 TF 1.0 模式,其中所有代码都必须在 TFF 控制的 tf.Graph 内构建。这意味着目前 TFF 无法使用已构建的模型;相反,模型定义逻辑打包在一个无参数函数中,该函数返回一个 tff.learning.models.VariableModel。然后,此函数由 TFF 调用以确保模型的所有组件都被序列化。此外,作为一种强类型环境,TFF 将需要一些额外的元数据,例如模型输入类型的规范。

聚合

我们强烈建议大多数用户使用 Keras 构建模型,请参阅下面的 Keras 转换器 部分。这些包装器会自动处理模型更新的聚合以及为模型定义的任何指标。但是,了解如何为一般 tff.learning.models.VariableModel 处理聚合可能仍然有用。

在联邦学习中,至少始终存在两层聚合:本地设备聚合和跨设备(或联邦)聚合

  • 本地聚合。此级别的聚合是指跨单个客户端拥有的多个示例批次进行的聚合。它适用于模型参数(变量),这些参数在模型进行本地训练时会继续依次演变,以及您计算的统计信息(例如平均损失、准确率和其他指标),您的模型将在迭代每个客户端的本地数据流时再次在本地更新这些统计信息。

    在这一级别上执行聚合是模型代码的责任,可以使用标准的 TensorFlow 结构来完成。

    处理的一般结构如下

    • 模型首先构建 tf.Variable 来保存聚合,例如批次数量或处理的示例数量、每个批次或每个示例损失的总和等。

    • TFF 在您的 Model 上多次调用 forward_pass 方法,依次遍历后续的客户端数据批次,这使您可以将保存各种聚合的变量作为副作用更新。

    • 最后,TFF 在您的模型上调用 report_local_unfinalized_metrics 方法,以允许您的模型将收集的所有摘要统计信息编译成一组紧凑的指标,以便客户端导出。例如,您的模型代码可以将损失总和除以处理的示例数量,以导出平均损失等。

  • 联邦聚合。这种级别的聚合是指系统中多个客户端(设备)之间的聚合。同样,它适用于模型参数(变量)(在客户端之间进行平均),以及您的模型作为本地聚合结果导出的指标。

    在这一级别上执行聚合是 TFF 的责任。但是,作为模型创建者,您可以控制此过程(更多内容见下文)。

    处理的一般结构如下

    • 初始模型以及训练所需的任何参数由服务器分发到将参与一轮训练或评估的客户端子集。

    • 在每个客户端上,您的模型代码独立且并行地在本地数据批次流上重复调用,以生成一组新的模型参数(在训练时)和一组新的本地指标,如上所述(这是本地聚合)。

    • TFF 运行一个分布式聚合协议,以在整个系统中累积和聚合模型参数和本地导出的指标。此逻辑使用 TFF 自己的联邦计算语言以声明方式表达(而不是在 TensorFlow 中)。有关聚合 API 的更多信息,请参见 自定义算法 教程。

抽象接口

此基本构造函数 + 元数据接口由接口 tff.learning.models.VariableModel 表示,如下所示

  • 构造函数、forward_passreport_local_unfinalized_metrics 方法应分别构建模型变量、前向传递和您希望报告的统计信息。这些方法构建的 TensorFlow 必须是可序列化的,如上所述。

  • input_spec 属性以及返回可训练变量、不可训练变量和本地变量子集的 3 个属性表示元数据。TFF 使用此信息来确定如何将模型的各个部分连接到联邦优化算法,以及定义内部类型签名以帮助验证构建的系统的正确性(以便您的模型不能在与模型设计用于使用的数据不匹配的数据上实例化)。

此外,抽象接口 tff.learning.models.VariableModel 公开了属性 metric_finalizers,它接收指标的未最终确定值(由 report_local_unfinalized_metrics() 返回)并返回最终确定的指标值。 metric_finalizersreport_local_unfinalized_metrics() 方法将一起用于在定义联邦训练过程或评估计算时构建跨客户端指标聚合器。例如,一个简单的 tff.learning.metrics.sum_then_finalize 聚合器将首先对来自客户端的未最终确定指标值求和,然后在服务器上调用最终确定器函数。

您可以在我们 图像分类 教程的第二部分以及我们在 model_examples.py 中用于测试的示例模型中找到如何定义您自己的自定义 tff.learning.models.VariableModel 的示例。

Keras 转换器

TFF 所需的几乎所有信息都可以通过调用 tf.keras 接口来获取,因此如果您有 Keras 模型,您可以依靠 tff.learning.models.from_keras_model 来构建一个 tff.learning.models.VariableModel.

请注意,TFF 仍然希望您提供一个构造函数 - 一个无参数的模型函数,例如以下函数

def model_fn():
  keras_model = ...
  return tff.learning.models.from_keras_model(keras_model, sample_batch, loss=...)

除了模型本身之外,您还提供了一个数据样本批次,TFF 使用它来确定模型输入的类型和形状。这确保了 TFF 可以为客户端设备上实际存在的数据正确实例化模型(因为我们假设在您构建要序列化的 TensorFlow 时,这些数据通常不可用)。

我们在 图像分类文本生成 教程中说明了 Keras 包装器的使用。

联邦计算构建器

tff.learning 包为执行与学习相关的任务的 tff.Computation 提供了几个构建器;我们预计此类计算的集合将在未来扩展。

架构假设

执行

运行联邦计算有两个不同的阶段。

  • 编译:TFF 首先将联邦学习算法编译成整个分布式计算的抽象序列化表示。这是 TensorFlow 序列化发生的时候,但其他转换也可能发生以支持更有效的执行。我们将编译器发出的序列化表示称为联邦计算

  • 执行 TFF 提供了执行这些计算的方法。目前,仅通过本地模拟(例如,在使用模拟分散数据的笔记本中)支持执行。

由 TFF 的联邦学习 API 生成的联邦计算,例如使用 联邦模型平均 的训练算法或联邦评估,包括许多元素,最值得注意的是

  • 您的模型代码的序列化形式以及联邦学习框架构建的额外 TensorFlow 代码,以驱动模型的训练/评估循环(例如,构建优化器、应用模型更新、迭代 tf.data.Dataset、计算指标和应用服务器上的聚合更新,仅举几例)。

  • 客户端服务器之间通信的声明性规范(通常是跨客户端设备的各种形式的聚合,以及从服务器到所有客户端的广播),以及这种分布式通信如何与客户端本地或服务器本地 TensorFlow 代码的执行交织在一起。

此序列化形式中表示的联邦计算是用一种与平台无关的内部语言表达的,该语言不同于 Python,但要使用联邦学习 API,您无需关心此表示的细节。计算在您的 Python 代码中表示为类型为 tff.Computation 的对象,在大多数情况下,您可以将它们视为不透明的 Python callable

在教程中,您将调用这些联邦计算,就好像它们是常规的 Python 函数一样,在本地执行。但是,TFF 旨在以一种对大多数执行环境方面无关的方式表达联邦计算,以便它们可以潜在地部署到例如运行 Android 的设备组或数据中心中的集群。同样,这带来的主要后果是对 序列化 的强烈假设。特别是,当您调用下面描述的 build_... 方法之一时,计算将完全序列化。

建模状态

TFF 是一个函数式编程环境,但联邦学习中许多感兴趣的过程是有状态的。例如,涉及多轮联邦模型平均的训练循环是我们可以归类为有状态过程的示例。在此过程中,从一轮到下一轮演化的状态包括正在训练的模型参数集,以及可能与优化器相关的其他状态(例如,动量向量)。

由于 TFF 是函数式的,因此有状态过程在 TFF 中被建模为接受当前状态作为输入,然后提供更新后的状态作为输出的计算。为了完全定义一个有状态过程,还需要指定初始状态来自哪里(否则我们就无法启动该过程)。这在辅助类 tff.templates.IterativeProcess 的定义中得到了体现,其中 initializenext 这 2 个属性分别对应于初始化和迭代。

可用构建器

目前,TFF 提供了各种构建器函数,这些函数生成用于联邦训练和评估的联邦计算。两个值得注意的例子包括

数据集

架构假设

客户端选择

在典型的联邦学习场景中,我们拥有一个庞大的人口,可能包含数亿个客户端设备,其中只有一小部分可能在任何给定时间处于活动状态并可用于训练(例如,这可能仅限于已插入电源、不在计量网络上且处于空闲状态的客户端)。通常,可用于参与训练或评估的客户端集不受开发人员控制。此外,由于协调数百万个客户端是不切实际的,因此典型的训练或评估轮次将仅包括一小部分可用客户端,这些客户端可能 随机抽样

这带来的关键结果是,联邦计算在设计上以一种对参与者确切集一无所知的方式表达;所有处理都表示为对一组抽象的匿名客户端的聚合操作,并且该组可能在一轮训练到下一轮训练之间有所不同。因此,将计算绑定到具体参与者(从而绑定到他们提供给计算的具体数据)是在计算本身之外建模的。

为了模拟您联邦学习代码的现实部署,您通常会编写一个看起来像这样的训练循环

trainer = tff.learning.algorithms.build_weighted_fed_avg(...)
state = trainer.initialize()
federated_training_data = ...

def sample(federate_data):
  return ...

while True:
  data_for_this_round = sample(federated_training_data)
  result = trainer.next(state, data_for_this_round)
  state = result.state

为了方便使用,在模拟中使用 TFF 时,联邦数据被接受为 Python list,每个参与的客户端设备对应一个元素,代表该设备的本地 tf.data.Dataset

抽象接口

为了标准化处理模拟联邦数据集,TFF 提供了一个抽象接口 tff.simulation.datasets.ClientData,它允许枚举客户端集合,并构建一个包含特定客户端数据的 tf.data.Dataset。这些 tf.data.Dataset 可以直接作为输入,在 eager 模式下馈送到生成的联邦计算中。

需要注意的是,访问客户端身份的能力是模拟数据集提供的功能,在模拟中可能需要对特定客户端子集的数据进行训练(例如,模拟不同类型客户端的昼夜可用性)。编译后的计算和底层运行时 *不* 涉及任何客户端身份的概念。一旦特定客户端子集的数据被选定为输入,例如在调用 tff.templates.IterativeProcess.next 时,客户端身份将不再出现。

可用数据集

我们为实现 tff.simulation.datasets.ClientData 接口以用于模拟的数据集专门分配了命名空间 tff.simulation.datasets,并为其提供了数据集以支持 图像分类文本生成 教程。我们鼓励您为平台贡献自己的数据集。