TensorFlow Federated:分布式数据上的机器学习

import collections
import tensorflow as tf
import tensorflow_federated as tff

# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()
def client_data(n):
  return source.create_tf_dataset_for_client(source.client_ids[n]).map(
      lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])
  ).repeat(10).batch(20)

# Pick a subset of client devices to participate in training.
train_data = [client_data(n) for n in range(3)]

# Wrap a Keras model for use with TFF.
keras_model = tf.keras.models.Sequential([
  tf.keras.layers.Dense(
    10, tf.nn.softmax, input_shape=(784,), kernel_initializer='zeros')
])
tff_model = tff.learning.models.functional_model_from_keras(
      keras_model,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
      input_spec=train_data[0].element_spec,
      metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy))

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.algorithms.build_weighted_fed_avg(
  tff_model,
  client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1))
state = trainer.initialize()
for _ in range(5):
  result = trainer.next(state, train_data)
  state = result.state
  metrics = result.metrics
  print(metrics['client_work']['train']['accuracy'])
  • TensorFlow Federated (TFF) 是一个用于在分布式数据上进行机器学习和其他计算的开源框架。TFF 的开发旨在促进对 联邦学习 (FL) 的开放研究和实验。联邦学习是一种机器学习方法,它通过许多参与的客户端训练共享的全局模型,同时将训练数据保留在本地。例如,FL 已被用于在不将敏感输入数据上传到服务器的情况下,训练移动键盘的预测模型

    TFF 使开发者能够在其模型和数据上模拟所包含的联邦学习算法,并尝试新的算法。研究人员可以在 入门指南和完整示例 中找到适用于多种研究的参考。TFF 提供的构建模块也可用于实现非学习类计算,例如联邦分析。TFF 的接口分为两个主要层级:

  • 该层提供了一组高级接口,允许开发者将其现有的 TensorFlow 模型应用于内置的联邦训练和评估实现中。
  • 系统的核心是一组更底层的接口,通过在强类型函数式编程环境中将 TensorFlow 与分布式通信算子相结合,简洁地表达新型联邦算法。该层也是我们构建联邦学习的基础。
  • TFF 使开发者能够声明式地表达联邦计算,从而将其部署到不同的运行时环境中。TFF 附带了一个用于实验的高性能多机模拟运行时。请访问 教程 并亲自尝试!

    如有问题或需要支持,请在 StackOverflow 上查看 tensorflow-federated 标签