TFF 中 JAX 的实验性支持

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看 下载笔记本

除了作为 TensorFlow 生态系统的一部分之外,TFF 还旨在与其他前端和后端 ML 框架实现互操作性。目前,对其他 ML 框架的支持仍处于孵化阶段,支持的 API 和功能可能会发生变化(很大程度上取决于 TFF 用户的需求)。本教程介绍了如何使用 TFF 将 JAX 作为替代 ML 前端,以及 XLA 编译器作为替代后端。此处显示的示例基于端到端的完全原生 JAX/XLA 堆栈。在未来的教程中,将讨论跨框架混合代码的可能性(例如,JAX 与 TensorFlow)。

与往常一样,我们欢迎您的贡献。如果您需要 JAX/XLA 支持或与其他 ML 框架互操作的能力,请考虑帮助我们将这些功能发展到与 TFF 的其余部分相当的水平。

开始之前

请参阅 TFF 文档的主体,了解如何配置您的环境。根据您运行本教程的位置,您可能需要取消注释并运行以下代码中的部分或全部代码。

# !pip install --quiet --upgrade tensorflow-federated
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

本教程还假设您已查看 TFF 的主要 TensorFlow 教程,并且熟悉 TFF 的核心概念。如果您尚未完成此操作,请考虑至少查看其中一个教程。

JAX 计算

TFF 中对 JAX 的支持旨在与 TFF 与 TensorFlow 互操作的方式对称,从导入开始

import jax
import numpy as np
import tensorflow_federated as tff

此外,就像 TensorFlow 一样,表达任何 TFF 代码的基础是本地运行的逻辑。您可以使用 JAX 表达此逻辑,如下所示,使用 @tff.jax_computation 包装器。它的行为类似于您现在熟悉的 @tff.tf_computation。让我们从一些简单的事情开始,例如一个将两个整数相加的计算

@tff.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

您可以像使用 TFF 计算一样使用上面定义的 JAX 计算。例如,您可以检查它的类型签名,如下所示

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

请注意,我们使用了 np.int32 来定义参数的类型。TFF 不区分 Numpy 类型(例如 np.int32)和 TensorFlow 类型(例如 tf.int32)。从 TFF 的角度来看,它们只是指代同一事物的不同方式。

现在,请记住,TFF 不是 Python(如果这听起来很陌生,请查看我们之前的一些教程,例如关于自定义算法的教程)。您可以将 @tff.jax_computation 包装器与任何可以跟踪和序列化的 JAX 代码一起使用,即与您通常使用 @jax.jit 注释的代码一起使用,预期将其编译为 XLA(但您不需要实际使用 @jax.jit 注释将您的 JAX 代码嵌入 TFF 中)。

实际上,在幕后,TFF 会立即将 JAX 计算编译为 XLA。您可以通过手动从 add_numbers 中提取和打印序列化后的 XLA 代码来自己验证这一点,如下所示

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

将 JAX 计算表示为 XLA 代码可以看作是使用 TensorFlow 表达的计算的 tf.GraphDef 的功能等效项。它可移植,可以在支持 XLA 的各种环境中执行,就像 tf.GraphDef 可以执行在任何 TensorFlow 运行时上一样。

TFF 提供了一个基于 XLA 编译器作为后端的运行时堆栈。您可以按如下方式激活它

tff.backends.xla.set_local_python_execution_context()

现在,您可以执行我们上面定义的计算

add_numbers(2, 3)
5

很简单。让我们更进一步,做一些更复杂的事情,比如 MNIST。

使用预置 API 进行 MNIST 训练的示例

像往常一样,我们首先为数据批次和模型定义一堆 TFF 类型(请记住,TFF 是一个强类型框架)。

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

现在,让我们在 JAX 中定义一个模型的损失函数,将模型和单个数据批次作为参数

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

现在,一种方法是使用预置 API。以下是如何使用我们的 API 创建基于刚刚定义的损失函数的训练过程的示例。

STEP_SIZE = 0.001

trainer = tff.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

您可以像使用从 tf.Keras 模型构建的训练器一样使用上面的内容。例如,以下是如何创建用于训练的初始模型

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

为了执行实际训练,我们需要一些数据。让我们生成随机数据以保持简单。由于数据是随机的,我们将对训练数据进行评估,因为否则,如果使用随机评估数据,模型很难预期会表现良好。此外,对于这个小型演示,我们不会担心随机抽取客户端(我们将其留作练习,让用户通过遵循其他教程中的模板来探索这些类型的更改)

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

有了这些,我们可以执行一步训练,如下所示

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

让我们评估训练步骤的结果。为了保持简单,我们可以在集中式环境中进行评估

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

损失正在下降。太好了!现在,让我们在多个轮次中运行它

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

正如您所见,在 TFF 中使用 JAX 与使用 TensorFlow 并没有太大区别,尽管实验性 API 在功能方面尚未与 TensorFlow API 相媲美。

幕后

如果您不想使用我们预先打包的 API,您可以实现自己的自定义计算,这与您在 TensorFlow 的自定义算法教程中看到的方式非常相似,只是您将使用 JAX 的梯度下降机制。例如,以下是如何定义一个在单个小批量上更新模型的 JAX 计算。

@tff.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

以下是如何测试它是否有效。

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

使用 JAX 的一个注意事项是,它没有提供等效于 tf.data.Dataset 的功能。因此,为了迭代数据集,您需要使用 TFF 的声明式结构来对序列进行操作,例如下面所示的结构。

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

让我们看看它是否有效。

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

执行一轮训练的计算看起来与您在 TensorFlow 教程中看到的非常相似。

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

让我们看看它是否有效。

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

正如您所见,在 TFF 中使用 JAX,无论是通过预先打包的 API 还是直接使用低级 TFF 结构,都类似于使用 TensorFlow 与 TFF。敬请关注未来的更新,如果您想看到对跨 ML 框架的互操作性的更好支持,请随时向我们发送拉取请求!