在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
这是 构建您自己的联合学习算法 教程和 simple_fedavg 示例的替代方案,用于为 联合平均 算法构建自定义迭代过程。本教程将使用 TFF 优化器 而不是 Keras 优化器。TFF 优化器抽象采用“状态输入、状态输出”(state-in-state-out)的设计,以便更容易地将其合并到 TFF 迭代过程中。tff.learning API 还接受 TFF 优化器作为输入参数。
开始之前
开始之前,请运行以下命令以确保您的环境已正确设置。如果您没有看到问候语,请参阅 安装 指南以获取说明。
pip install --quiet --upgrade tensorflow-federated
from typing import Any
import functools
import attrs
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
准备数据和模型
EMNIST 数据处理和模型与 simple_fedavg 示例非常相似。
# When set, the model below is built with 10 output classes instead of 62.
only_digits = True

# Load the federated EMNIST train and test splits.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=only_digits)
# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):
    """Batches `dataset` and converts each batch to a `(pixels, label)` tuple.

    Args:
      dataset: Dataset whose elements are mappings with 'pixels' and 'label'
        entries.
      batch_size: Number of examples per batch.

    Returns:
      The batched dataset, where every element is a tuple of the pixels (with
      one extra trailing axis appended) and the labels.
    """

    def _to_pixels_and_label(example):
        # Append a trailing axis of size 1 to the pixel tensor.
        pixels = tf.expand_dims(example['pixels'], axis=-1)
        return (pixels, example['label'])

    batched = dataset.batch(batch_size)
    return batched.map(_to_pixels_and_label)
# Preprocess and sample clients for prototyping.
# Sorting the client ids gives a stable order for selecting clients by index.
train_client_ids = sorted(emnist_train.client_ids)
# Register `preprocess_fn` as the preprocessing for the training client data.
train_data = emnist_train.preprocess(preprocess_fn)
# NOTE(review): the "central test" dataset is built from the first *training*
# client rather than from `emnist_test` -- confirm this is intentional.
central_test_data = preprocess_fn(
    emnist_train.create_tf_dataset_for_client(train_client_ids[0]))
# Define model.
def create_keras_model():
    """Builds the CNN model used in https://arxiv.org/abs/1602.05629.

    Returns:
      A `tf.keras.models.Sequential` with two 5x5 convolution + 2x2 max-pool
      stages, a 512-unit dense layer, and a final dense layer without
      activation that has 10 units when `only_digits` is set and 62 otherwise.
    """
    data_format = 'channels_last'
    # Keyword arguments shared by both convolution layers.
    conv_kwargs = dict(
        kernel_size=5,
        padding='same',
        data_format=data_format,
        activation=tf.nn.relu,
    )
    # Keyword arguments shared by both pooling layers.
    pool_kwargs = dict(
        pool_size=(2, 2),
        padding='same',
        data_format=data_format,
    )
    num_classes = 10 if only_digits else 62
    layers = [
        tf.keras.layers.Conv2D(
            filters=32, input_shape=[28, 28, 1], **conv_kwargs),
        tf.keras.layers.MaxPooling2D(**pool_kwargs),
        tf.keras.layers.Conv2D(filters=64, **conv_kwargs),
        tf.keras.layers.MaxPooling2D(**pool_kwargs),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        tf.keras.layers.Dense(num_classes),
    ]
    return tf.keras.models.Sequential(layers)
# Wrap as `tff.learning.models.VariableModel`.
def model_fn():
    """Creates a new Keras model and wraps it for use with TFF.

    Returns:
      The result of `tff.learning.models.from_keras_model` for a freshly built
      `create_keras_model()`, using the element spec of `central_test_data` as
      input spec and sparse categorical cross-entropy on logits as loss.
    """
    return tff.learning.models.from_keras_model(
        keras_model=create_keras_model(),
        input_spec=central_test_data.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
自定义迭代过程
在许多情况下,联合算法有 4 个主要组件:
- 服务器到客户端的广播步骤。
- 本地客户端更新步骤。
- 客户端到服务器的上传步骤。
- 服务器更新步骤。
在 TFF 中,我们通常将联合算法表示为 `tff.templates.IterativeProcess`(下文简称 `IterativeProcess`)。这是一个包含 `initialize` 和 `next` 函数的类。其中,`initialize` 用于初始化服务器,而 `next` 将执行联合算法的一轮通信。
我们将介绍不同的组件来构建联合平均 (FedAvg) 算法,该算法将在客户端更新步骤中使用一个优化器,在服务器更新步骤中使用另一个优化器。客户端和服务器更新的核心逻辑可以表示为纯 TF 块。
TF 块:客户端和服务器更新
在每个客户端上,都会初始化一个本地 `client_optimizer` 并用于更新客户端模型权重。在服务器上,`server_optimizer` 将使用来自上一轮的状态,并更新下一轮的状态。
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    """Performs local training on the client's dataset.

    Args:
      model: Model exposing `trainable_variables` and `forward_pass(batch)`;
        the forward pass result must carry a `loss` attribute.
      dataset: Iterable of batches accepted by `model.forward_pass`.
      server_weights: Structure matching `model.trainable_variables`; local
        training starts from these values.
      client_optimizer: Optimizer exposing `initialize(specs)` and
        `next(state, weights, grads)`, which returns `(state, new_weights)`.

    Returns:
      A structure matching `model.trainable_variables` that holds the trained
      client weights minus `server_weights` (the model delta).
    """
    # The client model's trainable variables; they are updated in place below.
    client_weights = model.trainable_variables
    # Assign the server weights to the client model.
    tf.nest.map_structure(lambda x, y: x.assign(y),
                          client_weights, server_weights)
    # Initialize the client optimizer from the shapes and dtypes of the
    # weights; a new optimizer state is created on every call.
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
    optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
    # Use the client_optimizer to update the local model.
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            # Compute a forward pass on the batch of data.
            outputs = model.forward_pass(batch)
        # Compute the corresponding gradient.
        grads = tape.gradient(outputs.loss, client_weights)
        # Apply the gradient using a client optimizer. The optimizer returns
        # new weight values, which are written back into the model variables.
        optimizer_state, updated_weights = client_optimizer.next(
            optimizer_state, client_weights, grads)
        tf.nest.map_structure(lambda a, b: a.assign(b),
                              client_weights, updated_weights)
    # Return model deltas.
    return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attrs.define(eq=False, frozen=True)
class ServerState:
    """Immutable container for the state the server carries between rounds."""

    # Trainable model weights held by the server.
    trainable_weights: Any
    # State of the server optimizer, as produced by its `initialize`/`next`.
    optimizer_state: Any
@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
    """Applies one server optimizer step driven by the averaged model delta.

    Args:
      server_state: State exposing `trainable_weights` and `optimizer_state`.
      mean_model_delta: Structure matching the trainable weights, holding the
        aggregated client model delta.
      server_optimizer: Optimizer exposing `next(state, weights, grads)`, which
        returns `(state, new_weights)`.

    Returns:
      `server_state` with `trainable_weights` and `optimizer_state` replaced by
      the values returned from the optimizer step.
    """

    def _negate(delta):
        return -1.0 * delta

    # The negated aggregate delta is handed to the optimizer as if it were a
    # gradient (a "pseudo gradient").
    pseudo_gradient = tf.nest.map_structure(_negate, mean_model_delta)
    next_optimizer_state, next_weights = server_optimizer.next(
        server_state.optimizer_state,
        server_state.trainable_weights,
        pseudo_gradient,
    )
    return tff.structure.update_struct(
        server_state,
        trainable_weights=next_weights,
        optimizer_state=next_optimizer_state,
    )
TFF 块:`tff.tensorflow.computation` 和 `tff.federated_computation`
我们现在使用 TFF 进行编排并为 FedAvg 构建迭代过程。我们必须使用 `tff.tensorflow.computation` 包装上面定义的 TF 块,并在 `tff.federated_computation` 函数中使用 TFF 方法 `tff.federated_broadcast`、`tff.federated_map`、`tff.federated_mean`。在定义自定义迭代过程时,可以很方便地使用带有 `initialize` 和 `next` 函数的 `tff.learning.optimizers.Optimizer` API。
# 1. Server and client optimizer to be used.
# Server side: SGD with momentum; it is applied to the averaged model delta
# inside `server_update_fn`.
server_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.05, momentum=0.9)
# Client side: built without a `momentum` argument; it drives local training
# inside `client_update_fn`.
client_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.01)
# 2. Functions return initial state on server.
@tff.tensorflow.computation
def server_init():
    """Builds the initial `ServerState`.

    Returns:
      A `ServerState` holding the trainable variables of a freshly built model
      and the state `server_optimizer.initialize` produces for their specs.
    """
    model = model_fn()
    weights = model.trainable_variables
    # The optimizer is initialized from shapes and dtypes only, not values.
    weight_specs = tf.nest.map_structure(
        lambda var: tf.TensorSpec(var.shape, var.dtype), weights)
    initial_optimizer_state = server_optimizer.initialize(weight_specs)
    return ServerState(
        trainable_weights=weights,
        optimizer_state=initial_optimizer_state,
    )
@tff.federated_computation
def server_init_tff():
    """Returns the result of `server_init` as a value placed at `tff.SERVER`."""
    initial_state = server_init()
    return tff.federated_value(initial_state, tff.SERVER)
# 3. One round of computation and communication.
# The server state type is read from the result type of `server_init`.
server_state_type = server_init.type_signature.result
print('server_state_type:\n',
      server_state_type.formatted_representation())
# Type of the trainable weights alone; it is also used below as the parameter
# type for the model delta and for the broadcast server weights.
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n',
      trainable_weights_type.formatted_representation())
# 3-1. Wrap server and client TF blocks with `tff.tensorflow.computation`.
@tff.tensorflow.computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
    """Runs `server_update` with the module-level `server_optimizer`."""
    return server_update(server_state, model_delta, server_optimizer)
# Throwaway model instance, built only to read its `input_spec`.
whimsy_model = model_fn()
# Type of one client's dataset: a sequence whose elements match the model's
# input spec.
tf_dataset_type = tff.SequenceType(tff.types.tensorflow_to_type(whimsy_model.input_spec))
print('tf_dataset_type:\n',
      tf_dataset_type.formatted_representation())
@tff.tensorflow.computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
    """Runs `client_update` on a newly built model with `client_optimizer`."""
    # A new model is created on every call; `client_update` overwrites its
    # trainable variables with `server_weights` before training.
    model = model_fn()
    return client_update(model, dataset, server_weights, client_optimizer)
# 3-2. Orchestration with `tff.federated_computation`.
# The server state is placed at the server; the datasets are placed at the
# clients.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Runs one round: broadcast, local update, aggregation, server update.

    Args:
      server_state: The server state, placed at `tff.SERVER`.
      federated_dataset: The client datasets, placed at `tff.CLIENTS`.

    Returns:
      The server state produced by `server_update_fn`, placed at `tff.SERVER`.
    """
    # Server-to-client broadcast of the current trainable weights.
    broadcast_weights = tff.federated_broadcast(server_state.trainable_weights)
    # Local client update: each client returns its model delta.
    client_deltas = tff.federated_map(
        client_update_fn, (federated_dataset, broadcast_weights))
    # Client-to-server upload and aggregation.
    averaged_delta = tff.federated_mean(client_deltas)
    # Server update.
    return tff.federated_map(
        server_update_fn, (server_state, averaged_delta))
# 4. Build the iterative process for FedAvg.
# `initialize` produces the initial server state; `next` runs a single round.
fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
# Print both type signatures for inspection.
print('type signature of `initialize`:\n',
      fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n',
      fedavg_process.next.type_signature.formatted_representation())
server_state_type: < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< learning_rate=float32, momentum=float32, accumulator=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > > >
trainable_weights_type: < float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >
tf_dataset_type: < float32[?,28,28,1], int32[?] >*
type signature of `initialize`: ( -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< learning_rate=float32, momentum=float32, accumulator=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > > >@SERVER)
type signature of `next`: (< server_state=< trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< learning_rate=float32, momentum=float32, accumulator=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > > >@SERVER, federated_dataset={< float32[?,28,28,1], int32[?] >*}@CLIENTS > -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< learning_rate=float32, momentum=float32, accumulator=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > > >@SERVER)
评估算法
我们在集中式评估数据集上评估性能。
def evaluate(server_state):
    """Computes the accuracy of the server weights on `central_test_data`.

    Args:
      server_state: State whose `trainable_weights` match the structure of the
        Keras model's trainable weights.

    Returns:
      The sparse categorical accuracy over all batches, as a NumPy scalar.
    """
    keras_model = create_keras_model()
    # Copy the server's trainable weights into the freshly built Keras model.
    tf.nest.map_structure(
        lambda variable, value: variable.assign(value),
        keras_model.trainable_weights,
        server_state.trainable_weights,
    )
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    # Each batch is an `(inputs, labels)` pair.
    for inputs, labels in central_test_data:
        predictions = keras_model(inputs, training=False)
        accuracy.update_state(y_true=labels, y_pred=predictions)
    return accuracy.result().numpy()
# Accuracy of the freshly initialized server model, before any training.
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)

# Evaluate after a few rounds
CLIENTS_PER_ROUND = 2
# The first `CLIENTS_PER_ROUND` clients are used in every round.
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
    train_data.create_tf_dataset_for_client(client)
    for client in sampled_clients
]
# The loop index is unused; it is named `round_num` so that the builtin
# `round` is not shadowed at module level.
for round_num in range(20):
    server_state = fedavg_process.next(server_state, sampled_train_data)
# Evaluate once, after all rounds have run.
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.06451613
Test accuracy 0.086021505