FFJORD

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

设置

首先安装此演示中使用的软件包。

pip install -q dm-sonnet

导入 (tf、具有伴随技巧的 tfp 等)

/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

用于可视化的辅助函数

FFJORD 双射器

在此 colab 中,我们演示了 FFJORD 双射器,最初由 Grathwohl、Will 等人在论文中提出。 arxiv 链接.

简而言之,这种方法背后的想法是在已知的 **基础分布** 和 **数据分布** 之间建立对应关系。

为了建立这种联系,我们需要

  1. 定义一个双射映射 \(\mathcal{T}_{\theta}:\mathbf{x} \rightarrow \mathbf{y}\),\(\mathcal{T}_{\theta}^{1}:\mathbf{y} \rightarrow \mathbf{x}\) 在定义 **基础分布** 的空间 \(\mathcal{Y}\) 和数据域空间 \(\mathcal{X}\) 之间。
  2. 有效地跟踪我们执行的变形,以将概率的概念转移到 \(\mathcal{X}\) 上。

第二个条件在定义在 \(\mathcal{X}\) 上的概率分布的以下表达式中形式化

\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]

FFJORD 双射器通过定义一个变换来实现这一点

\[ \mathcal{T_{\theta} }: \mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \]

只要描述状态 \(\mathbf{z}\) 演化的函数 \(\mathbf{f}\) 行为良好,并且可以计算 log_det_jacobian,此变换即可逆。

\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int_{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]

在此演示中,我们将训练一个 FFJORD 双射器,将高斯分布扭曲到由 moons 数据集定义的分布上。这将分三个步骤完成

  • 定义 **基础分布**
  • 定义 FFJORD 双射器
  • 最小化数据集的精确对数似然

首先,我们加载数据

数据集

png

接下来,我们实例化一个基础分布

base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)

我们使用多层感知器来模拟 state_derivative_fn

虽然对于此数据集来说并非必需,但让 state_derivative_fn 依赖于时间通常是有益的。在这里,我们通过将 t 连接到网络的输入来实现这一点。

class MLP_ODE(snt.Module):
  """Multi-layer NN ode_fn."""
  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super(MLP_ODE, self).__init__(name=name)
    self._num_hidden = num_hidden
    self._num_output = num_output
    self._num_layers = num_layers
    self._modules = []
    for _ in range(self._num_layers - 1):
      self._modules.append(snt.Linear(self._num_hidden))
      self._modules.append(tf.math.tanh)
    self._modules.append(snt.Linear(self._num_output))
    self._model = snt.Sequential(self._modules)

  def __call__(self, t, inputs):
    inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
    return self._model(inputs)

模型和训练参数

现在,我们构建一个 FFJORD 双射器堆栈。每个双射器都提供 ode_solve_fntrace_augmentation_fn,以及它自己的 state_derivative_fn 模型,以便它们代表一系列不同的变换。

构建双射器

现在,我们可以使用 TransformedDistribution,它是使用 stacked_ffjord 双射器扭曲 base_distribution 的结果。

transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)

现在,我们定义训练过程。我们只需最小化数据的负对数似然。

训练

样本

绘制来自基础分布和变换分布的样本。

evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(
    grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()

png

learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)

for epoch in tqdm.trange(NUM_EPOCHS // 2):
  base_samples, transformed_samples = get_samples()
  transformed_grid = get_transformed_grid()
  evaluation_samples.append(
      (base_samples, transformed_samples, transformed_grid))
  for batch in moons_ds:
    _ = train_step(optimizer, batch)
0%|          | 0/40 [00:00<?, ?it/s]
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/math/ode/base.py:350: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()

png

使用学习率进行更长时间的训练会导致进一步的改进。

本示例中未涵盖,FFJORD 双射器支持 Hutchinson 的随机迹估计。可以通过 trace_augmentation_fn 提供特定的估计器。类似地,可以通过定义自定义 ode_solve_fn 来使用替代积分器。