Import JAX models using JAX2TF

This notebook provides a complete, runnable example of creating a model with JAX and bringing it into TensorFlow to continue training. This is made possible by JAX2TF, a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem.

JAX is a high-performance array computing library. To create the model, this notebook uses Flax, a neural network library for JAX. To train it, it uses Optax, an optimization library for JAX.

If you're a researcher working in JAX, JAX2TF gives you a path to production using TensorFlow's proven tools.

This can be useful in many ways; here are just a few:

  • Inference: deploying a model written for JAX to a server (using TF Serving), on-device (using TFLite), or on the web (using TensorFlow.js); a minimal TFLite conversion sketch follows this list.

  • Fine-tuning: bringing components of a model trained with JAX into TF using JAX2TF, and continuing to train it in TensorFlow with your existing training data and setup.

  • Fusion: combining parts of models trained with JAX with parts trained with TensorFlow for maximum flexibility.
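
As a rough sketch of the TFLite route mentioned above: a SavedModel produced with JAX2TF (such as the one written to "./" later in this notebook) can typically be converted with the standard TFLite converter. The path and the ops-fallback setting below are assumptions, not part of this notebook; see the JAX Model Conversion For TFLite guide for the supported workflow.

import tensorflow as tf

# Convert a SavedModel (assumed to be in "./") to TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model("./")
# Some converted JAX ops may need to fall back to TensorFlow ops.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
  f.write(tflite_model)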

The key to enabling this kind of interoperability between JAX and TensorFlow is jax2tf.convert, which takes model components created on top of JAX (your loss function, prediction function, etc.) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel.
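
As a minimal sketch with a hypothetical toy function (jax_square, not part of the model built in this notebook), the conversion boils down to wrapping a JAX function with jax2tf.convert and then with tf.function:

import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def jax_square(x):  # any JAX function
  return jnp.square(x)

# The converted function behaves like a regular TensorFlow function.
tf_square = tf.function(jax2tf.convert(jax_square), autograph=False)
print(tf_square(tf.constant([1.0, 2.0, 3.0])))  # tf.Tensor([1. 4. 9.], shape=(3,), dtype=float32)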

Setup

import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
import flax
import optax
import os
from matplotlib import pyplot as plt
from jax.experimental import jax2tf
from threading import Lock # Only used in the visualization utility.
from functools import partial
# Needed for TensorFlow and JAX to coexist in GPU memory.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized.
    print(e)

Visualization utilities
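
The original notebook keeps this utility cell collapsed. The sketch below is an assumption-based stand-in that matches the Progress and display_train_curves calls used later (a text progress bar and a matplotlib plot of the training curves); the actual implementation in the notebook may differ.

# Hypothetical stand-in for the hidden visualization cell.
from matplotlib import pyplot as plt

class Progress:
  """Prints a simple text progress bar, one '.' every few training steps."""
  def __init__(self, maxi, size=100):
    self.maxi = maxi
    self.size = size
    self.i = 0
  def step(self, reset=False):
    if reset:
      self.i = 0
      print()
    self.i += 1
    if self.i % max(self.maxi // self.size, 1) == 0:
      print('.', end='', flush=True)

def display_train_curves(losses, avg_losses, eval_losses, eval_accuracies,
                         epochs, steps_per_epoch, ignore_first_n=0):
  """Plots per-step training losses plus per-epoch train/eval losses and eval accuracy."""
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
  ax1.plot(range(ignore_first_n, len(losses)), losses[ignore_first_n:],
           alpha=0.4, label='train loss (per step)')
  epoch_x = [(i + 1) * steps_per_epoch for i in range(epochs)]
  ax1.plot(epoch_x, avg_losses, label='train loss (epoch avg)')
  ax1.plot(epoch_x, eval_losses, label='eval loss')
  ax1.set_xlabel('training step')
  ax1.set_ylabel('loss')
  ax1.legend()
  ax2.plot(epoch_x, eval_accuracies, label='eval accuracy')
  ax2.set_xlabel('training step')
  ax2.set_ylabel('accuracy')
  ax2.legend()
  plt.show()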

Download and prepare the MNIST dataset

(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),
                                         tf.one_hot(y, depth=10)))

BATCH_SIZE = 256
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.cache()
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)

test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))
test_data = test_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),
                                         tf.one_hot(y, depth=10)))
test_data = test_data.batch(10000)
test_data = test_data.cache()

(one_batch, one_batch_labels) = next(iter(train_data)) # just one batch
(all_test_data, all_test_labels) = next(iter(test_data)) # all in one batch since batch size is 10000

Configure training

This notebook will create and train a simple model for demonstration purposes.

# Training hyperparameters.
JAX_EPOCHS = 3
TF_EPOCHS = 7
STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE
LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6

# The learning rate schedule for JAX (with Optax).
jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)

# The learning rate schedule for TensorFlow.
tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)

Create the model using Flax

class ConvModel(flax.linen.Module):

  @flax.linen.compact
  def __call__(self, x, train):
    x = flax.linen.Conv(features=12, kernel_size=(3,3), padding="SAME", use_bias=False)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = flax.linen.Dense(features=200, use_bias=True)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = flax.linen.Dropout(rate=0.3, deterministic=not train)(x)
    x = flax.linen.relu(x)
    x = flax.linen.Dense(features=10)(x)
    #x = flax.linen.log_softmax(x)
    return x

  # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).
  # `jax.grad` will differentiate it against the first argument.
  # The user must split trainable and non-trainable variables into `params` and `other_state`.
  # Must pass a different RNG key each time for the dropout mask to be different.
  def loss(self, params, other_state, rng, data, labels, train):
    logits, batch_stats = self.apply({'params': params, **other_state},
                                     data,
                                     mutable=['batch_stats'],
                                     rngs={'dropout': rng},
                                     train=train)
    # The loss averaged across the batch dimension.
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    return loss, batch_stats

  def predict(self, state, data):
    logits = self.apply(state, data, train=False) # predict and accuracy disable dropout and use accumulated batch norm stats (train=False)
    probabilities = flax.linen.log_softmax(logits)
    return probabilities

  def accuracy(self, state, data, labels):
    probabilities = self.predict(state, data)
    predictions = jnp.argmax(probabilities, axis=-1)
    dense_labels = jnp.argmax(labels, axis=-1)
    accuracy = jnp.equal(predictions, dense_labels).mean()
    return accuracy

Write the training step function

# The training step.
@partial(jax.jit, static_argnums=[0]) # this forces jax.jit to recompile for every new model
def train_step(model, state, optimizer_state, rng, data, labels):

  other_state, params = state.pop('params') # differentiate only against 'params' which represents trainable variables
  (loss, batch_stats), grads = jax.value_and_grad(model.loss, has_aux=True)(params, other_state, rng, data, labels, train=True)

  updates, optimizer_state = optimizer.update(grads, optimizer_state)
  params = optax.apply_updates(params, updates)
  new_state = state.copy(add_or_replace={**batch_stats, 'params': params})

  rng, _ = jax.random.split(rng)

  return new_state, optimizer_state, rng, loss

Write the training loop

def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):
  p = Progress(STEPS_PER_EPOCH)
  rng = jax.random.PRNGKey(0)
  for epoch in range(epochs):

    # This is where the learning rate schedule state is stored in the optimizer state.
    optimizer_step = optimizer_state[1].count

    # Run an epoch of training.
    for step, (data, labels) in enumerate(train_data):
      p.step(reset=(step==0))
      state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())
      losses.append(loss)
    avg_loss = np.mean(losses[-step:])
    avg_losses.append(avg_loss)

    # Run one epoch of evals (10,000 test images in a single batch).
    other_state, params = state.pop('params')
    # Gotcha: must discard modified batch_stats here
    eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)
    eval_losses.append(eval_loss)
    eval_accuracy = model.accuracy(state, all_test_data.numpy(), all_test_labels.numpy())
    eval_accuracies.append(eval_accuracy)

    print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", jlr_decay(optimizer_step))

  return state, optimizer_state

Create the model and the optimizer (with Optax)

# The model.
model = ConvModel()
state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for "dropout"

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule
optimizer_state = optimizer.init(state['params'])

losses=[]
avg_losses=[]
eval_losses=[]
eval_accuracies=[]

Train the model

new_state, new_optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS+TF_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Partially train the model

You will continue training the model in TensorFlow shortly.

model = ConvModel()
state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for "dropout"

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # LR must be an Optax LR Schedule
optimizer_state = optimizer.init(state['params'])

losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []
state, optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Save just enough for inference

If you want to deploy your JAX model (so you can run inference using model.predict()), simply exporting it to a SavedModel is sufficient. This section demonstrates how to accomplish that.

# Test data with a different batch size to test polymorphic shapes.
x, y = next(iter(train_data.unbatch().batch(13)))

m = tf.Module()
# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function).
state_vars = tf.nest.map_structure(tf.Variable, state)
# Keep the wrapped state as flat list (needed in TensorFlow fine-tuning).
m.vars = tf.nest.flatten(state_vars)
# Convert the desired JAX function (`model.predict`).
predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])
# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).
@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
def predict(data):
    return predict_fn(state_vars, data)
m.predict = predict
tf.saved_model.save(m, "./")
# Test the converted function.
print("Converted function predictions:", np.argmax(m.predict(x).numpy(), axis=-1))
# Reload the model.
reloaded_model = tf.saved_model.load("./")
# Test the reloaded converted function (the result should be the same).
print("Reloaded  function predictions:", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

Save everything

If you intend to do a comprehensive export (useful if you're planning on bringing the model into TensorFlow for fine-tuning, fusion, etc.), this section demonstrates how to save the model so you can access the following methods:

  • model.predict
  • model.accuracy
  • model.loss (including the train=True/False boolean, and the RNG used for dropout and for BatchNorm state updates)

from collections import abc

def _fix_frozen(d):
  """Changes any mappings (e.g. frozendict) back to dict."""
  if isinstance(d, list):
    return [_fix_frozen(v) for v in d]
  elif isinstance(d, tuple):
    return tuple(_fix_frozen(v) for v in d)
  elif not isinstance(d, abc.Mapping):
    return d
  d = dict(d)
  for k, v in d.items():
    d[k] = _fix_frozen(v)
  return d

class TFModel(tf.Module):
  def __init__(self, state, model):
    super().__init__()

    # Special care needed for the train=True/False parameter in the loss
    @jax.jit
    def loss_with_train_bool(state, rng, data, labels, train):
      other_state, params = state.pop('params')
      loss, batch_stats = jax.lax.cond(train,
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=True),
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=False),
                                       state, data, labels)
      # must use JAX to split the RNG, therefore, must do it in a @jax.jit function
      new_rng, _ = jax.random.split(rng)
      return loss, batch_stats, new_rng

    self.state_vars = tf.nest.map_structure(tf.Variable, state)
    self.vars = tf.nest.flatten(self.state_vars)
    self.jax_rng = tf.Variable(jax.random.PRNGKey(0))

    self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=["...", "...", "(b, 28, 28, 1)", "(b, 10)", "..."])
    self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=["...", "(b, 28, 28, 1)", "(b, 10)"])
    self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])

  # Must specify TensorSpec manually for variable batch size to work
  @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
  def predict(self, data):
    # Make sure the TFModel.predict function implicitly uses self.state_vars and not the JAX state directly,
    # otherwise, all model weights would be embedded in the TF graph as constants.
    return self.predict_fn(self.state_vars, data)

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def train_loss(self, data, labels):
      loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)
      # update batch norm stats
      flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])
      flat_values = tf.nest.flatten(batch_stats['batch_stats'])
      for var, val in zip(flat_vars, flat_values):
        var.assign(val)
      # update RNG
      self.jax_rng.assign(new_rng)
      return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def eval_loss(self, data, labels):
      loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)
      return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def accuracy(self, data, labels):
    return self.accuracy_fn(self.state_vars, data, labels)

# Instantiate the model.
tf_model = TFModel(state, model)

# Save the model.
tf.saved_model.save(tf_model, "./")

Reload the model

reloaded_model = tf.saved_model.load("./")

# Test if it works and that the batch size is indeed variable.
x,y = next(iter(train_data.unbatch().batch(13)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))
x,y = next(iter(train_data.unbatch().batch(20)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

print(reloaded_model.accuracy(one_batch, one_batch_labels))
print(reloaded_model.accuracy(all_test_data, all_test_labels))

Continue training the converted JAX model in TensorFlow

optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)

# Set the iteration step for the learning rate to resume from where it left off in JAX.
optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)

p = Progress(STEPS_PER_EPOCH)

for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):

  # This is where the learning rate schedule state is stored in the optimizer state.
  optimizer_step = optimizer.iterations

  for step, (data, labels) in enumerate(train_data):
    p.step(reset=(step==0))
    with tf.GradientTape() as tape:
      #loss = reloaded_model.loss(data, labels, True)
      loss = reloaded_model.train_loss(data, labels)
    grads = tape.gradient(loss, reloaded_model.vars)
    optimizer.apply_gradients(zip(grads, reloaded_model.vars))
    losses.append(loss)
  avg_loss = np.mean(losses[-step:])
  avg_losses.append(avg_loss)

  eval_loss = reloaded_model.eval_loss(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_losses.append(eval_loss)
  eval_accuracy = reloaded_model.accuracy(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_accuracies.append(eval_accuracy)

  print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", tflr_decay(optimizer.iterations).numpy())
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2*STEPS_PER_EPOCH)

# The loss takes a hit when the training restarts, but does not go back to random levels.
# This is likely caused by the optimizer momentum being reinitialized.

Next steps

You can learn more about JAX and Flax on their documentation websites, which contain detailed guides and examples. If you're new to JAX, be sure to explore the JAX 101 tutorials and check out the Flax quickstart. To learn more about converting JAX models to TensorFlow format, check out the jax2tf utility on GitHub. If you're interested in converting JAX models to run in the browser with TensorFlow.js, visit JAX on the Web with TensorFlow.js. If you'd like to prepare JAX models to run in TensorFlow Lite, visit the JAX Model Conversion For TFLite guide.