在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
设置
首先安装此演示中使用的软件包。
pip install -q dm-sonnet
导入 (tf、具有伴随技巧的 tfp 等)
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead. import pandas.util.testing as tm
用于可视化的辅助函数
FFJORD 双射器
在此 colab 中,我们演示了 FFJORD 双射器,最初由 Grathwohl、Will 等人在论文中提出。 arxiv 链接.
简而言之,这种方法背后的想法是在已知的 **基础分布** 和 **数据分布** 之间建立对应关系。
为了建立这种联系,我们需要
- 定义一个双射映射 \(\mathcal{T}_{\theta}:\mathbf{x} \rightarrow \mathbf{y}\),\(\mathcal{T}_{\theta}^{1}:\mathbf{y} \rightarrow \mathbf{x}\) 在定义 **基础分布** 的空间 \(\mathcal{Y}\) 和数据域空间 \(\mathcal{X}\) 之间。
- 有效地跟踪我们执行的变形,以将概率的概念转移到 \(\mathcal{X}\) 上。
第二个条件在定义在 \(\mathcal{X}\) 上的概率分布的以下表达式中形式化
\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]
FFJORD 双射器通过定义一个变换来实现这一点
\[ \mathcal{T_{\theta} }: \mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \]
只要描述状态 \(\mathbf{z}\) 演化的函数 \(\mathbf{f}\) 行为良好,并且可以计算 log_det_jacobian
,此变换即可逆。
\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int_{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]
在此演示中,我们将训练一个 FFJORD 双射器,将高斯分布扭曲到由 moons
数据集定义的分布上。这将分三个步骤完成
- 定义 **基础分布**
- 定义 FFJORD 双射器
- 最小化数据集的精确对数似然
首先,我们加载数据
数据集
接下来,我们实例化一个基础分布
base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)
我们使用多层感知器来模拟 state_derivative_fn
。
虽然对于此数据集来说并非必需,但让 state_derivative_fn
依赖于时间通常是有益的。在这里,我们通过将 t
连接到网络的输入来实现这一点。
class MLP_ODE(snt.Module):
"""Multi-layer NN ode_fn."""
def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
super(MLP_ODE, self).__init__(name=name)
self._num_hidden = num_hidden
self._num_output = num_output
self._num_layers = num_layers
self._modules = []
for _ in range(self._num_layers - 1):
self._modules.append(snt.Linear(self._num_hidden))
self._modules.append(tf.math.tanh)
self._modules.append(snt.Linear(self._num_output))
self._model = snt.Sequential(self._modules)
def __call__(self, t, inputs):
inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
return self._model(inputs)
模型和训练参数
现在,我们构建一个 FFJORD 双射器堆栈。每个双射器都提供 ode_solve_fn
和 trace_augmentation_fn
,以及它自己的 state_derivative_fn
模型,以便它们代表一系列不同的变换。
构建双射器
现在,我们可以使用 TransformedDistribution
,它是使用 stacked_ffjord
双射器扭曲 base_distribution
的结果。
transformed_distribution = tfd.TransformedDistribution(
distribution=base_distribution, bijector=stacked_ffjord)
现在,我们定义训练过程。我们只需最小化数据的负对数似然。
训练
样本
绘制来自基础分布和变换分布的样本。
evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers.
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
1, 4, figsize=(16, 6))
plot_panel(
grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()
learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)
for epoch in tqdm.trange(NUM_EPOCHS // 2):
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append(
(base_samples, transformed_samples, transformed_grid))
for batch in moons_ds:
_ = train_step(optimizer, batch)
0%| | 0/40 [00:00<?, ?it/s] WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/math/ode/base.py:350: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.while_loop(c, b, vars, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars)) 100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()
使用学习率进行更长时间的训练会导致进一步的改进。
本示例中未涵盖,FFJORD 双射器支持 Hutchinson 的随机迹估计。可以通过 trace_augmentation_fn
提供特定的估计器。类似地,可以通过定义自定义 ode_solve_fn
来使用替代积分器。