TensorFlow Federated:分布式数据上的机器学习
import collections import tensorflow as tf import tensorflow_federated as tff # Load simulation data. source, _ = tff.simulation.datasets.emnist.load_data() def client_data(n): return source.create_tf_dataset_for_client(source.client_ids[n]).map( lambda e: (tf.reshape(e['pixels'], [-1]), e['label']) ).repeat(10).batch(20) # Pick a subset of client devices to participate in training. train_data = [client_data(n) for n in range(3)] # Wrap a Keras model for use with TFF. keras_model = tf.keras.models.Sequential([ tf.keras.layers.Dense( 10, tf.nn.softmax, input_shape=(784,), kernel_initializer='zeros') ]) tff_model = tff.learning.models.functional_model_from_keras( keras_model, loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(), input_spec=train_data[0].element_spec, metrics_constructor=collections.OrderedDict( accuracy=tf.keras.metrics.SparseCategoricalAccuracy)) # Simulate a few rounds of training with the selected client devices. trainer = tff.learning.algorithms.build_weighted_fed_avg( tff_model, client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1)) state = trainer.initialize() for _ in range(5): result = trainer.next(state, train_data) state = result.state metrics = result.metrics print(metrics['client_work']['train']['accuracy'])
-
联邦学习 (FL) API
该层提供了一组高级接口,允许开发者将其现有的 TensorFlow 模型应用于内置的联邦训练和评估实现中。 -
联邦核心 (FC) API
系统的核心是一组更底层的接口,通过在强类型函数式编程环境中将 TensorFlow 与分布式通信算子相结合,简洁地表达新型联邦算法。该层也是我们构建联邦学习的基础。 -
TFF 使开发者能够声明式地表达联邦计算,从而将其部署到不同的运行时环境中。TFF 附带了一个用于实验的高性能多机模拟运行时。请访问 教程 并亲自尝试!
如有问题或需要支持,请在 StackOverflow 上查看 tensorflow-federated 标签。