在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
开始之前
开始之前,请运行以下代码以确保您的环境已正确设置。如果您没有看到问候语,请参阅 安装 指南以获取说明。
pip install --quiet --upgrade tensorflow-federated
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
构建学习算法
构建您自己的联合学习算法教程 使用 TFF 的联合核心直接实现联合平均 (FedAvg) 算法的版本。
在本教程中,您将使用 TFF API 中的联合学习组件以模块化方式构建联合学习算法,而无需从头开始重新实现所有内容。
在本教程中,您将实现 FedAvg 的变体,该变体通过本地训练采用梯度裁剪。
学习算法构建块
从高层次来看,许多学习算法可以分为 4 个独立的组件,称为 **构建块**。它们如下所示
- 分发器(即服务器到客户端的通信)
- 客户端工作(即本地客户端计算)
- 聚合器(即客户端到服务器的通信)
- 终结器(即服务器使用聚合的客户端输出进行计算)
虽然 构建您自己的联合学习算法教程 从头开始实现了所有这些构建块,但这通常是不必要的。相反,您可以重复使用来自类似算法的构建块。
在本例中,要实现带有梯度裁剪的 FedAvg,您只需要修改 **客户端工作** 构建块。其余块可以与“普通”FedAvg 中使用的块相同。
实现客户端工作
首先,让我们编写 TF 逻辑,该逻辑执行带有梯度裁剪的本地模型训练。为简单起见,梯度将被裁剪,其范数最大为 1。
TF 逻辑
@tf.function
def client_update(
model: tff.learning.models.FunctionalModel,
dataset: tf.data.Dataset,
initial_weights: tff.learning.models.ModelWeights,
client_optimizer: tff.learning.optimizers.Optimizer,
):
"""Performs training (using the initial server model weights) on the client's dataset."""
# Keep track of the number of examples.
num_examples = 0.0
# Use the client_optimizer to update the local model.
trainable_weights, non_trainable_weights = (
initial_weights.trainable,
initial_weights.non_trainable,
)
optimizer_state = client_optimizer.initialize(
tf.nest.map_structure(lambda x: tf.TensorSpec, trainable_weights)
)
for batch in dataset:
x, y = batch
with tf.GradientTape() as tape:
tape.watch(trainable_weights)
logits = model.predict_on_batch(
model_weights=(trainable_weights, non_trainable_weights),
x=x,
training=True,
)
num_examples += tf.cast(tf.shape(y)[0], tf.float32)
loss = model.loss(output=logits, label=y)
# Compute the corresponding gradient
grads = tape.gradient(loss, trainable_weights)
# Compute the gradient norm and clip
gradient_norm = tf.linalg.global_norm(grads)
if gradient_norm > 1:
grads = tf.nest.map_structure(lambda x: x / gradient_norm, grads)
# Apply the gradient using a client optimizer.
optimizer_state, trainable_weights = client_optimizer.next(
optimizer_state, trainable_weights, grads
)
# Compute the difference between the initial weights and the client weights
client_update = tf.nest.map_structure(
tf.subtract, trainable_weights, initial_weights[0]
)
return tff.learning.templates.ClientResult(
update=client_update, update_weight=num_examples
)
上面代码中有一些重要的点。首先,它跟踪已查看的示例数量,因为这将构成客户端更新的 *权重*(在计算跨客户端的平均值时)。
其次,它使用 tff.learning.templates.ClientResult
来打包输出。此返回类型用于在 tff.learning
中标准化客户端工作构建块。
创建 ClientWorkProcess
虽然上面的 TF 逻辑将执行带有裁剪的本地训练,但它仍然需要包装在 TFF 代码中才能创建必要的构建块。
具体来说,4 个构建块表示为 tff.templates.MeasuredProcess
。这意味着所有 4 个块都具有 initialize
和 next
函数,用于实例化和运行计算。
这允许每个构建块跟踪其自己的 **状态**(存储在服务器上),以便根据需要执行其操作。虽然在本教程中不会使用它,但它可用于跟踪已发生的迭代次数或跟踪优化器状态等。
客户端工作 TF 逻辑通常应包装为 tff.learning.templates.ClientWorkProcess
,它对进入和离开客户端本地训练的预期类型进行编码。它可以由模型和优化器参数化,如下所示。
def build_gradient_clipping_client_work(
model: tff.learning.models.FunctionalModel,
optimizer: tff.learning.optimizers.Optimizer,
) -> tff.learning.templates.ClientWorkProcess:
"""Creates a client work process that uses gradient clipping."""
data_type = tff.SequenceType(tff.types.tensorflow_to_type(model.input_spec))
model_weights_type = tff.types.to_type(
tf.nest.map_structure(
lambda arr: tff.types.TensorType(shape=arr.shape, dtype=arr.dtype),
tff.learning.models.ModelWeights(*model.initial_weights),
)
)
@tff.federated_computation
def initialize_fn():
return tff.federated_value((), tff.SERVER)
@tff.tensorflow.computation(model_weights_type, data_type)
def client_update_computation(model_weights, dataset):
return client_update(model, dataset, model_weights, optimizer)
@tff.federated_computation(
initialize_fn.type_signature.result,
tff.FederatedType(model_weights_type, tff.CLIENTS),
tff.FederatedType(data_type, tff.CLIENTS),
)
def next_fn(state, model_weights, client_dataset):
client_result = tff.federated_map(
client_update_computation, (model_weights, client_dataset)
)
# Return empty measurements, though a more complete algorithm might
# measure something here.
measurements = tff.federated_value((), tff.SERVER)
return tff.templates.MeasuredProcessOutput(
state, client_result, measurements
)
return tff.learning.templates.ClientWorkProcess(initialize_fn, next_fn)
构建学习算法
让我们将上面的客户端工作放入一个完整的算法中。首先,让我们设置我们的数据和模型。
准备输入数据
加载并预处理 TFF 中包含的 EMNIST 数据集。有关更多详细信息,请参阅 图像分类 教程。
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
为了将数据集馈送到我们的模型中,数据被扁平化并转换为形式为 (flattened_image_vector, label)
的元组。
让我们选择少量客户端,并将上述预处理应用于它们的数据集。
NUM_CLIENTS = 10
BATCH_SIZE = 20
def preprocess(dataset):
def batch_format_fn(element):
"""Flatten a batch of EMNIST data and return a (features, label) tuple."""
return (tf.reshape(element['pixels'], [-1, 784]),
tf.reshape(element['label'], [-1, 1]))
return dataset.batch(BATCH_SIZE).map(batch_format_fn)
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
for x in client_ids
]
准备模型
这使用与 图像分类 教程中相同的模型。此模型(通过 tf.keras
实现)具有一个隐藏层,后面跟着一个 softmax 层。为了在 TFF 中使用此模型,Keras 模型被包装为 tff.learning.models.FunctionalModel
。这使我们能够执行模型的 前向传递 aggregator_factory = tff.aggregators.MeanFactory() aggregator = aggregator_factory.create( model_weights_type.trainable, tff.TensorType(np.float32) ) finalizer = tff.learning.templates.build_apply_optimizer_finalizer( server_optimizer, model_weights_type )
initializer = tf.keras.initializers.GlorotNormal(seed=0)
keras_model = tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer=initializer),
tf.keras.layers.Softmax(),
])
tff_model = tff.learning.models.functional_model_from_keras(
keras_model,
loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
input_spec=federated_train_data[0].element_spec,
metrics_constructor=collections.OrderedDict(
accuracy=tf.keras.metrics.SparseCategoricalAccuracy
),
)
准备优化器
就像在 tff.learning.algorithms.build_weighted_fed_avg
中一样,这里有两个优化器:一个客户端优化器和一个服务器优化器。为简单起见,优化器将是具有不同学习率的 SGD。
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=1.0)
定义构建块
现在客户端工作构建块、数据、模型和优化器都已设置好,剩下的就是创建用于分发器、聚合器和最终器的构建块。这可以通过借用 TFF 中的一些默认值来完成,这些默认值被 FedAvg 使用。
@tff.tensorflow.computation
def initial_model_weights_fn():
return tff.learning.models.ModelWeights(*tff_model.initial_weights)
model_weights_type = initial_model_weights_fn.type_signature.result
distributor = tff.learning.templates.build_broadcast_process(model_weights_type)
client_work = build_gradient_clipping_client_work(tff_model, client_optimizer)
# TFF aggregators use a factory pattern, which create an aggregator
# based on the output type of the client work. This also uses a float (the number
# of examples) to govern the weight in the average being computed.)
aggregator_factory = tff.aggregators.MeanFactory()
aggregator = aggregator_factory.create(
model_weights_type.trainable, tff.TensorType(np.float32)
)
finalizer = tff.learning.templates.build_apply_optimizer_finalizer(
server_optimizer, model_weights_type
)
组合构建块
最后,您可以使用 TFF 中的内置 **组合器** 将构建块组合在一起。这是一个相对简单的组合器,它接受上述 4 个构建块并将它们类型连接在一起。
fed_avg_with_clipping = tff.learning.templates.compose_learning_process(
initial_model_weights_fn,
distributor,
client_work,
aggregator,
finalizer
)
运行算法
现在算法已经完成,让我们运行它。首先,**初始化** 算法。此算法的 **状态** 具有每个构建块的组件,以及一个用于 *全局模型权重* 的组件。
state = fed_avg_with_clipping.initialize()
state.client_work
()
正如预期的那样,客户端工作具有空状态(请记住上面的客户端工作代码!)。但是,其他构建块可能具有非空状态。例如,最终器会跟踪已发生的迭代次数。由于 next
尚未运行,因此它的状态为 0
。
state.finalizer
OrderedDict([('learning_rate', 1.0)])
现在运行一个训练回合。
learning_process_output = fed_avg_with_clipping.next(state, federated_train_data)
此输出(tff.learning.templates.LearningProcessOutput
)同时具有 .state
和 .metrics
输出。让我们看看两者。
learning_process_output.state.finalizer
OrderedDict([('learning_rate', 1.0)])
显然,最终器状态已增加 1,因为已运行了一轮 .next
。
learning_process_output.metrics
OrderedDict([('distributor', ()), ('client_work', ()), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
虽然指标为空,但对于更复杂和实用的算法,它们通常会包含大量有用信息。
结论
通过使用上面的构建块/组合器框架,您可以创建全新的学习算法,而无需从头开始重新执行所有操作。但是,这仅仅是起点。此框架使将算法表示为 FedAvg 的简单修改变得更加容易。有关更多算法,请参阅 tff.learning.algorithms
,其中包含诸如 FedProx 和 具有客户端学习率调度的 FedAvg 等算法。这些 API 甚至可以帮助实现全新的算法,例如 联邦 k 均值聚类。