构建学习算法

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

开始之前

开始之前,请运行以下代码以确保您的环境已正确设置。如果您没有看到问候语,请参阅 安装 指南以获取说明。

pip install --quiet --upgrade tensorflow-federated
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

构建学习算法

构建您自己的联合学习算法教程 使用 TFF 的联合核心直接实现联合平均 (FedAvg) 算法的版本。

在本教程中,您将使用 TFF API 中的联合学习组件以模块化方式构建联合学习算法,而无需从头开始重新实现所有内容。

在本教程中,您将实现 FedAvg 的变体,该变体通过本地训练采用梯度裁剪。

学习算法构建块

从高层次来看,许多学习算法可以分为 4 个独立的组件,称为 **构建块**。它们如下所示

  1. 分发器(即服务器到客户端的通信)
  2. 客户端工作(即本地客户端计算)
  3. 聚合器(即客户端到服务器的通信)
  4. 终结器(即服务器使用聚合的客户端输出进行计算)

虽然 构建您自己的联合学习算法教程 从头开始实现了所有这些构建块,但这通常是不必要的。相反,您可以重复使用来自类似算法的构建块。

在本例中,要实现带有梯度裁剪的 FedAvg,您只需要修改 **客户端工作** 构建块。其余块可以与“普通”FedAvg 中使用的块相同。

实现客户端工作

首先,让我们编写 TF 逻辑,该逻辑执行带有梯度裁剪的本地模型训练。为简单起见,梯度将被裁剪,其范数最大为 1。

TF 逻辑

@tf.function
def client_update(
    model: tff.learning.models.FunctionalModel,
    dataset: tf.data.Dataset,
    initial_weights: tff.learning.models.ModelWeights,
    client_optimizer: tff.learning.optimizers.Optimizer,
):
  """Performs training (using the initial server model weights) on the client's dataset."""
  # Keep track of the number of examples.
  num_examples = 0.0
  # Use the client_optimizer to update the local model.
  trainable_weights, non_trainable_weights = (
      initial_weights.trainable,
      initial_weights.non_trainable,
  )
  optimizer_state = client_optimizer.initialize(
      tf.nest.map_structure(lambda x: tf.TensorSpec, trainable_weights)
  )
  for batch in dataset:
    x, y = batch
    with tf.GradientTape() as tape:
      tape.watch(trainable_weights)
      logits = model.predict_on_batch(
          model_weights=(trainable_weights, non_trainable_weights),
          x=x,
          training=True,
      )
      num_examples += tf.cast(tf.shape(y)[0], tf.float32)
      loss = model.loss(output=logits, label=y)

    # Compute the corresponding gradient
    grads = tape.gradient(loss, trainable_weights)

    # Compute the gradient norm and clip
    gradient_norm = tf.linalg.global_norm(grads)
    if gradient_norm > 1:
      grads = tf.nest.map_structure(lambda x: x / gradient_norm, grads)

    # Apply the gradient using a client optimizer.
    optimizer_state, trainable_weights = client_optimizer.next(
        optimizer_state, trainable_weights, grads
    )

  # Compute the difference between the initial weights and the client weights
  client_update = tf.nest.map_structure(
      tf.subtract, trainable_weights, initial_weights[0]
  )

  return tff.learning.templates.ClientResult(
      update=client_update, update_weight=num_examples
  )

上面代码中有一些重要的点。首先,它跟踪已查看的示例数量,因为这将构成客户端更新的 *权重*(在计算跨客户端的平均值时)。

其次,它使用 tff.learning.templates.ClientResult 来打包输出。此返回类型用于在 tff.learning 中标准化客户端工作构建块。

创建 ClientWorkProcess

虽然上面的 TF 逻辑将执行带有裁剪的本地训练,但它仍然需要包装在 TFF 代码中才能创建必要的构建块。

具体来说,4 个构建块表示为 tff.templates.MeasuredProcess。这意味着所有 4 个块都具有 initializenext 函数,用于实例化和运行计算。

这允许每个构建块跟踪其自己的 **状态**(存储在服务器上),以便根据需要执行其操作。虽然在本教程中不会使用它,但它可用于跟踪已发生的迭代次数或跟踪优化器状态等。

客户端工作 TF 逻辑通常应包装为 tff.learning.templates.ClientWorkProcess,它对进入和离开客户端本地训练的预期类型进行编码。它可以由模型和优化器参数化,如下所示。

def build_gradient_clipping_client_work(
    model: tff.learning.models.FunctionalModel,
    optimizer: tff.learning.optimizers.Optimizer,
) -> tff.learning.templates.ClientWorkProcess:
  """Creates a client work process that uses gradient clipping."""
  data_type = tff.SequenceType(tff.types.tensorflow_to_type(model.input_spec))
  model_weights_type = tff.types.to_type(
      tf.nest.map_structure(
          lambda arr: tff.types.TensorType(shape=arr.shape, dtype=arr.dtype),
          tff.learning.models.ModelWeights(*model.initial_weights),
      )
  )

  @tff.federated_computation
  def initialize_fn():
    return tff.federated_value((), tff.SERVER)

  @tff.tensorflow.computation(model_weights_type, data_type)
  def client_update_computation(model_weights, dataset):
    return client_update(model, dataset, model_weights, optimizer)

  @tff.federated_computation(
      initialize_fn.type_signature.result,
      tff.FederatedType(model_weights_type, tff.CLIENTS),
      tff.FederatedType(data_type, tff.CLIENTS),
  )
  def next_fn(state, model_weights, client_dataset):
    client_result = tff.federated_map(
        client_update_computation, (model_weights, client_dataset)
    )
    # Return empty measurements, though a more complete algorithm might
    # measure something here.
    measurements = tff.federated_value((), tff.SERVER)
    return tff.templates.MeasuredProcessOutput(
        state, client_result, measurements
    )

  return tff.learning.templates.ClientWorkProcess(initialize_fn, next_fn)

构建学习算法

让我们将上面的客户端工作放入一个完整的算法中。首先,让我们设置我们的数据和模型。

准备输入数据

加载并预处理 TFF 中包含的 EMNIST 数据集。有关更多详细信息,请参阅 图像分类 教程。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

为了将数据集馈送到我们的模型中,数据被扁平化并转换为形式为 (flattened_image_vector, label) 的元组。

让我们选择少量客户端,并将上述预处理应用于它们的数据集。

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

准备模型

这使用与 图像分类 教程中相同的模型。此模型(通过 tf.keras 实现)具有一个隐藏层,后面跟着一个 softmax 层。为了在 TFF 中使用此模型,Keras 模型被包装为 tff.learning.models.FunctionalModel。这使我们能够执行模型的 前向传递 aggregator_factory = tff.aggregators.MeanFactory() aggregator = aggregator_factory.create( model_weights_type.trainable, tff.TensorType(np.float32) ) finalizer = tff.learning.templates.build_apply_optimizer_finalizer( server_optimizer, model_weights_type )

initializer = tf.keras.initializers.GlorotNormal(seed=0)
keras_model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(784,)),
    tf.keras.layers.Dense(10, kernel_initializer=initializer),
    tf.keras.layers.Softmax(),
])

tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=federated_train_data[0].element_spec,
    metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy
    ),
)

准备优化器

就像在 tff.learning.algorithms.build_weighted_fed_avg 中一样,这里有两个优化器:一个客户端优化器和一个服务器优化器。为简单起见,优化器将是具有不同学习率的 SGD。

client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=1.0)

定义构建块

现在客户端工作构建块、数据、模型和优化器都已设置好,剩下的就是创建用于分发器、聚合器和最终器的构建块。这可以通过借用 TFF 中的一些默认值来完成,这些默认值被 FedAvg 使用。

@tff.tensorflow.computation
def initial_model_weights_fn():
  return tff.learning.models.ModelWeights(*tff_model.initial_weights)


model_weights_type = initial_model_weights_fn.type_signature.result

distributor = tff.learning.templates.build_broadcast_process(model_weights_type)
client_work = build_gradient_clipping_client_work(tff_model, client_optimizer)

# TFF aggregators use a factory pattern, which create an aggregator
# based on the output type of the client work. This also uses a float (the number
# of examples) to govern the weight in the average being computed.)
aggregator_factory = tff.aggregators.MeanFactory()
aggregator = aggregator_factory.create(
    model_weights_type.trainable, tff.TensorType(np.float32)
)
finalizer = tff.learning.templates.build_apply_optimizer_finalizer(
    server_optimizer, model_weights_type
)

组合构建块

最后,您可以使用 TFF 中的内置 **组合器** 将构建块组合在一起。这是一个相对简单的组合器,它接受上述 4 个构建块并将它们类型连接在一起。

fed_avg_with_clipping = tff.learning.templates.compose_learning_process(
    initial_model_weights_fn,
    distributor,
    client_work,
    aggregator,
    finalizer
)

运行算法

现在算法已经完成,让我们运行它。首先,**初始化** 算法。此算法的 **状态** 具有每个构建块的组件,以及一个用于 *全局模型权重* 的组件。

state = fed_avg_with_clipping.initialize()

state.client_work
()

正如预期的那样,客户端工作具有空状态(请记住上面的客户端工作代码!)。但是,其他构建块可能具有非空状态。例如,最终器会跟踪已发生的迭代次数。由于 next 尚未运行,因此它的状态为 0

state.finalizer
OrderedDict([('learning_rate', 1.0)])

现在运行一个训练回合。

learning_process_output = fed_avg_with_clipping.next(state, federated_train_data)

此输出(tff.learning.templates.LearningProcessOutput)同时具有 .state.metrics 输出。让我们看看两者。

learning_process_output.state.finalizer
OrderedDict([('learning_rate', 1.0)])

显然,最终器状态已增加 1,因为已运行了一轮 .next

learning_process_output.metrics
OrderedDict([('distributor', ()), ('client_work', ()), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

虽然指标为空,但对于更复杂和实用的算法,它们通常会包含大量有用信息。

结论

通过使用上面的构建块/组合器框架,您可以创建全新的学习算法,而无需从头开始重新执行所有操作。但是,这仅仅是起点。此框架使将算法表示为 FedAvg 的简单修改变得更加容易。有关更多算法,请参阅 tff.learning.algorithms,其中包含诸如 FedProx具有客户端学习率调度的 FedAvg 等算法。这些 API 甚至可以帮助实现全新的算法,例如 联邦 k 均值聚类