使用 tff 的 ClientData。

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

TFF 中建模的联邦计算中,按客户端(例如用户)键控数据集的概念至关重要。TFF 提供了接口 tff.simulation.datasets.ClientData 来抽象化此概念,并且 TFF 托管的数据集(stackoverflowshakespeareemnistcifar100gldv2)都实现了此接口。

如果您正在使用自己的数据集进行联邦学习,TFF 强烈建议您要么实现 ClientData 接口,要么使用 TFF 的辅助函数之一来生成一个 ClientData,该 ClientData 表示您磁盘上的数据,例如 tff.simulation.datasets.ClientData.from_clients_and_fn

由于大多数 TFF 的端到端示例都从 ClientData 对象开始,因此使用您的自定义数据集实现 ClientData 接口将使您更容易浏览使用 TFF 编写的现有代码。此外,ClientData 构造的 tf.data.Datasets 可以直接迭代以生成 numpy 数组的结构,因此 ClientData 对象可以在迁移到 TFF 之前与任何基于 Python 的 ML 框架一起使用。

如果您打算将模拟扩展到许多机器或部署它们,则可以使用几种模式来简化您的工作。下面我们将介绍几种使用 ClientData 和 TFF 的方法,以使我们的从小规模迭代到大规模实验再到生产部署的体验尽可能顺利。

我应该使用哪种模式将 ClientData 传递到 TFF 中?

我们将深入讨论 TFF 的 ClientData 的两种用法;如果您符合以下两种类别中的任何一种,您将明显更喜欢其中一种。如果不是,您可能需要更详细地了解每种方法的优缺点,才能做出更细致的选择。

  • 我想尽快在本地机器上进行迭代;我不需要能够轻松利用 TFF 的分布式运行时。

    • 您想将 tf.data.Datasets 直接传递到 TFF 中。
    • 这使您可以使用 tf.data.Dataset 对象进行命令式编程,并任意处理它们。
    • 它提供了比以下选项更多的灵活性;将逻辑推送到客户端要求此逻辑可序列化。
  • 我想在 TFF 的远程运行时中运行我的联邦计算,或者我很快就会这样做。

    • 在这种情况下,您想将数据集构建和预处理映射到客户端。
    • 这会导致您仅将 client_ids 列表直接传递到您的联邦计算中。
    • 将数据集构建和预处理推送到客户端可以避免序列化瓶颈,并且在数百到数千个客户端的情况下显着提高性能。

设置开源环境

导入包

操作 ClientData 对象

让我们从加载和探索 TFF 的 EMNIST ClientData 开始

client_data, _ = tff.simulation.datasets.emnist.load_data()

检查第一个数据集可以告诉我们 ClientData 中包含哪些类型的示例。

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

请注意,数据集会生成 collections.OrderedDict 对象,这些对象具有 pixelslabel 键,其中 pixels 是一个形状为 [28, 28] 的张量。假设我们希望将输入展平为形状 [784]。一种可能的方法是将预处理函数应用于我们的 ClientData 对象。

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

此外,我们可能希望执行一些更复杂的(可能是有状态的)预处理,例如洗牌。

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

tff.Computation 交互

现在我们可以对 ClientData 对象进行一些基本操作,我们已准备好将数据馈送到 tff.Computation。我们定义一个 tff.templates.IterativeProcess,它实现了 联邦平均,并探索将数据传递给它的不同方法。

keras_model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(784,)),
    tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Note: input spec is the _batched_ shape, and includes the
    # label tensor which will be passed to the loss function. This model is
    # therefore configured to accept data _after_ it has been preprocessed.
    input_spec=collections.OrderedDict(
        x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
        y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64),
    ),
    metrics_constructor=collections.OrderedDict(
        loss=lambda: tf.keras.metrics.SparseCategoricalCrossentropy(
            from_logits=True
        ),
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy,
    ),
)

trainer = tff.learning.algorithms.build_weighted_fed_avg(
    tff_model,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.01),
)

在我们开始使用此 IterativeProcess 之前,需要对 ClientData 的语义进行说明。一个 ClientData 对象代表可用于联邦训练的整个群体,一般来说,生产 FL 系统的执行环境无法访问,并且特定于模拟。 ClientData 确实为用户提供了绕过联邦计算并通过 ClientData.create_tf_dataset_from_all_clients 以通常方式训练服务器端模型的能力。

TFF 的模拟环境使研究人员能够完全控制外循环。特别是,这意味着用户或 Python 驱动程序脚本必须解决客户端可用性、客户端掉线等问题。例如,可以通过调整 ClientDataclient_ids 上的采样分布来模拟客户端掉线,这样数据量更多(相应地,本地计算运行时间更长)的用户将以较低的概率被选中。

然而,在真实的联邦系统中,模型训练器无法显式地选择客户端;客户端的选择委托给执行联邦计算的系统。

tf.data.Datasets 直接传递给 TFF

ClientDataIterativeProcess 之间进行交互的一种选择是在 Python 中构造 tf.data.Datasets,并将这些数据集传递给 TFF。

请注意,如果我们使用预处理后的 ClientData,那么我们生成的数据集将是我们上面定义的模型所期望的适当类型。

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]
    )
    for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  result = trainer.next(state, preprocessed_data_for_clients)
  state = result.state
  train_metrics = result.metrics['client_work']['train']
  t2 = time.time()
  print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')
loss 2.89, round time 2.35 seconds
loss 3.05, round time 2.26 seconds
loss 2.80, round time 0.63 seconds
loss 2.94, round time 3.18 seconds
loss 3.17, round time 2.44 seconds

但是,如果我们采用这种方法,我们将 **无法轻松地迁移到多机模拟**。我们在本地 TensorFlow 运行时构造的数据集可以 *捕获来自周围 Python 环境的状态*,并在尝试引用不再可用的状态时,在序列化或反序列化过程中失败。例如,这可能会在 TensorFlow 的 tensor_util.cc 中出现难以理解的错误。

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

在客户端上映射构造和预处理

为了避免此问题,TFF 建议用户将数据集实例化和预处理视为 *在每个客户端本地发生的事情*,并使用 TFF 的帮助程序或 federated_map 在每个客户端显式地运行此预处理代码。

从概念上讲,首选此方法的原因很清楚:在 TFF 的本地运行时,客户端仅“意外地”访问全局 Python 环境,因为整个联邦编排都在一台机器上进行。值得注意的是,类似的思考方式导致了 TFF 的跨平台、始终可序列化的函数式理念。

TFF 通过 ClientData 的属性 dataset_computation 使这种更改变得简单,这是一个 tff.Computation,它接受一个 client_id 并返回关联的 tf.data.Dataset

请注意,preprocess 只与 dataset_computation 一起工作;预处理后的 ClientDatadataset_computation 属性包含我们刚刚定义的整个预处理管道。

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(str -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(str -> <x=float32[?,784],y=int64[?,1]>*)

我们可以调用 dataset_computation 并接收 Python 运行时中的一个急切数据集,但这种方法的真正强大之处在于,当我们与迭代过程或其他计算进行组合时,完全避免在全局急切运行时中物化这些数据集。TFF 提供了一个帮助程序函数 tff.simulation.compose_dataset_computation_with_iterative_process,它可以用来完成这项工作。

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

这两个 tff.templates.IterativeProcesses 和上面的那个都以相同的方式运行;但前者接受预处理后的客户端数据集,而后者接受表示客户端 ID 的字符串,在其主体中处理数据集构造和预处理——事实上,state 可以在这两者之间传递。

for _ in range(5):
  t1 = time.time()
  result = trainer_accepting_ids.next(state, selected_client_ids)
  state = result.state
  train_metrics = result.metrics['client_work']['train']
  t2 = time.time()
  print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')

扩展到大量客户端

trainer_accepting_ids 可以立即在 TFF 的多机运行时中使用,并且避免物化 tf.data.Datasets 和控制器(因此避免序列化它们并将它们发送到工作器)。

这显着加快了分布式模拟的速度,尤其是在客户端数量众多时,并且能够进行中间聚合以避免类似的序列化/反序列化开销。

可选深入研究:在 TFF 中手动组合预处理逻辑

TFF 从一开始就旨在实现组合性;TFF 帮助程序刚刚执行的组合类型完全在我们用户的控制范围内。我们可以手动将我们刚刚定义的预处理计算与训练器自己的 next 进行组合,非常简单。

selected_clients_type = tff.FederatedType(
    preprocessed_and_shuffled.dataset_computation.type_signature.parameter,
    tff.CLIENTS,
)


@tff.federated_computation(
    trainer.next.type_signature.parameter[0], selected_clients_type
)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(
      preprocessed_and_shuffled.dataset_computation, selected_clients
  )
  return trainer.next(server_state, preprocessed_data)


manual_trainer_with_preprocessing = tff.templates.IterativeProcess(
    initialize_fn=trainer.initialize, next_fn=new_next
)

事实上,这实际上是我们使用的帮助程序在幕后所做的事情(加上执行适当的类型检查和操作)。我们甚至可以以略微不同的方式表达相同的逻辑,方法是将 preprocess_and_shuffle 序列化为一个 tff.Computation,并将 federated_map 分解为一个步骤,该步骤构造未预处理的数据集,另一个步骤在每个客户端运行 preprocess_and_shuffle

我们可以验证这条更手动路径会导致计算结果与 TFF 帮助程序的类型签名相同(参数名称除外)。

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,client_data={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)
(<server_state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,selected_clients={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)