联邦学习用于图像分类

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

在本教程中,我们使用经典的 MNIST 训练示例来介绍 TFF 的联邦学习 (FL) API 层,tff.learning - 一组更高级别的接口,可用于执行常见的联邦学习任务类型,例如针对用户提供的 TensorFlow 模型实现的联邦训练。

本教程和联邦学习 API 主要面向希望将自己的 TensorFlow 模型插入 TFF 的用户,将后者主要视为黑盒。要更深入地了解 TFF 以及如何实现自己的联邦学习算法,请参阅有关 FC 核心 API 的教程 - 自定义联邦算法第 1 部分第 2 部分

有关 tff.learning 的更多信息,请继续学习 联邦学习用于文本生成 教程,该教程除了介绍循环模型外,还演示了如何加载预训练的序列化 Keras 模型以使用联邦学习进行细化,并结合使用 Keras 进行评估。

开始之前

开始之前,请运行以下内容以确保您的环境已正确设置。如果您没有看到问候语,请参阅 安装 指南以获取说明。

pip install --quiet --upgrade tensorflow-federated
%load_ext tensorboard
Fetching TensorBoard MPM version 'live'... done.
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

准备输入数据

让我们从数据开始。联邦学习需要一个联邦数据集,即来自多个用户的数据的集合。联邦数据通常是非 i.i.d. 的,这带来了独特的挑战。

为了便于实验,我们在 TFF 存储库中添加了一些数据集,包括联邦版本的 MNIST,其中包含 原始 NIST 数据集 的版本,该版本已使用 Leaf 重新处理,以便数据按数字的原始编写者进行键控。由于每个编写者都有独特的风格,因此该数据集展现了联邦数据集所期望的非 i.i.d. 行为。

以下是如何加载它。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

load_data() 返回的数据集是 tff.simulation.ClientData 的实例,这是一个接口,允许您枚举用户集,构建一个 tf.data.Dataset 来表示特定用户的 data,并查询单个元素的结构。以下是如何使用此接口来探索数据集的内容。请记住,虽然此接口允许您迭代客户端 ID,但这只是模拟数据的特性。正如您很快将看到的那样,客户端身份不会被联邦学习框架使用 - 它们的唯一目的是允许您选择数据的子集以进行模拟。

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

探索联邦数据中的异质性

联邦数据通常是非 i.i.d. 的,用户通常具有不同的数据分布,具体取决于使用模式。一些客户端可能在设备上具有较少的训练示例,在本地遭受数据匮乏,而一些客户端将拥有足够的训练示例。让我们使用可用的 EMNIST 数据来探索联邦系统中常见的数据异质性概念。重要的是要注意,我们只能深入分析客户端的数据,因为这是一个模拟环境,所有数据都在本地可用。在实际的生产联邦环境中,您将无法检查单个客户端的数据。

首先,让我们获取一个客户端数据的样本,以了解一个模拟设备上的示例。由于我们使用的数据集已按唯一的编写者进行键控,因此一个客户端的数据代表了一个人对 0 到 9 的数字样本的笔迹,模拟了一个用户的独特“使用模式”。

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

现在让我们可视化每个客户端在每个 MNIST 数字标签上的示例数量。在联邦环境中,每个客户端上的示例数量可能会有很大差异,具体取决于用户行为。

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

现在让我们可视化每个 MNIST 标签的每个客户端的平均图像。此代码将生成所有用户示例在一个标签上的每个像素值的平均值。我们会看到,一个客户端对一个数字的平均图像看起来与另一个客户端对同一个数字的平均图像不同,这是由于每个人的书写风格都是独一无二的。我们可以思考一下,每个本地训练轮次将如何推动模型在每个客户端上朝着不同的方向发展,因为我们在该本地轮次中从该用户的独特数据中学习。在本教程的后面,我们将看到如何将来自所有客户端的模型的每次更新汇总在一起,形成我们的新全局模型,该模型已经从我们每个客户端的独特数据中学习。

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

用户数据可能存在噪声,并且标签不可靠。例如,查看上面的客户端 #2 的数据,我们可以看到,对于标签 2,可能存在一些错误标记的示例,从而导致更嘈杂的平均图像。

预处理输入数据

由于数据已经是 tf.data.Dataset,因此可以使用数据集转换来完成预处理。在这里,我们将 28x28 图像展平成 784 元素数组,对单个示例进行混洗,将它们组织成批次,并将特征从 pixelslabel 重命名为 xy 以供 Keras 使用。我们还添加了对数据集的 repeat 操作,以运行多个 epoch。

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

让我们验证一下是否成功。

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]], dtype=int32))])

我们已经具备了构建联合数据集所需的大部分构建块。

在模拟中将联合数据馈送到 TFF 的一种方法是将其简单地作为 Python 列表,列表的每个元素都包含单个用户的数据,无论是作为列表还是作为 tf.data.Dataset。由于我们已经有一个提供后者的接口,让我们使用它。

这是一个简单的辅助函数,它将从给定的用户集中构建一个数据集列表,作为训练或评估轮次的输入。

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

现在,我们如何选择客户端?

在典型的联合训练场景中,我们正在处理可能非常庞大的用户设备群体,其中只有一小部分可能在给定时间点可用以进行训练。例如,当客户端设备是移动电话时,它们只有在插入电源、脱离计量网络且处于闲置状态时才会参与训练。

当然,我们是在模拟环境中,所有数据都可以在本地获得。因此,在运行模拟时,我们通常会简单地对参与每轮训练的客户端进行随机子集采样,通常在每轮中都不同。

也就是说,正如您通过研究关于 联合平均 算法的论文可以发现的那样,在每轮中使用随机采样的客户端子集的系统中实现收敛可能需要一段时间,并且在该交互式教程中必须运行数百轮是不切实际的。

我们将做的是对客户端集进行一次采样,并在各轮中重复使用相同的集,以加快收敛速度(有意地过度拟合这些用户的少量数据)。我们将此作为练习留给读者,让他们修改本教程以模拟随机采样——这很容易做到(一旦您做到了,请记住,让模型收敛可能需要一段时间)。

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')
Number of client datasets: 10
First dataset: <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>

使用 Keras 创建模型

如果您使用的是 Keras,您可能已经拥有构建 Keras 模型的代码。这是一个简单模型的示例,足以满足我们的需求。

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

为了将任何模型与 TFF 一起使用,它需要包装在 tff.learning.models.VariableModel 接口的实例中,该接口公开了一些方法来标记模型的前向传递、元数据属性等,类似于 Keras,但也引入了其他元素,例如控制计算联合指标过程的方法。现在让我们先不要担心这个问题;如果您有一个像我们上面定义的 Keras 模型,您可以通过调用 tff.learning.models.from_keras_model 来让 TFF 为您包装它,并将模型和样本数据批次作为参数传递,如下所示。

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.models.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

在联合数据上训练模型

现在我们已经将模型包装为 tff.learning.models.VariableModel 以供 TFF 使用,我们可以通过调用辅助函数 tff.learning.algorithms.build_weighted_fed_avg 来让 TFF 构建联合平均算法,如下所示。

请记住,参数需要是一个构造函数(例如上面的 model_fn),而不是已经构建的实例,以便可以在 TFF 控制的上下文中构建您的模型(如果您对原因感到好奇,我们建议您阅读关于 自定义算法 的后续教程)。

关于下面的联合平均算法,有一个关键的注意事项,即存在 **2** 个优化器:一个 _客户端_优化器和一个 _服务器_优化器。_客户端_优化器仅用于计算每个客户端的本地模型更新。_服务器_优化器将平均更新应用于服务器上的全局模型。特别是,这意味着使用的优化器和学习率的选择可能需要与您用于在标准 i.i.d. 数据集上训练模型的优化器和学习率不同。我们建议从常规 SGD 开始,可能使用比平时更小的学习率。我们使用的学习率尚未经过仔细调整,您可以随意进行试验。

training_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

刚刚发生了什么?TFF 构建了一对 _联合计算_,并将它们打包到 tff.templates.IterativeProcess 中,其中这些计算作为一对属性 initializenext 可用。

简而言之,_联合计算_ 是 TFF 内部语言中的程序,可以表达各种联合算法(您可以在 自定义算法 教程中找到更多相关信息)。在本例中,生成的两个计算并打包到 iterative_process 中,实现了 联合平均

TFF 的目标是定义计算,以便它们可以在真实的联合学习环境中执行,但目前仅实现了本地执行模拟运行时。要在模拟器中执行计算,您只需像调用 Python 函数一样调用它即可。此默认解释环境并非为高性能而设计,但对于本教程来说已经足够了;我们预计将在未来的版本中提供更高性能的模拟运行时,以促进更大规模的研究。

让我们从 initialize 计算开始。与所有联合计算一样,您可以将其视为一个函数。该计算不接受任何参数,并返回一个结果——服务器上联合平均过程状态的表示。虽然我们不想深入研究 TFF 的细节,但了解此状态的外观可能会有所帮助。您可以将其可视化如下。

print(training_process.initialize.type_signature.formatted_representation())
( -> <
  global_model_weights=<
    trainable=<
      float32[784,10],
      float32[10]
    >,
    non_trainable=<>
  >,
  distributor=<>,
  client_work=<>,
  aggregator=<
    value_sum_process=<>,
    weight_sum_process=<>
  >,
  finalizer=<
    int64,
    float32[784,10],
    float32[10]
  >
>@SERVER)

虽然上面的类型签名乍一看可能有点神秘,但您可以识别出服务器状态包含一个 global_model_weights(将分发到所有设备的 MNIST 的初始模型参数)、一些空参数(如 distributor,它控制服务器到客户端的通信)和一个 finalizer 组件。最后一个组件控制服务器在轮次结束时用于更新其模型的逻辑,并包含一个整数,表示已发生的 FedAvg 轮次数量。

让我们调用 initialize 计算来构建服务器状态。

train_state = training_process.initialize()

第二对联合计算 next 表示联合平均的单轮,它包括将服务器状态(包括模型参数)推送到客户端、在本地数据上进行设备内训练、收集和平均模型更新,以及在服务器上生成新的更新模型。

从概念上讲,您可以将 next 视为具有如下功能类型签名的函数。

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

特别是,应该将 next() 视为不是在服务器上运行的函数,而是整个分散计算的声明性函数表示——一些输入由服务器提供 (SERVER_STATE),但每个参与的设备都会贡献自己的本地数据集。

让我们运行一轮训练,并可视化结果。我们可以使用上面为用户样本生成的联合数据。

result = training_process.next(train_state, federated_train_data)
train_state = result.state
train_metrics = result.metrics
print('round  1, metrics={}'.format(train_metrics))
round  1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193733), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

让我们再运行几轮。如前所述,通常在这一点上,您会从每轮中新随机选择的用户样本中选择模拟数据的子集,以模拟用户不断进出、但在此交互式笔记本中,为了演示的目的,我们将重复使用相同的用户,以便系统能够快速收敛。

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  result = training_process.next(train_state, federated_train_data)
  train_state = result.state
  train_metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, train_metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.14012346), ('loss', 2.9851403), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.1590535), ('loss', 2.8617127), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.17860082), ('loss', 2.7401376), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20102881), ('loss', 2.6186547), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.22345679), ('loss', 2.5006158), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.24794239), ('loss', 2.3858356), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.27160493), ('loss', 2.2757034), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.2958848), ('loss', 2.17098), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3251029), ('loss', 2.072707), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

在每轮联合训练后,训练损失都在下降,表明模型正在收敛。但是,这些训练指标有一些重要的注意事项,请参阅本教程后面关于 _评估_ 的部分。

在 TensorBoard 中显示模型指标

接下来,让我们使用 Tensorboard 可视化这些联合计算的指标。

让我们首先创建目录和相应的摘要写入器,以将指标写入其中。

logdir = "/tmp/logs/scalars/training/"
try:
  tf.io.gfile.rmtree(logdir)  # delete any previous results
except tf.errors.NotFoundError as e:
  pass # Ignore if the directory didn't previously exist.
summary_writer = tf.summary.create_file_writer(logdir)
train_state = training_process.initialize()

使用相同的摘要写入器绘制相关的标量指标。

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    result = training_process.next(train_state, federated_train_data)
    train_state = result.state
    train_metrics = result.metrics
    for name, value in train_metrics['client_work']['train'].items():
      tf.summary.scalar(name, value, step=round_num)

使用上面指定的根日志目录启动 TensorBoard。数据加载可能需要几秒钟。

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
# Uncomment and run this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

为了以相同的方式查看评估指标,您可以创建一个单独的 eval 文件夹,例如“logs/scalars/eval”,以写入 TensorBoard。

自定义模型实现

Keras 是 TensorFlow 推荐的高级模型 API,我们鼓励在 TFF 中尽可能使用 Keras 模型(通过 tff.learning.models.from_keras_model)。

但是,tff.learning 提供了一个更低级的模型接口,tff.learning.models.VariableModel,它公开了使用模型进行联合学习所需的最低限度功能。直接实现此接口(可能仍然使用 tf.keras.layers 等构建块)允许最大程度地自定义,而无需修改联合学习算法的内部机制。

所以让我们从头开始重新做一遍。

定义模型变量、前向传递和指标

第一步是识别我们将要使用的 TensorFlow 变量。为了使以下代码更易读,让我们定义一个数据结构来表示整个集合。这将包括我们将要训练的变量,例如 weightsbias,以及将在训练期间更新的各种累积统计信息和计数器的变量,例如 loss_sumaccuracy_sumnum_examples

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

这是一个创建变量的方法。为了简单起见,我们将所有统计数据表示为 tf.float32,这样可以避免在后续阶段进行类型转换。将变量初始化器包装为 lambda 是由 资源变量 要求的。

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

有了模型参数和累积统计信息的变量,我们现在可以定义前向传递方法,该方法计算损失,发出预测,并更新单个批次输入数据的累积统计信息,如下所示。

def predict_on_batch(variables, x):
  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)

def mnist_forward_pass(variables, batch):
  y = predict_on_batch(variables, batch['x'])
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

接下来,我们定义两个与本地指标相关的函数,同样使用 TensorFlow。

第一个函数 get_local_unfinalized_metrics 返回未最终确定的指标值(除了自动处理的模型更新之外),这些值有资格在联邦学习或评估过程中聚合到服务器。

def get_local_unfinalized_metrics(variables):
  return collections.OrderedDict(
      num_examples=[variables.num_examples],
      loss=[variables.loss_sum, variables.num_examples],
      accuracy=[variables.accuracy_sum, variables.num_examples])

第二个函数 get_metric_finalizers 返回一个 OrderedDict,其中包含与 get_local_unfinalized_metrics 相同键(即指标名称)的 tf.function。每个 tf.function 都接收指标的未最终确定的值,并计算最终确定的指标。

def get_metric_finalizers():
  return collections.OrderedDict(
      num_examples=tf.function(func=lambda x: x[0]),
      loss=tf.function(func=lambda x: x[0] / x[1]),
      accuracy=tf.function(func=lambda x: x[0] / x[1]))

get_local_unfinalized_metrics 返回的本地未最终确定的指标如何在客户端之间聚合,由定义联邦学习或评估过程时的 metrics_aggregator 参数指定。例如,在 tff.learning.algorithms.build_weighted_fed_avg API(下一节中显示)中,metrics_aggregator 的默认值为 tff.learning.metrics.sum_then_finalize,它首先将来自 CLIENTS 的未最终确定的指标相加,然后在 SERVER 上应用指标最终确定器。

构造 tff.learning.models.VariableModel 的实例

有了以上所有内容,我们就可以为 TFF 构造一个模型表示,类似于您在让 TFF 摄取 Keras 模型时为您生成的模型表示。

import collections
from collections.abc import Callable

class MnistModel(tff.learning.models.VariableModel):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def predict_on_batch(self, x, training=True):
    del training
    return predict_on_batch(self._variables, x)

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.models.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_unfinalized_metrics(
      self) -> collections.OrderedDict[str, list[tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to unfinalized values."""
    return get_local_unfinalized_metrics(self._variables)

  def metric_finalizers(
      self) -> collections.OrderedDict[str, Callable[[list[tf.Tensor]], tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to finalizers."""
    return get_metric_finalizers()

  @tf.function
  def reset_metrics(self):
    """Resets metrics variables to initial value."""
    for var in self.local_variables:
      var.assign(tf.zeros_like(var))

如您所见,由 tff.learning.models.VariableModel 定义的抽象方法和属性对应于前面介绍变量并定义损失和统计信息的代码片段。

以下是一些值得强调的要点

  • 您的模型将使用的所有状态都必须作为 TensorFlow 变量捕获,因为 TFF 在运行时不使用 Python(请记住,您的代码应该编写成可以部署到移动设备上;有关原因的更深入评论,请参阅 自定义算法 教程)。
  • 您的模型应该描述它接受的数据形式 (input_spec),因为一般来说,TFF 是一个强类型环境,希望为所有组件确定类型签名。声明模型输入的格式是其必不可少的组成部分。
  • 虽然从技术上讲不是必需的,但我们建议将所有 TensorFlow 逻辑(前向传递、指标计算等)包装为 tf.function,因为这有助于确保 TensorFlow 可以序列化,并且无需显式控制依赖关系。

以上内容足以用于评估和联邦 SGD 等算法。但是,对于联邦平均,我们需要指定模型应该如何在每个批次上进行本地训练。在构建联邦平均算法时,我们将指定一个本地优化器。

使用新模型模拟联邦训练

有了以上所有内容,剩下的过程看起来与我们之前看到的相同 - 只需将模型构造函数替换为我们新模型类的构造函数,并在您创建的迭代过程中使用两个联邦计算来循环遍历训练轮次。

training_process = tff.learning.algorithms.build_weighted_fed_avg(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
train_state = training_process.initialize()
result = training_process.next(train_state, federated_train_data)
train_state = result.state
metrics = result.metrics
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.119374), ('accuracy', 0.12345679)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
for round_num in range(2, 11):
  result = training_process.next(train_state, federated_train_data)
  train_state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.98514), ('accuracy', 0.14012346)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.8617127), ('accuracy', 0.1590535)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.740137), ('accuracy', 0.17860082)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6186547), ('accuracy', 0.20102881)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5006158), ('accuracy', 0.22345679)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3858361), ('accuracy', 0.24794239)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.275704), ('accuracy', 0.27160493)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1709805), ('accuracy', 0.2958848)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.0727067), ('accuracy', 0.3251029)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

要查看 TensorBoard 中的这些指标,请参阅上面“在 TensorBoard 中显示模型指标”中列出的步骤。

评估

到目前为止,我们所有的实验都只展示了联邦训练指标 - 在轮次中所有客户端上训练的所有数据批次的平均指标。这引入了关于过拟合的正常问题,尤其是在为了简单起见,我们在每一轮中都使用了相同的客户端集,但联邦平均算法的训练指标中还存在一种特定于过拟合的概念。如果我们想象每个客户端只有一个数据批次,并且我们在该批次上训练了许多迭代(轮次),那么最容易看到这一点。在这种情况下,本地模型将很快完全拟合到该批次,因此我们平均的本地准确性指标将接近 1.0。因此,这些训练指标可以被视为训练正在进行的标志,但仅此而已。

要在联邦数据上执行评估,您可以使用 tff.learning.build_federated_evaluation 函数构造另一个专门用于此目的的联邦计算,并将您的模型构造函数作为参数传递。请注意,与联邦平均不同,我们在联邦平均中使用了 MnistTrainableModel,而传递 MnistModel 就足够了。评估不执行梯度下降,因此无需构造优化器。

为了进行实验和研究,当存在集中式测试数据集时,用于文本生成的联邦学习 演示了另一种评估选项:从联邦学习中获取训练后的权重,将其应用于标准 Keras 模型,然后简单地调用 tf.keras.models.Model.evaluate() 在集中式数据集上。

evaluation_process = tff.learning.algorithms.build_fed_eval(MnistModel)

您可以检查评估函数的抽象类型签名,如下所示。

print(evaluation_process.next.type_signature.formatted_representation())
(<
  state=<
    global_model_weights=<
      trainable=<
        float32[784,10],
        float32[10]
      >,
      non_trainable=<>
    >,
    distributor=<>,
    client_work=<
      <>,
      <
        num_examples=<
          float32
        >,
        loss=<
          float32,
          float32
        >,
        accuracy=<
          float32,
          float32
        >
      >
    >,
    aggregator=<
      value_sum_process=<>,
      weight_sum_process=<>
    >,
    finalizer=<>
  >@SERVER,
  client_data={<
    x=float32[?,784],
    y=int32[?,1]
  >*}@CLIENTS
> -> <
  state=<
    global_model_weights=<
      trainable=<
        float32[784,10],
        float32[10]
      >,
      non_trainable=<>
    >,
    distributor=<>,
    client_work=<
      <>,
      <
        num_examples=<
          float32
        >,
        loss=<
          float32,
          float32
        >,
        accuracy=<
          float32,
          float32
        >
      >
    >,
    aggregator=<
      value_sum_process=<>,
      weight_sum_process=<>
    >,
    finalizer=<>
  >@SERVER,
  metrics=<
    distributor=<>,
    client_work=<
      eval=<
        current_round_metrics=<
          num_examples=float32,
          loss=float32,
          accuracy=float32
        >,
        total_rounds_metrics=<
          num_examples=float32,
          loss=float32,
          accuracy=float32
        >
      >
    >,
    aggregator=<
      mean_value=<>,
      mean_weight=<>
    >,
    finalizer=<>
  >@SERVER
>)

请注意,评估过程是一个 tff.lenaring.templates.LearningProcess 对象。该对象有一个 initialize 方法,它将创建状态,但最初将包含一个未训练的模型。使用 set_model_weights 方法,必须插入要评估的训练状态的权重。

evaluation_state = evaluation_process.initialize()
model_weights = training_process.get_model_weights(train_state)
evaluation_state = evaluation_process.set_model_weights(evaluation_state, model_weights)

现在,评估状态包含要评估的模型权重,我们可以通过在进程上调用 next 方法来使用评估数据集计算评估指标,就像在训练中一样。

这将再次返回一个 tff.learning.templates.LearingProcessOutput 实例。

evaluation_output = evaluation_process.next(evaluation_state, federated_train_data)

以下是我们得到的结果。请注意,这些数字看起来比上面最后一轮训练报告的数字略好。按照惯例,迭代训练过程报告的训练指标通常反映了训练轮次开始时模型的性能,因此评估指标将始终领先一步。

str(evaluation_output.metrics)
"OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('num_examples', 4860.0), ('loss', 1.6654209), ('accuracy', 0.3621399)])), ('total_rounds_metrics', OrderedDict([('num_examples', 4860.0), ('loss', 1.6654209), ('accuracy', 0.3621399)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])"

现在,让我们编译一个联邦数据的测试样本,并在测试数据上重新运行评估。数据将来自同一组真实用户,但来自不同的保留数据集。

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>)
evaluation_output = evaluation_process.next(evaluation_state, federated_test_data)
str(evaluation_output.metrics)
"OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('num_examples', 580.0), ('loss', 1.7750846), ('accuracy', 0.33620688)])), ('total_rounds_metrics', OrderedDict([('num_examples', 580.0), ('loss', 1.7750846), ('accuracy', 0.33620688)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])"

本教程到此结束。我们鼓励您尝试不同的参数(例如,批次大小、用户数量、轮次、学习率等),修改上面的代码以模拟在每一轮中对用户随机样本进行训练,并探索我们开发的其他教程。