Loading remote data in TFF


In a real-world application of federated learning, the raw training data is typically distributed across many devices or data silos, and it requires special preprocessing and loading before it can be used.

This tutorial describes how to load examples stored in those remote locations with TFF's DataBackend and DataExecutor interfaces, and how to use them to train a model with federated learning. We will demonstrate the use of the data-loading APIs with a training dataset stored locally, and simulate the sampling of examples as if the dataset were partitioned over distinct remote clients. As you adapt this tutorial to your use case, you simply swap that dataset with your own distributed data.

If you are new to federated learning or TFF, consider reading Federated Learning for Image Classification for a primer.


Before we start

Before we start, please run the following to make sure that your environment is correctly set up. Refer to the Installation guide for more information.

Set up the open-source environment
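
In a notebook environment, the setup usually amounts to installing the tensorflow-federated package. A minimal setup cell, assuming a standard Colab or Jupyter kernel, might look like this:

!pip install --quiet --upgrade tensorflow-federated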

Import packages
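
The cells below rely on the following imports; this is the minimal set inferred from the code in this tutorial:

import collections
import random
from typing import Any, List, Optional, Sequence

import tensorflow as tf
import tensorflow_federated as tff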

Preparing the input data

Let's begin by loading TFF's federated version of the EMNIST dataset from the built-in repository:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

And let's construct a preprocessing function to transform the raw examples in the EMNIST dataset.

NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100


def preprocess(dataset):

  def map_fn(element):
    # Rename the features from `pixels` and `label`, to `x` and `y` for use with
    # Keras.
    return collections.OrderedDict(
        # Transform each `28x28` image into a `784`-element array.
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  # Shuffle the individual examples and `repeat` over several epochs.
  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)

Let's verify this works:

# The local dataset corresponding to a single client as tf.data.Dataset.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

preprocessed_example_dataset = preprocess(example_dataset)
print(preprocessed_example_dataset)
<MapDataset element_spec=OrderedDict([('x', TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))])>
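
As an optional sanity check (this cell is an addition to the original flow), you can also pull a single element from the preprocessed dataset and inspect its shapes; next(iter(...)) works here because the runtime executes eagerly:

example_element = next(iter(preprocessed_example_dataset))
print(example_element['x'].shape, example_element['y'].shape)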

Next, let's construct an implementation of DataBackend that will load and preprocess local examples from clients in the EMNIST dataset, which is crucial for fetching trainable examples during federated learning.

Defining how to fetch the client data

We need an instance of DataBackend to instruct the TFF workers how to load and transform the local data.

TFF workers are the processes that run on edge machines and perform the work for one or multiple logical clients. In this example, the EMNIST dataset we use for training is already partitioned by logical clients, and all the workers are going to run in the same local environment, so our DataBackend can reference the data corresponding to any client. But in a non-experimental setting, the TFF workers are distributed across individual remote machines, each mapping to a distinct set of clients, and you need to make sure that the DataBackend can correctly resolve data references according to its local context.

# A `DataBackend` is a programmatic construct that resolves symbolic references,
# represented as application-specific URIs, to materialized examples that
# TFF operations can process.
class MyDataBackend(tff.framework.DataBackend):

  async def materialize(self, data, type_spec):
    # In this example, the URI encodes the index of a client in
    # `emnist_train.client_ids` (e.g. `uri://42`), so we strip the scheme
    # prefix and parse the remainder as an integer.
    client_id = int(data.uri[len('uri://'):])
    # The client index is used to retrieve the corresponding local data.
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[client_id])
    # We process the client dataset before returning so it's compatible with
    # our model definitions.
    return preprocess(client_dataset)

Set up the runtime environment

TFF computations are invoked by an ExecutionContext. In order for the data URIs defined in TFF computations to be understood at runtime, a custom context that includes a pointer to the DataBackend we just created must be defined for the TFF workers, so the URIs can be properly resolved.

def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
  # A `DataBackend` object is wrapped by a `DataExecutor`, which queries the
  # backend when a TFF worker encounters an operation that requires fetching
  # local data.
  return tff.framework.DataExecutor(
      tff.framework.EagerTFExecutor(device), data_backend=MyDataBackend())


# In a distributed setting, this needs to run in the TFF worker as a service
# connecting to some port. The top-level controller feeding TFF computations
# would then connect to this port.
factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)
ctx = tff.framework.SyncExecutionContext(executor_fn=factory)
tff.framework.set_default_context(ctx)

Training the model

Now we are ready to train a model with federated learning. Let's define a Keras model:

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])


def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

We can pass this TFF-wrapped definition of our model to the helper function tff.learning.algorithms.build_weighted_fed_avg to produce a Federated Averaging algorithm, as follows:

iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

state = iterative_process.initialize()

The initialize computation returns the initial state of the Federated Averaging process.
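
If you want to confirm what initialize produces before running a round, you can print its type signature; this check is optional (type_signature is a standard attribute of TFF computations):

print(iterative_process.initialize.type_signature)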

To run a round of training, we need to construct a sample of data by collecting a sample of URI references, as follows:

NUM_CLIENTS = 10

element_type = tff.types.StructWithPythonType(
    preprocessed_example_dataset.element_spec,
    container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

round_data_uris = [f'uri://{i}' for i in range(NUM_CLIENTS)]
round_train_data = tff.framework.CreateDataDescriptor(
    arg_uris=round_data_uris, arg_type=dataset_type)

Now we can run one round of training:

result = iterative_process.next(state, round_train_data)
state = result.state
metrics = result.metrics
print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.11234568), ('loss', 11.965633), ('num_examples', 4860), ('num_batches', 4860)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])

Training over multiple rounds

We can define a FederatedDataSource container for selecting clients and assembling the inputs for retrieving local data. This makes it convenient to loop over multiple rounds of training, and it can be reused across multiple training jobs.

class MyFederatedDataSourceIterator(tff.program.FederatedDataSourceIterator):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

  def select(self, num_clients: Optional[int] = None) -> Any:
    # Sample client indices rather than the string client ids, so that the
    # resulting URIs can be parsed back into indices by `materialize`.
    client_ids_sample = random.sample(
        range(len(self._client_ids)), num_clients)
    data_uris = [f'uri://{i}' for i in client_ids_sample]
    return tff.framework.CreateDataDescriptor(
        arg_uris=data_uris, arg_type=self._federated_type)


class MyFederatedDataSource(tff.program.FederatedDataSource):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type
    self._capabilities = [tff.program.Capability.RANDOM_UNIFORM]

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

  @property
  def capabilities(self) -> List[tff.program.Capability]:
    return self._capabilities

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    return MyFederatedDataSourceIterator(self._client_ids, self._federated_type)


train_data_source = MyFederatedDataSource(
    client_ids=emnist_train.client_ids, federated_type=dataset_type)
train_data_iterator = train_data_source.iterator()

Now we can run the federated learning training loop like so:

NUM_ROUNDS = 10

for round_num in range(2, NUM_ROUNDS + 1):
  round_train_data = train_data_iterator.select(NUM_CLIENTS)
  result = iterative_process.next(state, round_train_data)
  state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12357217), ('loss', 9.161968), ('num_examples', 4815), ('num_batches', 4815)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20563674), ('loss', 7.0862083), ('num_examples', 4790), ('num_batches', 4790)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.30241227), ('loss', 5.6945825), ('num_examples', 4560), ('num_batches', 4560)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3867347), ('loss', 4.7210026), ('num_examples', 4900), ('num_batches', 4900)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.42311886), ('loss', 4.205554), ('num_examples', 4585), ('num_batches', 4585)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.4501548), ('loss', 4.1297464), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.56590474), ('loss', 2.8927681), ('num_examples', 5250), ('num_batches', 5250)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.59917355), ('loss', 2.7431731), ('num_examples', 4840), ('num_batches', 4840)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.5717234), ('loss', 2.9738288), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])

Concluding remarks

This brings us to the end of the tutorial. We encourage you to explore the other tutorials we've developed to learn about the many other features of the TFF framework.