JAX 上的 TensorFlow Probability

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

TensorFlow Probability (TFP) 是一个用于概率推理和统计分析的库,现在也支持 JAX!对于不熟悉的人来说,JAX 是一个基于可组合函数转换的加速数值计算库。

JAX 上的 TFP 支持常规 TFP 中许多最实用的功能,同时保留了现在许多 TFP 用户熟悉的抽象和 API。

设置

JAX 上的 TFP 依赖 TensorFlow;让我们完全从这个 Colab 中卸载 TensorFlow。

pip uninstall tensorflow -y -q

我们可以使用 TFP 的最新 nightly 版本安装 JAX 上的 TFP。

pip install -Uq tfp-nightly[jax] > /dev/null

让我们导入一些有用的 Python 库。

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

让我们也导入一些基本的 JAX 功能。

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

导入 JAX 上的 TFP

要在 JAX 上使用 TFP,只需导入 jax “基底” 并像往常一样使用它 tfp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

演示:贝叶斯逻辑回归

为了演示我们使用 JAX 后端可以做什么,我们将实现应用于经典 Iris 数据集的贝叶斯逻辑回归。

首先,让我们导入 Iris 数据集并提取一些元数据。

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

我们可以使用 tfd.JointDistributionCoroutine 定义模型。我们将对权重和偏差项都使用标准正态先验,然后编写一个 target_log_prob 函数,该函数将采样的标签固定到数据。

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

我们从 dist 中采样以生成 MCMC 的初始状态。然后,我们可以定义一个函数,该函数接收一个随机键和一个初始状态,并从 No-U-Turn-Sampler (NUTS) 中生成 500 个样本。请注意,我们可以使用 JAX 变换(如 jit)使用 XLA 编译我们的 NUTS 采样器。

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

png

让我们使用我们的样本通过对每个权重集的预测概率进行平均来执行贝叶斯模型平均 (BMA)。

首先,让我们编写一个函数,该函数对于给定的参数集将生成每个类别的概率。我们可以使用 dist.sample_distributions 获取模型中的最终分布。

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

我们可以对样本集使用 vmap(classifier_probs) 来获取我们每个样本的预测类别概率。然后,我们计算每个样本的平均准确率以及贝叶斯模型平均的准确率。

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

看起来 BMA 将我们的错误率降低了近三分之一!

基础知识

JAX 上的 TFP 具有与 TF 相同的 API,只是它不接受 TF 对象(如 tf.Tensor),而是接受 JAX 等效项。例如,在以前使用 tf.Tensor 作为输入的任何地方,API 现在都期望一个 JAX DeviceArray。TFP 方法不再返回 tf.Tensor,而是返回 DeviceArray。JAX 上的 TFP 也适用于 JAX 对象的嵌套结构,例如 DeviceArray 的列表或字典。

分布

TFP 中的大多数分布都支持 JAX,其语义与 TF 中的对应分布非常相似。它们也注册为 JAX Pytrees,因此它们可以作为 JAX 变换函数的输入和输出。

基本分布

分布的 log_prob 方法的工作方式相同。

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

从分布中采样需要显式地将 PRNGKey(或整数列表)作为 seed 关键字参数传递。如果未显式传递种子,则会抛出错误。

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

分布的形状语义在 JAX 中保持不变,其中每个分布都具有 event_shapebatch_shape,并且绘制多个样本将添加额外的 sample_shape 维度。

例如,具有向量参数的 tfd.MultivariateNormalDiag 将具有向量事件形状和空批次形状。

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

另一方面,使用向量参数化的 tfd.Normal 将具有标量事件形状和向量批次形状。

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

在 JAX 中,对样本进行 log_prob 操作的语义与 TensorFlow 中相同。

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

由于 JAX 的 DeviceArray 与 NumPy 和 Matplotlib 等库兼容,我们可以直接将样本输入绘图函数。

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribution 方法与 JAX 变换兼容。

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

png

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

png

由于 TFP 分布被注册为 JAX pytree 节点,我们可以编写以分布作为输入或输出的函数,并使用 jit 对其进行变换,但它们目前尚不支持作为 vmap 函数的参数。

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

变换后的分布

变换后的分布,即样本经过 Bijector 处理的分布,也可以直接使用(双射器也可以使用!见下文)。

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

联合分布

TFP 提供了 JointDistribution,可以将组件分布组合成单个分布,覆盖多个随机变量。目前,TFP 提供三种核心变体(JointDistributionSequentialJointDistributionNamedJointDistributionCoroutine),它们都支持 JAX。 AutoBatched 变体也全部支持。

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

png

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

其他分布

高斯过程在 JAX 模式下也能正常工作!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

png

隐马尔可夫模型也受支持。

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

一些分布,例如 PixelCNN,由于严格依赖于 TensorFlow 或 XLA 不兼容性,目前尚不支持。

双射器

TFP 的大多数双射器在 JAX 中都受支持!

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

双射器与 JAX 变换(如 jitgradvmap)兼容。

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

png

一些双射器,例如 RealNVPFFJORD,目前尚不支持。

MCMC

我们也已将 tfp.mcmc 移植到 JAX,因此我们可以在 JAX 中运行哈密顿蒙特卡罗 (HMC) 和无 U 转变采样器 (NUTS) 等算法。

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

与 TF 上的 TFP 不同,我们需要使用 seed 关键字参数将 PRNGKey 传递到 sample_chain 中。

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

png

png

要运行多个链,我们可以将一批状态传递到 sample_chain 中,或者使用 vmap(尽管我们尚未探索两种方法之间的性能差异)。

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

png

png

优化器

JAX 上的 TFP 支持一些重要的优化器,如 BFGS 和 L-BFGS。让我们设置一个简单的缩放二次损失函数。

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

BFGS 可以找到此损失的最小值。

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

L-BFGS 也可以。

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

要对 L-BFGS 进行 vmap,让我们设置一个函数,该函数针对单个起点优化损失。

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

注意事项

TF 和 JAX 之间存在一些根本差异,TFP 在两种底层结构中的某些行为将有所不同,并非所有功能都受支持。例如,

  • JAX 上的 TFP 不支持任何类似于 tf.Variable 的东西,因为 JAX 中不存在类似的东西。这也意味着 tfp.util.TransformedVariable 等实用程序也不受支持。
  • tfp.layers 目前在后端不受支持,因为它依赖于 Keras 和 tf.Variable
  • tfp.math.minimize 在 JAX 上的 TFP 中不起作用,因为它依赖于 tf.Variable
  • 在 JAX 上的 TFP 中,张量形状始终是具体的整数值,永远不会像 TF 上的 TFP 那样未知/动态。
  • TF 和 JAX 中的伪随机性处理方式不同(见附录)。
  • tfp.experimental 中的库不能保证在 JAX 底层结构中存在。
  • TF 和 JAX 之间的 dtype 提升规则不同。为了保持一致性,JAX 上的 TFP 尝试在内部尊重 TF 的 dtype 语义。
  • 双射器尚未注册为 JAX pytree。

要查看 JAX 上的 TFP 中支持的完整列表,请参阅 API 文档

结论

我们已将 TFP 的许多功能移植到 JAX,并期待看到大家将构建出什么。一些功能目前尚不支持;如果您发现我们遗漏了一些对您很重要的功能(或发现错误!),请与我们联系 - 您可以发送电子邮件至 [email protected] 或在 我们的 Github 仓库 中提交问题。

附录:JAX 中的伪随机性

JAX 的伪随机数生成 (PRNG) 模型是无状态的。与有状态模型不同,没有可变的全局状态在每次随机抽取后发生变化。在 JAX 的模型中,我们从一个 PRNG 开始,它就像一对 32 位整数。我们可以使用 jax.random.PRNGKey 来构造这些键。

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

JAX 中的随机函数使用一个键来确定性地生成随机变量,这意味着它们不应再次使用。例如,我们可以使用 key 来采样一个正态分布的值,但我们不应在其他地方再次使用 key。此外,将相同的值传递到 random.normal 将生成相同的值。

print(random.normal(key))
-0.20584226

那么我们如何从单个键中抽取多个样本呢?答案是键拆分。基本思想是我们可以将一个 PRNGKey 拆分为多个,每个新键都可以被视为一个独立的随机性来源。

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

键拆分是确定性的,但具有混沌性,因此每个新键现在都可以用来抽取不同的随机样本。

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

有关 JAX 的确定性键拆分模型的更多详细信息,请参阅 本指南