使用加速器进行 TFF 模拟

本教程将介绍如何使用加速器设置 TFF 模拟。目前我们重点关注单机(多)GPU,并将在此教程中更新多机和 TPU 设置。

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

开始之前

首先,让我们确保笔记本已连接到编译了相关组件的后端。

pip install --quiet --upgrade tensorflow-federated
pip install -U tensorboard_plugin_profile
%load_ext tensorboard
import collections
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

检查 TF 是否可以检测到物理 GPU 并为 TFF GPU 模拟创建虚拟多 GPU 环境。两个虚拟 GPU 将具有有限的内存,以演示如何配置 TFF 运行时。

# Fail fast when TensorFlow cannot detect any physical GPU.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
  raise ValueError('Cannot detect physical GPU device in TF')
# TODO: b/277213652 - Remove this call, as it doesn't work with C++ executor
# Split the first physical GPU into two logical GPUs, each configured with
# memory_limit=1024.
tf.config.set_logical_device_configuration(
    gpu_devices[0], 
    [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
     tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
# Display the logical devices that result from the configuration above.
tf.config.list_logical_devices()
[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/device:GPU:1', device_type='GPU')]

运行以下“Hello World”示例,以确保 TFF 环境已正确设置。如果它不起作用,请参阅安装指南以获取说明。

@tff.federated_computation
def hello_world():
  """Return a constant greeting; used to check that TFF is set up."""
  return 'Hello, World!'

# Invoke the computation; the notebook displays the returned value.
hello_world()
b'Hello, World!'

EMNIST 实验设置

在本教程中,我们将使用联合平均算法训练 EMNIST 图像分类器。让我们先从 TFF 网站加载 EMNIST 示例。

# Load the federated EMNIST train and test data, restricted to digits.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=True)

我们定义一个函数,参照 simple_fedavg 示例对 EMNIST 样本进行预处理。请注意,参数 client_epochs_per_round 控制联合学习中每个客户端在每轮中的本地训练周期(epoch)数。

def preprocess_emnist_dataset(client_epochs_per_round, batch_size, test_batch_size):
  """Build the preprocessed EMNIST training and test data.

  Args:
    client_epochs_per_round: Number of times each client dataset is repeated
      before batching.
    batch_size: Batch size applied to each client training dataset.
    test_batch_size: Batch size applied to the pooled test dataset.

  Returns:
    A `(train_set, test_set)` tuple: `emnist_train` with the training
    preprocessing attached, and the batched dataset created from all clients
    of `emnist_test`.
  """

  def _to_xy(example):
    # Append a trailing axis to the pixels and pair them with the label.
    features = tf.expand_dims(example['pixels'], -1)
    return collections.OrderedDict(x=features, y=example['label'])

  def _prepare_train(client_dataset):
    # The shuffle buffer matches the maximum client dataset size,
    # 418 for Federated EMNIST.
    mapped = client_dataset.map(_to_xy)
    shuffled = mapped.shuffle(buffer_size=418)
    repeated = shuffled.repeat(count=client_epochs_per_round)
    return repeated.batch(batch_size, drop_remainder=False)

  def _prepare_test(pooled_dataset):
    mapped = pooled_dataset.map(_to_xy)
    return mapped.batch(test_batch_size, drop_remainder=False)

  all_test_examples = emnist_test.create_tf_dataset_from_all_clients()
  test_set = _prepare_test(all_test_examples)
  train_set = emnist_train.preprocess(_prepare_train)
  return train_set, test_set

我们使用类似 VGG 的模型,即每个块具有两个 3x3 卷积,并且当特征图被下采样时,滤波器数量会加倍。

def _conv_3x3(input_tensor, filters, strides):
  """Apply a bias-free 3x3 `Conv2D` layer with 'same' padding.

  Args:
    input_tensor: Input passed to the convolution layer.
    filters: Number of output filters.
    strides: Strides of the convolution.

  Returns:
    The output of the convolution layer applied to `input_tensor`.
  """
  conv_layer = tf.keras.layers.Conv2D(
      filters=filters,
      kernel_size=3,
      strides=strides,
      padding='same',
      use_bias=False,
      kernel_initializer='he_normal',
  )
  return conv_layer(input_tensor)


def _basic_block(input_tensor, filters, strides):
  """Stack two 3x3 convolutions, each followed by a ReLU activation.

  Only the first convolution uses `strides`; the second always uses stride 1.
  """
  x = input_tensor
  for conv_strides in (strides, 1):
    x = _conv_3x3(x, filters, conv_strides)
    x = tf.keras.layers.Activation('relu')(x)
  return x


def _vgg_block(input_tensor, size, filters, strides):
  """Chain basic blocks; only the first one uses `strides`.

  One basic block is always applied, followed by `size - 1` further blocks
  with stride 1.
  """
  # A non-positive repeat count yields an empty list, so at least the first
  # (strided) block is always built.
  per_block_strides = [strides] + [1] * (size - 1)
  x = input_tensor
  for block_strides in per_block_strides:
    x = _basic_block(x, filters, strides=block_strides)
  return x


def create_cnn(num_blocks, conv_width_multiplier=1, num_classes=10):
  """Build a VGG-like Keras CNN for 28x28 single-channel inputs.

  The network is one stem 3x3 convolution, three stages of `num_blocks` basic
  blocks each (two convolutions per block), global average pooling and a
  final dense layer, i.e. (6*num_blocks + 2) convolution/dense layers.

  Args:
    num_blocks: Number of basic blocks in each of the three stages.
    conv_width_multiplier: Factor applied to the base filter counts
      (16, 32, 64).
    num_classes: Number of units in the final dense layer.

  Returns:
    A `tf.keras.models.Model` named 'cnn-<depth>-<conv_width_multiplier>'.
  """
  inputs = tf.keras.layers.Input(shape=(28, 28, 1))  # channels_last
  x = tf.image.per_image_standardization(inputs)

  x = _conv_3x3(x, 16 * conv_width_multiplier, 1)
  # (base filters, strides) of each stage; the later stages use stride 2.
  for base_filters, stage_strides in ((16, 1), (32, 2), (64, 2)):
    x = _vgg_block(
        x,
        size=num_blocks,
        filters=base_filters * conv_width_multiplier,
        strides=stage_strides)

  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  outputs = tf.keras.layers.Dense(num_classes)(x)

  depth = 6 * num_blocks + 2
  return tf.keras.models.Model(
      inputs, outputs, name=f'cnn-{depth}-{conv_width_multiplier}')

现在让我们定义 EMNIST 的训练循环。请注意,tff.learning.algorithms.build_weighted_fed_avg中的use_experimental_simulation_loop=True建议用于高性能 TFF 模拟,并且需要利用单机上的多 GPU。有关如何定义在 GPU 上具有高性能的自定义联合学习算法的示例,请参阅simple_fedavg示例,其中一个关键功能是显式使用for ... iter(dataset)进行训练循环。

def keras_evaluate(model, test_data, metric):
  """Reset `metric`, accumulate it over `test_data`, and return its result.

  Args:
    model: Callable invoked as `model(x, training=False)` on each batch.
    test_data: Iterable of batches indexable by the keys 'x' and 'y'.
    metric: Object providing `reset_states()`, `update_state(y_true, y_pred)`
      and `result()`.

  Returns:
    The value of `metric.result()` after all batches have been processed.
  """
  metric.reset_states()
  for examples in test_data:
    metric.update_state(
        y_pred=model(examples['x'], training=False),
        y_true=examples['y'])
  return metric.result()


def run_federated_training(client_epochs_per_round,
                           train_batch_size,
                           test_batch_size,
                           cnn_num_blocks,
                           conv_width_multiplier,
                           server_learning_rate,
                           client_learning_rate,
                           total_rounds,
                           clients_per_round,
                           rounds_per_eval,
                           logdir='logdir'):
  """Train a CNN on EMNIST with weighted federated averaging.

  Each round samples `clients_per_round` client ids without replacement, runs
  one step of the learning process on their datasets and prints the training
  loss together with the average wall-clock time per round so far. The final
  round is executed under `tf.profiler` with traces written to `logdir`.
  Every `rounds_per_eval` rounds, and after the final round, the current
  server weights are copied into a Keras model and its accuracy on the test
  data is printed.

  Args:
    client_epochs_per_round: Passed to `preprocess_emnist_dataset`.
    train_batch_size: Batch size for client training datasets.
    test_batch_size: Batch size for the test dataset.
    cnn_num_blocks: `num_blocks` argument of `create_cnn`.
    conv_width_multiplier: `conv_width_multiplier` argument of `create_cnn`.
    server_learning_rate: Learning rate of the server SGD optimizer.
    client_learning_rate: Learning rate of the client SGD optimizer.
    total_rounds: Number of federated rounds to run.
    clients_per_round: Number of clients sampled in each round.
    rounds_per_eval: Evaluate on the test data every this many rounds.
    logdir: Directory receiving the profile of the final round.
  """

  train_data, test_data = preprocess_emnist_dataset(
      client_epochs_per_round, train_batch_size, test_batch_size)
  data_spec = test_data.element_spec

  def _model_fn():
    keras_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return tff.learning.models.from_keras_model(
        keras_model, input_spec=data_spec, loss=loss)

  def _server_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)

  def _client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=client_learning_rate)

  learning_process = tff.learning.algorithms.build_weighted_fed_avg(
      model_fn=_model_fn,
      server_optimizer_fn=_server_optimizer_fn,
      client_optimizer_fn=_client_optimizer_fn,
      use_experimental_simulation_loop=True)

  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  eval_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
  # `summary()` prints the architecture itself and returns None, so it is
  # called directly instead of being routed through a `logging` module that
  # this file never imports.
  eval_model.summary()

  server_state = learning_process.initialize()
  start_time = time.time()
  last_round = total_rounds - 1
  for round_num in range(total_rounds):
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    if round_num == last_round:
      # Only the final round is profiled.
      with tf.profiler.experimental.Profile(logdir):
        result = learning_process.next(server_state, sampled_train_data)
    else:
      result = learning_process.next(server_state, sampled_train_data)
    server_state = result.state
    train_metrics = result.metrics['client_work']['train']
    # The reported time is the elapsed time since the first round divided by
    # the number of completed rounds (evaluation time included).
    print(f'Round {round_num} training loss: {train_metrics["loss"]}, '
          f'time: {(time.time()-start_time)/(round_num+1.)} secs')
    if round_num % rounds_per_eval == 0 or round_num == last_round:
      model_weights = learning_process.get_model_weights(server_state)
      model_weights.assign_weights_to(eval_model)
      accuracy = keras_evaluate(eval_model, test_data, metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')

单 GPU 执行

TFF 的默认运行时与 TF 相同:当提供 GPU 时,将选择第一个 GPU 进行执行。我们使用相对较小的模型运行先前定义的联合训练,并进行几轮。最后一轮执行使用 tf.profiler 进行分析,并通过 tensorboard 进行可视化。分析结果验证了确实使用的是第一个 GPU。

# Small model (2 blocks per stage) trained for 10 rounds and evaluated every
# 2 rounds; the final-round profile goes to the default 'logdir' directory.
run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=2, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=10,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation (Activation)      (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_4 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_5 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_6 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_7 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_8 (Activation)    (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_9 (Activation)    (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_10 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_11 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2570      
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4688243865966797, time: 13.382015466690063 secs
Round 0 validation accuracy: 15.240497589111328
Round 1 training loss: 2.3217368125915527, time: 9.311999917030334 secs
Round 2 training loss: 2.3100595474243164, time: 6.972411632537842 secs
Round 2 validation accuracy: 11.226489067077637
Round 3 training loss: 2.303222417831421, time: 6.467299699783325 secs
Round 4 training loss: 2.2976326942443848, time: 5.526083135604859 secs
Round 4 validation accuracy: 11.224040031433105
Round 5 training loss: 2.2919719219207764, time: 5.468692660331726 secs
Round 6 training loss: 2.2911534309387207, time: 4.935825347900391 secs
Round 6 validation accuracy: 11.833855628967285
Round 7 training loss: 2.2871201038360596, time: 4.918408691883087 secs
Round 8 training loss: 2.2818832397460938, time: 4.602836343977186 secs
Round 8 validation accuracy: 11.385677337646484
Round 9 training loss: 2.2790346145629883, time: 4.99558527469635 secs
Round 9 validation accuracy: 11.226489067077637
%tensorboard --logdir=logdir --port=0

更大模型和 OOM

让我们在 CPU 上运行更大的模型,并进行较少的联合轮数。

# Larger model (4 blocks per stage instead of 2) trained for only 5 rounds.
run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=4, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_39 (Conv2D)           (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_40 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_36 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_41 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_37 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_42 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_38 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_43 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_39 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_44 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_40 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_45 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_41 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_46 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_42 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_47 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_43 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_48 (Conv2D)           (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_44 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_49 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_45 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_50 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_46 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_51 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_47 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_52 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_48 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_53 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_49 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_54 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_50 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_55 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_51 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_56 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_52 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_57 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_53 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_58 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_54 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_59 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_55 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_60 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_56 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_61 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_57 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_62 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_58 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_63 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_59 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d_3 ( (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                2570      
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.437223434448242, time: 24.121686458587646 secs
Round 0 validation accuracy: 9.024785041809082
Round 1 training loss: 2.3081459999084473, time: 19.48685622215271 secs
Round 2 training loss: 2.305305242538452, time: 15.73950457572937 secs
Round 2 validation accuracy: 9.791339874267578
Round 3 training loss: 2.303149700164795, time: 15.194068729877472 secs
Round 4 training loss: 2.3026506900787354, time: 14.036769819259643 secs
Round 4 validation accuracy: 12.193867683410645

此模型可能会在单个 GPU 上遇到内存不足(OOM)问题。从大规模 CPU 实验迁移到 GPU 模拟可能会受到内存使用量的限制,因为 GPU 的内存通常比较有限。TFF 运行时中有几个参数可以调整,以缓解 OOM 问题:

# Control concurrency by `max_concurrent_computation_calls`.
# Use integer division: `16/2` evaluates to the float 8.0 in Python 3, while
# this argument is a count of concurrent calls (half of the 16 clients that
# are sampled per round in the run that follows).
tff.backends.native.set_sync_local_cpp_execution_context(
    max_concurrent_computation_calls=16 // 2)

# Repeat the larger-model run (4 blocks per stage, 5 rounds) with the same
# arguments as before.
run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=4, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation (Activation)      (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_5 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_6 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_7 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_8 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_9 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_10 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_11 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_13 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_14 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_15 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_16 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_17 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_18 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_19 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_20 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_21 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_22 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_23 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2570      
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4990053176879883, time: 11.922378778457642 secs
Round 0 validation accuracy: 11.224040031433105
Round 1 training loss: 2.307560920715332, time: 9.916815996170044 secs
Round 2 training loss: 2.3032877445220947, time: 7.68927804629008 secs
Round 2 validation accuracy: 11.224040031433105
Round 3 training loss: 2.302366256713867, time: 7.681552231311798 secs
Round 4 training loss: 2.301671028137207, time: 7.613566827774048 secs
Round 4 validation accuracy: 11.224040031433105

优化性能

通常,TF 中可以实现更高性能的技术也可以在 TFF 中使用,例如混合精度训练和 XLA。混合精度带来的加速(在 V100 等 GPU 上)和内存节省通常非常显著,可以通过 tf.profiler 进行检查。

# Mixed precision training.
# Re-create the TFF execution context without an explicit concurrency limit.
tff.backends.native.set_sync_local_cpp_execution_context()
# Use the stable mixed-precision API; the `experimental` namespace used
# previously is deprecated and absent from recent TensorFlow releases.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Same larger-model run as before; the final-round profile is written to the
# 'mixed' directory instead of the default one.
run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=4, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16, 
    rounds_per_eval=2,
    logdir='mixed'
    )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation (Activation)      (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_5 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_6 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_7 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_8 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_9 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_10 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_11 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_13 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_14 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_15 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_16 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_17 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_18 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_19 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_20 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_21 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_22 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_23 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2570      
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4187185764312744, time: 18.763780117034912 secs
Round 0 validation accuracy: 9.977468490600586
Round 1 training loss: 2.305102825164795, time: 13.712820529937744 secs
Round 2 training loss: 2.304737091064453, time: 9.993690172831217 secs
Round 2 validation accuracy: 11.779976844787598
Round 3 training loss: 2.2996833324432373, time: 9.29404467344284 secs
Round 4 training loss: 2.299349308013916, time: 9.195427560806275 secs
Round 4 validation accuracy: 11.224040031433105
# Open TensorBoard inside the notebook (the extension is loaded earlier with
# `%load_ext tensorboard`) on the `mixed` directory, the same value passed as
# `logdir` to the run above.
# NOTE(review): `--port=0` presumably lets TensorBoard pick any free port;
# confirm against the TensorBoard command-line documentation.
%tensorboard --logdir=mixed --port=0