Random notes on Computer Science, Mathematics and Software Engineering

Neural Variational Inference: Importance Weighted Autoencoders

Previously we covered Variational Autoencoders (VAE), a popular inference tool based on neural networks. In this post we’ll consider a follow-up work from Toronto by Y. Burda, R. Grosse and R. Salakhutdinov, Importance Weighted Autoencoders (IWAE). The crucial contribution of this work is the introduction of a new lower bound on the marginal log-likelihood \(\log p(x)\) which generalizes the ELBO and also allows one to use less accurate approximate posteriors \(q(z \mid x, \Lambda)\).

For dessert, we’ll discuss another paper, Variational inference for Monte Carlo objectives by A. Mnih and D. Rezende, which aims to broaden the applicability of this approach to models where the reparametrization trick cannot be used (e.g. for discrete variables).

Importance Weighted Autoencoders

Let’s first answer the question: how can one come up with a lower bound on the marginal log-likelihood? At the very beginning of the series, in the Classical Theory post, we used some trickery to arrive at the ELBO. That massaging of the marginal log-likelihood wasn’t particularly enlightening about how one could invent such a lower bound. Now we’re going to consider a principled approach to inventing new lower bounds based on Jensen’s inequality.

Suppose we have some unbiased estimator \(f(x, z)\) of \(p(x)\), that is, \(\mathbb{E}_z f(x, z) = p(x)\). Then

\[ \log p(x) = \log \mathbb{E}_z f(x, z) \stackrel{\text{Jensen}}{\ge} \mathbb{E}_z \log f(x, z) \]

In particular, if \(z \sim q(z \mid x)\) and \(f(x, z) = \tfrac{p(x, z)}{q(z \mid x)}\), we obtain the standard ELBO. The IWAE paper proposes another estimator (actually, a family of estimators parametrized by an integer \(K\)) of the marginal \(p(x)\):
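To make the Jensen argument concrete, here is a minimal Monte Carlo sketch. It assumes a hypothetical toy model that does not come from the papers, \(z \sim \mathcal{N}(0, 1)\) and \(x \mid z \sim \mathcal{N}(z, 1)\), chosen only because \(p(x) = \mathcal{N}(x; 0, 2)\) is then available in closed form, together with a hand-picked Gaussian \(q(z \mid x) = \mathcal{N}(m, v)\). Averaging \(\log f(x, z)\) over samples from \(q\) should land below \(\log p(x)\).

```python
import math
import random


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)


def log_joint(x, z):
    """log p(x, z) for the hypothetical toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def elbo_mc(x, q_mean, q_var, n_samples=50_000, seed=0):
    """Monte Carlo estimate of E_q[log p(x, z) - log q(z | x)], i.e. E_q log f(x, z)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        total += log_joint(x, z) - log_normal(z, q_mean, q_var)
    return total / n_samples


x = 1.0
log_px = log_normal(x, 0.0, 2.0)  # exact log p(x) for this toy model
print(log_px, elbo_mc(x, q_mean=0.0, q_var=1.0))
```

For this particular choice of \(q\) the gap is \(\text{KL}(\mathcal{N}(0, 1) \,\|\, \mathcal{N}(0.5, 0.5)) \approx 0.403\) nats, since the exact posterior of the toy model at \(x = 1\) is \(\mathcal{N}(0.5, 0.5)\).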

\[ f(x, z_1, \dots, z_K) = \frac{1}{K} \sum_{k=1}^K \frac{p(x, z_k)}{q(z_k \mid x)} \]

where each \(z_k\) is drawn independently from the same distribution \(q(z_k \mid x) = q(z \mid x)\). Since it is an average of \(K\) unbiased estimators, \(f(x, z_1, \dots, z_K)\) is still an unbiased estimator of \(p(x)\), and therefore \(\mathbb{E}_{z_1, \dots, z_K} \log f(x, z_1, \dots, z_K)\) is a valid lower bound on the marginal log-likelihood.
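A quick simulation can illustrate the unbiasedness claim. As an assumption for illustration only, take the hypothetical Gaussian model \(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\) (so \(p(x) = \mathcal{N}(x; 0, 2)\)) and set \(q\) to the prior \(\mathcal{N}(0, 1)\) purely for simplicity. Averaging many independent draws of \(f(x, z_1, \dots, z_K)\) should then recover \(p(x)\) for every \(K\).

```python
import math
import random


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)


def weight(x, z, q_mean, q_var):
    """p(x, z) / q(z | x) for the hypothetical model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return math.exp(log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
                    - log_normal(z, q_mean, q_var))


def f_K(x, K, q_mean, q_var, rng):
    """One draw of (1/K) * sum_k p(x, z_k) / q(z_k | x) with z_k drawn i.i.d. from q."""
    zs = [rng.gauss(q_mean, math.sqrt(q_var)) for _ in range(K)]
    return sum(weight(x, z, q_mean, q_var) for z in zs) / K


rng = random.Random(1)
x = 1.0
p_x = math.exp(log_normal(x, 0.0, 2.0))
for K in (1, 10):
    draws = [f_K(x, K, 0.0, 1.0, rng) for _ in range(20_000)]
    print(K, p_x, sum(draws) / len(draws))
```

The mean of \(f\) stays at \(p(x)\) for every \(K\); what shrinks with \(K\) is its variance (by a factor of \(1/K\) for i.i.d. samples).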

Let’s analyze this new lower-bound now. First, let’s dissect the ELBO:

\[ \mathcal{L}(\Theta, \Lambda) = \mathbb{E}_q \log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)} = \mathbb{E}_q \left[\log \frac{p(z \mid x, \Theta)}{q(z \mid x, \Lambda)} \right] + \log p(x \mid \Theta) \]
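For Gaussians this decomposition can be checked exactly. The sketch below assumes a hypothetical model \(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\), whose true posterior is \(p(z \mid x) = \mathcal{N}(x/2, 1/2)\), and a Gaussian \(q = \mathcal{N}(m, v)\). It computes the ELBO term by term in closed form and compares it with \(\log p(x) - \text{KL}(q \,\|\, p(z \mid x))\).

```python
import math


def elbo_closed_form(x, m, v):
    """E_q[log p(z) + log p(x | z) - log q(z)] for q = N(m, v) in the hypothetical
    model z ~ N(0, 1), x | z ~ N(z, 1); every expectation is computed analytically."""
    e_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + v)
    e_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + v)
    entropy = 0.5 * math.log(2 * math.pi * math.e * v)  # -E_q log q(z)
    return e_log_prior + e_log_lik + entropy


def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2))."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def decomposition(x, m, v):
    """log p(x) - KL(q || p(z | x)), with p(x) = N(x; 0, 2), p(z | x) = N(x / 2, 1 / 2)."""
    log_px = -0.5 * (math.log(2 * math.pi * 2.0) + x ** 2 / 2.0)
    return log_px - kl_gauss(m, v, x / 2.0, 0.5)


for x, m, v in [(1.0, 0.0, 1.0), (-2.0, 0.3, 0.2), (0.5, 0.25, 0.5)]:
    print(elbo_closed_form(x, m, v), decomposition(x, m, v))
```

The last triple sets \(q\) to the exact posterior for \(x = 0.5\), where the KL term vanishes and the ELBO equals \(\log p(x)\).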

If \(q\) approximates the true posterior accurately, the first term (which is the negative KL-divergence \(-\text{KL}(q(z \mid x, \Lambda) \,\|\, p(z \mid x, \Theta))\), BTW) is close to zero. However, when it is estimated using Monte Carlo samples, the ELBO heavily penalizes inaccurate approximations: if \(q(z \mid x, \Lambda)\) gives us samples from high-probability regions of the true posterior \(p(z \mid x, \Theta)\) only occasionally (say, 20% of the time), the gap between the ELBO and the marginal log-likelihood is huge (for the remaining samples \(p(z \mid x, \Theta)\) is small while \(q(z \mid x, \Lambda)\) is large), which does not help learning. As you might have guessed, IWAE lets us use several samples to mitigate this. Let’s see it in detail:

\[ \mathcal{L}_K(\Theta, \Lambda) := \mathbb{E}_q \left[\log \frac{1}{K} \sum_{k=1}^K \frac{p(x, z_k \mid \Theta)}{q(z_k \mid x, \Lambda)} \right] = \mathbb{E}_q \left[\log \frac{1}{K} \sum_{k=1}^K \frac{p(z_k \mid x, \Theta)}{q(z_k \mid x, \Lambda)} \right] + \log p(x \mid \Theta) \]

This averaging of posterior ratios keeps a few bad samples from ruining the lower bound, as it will be pushed up by the good samples (provided the approximation has a reasonable probability of generating a good sample in \(K\) attempts). This allows one to perform model inference even with poor approximations \(q(z \mid x, \Lambda)\): the more samples \(K\) we use, the less accurate an approximation we can tolerate. In fact, the authors prove the following theorem:

Theorem 1. For all \(K\), the lower bounds satisfy \[ \log p(x \mid \Theta) \ge \mathcal{L}_{K+1}(\Theta, \Lambda) \ge \mathcal{L}_{K}(\Theta, \Lambda) \]

Moreover, if \(p(z, x \mid \Theta) / q(z \mid x, \Lambda)\) is bounded, then \(\mathcal{L}_{K}(\Theta, \Lambda)\) approaches \(\log p(x \mid \Theta)\) as \(K\) goes to infinity.

The convergence result follows from the strong law of large numbers.
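Theorem 1 can be illustrated numerically. This sketch assumes a hypothetical Gaussian model, \(z \sim \mathcal{N}(0, 1)\) and \(x \mid z \sim \mathcal{N}(z, 1)\), and deliberately picks a poor \(q = \mathcal{N}(-1.5, 1)\), while the true posterior for \(x = 1\) is \(\mathcal{N}(0.5, 0.5)\). Because \(q\) has heavier tails than the posterior, \(p(x, z)/q(z \mid x)\) is bounded, as the theorem requires. The log of the average weight is computed with the usual log-sum-exp shift to avoid underflow.

```python
import math
import random


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)


def log_mean_exp(log_ws):
    """Numerically stable log((1/K) * sum_k exp(log_ws[k]))."""
    top = max(log_ws)
    return top + math.log(sum(math.exp(a - top) for a in log_ws) / len(log_ws))


def iwae_bound(x, K, q_mean, q_var, n_rep=2000, seed=0):
    """Monte Carlo estimate of L_K for the hypothetical model z ~ N(0,1), x | z ~ N(z,1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_rep):
        log_ws = []
        for _ in range(K):
            z = rng.gauss(q_mean, math.sqrt(q_var))
            log_ws.append(log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
                          - log_normal(z, q_mean, q_var))
        total += log_mean_exp(log_ws)
    return total / n_rep


x = 1.0
print("log p(x):", log_normal(x, 0.0, 2.0))
for K in (1, 5, 50):
    print(K, iwae_bound(x, K, q_mean=-1.5, q_var=1.0))
```

These are noisy estimates, especially for small \(K\); Theorem 1 is a statement about the exact expectations, so with very few replicates Monte Carlo error could occasionally break the ordering.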

As with VAE, we use the reparametrization trick to avoid backpropagation through stochastic units:

\[ \mathcal{L}_K(\Theta, \Lambda) = \mathbb{E}_{\varepsilon_1, \dots, \varepsilon_K} \log \frac{1}{K} \sum_{k=1}^K \overbrace{\frac{p(x, g(\varepsilon_k; \Lambda) \mid \Theta)}{q(g(\varepsilon_k; \Lambda) \mid x, \Lambda)}}^{w(x, \varepsilon_k, \Theta, \Lambda)} \]

The gradients then are

\[ \nabla_\Theta \mathcal{L}_K(\Theta, \Lambda) = \mathbb{E}_{\varepsilon_1, \dots, \varepsilon_K} \sum_{k=1}^K \hat w_k(x, \varepsilon_{1 \dots K}, \Theta, \Lambda) \nabla_\Theta \log w(x, \varepsilon_k, \Theta, \Lambda) \\ \nabla_\Lambda \mathcal{L}_K(\Theta, \Lambda) = \mathbb{E}_{\varepsilon_1, \dots, \varepsilon_K} \sum_{k=1}^K \hat w_k(x, \varepsilon_{1 \dots K}, \Theta, \Lambda) \nabla_\Lambda \log w(x, \varepsilon_k, \Theta, \Lambda) \\ \text{where } \hat w_k(x, \varepsilon_{1 \dots K}, \Theta, \Lambda) := \frac{w(x, \varepsilon_k, \Theta, \Lambda)}{\sum_{j=1}^K w(x, \varepsilon_j, \Theta, \Lambda)} \]

(We used the identity \(\nabla_x f(x) = f(x) \nabla_x \log f(x)\) here).
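The weighted-gradient formula can be sanity-checked against finite differences. The sketch below is an assumption-laden toy rather than the paper's setup: a hypothetical model \(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\), a Gaussian \(q\) reparametrized as \(g(\varepsilon; \Lambda) = \mu + \sigma \varepsilon\) with \(\Lambda = (\mu, \log \sigma)\), and \(\nabla_\Lambda \log w\) derived by hand for this model only. With the noise \(\varepsilon_{1 \dots K}\) held fixed, the single-draw estimate of \(\mathcal{L}_K\) is a deterministic function of \(\Lambda\), so its gradient should match \(\sum_k \hat w_k \nabla_\Lambda \log w_k\).

```python
import math
import random


def log_normal(v, mean, var):
    """Log density of N(mean, var) at v."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)


def log_weights(x, mu, log_sigma, eps):
    """log w(x, eps_k) with z_k = g(eps_k) = mu + sigma * eps_k (hypothetical Gaussian q)."""
    sigma = math.exp(log_sigma)
    out = []
    for e in eps:
        z = mu + sigma * e
        out.append(log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
                   - log_normal(z, mu, sigma ** 2))
    return out


def normalized_weights(log_ws):
    """w_hat_k = w_k / sum_j w_j, computed stably from log-weights."""
    top = max(log_ws)
    unnorm = [math.exp(a - top) for a in log_ws]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def objective(x, mu, log_sigma, eps):
    """Single-draw reparametrized estimate of L_K: log of the mean weight."""
    log_ws = log_weights(x, mu, log_sigma, eps)
    top = max(log_ws)
    return top + math.log(sum(math.exp(a - top) for a in log_ws) / len(log_ws))


def gradient(x, mu, log_sigma, eps):
    """sum_k w_hat_k * grad log w_k, with grad log w_k derived by hand for this model."""
    sigma = math.exp(log_sigma)
    w_hat = normalized_weights(log_weights(x, mu, log_sigma, eps))
    g_mu = g_log_sigma = 0.0
    for wh, e in zip(w_hat, eps):
        z = mu + sigma * e
        dlogp_dz = -z + (x - z)  # d/dz [log p(z) + log p(x | z)]
        # log q(g(eps) | mu, sigma) = -log_sigma - eps^2 / 2 + const, so it has
        # zero derivative w.r.t. mu and derivative -1 w.r.t. log_sigma.
        g_mu += wh * dlogp_dz
        g_log_sigma += wh * (dlogp_dz * sigma * e + 1.0)
    return g_mu, g_log_sigma


rng = random.Random(0)
eps = [rng.gauss(0.0, 1.0) for _ in range(5)]  # K = 5 noise draws, held fixed
x, mu, log_sigma, h = 1.0, -0.3, 0.2, 1e-6
fd_mu = (objective(x, mu + h, log_sigma, eps) - objective(x, mu - h, log_sigma, eps)) / (2 * h)
fd_log_sigma = (objective(x, mu, log_sigma + h, eps)
                - objective(x, mu, log_sigma - h, eps)) / (2 * h)
print(gradient(x, mu, log_sigma, eps), (fd_mu, fd_log_sigma))
```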

Just as one would expect, setting \(K=1\) reduces these gradients to the ones we’ve seen in VAEs, as the only importance weight \(\hat w_1\) is equal to 1. Unfortunately, this approach does not allow one to decompose the lower bound into the reconstruction error and the KL-divergence so as to compute the latter analytically. However, the authors report no noticeable difference in performance between the two approaches (with the KL computed analytically or estimated using Monte Carlo) in the case of \(K=1\).

BTW, Hugo Larochelle writes notes on different papers, and he has written and made publicly available Notes on Importance Weighted Autoencoders.

Variational inference for Monte Carlo objectives

As I said in the introduction, IWAE has been “generalized” to discrete variables, a case where one cannot employ the reparametrization trick and instead has to somehow reduce the high variance of a score-function-based estimator. Previously, during our discussion of Blackbox VI and variance reduction techniques, we covered the NVIL (Neural Variational Inference and Learning) estimator, which uses another neural network to estimate the marginal likelihood and reduce the variance. This work builds upon a similar idea.

First, let’s derive the score-function-based gradients for the variational parameters \(\Lambda\) (where \(w\) is now defined as \(w(x, z, \Theta, \Lambda) = \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)}\), and \(\hat w\) is its version normalized across all samples \(z_{1\dots K}\)):

\[ \begin{align} \nabla_\Lambda \mathcal{L}_K(\Theta, \Lambda) &= \nabla_\Lambda \mathbb{E}_{q(z_1, \dots, z_K \mid x, \Lambda)} \log \frac{1}{K} \sum_{k=1}^K w(x, z_k, \Theta, \Lambda) \\ &= \nabla_\Lambda \int q(z_1, \dots, z_K \mid x, \Lambda) \log \frac{1}{K} \sum_{k=1}^K w(x, z_k, \Theta, \Lambda) \; dz_1 \dots dz_K \\ &= \mathbb{E}_{q(z_1, \dots, z_K \mid x, \Lambda)} \left[ \sum_{k=1}^K \nabla_\Lambda \log q(z_k \mid x, \Lambda) \log \frac{1}{K} \sum_{j=1}^K w(x, z_j, \Theta, \Lambda) \right] \\ & \quad+ \mathbb{E}_{q(z_1, \dots, z_K \mid x, \Lambda)} \nabla_\Lambda \log \sum_{k=1}^K w(x, z_k, \Theta, \Lambda) \\ &= \mathbb{E}_{q(z_1, \dots, z_K \mid x, \Lambda)} \left[ \sum_{k=1}^K \nabla_\Lambda \log q(z_k \mid x, \Lambda) \log \frac{1}{K} \sum_{j=1}^K w(x, z_j, \Theta, \Lambda) \right] \\ & \quad+ \mathbb{E}_{q(z_1, \dots, z_K \mid x, \Lambda)} \left[ \sum_{k=1}^K \hat w_k(x, z_{1 \dots K}, \Theta, \Lambda) \nabla_\Lambda \log w(x, z_k, \Theta, \Lambda) \right] \end{align} \]
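Because the derivation above also holds for discrete \(z\), it can be verified exactly on a tiny example by enumerating all \(2^K\) joint configurations. The sketch assumes a hypothetical model with a single binary latent variable (the numbers in `P_JOINT` are made up) and a Bernoulli \(q(z = 1 \mid x, \lambda) = \text{sigmoid}(\lambda)\), for which \(\nabla_\lambda \log q(z \mid x, \lambda) = z - \text{sigmoid}(\lambda)\). The sum of the two terms should match a finite-difference derivative of \(\mathcal{L}_K\).

```python
import itertools
import math

P_JOINT = {0: 0.7 * 0.2, 1: 0.3 * 0.9}  # made-up p(x, z) = p(z) p(x | z), z in {0, 1}


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def q_prob(z, lam):
    """Bernoulli q(z | x, lam) with q(z = 1) = sigmoid(lam)."""
    return sigmoid(lam) if z == 1 else 1.0 - sigmoid(lam)


def score(z, lam):
    """d/dlam log q(z | x, lam) = z - sigmoid(lam) for this Bernoulli."""
    return z - sigmoid(lam)


def bound(lam, K):
    """Exact L_K(lam), enumerating all 2^K joint samples z_1..z_K."""
    total = 0.0
    for zs in itertools.product((0, 1), repeat=K):
        prob = math.prod(q_prob(z, lam) for z in zs)
        w = [P_JOINT[z] / q_prob(z, lam) for z in zs]
        total += prob * math.log(sum(w) / K)
    return total


def gradient_terms(lam, K):
    """Exact expectations of the two terms of the score-function gradient."""
    first = second = 0.0
    for zs in itertools.product((0, 1), repeat=K):
        prob = math.prod(q_prob(z, lam) for z in zs)
        w = [P_JOINT[z] / q_prob(z, lam) for z in zs]
        signal = math.log(sum(w) / K)  # the learning signal
        first += prob * sum(score(z, lam) for z in zs) * signal
        w_hat = [wk / sum(w) for wk in w]  # normalized importance weights
        # For fixed z_k, d/dlam log w(x, z_k) = -d/dlam log q(z_k | x, lam).
        second += prob * sum(-wh * score(z, lam) for wh, z in zip(w_hat, zs))
    return first, second


lam, K, h = 0.4, 3, 1e-5
first, second = gradient_terms(lam, K)
finite_diff = (bound(lam + h, K) - bound(lam - h, K)) / (2 * h)
print(first, second, first + second, finite_diff)
```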

The second term has the same form as the gradient in the reparametrized case, and it does not cause us any trouble. The first term, however, has some issues.

First, it does not distinguish individual samples’ contributions: gradients for all samples get the same weight \(\log \tfrac{1}{K} \sum_{j=1}^K w(x, z_j, \Theta, \Lambda)\) (called the learning signal), regardless of how probable they are under the true posterior (that is, how well they explain the observation \(x\)). Compare this with the second term, where the gradient for each sample \(z_k\) is weighted in proportion to its importance weight \(\hat w_k\).

The second problem is that the learning signal is unbounded and can have a large magnitude. Again, the second term does not suffer from this, as the importance weights \(\hat w_k\) are normalized to sum to 1.

One can use the NVIL estimator we discussed previously to reduce the variance caused by the large magnitude of the learning signal. However, it does not address the problem of all gradients having the same weight. For this, the authors propose introducing per-sample baselines that minimize dependencies between samples.
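Here is a sketch of one such per-sample baseline, in the leave-one-out form used, as far as I recall, in the VIMCO paper: the baseline for sample \(j\) recomputes the learning signal with \(w_j\) replaced by the geometric mean of the other \(K-1\) weights (treat this exact choice as an assumption and check the paper for details). Since this baseline does not depend on \(z_j\), subtracting it from sample \(j\)'s learning signal leaves the gradient unbiased. The code verifies this by exact enumeration on a tiny made-up model with one binary latent variable and a Bernoulli \(q(z = 1 \mid x, \lambda) = \text{sigmoid}(\lambda)\).

```python
import itertools
import math

P_JOINT = {0: 0.7 * 0.2, 1: 0.3 * 0.9}  # made-up p(x, z) for a binary latent z


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def q_prob(z, lam):
    """Bernoulli q(z | x, lam) with q(z = 1) = sigmoid(lam)."""
    return sigmoid(lam) if z == 1 else 1.0 - sigmoid(lam)


def score(z, lam):
    """d/dlam log q(z | x, lam)."""
    return z - sigmoid(lam)


def log_mean(ws):
    return math.log(sum(ws) / len(ws))


def loo_baseline(w, j):
    """Learning signal recomputed with w_j replaced by the geometric mean of the others."""
    others = w[:j] + w[j + 1:]
    geo = math.exp(sum(math.log(v) for v in others) / len(others))
    return log_mean(others + [geo])


def grad_estimate(zs, lam, use_baseline):
    """Single-draw score-function estimate of dL_K/dlam for the samples zs."""
    w = [P_JOINT[z] / q_prob(z, lam) for z in zs]
    signal = log_mean(w)
    first = 0.0
    for j, z in enumerate(zs):
        baseline = loo_baseline(w, j) if use_baseline else 0.0
        first += score(z, lam) * (signal - baseline)
    total = sum(w)
    # Second term: sum_k w_hat_k * d/dlam log w_k, with d/dlam log w_k = -score(z_k).
    second = sum(-(wk / total) * score(z, lam) for wk, z in zip(w, zs))
    return first + second


def exact_moments(lam, K, use_baseline):
    """Exact mean and variance of grad_estimate under q, by enumerating z_1..z_K."""
    mean = second_moment = 0.0
    for zs in itertools.product((0, 1), repeat=K):
        prob = math.prod(q_prob(z, lam) for z in zs)
        g = grad_estimate(list(zs), lam, use_baseline)
        mean += prob * g
        second_moment += prob * g * g
    return mean, second_moment - mean ** 2


for use_baseline in (False, True):
    print(use_baseline, exact_moments(0.4, 3, use_baseline))
```

The means agree exactly; the printed variances apply only to this toy, and whether the baseline reduces variance in general depends on the model.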

This paper has also caught Dr. Larochelle’s attention: Notes on Variational inference for Monte Carlo objectives.
