构建您自己的联邦学习算法

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

开始之前

在开始之前,请运行以下代码以确保您的环境已正确设置。如果您没有看到问候语,请参阅 安装 指南以获取说明。

pip install --quiet --upgrade tensorflow-federated
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

图像分类文本生成 教程中,您学习了如何为联邦学习 (FL) 设置模型和数据管道,并通过 TFF 的 tff.learning API 层执行联邦训练。

这只是 FL 研究的冰山一角。本教程讨论了如何在依赖 tff.learning API 的情况下实现联邦学习算法。在本教程中,您将完成以下操作

目标

  • 了解联邦学习算法的一般结构。
  • 探索 TFF 的联邦核心
  • 使用联邦核心直接实现联邦平均。

虽然本教程是自包含的,但最好先查看 图像分类文本生成 教程。

准备输入数据

首先加载并预处理 TFF 中包含的 EMNIST 数据集。有关更多详细信息,请参阅 图像分类 教程。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

为了将数据集馈送到我们的模型,数据被扁平化,每个示例都被转换为形式为 (flattened_image_vector, label) 的元组。

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

现在,选择少量客户端,并将上述预处理应用于它们的数据集。

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

准备模型

这使用与 图像分类 教程中相同的模型。此模型(通过 tf.keras 实现)具有一个隐藏层,然后是一个 softmax 层。

def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

为了在 TFF 中使用此模型,将 Keras 模型包装为 tff.learning.models.FunctionalModel。这允许您在 TFF 中执行模型的 前向传递,并 提取模型输出。有关更多详细信息,请参阅 图像分类 教程。

keras_model = create_keras_model()
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=federated_train_data[0].element_spec,
    metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy
    ),
)

虽然上面使用 tf.keras 创建了 tff.learning.models.FunctionalModel,但 TFF 支持更通用的模型。这些模型具有以下捕获模型权重的相关属性

  • trainable_variables:对应于可训练层的张量的可迭代对象。
  • non_trainable_variables:对应于不可训练层的张量的可迭代对象。

在本教程中,将仅使用 trainable_variables(因为模型只有这些!)。

构建您自己的联邦学习算法

虽然 tff.learning API 允许您创建联邦平均的许多变体,但还有其他联邦算法不适合此框架。例如,您可能希望添加正则化、裁剪或更复杂的算法,例如 联邦 GAN 训练。您可能还对 联邦分析 感兴趣。

对于这些更高级的算法,您需要使用 TFF 编写我们自己的自定义算法。在许多情况下,联邦算法具有 4 个主要组件

  1. 服务器到客户端的广播步骤。
  2. 本地客户端更新步骤。
  3. 客户端到服务器的上传步骤。
  4. 服务器更新步骤。

在 TFF 中,联邦算法通常表示为一个 tff.templates.IterativeProcess(在整个过程中将简称为 IterativeProcess)。这是一个包含 initializenext 函数的类。这里,initialize 用于初始化服务器,而 next 将执行联邦算法的一个通信轮次。让我们写一个 FedAvg 迭代过程的骨架。

首先,有一个初始化函数,它只是创建一个 tff.learning.models.FunctionalModel,并返回其可训练权重。

def initialize_fn():
  trainable_weights, _ =  tff_model.initial_weights
  return trainable_weights

这个函数看起来不错,但正如你将在后面看到的那样,你需要做一个小小的修改才能使其成为一个“TFF 计算”。

接下来,让我们写一个 next_fn 的草图。

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

让我们专注于分别实现这四个组件。首先,让我们专注于可以用纯 TensorFlow 实现的部分,即客户端和服务器更新步骤。

TensorFlow 模块

客户端更新

可以使用 tff.learning.models.FunctionalModel 以与训练 TensorFlow 模型基本相同的方式进行客户端训练。特别是,可以使用 tf.GradientTape 计算数据批次的梯度,然后使用 client_optimizer 应用这些梯度。这将只涉及可训练权重。

@tf.function
def client_update(model, dataset, initial_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights and the optimizer
  # state.
  client_weights = initial_weights.trainable
  optimizer_state = client_optimizer.initialize(
      tf.nest.map_structure(tf.TensorSpec.from_tensor, client_weights)
  )

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    x, y = batch
    with tf.GradientTape() as tape:
      tape.watch(client_weights)
      # Compute a forward pass on the batch of data
      outputs = model.predict_on_batch(
          model_weights=(client_weights, ()), x=x, training=True
      )
      loss = model.loss(output=outputs, label=y)

    # Compute the corresponding gradient
    grads = tape.gradient(loss, client_weights)

    # Apply the gradient using a client optimizer.
    optimizer_state, client_weights = client_optimizer.next(
        optimizer_state, weights=client_weights, gradients=grads
    )

  return tff.learning.models.ModelWeights(client_weights, non_trainable=())

服务器更新

FedAvg 的服务器更新比客户端更新更简单。本教程将实现“普通”联邦平均,其中服务器模型权重被替换为客户端模型权重的平均值。同样,这仅使用可训练权重。

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  del model  # Unused, just take the mean_client_weights.
  return mean_client_weights

可以通过简单地返回 mean_client_weights 来简化代码片段。但是,联邦平均的更高级实现使用 mean_client_weights 以及更复杂的技巧,例如动量或自适应性。

挑战:实现一个 server_update 版本,该版本将服务器权重更新为 model_weights 和 mean_client_weights 的中点。(注意:这种“中点”方法类似于最近关于 Lookahead 优化器 的工作!)。

到目前为止,这仅涉及 TensorFlow 代码。这是有意为之,因为 TFF 允许你使用你已经熟悉的许多 TensorFlow 代码。接下来,你将需要指定编排逻辑,即决定服务器向客户端广播什么以及客户端向服务器上传什么的逻辑。

这将需要 TFF 的联邦核心

联邦核心简介

联邦核心 (FC) 是一组低级接口,它们是 tff.learning API 的基础。但是,这些接口并不局限于学习。事实上,它们可以用于分析和许多其他分布式数据的计算。

从高层次来看,联邦核心是一个开发环境,它允许紧凑地表达的程序逻辑将 TensorFlow 代码与分布式通信运算符(例如分布式求和和广播)结合起来。目标是让研究人员和从业人员明确控制其系统中的分布式通信,而无需系统实现细节(例如指定点对点网络消息交换)。

一个关键点是 TFF 是为隐私保护而设计的。因此,它允许明确控制数据驻留的位置,以防止数据在集中式服务器位置意外累积。

联邦数据

TFF 中的一个关键概念是“联邦数据”,它指的是分布式系统中一组设备上托管的一组数据项(例如客户端数据集或服务器模型权重)。所有设备上的整个值集合表示为单个联邦值

例如,假设有一些客户端设备,每个设备都有一个表示传感器温度的浮点数。这些浮点数可以通过以下方式表示为联邦浮点数

federated_float_on_clients = tff.FederatedType(np.float32, tff.CLIENTS)

联邦类型由其成员成分的类型 T(例如 np.float32)和设备组 G 指定。通常,G 既可以是 tff.CLIENTS,也可以是 tff.SERVER。这种联邦类型表示为 {T}@G,如下所示。

str(federated_float_on_clients)
'{float32}@CLIENTS'

为什么 TFF 如此重视放置?TFF 的一个关键目标是能够编写可以在真实分布式系统上部署的代码。这意味着必须推理哪些设备子集执行哪些代码以及不同的数据片段驻留在哪里。

TFF 专注于三件事:数据、数据放置的位置以及数据如何被转换。前两个封装在联邦类型中,而最后一个封装在联邦计算中。

联邦计算

TFF 是一个强类型函数式编程环境,其基本单元是联邦计算。这些是接受联邦值作为输入并返回联邦值作为输出的逻辑片段。

例如,假设你想对我们客户端传感器上的温度进行平均。你可以定义以下内容(使用我们的联邦浮点数)

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

你可能会问,这与 TensorFlow 中的 tf.function 装饰器有什么区别?关键答案是,由 tff.federated_computation 生成的代码既不是 TensorFlow 代码也不是 Python 代码;它是在内部平台无关的粘合语言中对分布式系统的规范。

虽然这听起来可能很复杂,但你可以将 TFF 计算视为具有明确定义类型签名的函数。这些类型签名可以直接查询。

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

这个 tff.federated_computation 接受联邦类型 {float32}@CLIENTS 的参数,并返回联邦类型 {float32}@SERVER 的值。联邦计算也可以从服务器到客户端、从客户端到客户端或从服务器到服务器。联邦计算也可以像普通函数一样组合,只要它们的类型签名匹配即可。

为了支持开发,TFF 允许你将 tff.federated_computation 作为 Python 函数调用。例如,你可以调用

get_average_temperature([68.5, 70.3, 69.8])
69.53333

非急切计算和 TensorFlow

需要注意两个关键限制。首先,当 Python 解释器遇到 tff.federated_computation 装饰器时,该函数将被跟踪一次并序列化以供将来使用。由于联邦学习的去中心化性质,这种未来的使用可能会发生在其他地方,例如远程执行环境。因此,TFF 计算本质上是非急切的。这种行为与 TensorFlow 中的 tf.function 装饰器的行为有些类似。

其次,联邦计算只能包含联邦运算符(例如 tff.federated_mean),它们不能包含 TensorFlow 运算。TensorFlow 代码必须限制在用 tff.tensorflow.computation 装饰的块中。大多数普通的 TensorFlow 代码可以直接装饰,例如以下函数,它接受一个数字并向其添加 0.5

@tff.tensorflow.computation(np.float32)
def add_half(x):
  return tf.add(x, 0.5)

这些也有类型签名,但没有放置。例如,你可以调用

str(add_half.type_signature)
'(float32 -> float32)'

这展示了 tff.federated_computationtff.tensorflow.computation 之间的一个重要区别。前者具有明确的放置,而后者则没有。

你可以在联邦计算中使用 tff.tensorflow.computation 块,方法是指定放置。让我们创建一个函数,它将一半加到客户端的联邦浮点数上。你可以通过使用 tff.federated_map 来做到这一点,它应用给定的 tff.tensorflow.computation,同时保留放置。

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

这个函数与 add_half 几乎相同,只是它只接受放置在 tff.CLIENTS 上的值,并返回具有相同放置的值。这可以在其类型签名中看到

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

总结

  • TFF 对联邦值进行操作。
  • 每个联邦值都有一个联邦类型,它包含一个类型(例如 np.float32)和一个放置(例如 tff.CLIENTS)。
  • 联邦值可以使用联邦计算进行转换,联邦计算必须用 tff.federated_computation 和联邦类型签名进行装饰。
  • TensorFlow 代码必须包含在用 tff.tensorflow.computation 装饰器装饰的块中。
  • 然后,这些块可以被合并到联邦计算中。

构建你自己的联邦学习算法,重温

现在你已经对联邦核心有了一点了解,你可以构建我们自己的联邦学习算法。请记住,在上面,你为我们的算法定义了一个 initialize_fn 和一个 next_fnnext_fn 将使用你用纯 TensorFlow 代码定义的 client_updateserver_update

但是,为了使我们的算法成为联邦计算,你需要让 next_fninitialize_fn 都成为 tff.federated_computation

TensorFlow 联邦模块

创建初始化计算

初始化函数将非常简单:您将使用 model_fn 创建一个模型。但是,请记住,您必须使用 tff.tensorflow.computation 将我们的 TensorFlow 代码分离出来。

@tff.tensorflow.computation
def server_init():
  return tff.learning.models.ModelWeights(*tff_model.initial_weights)

然后,您可以直接将它传递到使用 tff.federated_value 的联合计算中。

@tff.federated_computation
def initialize_fn():
  return tff.federated_eval(server_init, tff.SERVER)

创建 next_fn

现在可以使用客户端和服务器更新代码来编写实际的算法。首先,您将把 client_update 转换为一个 tff.tensorflow.computation,它接受客户端数据集和服务器权重,并输出更新的客户端权重张量。

您将需要相应的类型来正确地装饰我们的函数。幸运的是,服务器权重的类型可以直接从我们的模型中提取。

tf_dataset_type = tff.SequenceType(
    tff.types.tensorflow_to_type(tff_model.input_spec)
)

让我们看一下数据集类型签名。请记住,您使用了 28x28 的图像(带有整数标签)并将它们展平。

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

您还可以使用上面的 server_init 函数来提取模型权重类型。

model_weights_type = server_init.type_signature.result

检查类型签名,您将能够看到我们模型的架构!

str(model_weights_type)
'<trainable=<float32[784,10],float32[10]>,non_trainable=<>>'

您现在可以为客户端更新创建我们的 tff.tensorflow.computation

@tff.tensorflow.computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
  return client_update(tff_model, tf_dataset, server_weights, client_optimizer)

服务器更新的 tff.tensorflow.computation 版本可以使用类似的方式定义,使用您已经提取的类型。

@tff.tensorflow.computation(model_weights_type)
def server_update_fn(mean_client_weights):
  return server_update(tff_model, mean_client_weights)

最后,您需要创建将所有这些整合在一起的 tff.federated_computation。此函数将接受两个 *联合值*,一个对应于服务器权重(放置在 tff.SERVER),另一个对应于客户端数据集(放置在 tff.CLIENTS)。

请注意,这两种类型都在上面定义了!您只需要使用 tff.FederatedType 为它们提供适当的放置。

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

还记得 FL 算法的 4 个要素吗?

  1. 服务器到客户端的广播步骤。
  2. 本地客户端更新步骤。
  3. 客户端到服务器的上传步骤。
  4. 服务器更新步骤。

现在您已经构建了以上内容,每个部分都可以简洁地表示为一行 TFF 代码。这种简洁性是您必须格外小心地指定联合类型等内容的原因!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client)
  )

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

您现在拥有一个 tff.federated_computation,用于算法初始化和运行算法的一步。要完成我们的算法,您需要将它们传递到 tff.templates.IterativeProcess 中。

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

让我们看一下迭代过程的 initializenext 函数的 *类型签名*。

str(federated_algorithm.initialize.type_signature)
'( -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'

这反映了 federated_algorithm.initialize 是一个无参数函数,它返回一个单层模型(具有 784x10 的权重矩阵和 10 个偏差单元)。

str(federated_algorithm.next.type_signature)
'(<server_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'

在这里,可以看出 federated_algorithm.next 接受服务器模型和客户端数据,并返回更新的服务器模型。

评估算法

让我们运行几轮,看看损失如何变化。首先,您将使用在第二个教程中讨论的 *集中式* 方法定义一个评估函数。

您将首先创建一个集中式评估数据集,然后应用与训练数据相同的预处理。

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

接下来,您将编写一个接受服务器状态的函数,并使用 Keras 在测试数据集上进行评估。如果您熟悉 tf.Keras,那么这一切看起来都将很熟悉,但请注意 set_weights 的使用!

def evaluate(model_weights):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  model_weights.assign_weights_to(keras_model)
  keras_model.evaluate(central_emnist_test)

现在,让我们初始化我们的算法并在测试集上进行评估。

server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 26s 10ms/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

让我们训练几轮,看看是否有什么变化。

for _ in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 4s 1ms/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

损失函数略有下降。虽然跳跃很小,但您只执行了 15 轮训练,并且是在一小部分客户端上进行的。要看到更好的结果,您可能需要进行数百甚至数千轮训练。

修改我们的算法

在这一点上,让我们停下来思考一下您已经完成了什么。您已经通过将纯 TensorFlow 代码(用于客户端和服务器更新)与来自 TFF 联合核心的联合计算相结合,直接实现了联合平均。

要执行更复杂的学习,您只需更改上面的内容。特别是,通过编辑上面的纯 TF 代码,您可以更改客户端执行训练的方式,或服务器更新其模型的方式。

挑战:client_update 函数中添加 梯度裁剪

如果您想进行更大的更改,您还可以让服务器存储和广播更多数据。例如,服务器还可以存储客户端学习率,并随着时间的推移使其衰减!请注意,这将需要更改上面 tff.tensorflow.computation 调用中使用的类型签名。

更难的挑战:实现具有客户端学习率衰减的联合平均。

在这一点上,您可能会开始意识到在这个框架中您可以实现多少灵活性。有关想法(包括上面更难挑战的答案),您可以查看 tff.learning.algorithms.build_weighted_fed_avg 的源代码,或查看使用 TFF 的各种 研究项目