B.log

Random notes on Computer Science, Mathematics and Software Engineering

Neural Variational Inference: Blackbox Mode

In the previous post we covered Stochastic VI: an efficient and scalable variational inference method for exponential family models. However, there are many more distributions than those belonging to the exponential family, and inference for them requires a significant amount of model-specific analysis. In this post we consider Black Box Variational Inference by Ranganath et al. Like the previous one, this work comes from the lab of David Blei, one of the leading researchers in VI. And, just for dessert, we'll touch upon another paper, which will finally introduce some neural networks in VI.

Blackbox Variational Inference

As we have learned so far, the goal of VI is to maximize the ELBO \(\mathcal{L}(\Theta, \Lambda)\). When we maximize it w.r.t. \(\Lambda\), we decrease the gap between the ELBO and the marginal log-likelihood \(\log p(x \mid \Theta)\) of the model; when we maximize it w.r.t. \(\Theta\), we actually fit the model. So let's concentrate on optimizing this objective:

\[ \mathcal{L}(\Theta, \Lambda) = \mathbb{E}_{q(z \mid x, \Lambda)} \left[\log p(x, z \mid \Theta) - \log q(z \mid x, \Lambda) \right] \]
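To make the objective concrete, here is a minimal Monte Carlo sketch of the ELBO, assuming a hypothetical toy model \(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\) with no learnable \(\Theta\), and a Gaussian \(q(z) = \mathcal{N}(\mu, \sigma^2)\), so that \(\Lambda = (\mu, \sigma)\). The names `log_joint`, `elbo_mc` and `elbo_exact` are illustrative only; the closed form exists just to sanity-check the estimate.

```python
import math
import random

LOG_2PI = math.log(2 * math.pi)


def log_normal(v, mean, std):
    """Log-density of N(mean, std^2) evaluated at v."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((v - mean) / std) ** 2


def log_joint(x, z):
    """Hypothetical toy model: log p(x, z) with z ~ N(0, 1) and x | z ~ N(z, 1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def elbo_mc(x, mu, sigma, num_samples, rng):
    """Monte Carlo estimate of E_q[log p(x, z) - log q(z)] for q = N(mu, sigma^2)."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)
        total += log_joint(x, z) - log_normal(z, mu, sigma)
    return total / num_samples


def elbo_exact(x, mu, sigma):
    """Closed-form ELBO of the same toy model, used only to check the estimate."""
    expected_log_joint = (-LOG_2PI - 0.5 * (mu ** 2 + sigma ** 2)
                          - 0.5 * ((x - mu) ** 2 + sigma ** 2))
    entropy = 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)
    return expected_log_joint + entropy


print(elbo_mc(1.0, 0.0, 1.0, 10_000, random.Random(0)), elbo_exact(1.0, 0.0, 1.0))
```

When \(q\) equals the exact posterior \(\mathcal{N}(x/2, 1/2)\) of this toy model, every sample of \(\log p(x, z) - \log q(z)\) equals \(\log p(x)\), so the bound is tight and the estimate has zero variance.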

Let’s find gradients of this objective:

\[ \begin{align} \nabla_{\Lambda} \mathcal{L}(\Theta, \Lambda) &= \nabla_{\Lambda} \int q(z \mid x, \Lambda) \left[\log p(x, z \mid \Theta) - \log q(z \mid x, \Lambda) \right] dz \\ &= \int \nabla_{\Lambda} q(z \mid x, \Lambda) \left[\log p(x, z \mid \Theta) - \log q(z \mid x, \Lambda) \right] dz - \int q(z \mid x, \Lambda) \nabla_{\Lambda} \log q(z \mid x, \Lambda) dz \\ &= \mathbb{E}_{q} \left[\frac{\nabla_{\Lambda} q(z \mid x, \Lambda)}{q(z \mid x, \Lambda)} \log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)} \right] - \int q(z \mid x, \Lambda) \frac{\nabla_{\Lambda} q(z \mid x, \Lambda)}{q(z \mid x, \Lambda)} dz \\ &= \mathbb{E}_{q} \left[\nabla_{\Lambda} \log q(z \mid x, \Lambda) \log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)} \right] - \int \nabla_{\Lambda} q(z \mid x, \Lambda) dz \\ &= \mathbb{E}_{q} \left[\nabla_{\Lambda} \log q(z \mid x, \Lambda) \log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)} \right] - \nabla_{\Lambda} \overbrace{\int q(z \mid x, \Lambda) dz}^{=1} \\ &= \mathbb{E}_{q} \left[\nabla_{\Lambda} \log q(z \mid x, \Lambda) \log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)} \right] \end{align} \]

In statistics \(\nabla_\Lambda \log q(z \mid x, \Lambda)\) is known as the score function. For more on this “trick” see a blogpost by Shakir Mohamed. In many cases of practical interest \(\log p(x, z \mid \Theta)\) is too complicated for this expectation to be computed in closed form. Recall that we have already used stochastic optimization successfully, so we can settle for an unbiased estimate of the true gradient. We get one by approximating the expectation with a Monte Carlo average over \(L\) samples \(z^{(l)} \sim q(z \mid x, \Lambda)\) (in practice a single sample, \(L=1\), is often used, relying on averaging over the minibatch to reduce the noise):

\[ \nabla_{\Lambda} \mathcal{L}(\Theta, \Lambda) \approx \frac{1}{L} \sum_{l=1}^L \nabla_{\Lambda} \log q(z^{(l)} \mid x, \Lambda) \log \frac{p(x, z^{(l)} \mid \Theta)}{q(z^{(l)} \mid x, \Lambda)} \]
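A minimal sketch of this score-function estimator, assuming the same hypothetical toy model as a running example (\(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\), \(q = \mathcal{N}(\mu, \sigma^2)\), \(\Lambda = (\mu, \sigma)\)). Note that the model enters only through evaluations of \(\log p(x, z)\), which is what makes the method black-box. For this model the exact gradient is \((x - 2\mu,\ 1/\sigma - 2\sigma)\); `score_function_grad` is a hypothetical name.

```python
import math
import random

LOG_2PI = math.log(2 * math.pi)


def log_normal(v, mean, std):
    """Log-density of N(mean, std^2) evaluated at v."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((v - mean) / std) ** 2


def log_joint(x, z):
    """Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def score(z, mu, sigma):
    """Gradient of log N(z; mu, sigma^2) w.r.t. (mu, sigma)."""
    return ((z - mu) / sigma ** 2, -1.0 / sigma + (z - mu) ** 2 / sigma ** 3)


def score_function_grad(x, mu, sigma, num_samples, rng):
    """Black-box estimate of the ELBO gradient w.r.t. (mu, sigma).

    Uses only samples from q, log q, its score and log p(x, z);
    the model itself is never differentiated.
    """
    g_mu = g_sigma = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)
        signal = log_joint(x, z) - log_normal(z, mu, sigma)  # log p(x, z) - log q(z)
        s_mu, s_sigma = score(z, mu, sigma)
        g_mu += s_mu * signal
        g_sigma += s_sigma * signal
    return g_mu / num_samples, g_sigma / num_samples


print(score_function_grad(1.0, 0.0, 1.0, 50_000, random.Random(0)))
```

With \(x = 1\), \(\mu = 0\), \(\sigma = 1\) the exact gradient is \((1, -1)\). With `num_samples=1` the same function returns a single draw whose \(\mu\)-component has a standard deviation of roughly 3.5 in this setting, a first hint of the variance problem discussed below.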

For the model parameters \(\Theta\) the gradients look even simpler: \(q\) does not depend on \(\Theta\), so we don't need to differentiate through the distribution the expectation is taken over, and the \(\log q\) term drops out:

\[ \begin{align} \nabla_{\Theta} \mathcal{L}(\Theta, \Lambda) &= \mathbb{E}_{q} \nabla_{\Theta} \log p(x, z \mid \Theta) \approx \frac{1}{L} \sum_{l=1}^L \nabla_{\Theta} \log p(x, z^{(l)} \mid \Theta) \end{align} \]
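For \(\Theta\) the estimator is just an average of ordinary gradients of the log-joint at samples from \(q\). A minimal sketch, assuming a hypothetical model in which \(\Theta = \theta\) is the prior mean (\(z \sim \mathcal{N}(\theta, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\)), so that \(\nabla_\theta \log p(x, z \mid \theta) = z - \theta\) and the exact answer is \(\mu - \theta\):

```python
import random


def grad_log_joint_theta(x, z, theta):
    """d/dtheta log p(x, z | theta) for the hypothetical model
    z ~ N(theta, 1), x | z ~ N(z, 1). Only the prior depends on theta,
    so x does not appear in the result."""
    return z - theta


def theta_grad(x, mu, sigma, theta, num_samples, rng):
    """Monte Carlo estimate of E_q[grad_theta log p(x, z | theta)] with q = N(mu, sigma^2)."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)  # sample from q; no score function is needed here
        total += grad_log_joint_theta(x, z, theta)
    return total / num_samples


print(theta_grad(1.0, 0.3, 0.8, -0.5, 10_000, random.Random(0)))
```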

We can even “naturalize” these gradients by premultiplying them by the inverse Fisher Information Matrix \(\mathcal{I}(\Lambda)^{-1}\). And that's it! Much simpler than before, right? Of course, there's no free lunch, so there must be a catch… And there is: the performance of stochastic optimization methods crucially depends on the variance of the gradient estimators. This makes perfect sense: the higher the variance, the less information about the step direction we get. Unfortunately, in practice the score-function estimator above often has impractically high variance. Luckily, the Monte Carlo community knows many variance reduction techniques; we now describe some of them.
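Before turning to those, here is a sketch of the natural-gradient step mentioned above, assuming (for illustration only) a Gaussian \(q = \mathcal{N}(\mu, \sigma^2)\) parameterized by \(\Lambda = (\mu, \sigma)\). Its Fisher matrix \(\mathbb{E}_q[\nabla \log q \, \nabla \log q^\top]\) is \(\mathrm{diag}(1/\sigma^2, 2/\sigma^2)\); the code estimates it from score samples, which can be compared against that closed form, and then uses the closed form to precondition a gradient. `fisher_mc` and `natural_gradient` are hypothetical names.

```python
import random


def score(z, mu, sigma):
    """Gradient of log N(z; mu, sigma^2) w.r.t. (mu, sigma)."""
    return ((z - mu) / sigma ** 2, -1.0 / sigma + (z - mu) ** 2 / sigma ** 3)


def fisher_mc(mu, sigma, num_samples, rng):
    """Monte Carlo estimate of the 2x2 Fisher matrix E_q[score score^T]."""
    fisher = [[0.0, 0.0], [0.0, 0.0]]
    for _ in range(num_samples):
        s = score(rng.gauss(mu, sigma), mu, sigma)
        for i in range(2):
            for j in range(2):
                fisher[i][j] += s[i] * s[j]
    return [[v / num_samples for v in row] for row in fisher]


def natural_gradient(grad, sigma):
    """Premultiply (grad_mu, grad_sigma) by the inverse of the exact
    Fisher matrix diag(1 / sigma^2, 2 / sigma^2)."""
    return (sigma ** 2 * grad[0], 0.5 * sigma ** 2 * grad[1])


print(fisher_mc(0.0, 2.0, 10_000, random.Random(0)), natural_gradient((1.0, 1.0), 2.0))
```

In this parameterization the natural gradient rescales the \(\mu\)-step by \(\sigma^2\): a wide \(q\) takes proportionally larger steps in \(\mu\), which matches how little a unit change in \(\mu\) alters a wide distribution.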

The first technique we'll describe is Rao-Blackwellization. The idea is simple: if it's possible to compute the expectation w.r.t. some of the random variables analytically, you should do it. If you think about it, that's obvious advice, as you essentially reduce the amount of randomness in your Monte Carlo estimates. But let's put it more formally: we use the law of total expectation to rewrite the joint expectation as a marginal expectation of a conditional one:

\[ \mathbb{E}_{X, Y} f(X, Y) = \mathbb{E}_X \left[ \mathbb{E}_{Y \mid X} f(X, Y) \right] \]

Let's see what happens to the variance (in the scalar case) when we estimate the expectation of \(\mathbb{E}_{Y \mid X} f(X, Y)\) instead of the expectation of \(f(X, Y)\):

\[ \begin{align} \text{Var}_X(\mathbb{E}_{Y \mid X} f(X, Y)) &= \mathbb{E}_X (\mathbb{E}_{Y \mid X} f(X, Y))^2 - (\mathbb{E}_{X, Y} f(X, Y))^2 \\ &= \text{Var}_{X,Y}(f(X, Y)) - \mathbb{E}_X \left(\mathbb{E}_{Y \mid X} [f(X, Y)^2] - (\mathbb{E}_{Y \mid X} f(X, Y))^2 \right) \\ &= \text{Var}_{X,Y}(f(X, Y)) - \mathbb{E}_X \text{Var}_{Y\mid X} (f(X, Y)) \end{align} \]

This formula says that Rao-Blackwellizing an estimator reduces its variance by \(\mathbb{E}_X \text{Var}_{Y\mid X} (f(X, Y))\). You can think of this term as a measure of how much randomness \(Y\) still injects into \(f(X, Y)\) once \(X\) is known. Suppose \(Y = X\): then you have \(\mathbb{E}_X f(X, X)\), and taking the expectation w.r.t. \(Y\) does not reduce the amount of randomness in the estimator. This is exactly what the formula tells us, as \(\text{Var}_{Y \mid X} f(X, Y)\) is 0 in this case. Here's another example: suppose \(f\) does not use \(X\) at all and \(Y\) is independent of \(X\). Then only the randomness in \(Y\) affects the estimate, and after Rao-Blackwellization we expect the variance to drop to 0. The formula agrees with our expectations, as \(\mathbb{E}_X \text{Var}_{Y \mid X} f(X, Y) = \text{Var}_Y f(X, Y)\): neither \(f\) nor the distribution of \(Y\) depends on \(X\).
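A quick numerical check of the variance formula, assuming a hypothetical setup \(X \sim \mathcal{N}(0, 1)\), \(Y \mid X \sim \mathcal{N}(X, 1)\) and \(f(X, Y) = Y^2\). Here \(\mathbb{E}_{Y \mid X} f = X^2 + 1\) is available in closed form, \(\text{Var}(Y^2) = 8\) (since \(Y \sim \mathcal{N}(0, 2)\)), and \(\mathbb{E}_X \text{Var}_{Y \mid X}(Y^2) = \mathbb{E}_X (2 + 4X^2) = 6\), so the formula predicts a Rao-Blackwellized variance of \(8 - 6 = 2\):

```python
import random
import statistics


def rao_blackwell_demo(num_samples, seed=0):
    """Compare the plain estimator f(X, Y) = Y^2 with its Rao-Blackwellized
    version E[Y^2 | X] = X^2 + 1 on the same draws of X.
    Returns (mean_plain, var_plain, mean_rb, var_rb)."""
    rng = random.Random(seed)
    plain, rao_blackwellized = [], []
    for _ in range(num_samples):
        x = rng.gauss(0.0, 1.0)
        y = rng.gauss(x, 1.0)                    # Y | X ~ N(X, 1)
        plain.append(y ** 2)                     # f(X, Y)
        rao_blackwellized.append(x ** 2 + 1.0)   # E_{Y|X} f(X, Y), integrated analytically
    return (statistics.fmean(plain), statistics.pvariance(plain),
            statistics.fmean(rao_blackwellized), statistics.pvariance(rao_blackwellized))


print(rao_blackwell_demo(10_000))
```

Both estimators target the same mean (2 here), but the Rao-Blackwellized one does so with about a quarter of the variance.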

The next technique is Control Variates, which is slightly less intuitive. The idea is to subtract a scaled zero-mean function \(\alpha h(X)\): this preserves the expectation and, for a suitable \(\alpha\), reduces the variance. Again, for the scalar case

\[ \text{Var}(f(X) - \alpha h(X)) = \text{Var}(f(X)) - 2 \alpha \text{Cov}(f(X), h(X)) + \alpha^2 \text{Var}(h(X)) \]

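Minimizing this quadratic in \(\alpha\) gives the optimal \(\alpha^* = \frac{\text{Cov}(f(X), h(X))}{\text{Var}(h(X))}\), and the remaining variance is \(\text{Var}(f(X)) \left(1 - \text{Corr}(f(X), h(X))^2\right)\). This reflects an obvious fact: to reduce the variance, \(h(X)\) must be correlated with \(f(X)\). The sign of the correlation does not matter, as \(\alpha^*\) adjusts accordingly. BTW, in reinforcement learning the analogous quantity subtracted from the reward is called a baseline.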
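A minimal control-variate sketch on a textbook example rather than anything from the post: estimate \(\mathbb{E}[e^U]\) for \(U \sim \text{Uniform}(0, 1)\) (exact value \(e - 1\)) using \(h(U) = U - \tfrac{1}{2}\), which has zero mean. Here \(\alpha^*\) is estimated from the same samples, which introduces a small bias that is usually ignored in practice; `control_variate_estimate` is a hypothetical name.

```python
import math
import random


def control_variate_estimate(num_samples, seed=0):
    """Estimate E[exp(U)], U ~ Uniform(0, 1), with and without the control
    variate h(U) = U - 1/2 (known mean 0).
    Returns (mean_plain, mean_cv, var_plain, var_cv, alpha)."""
    rng = random.Random(seed)
    us = [rng.random() for _ in range(num_samples)]
    f = [math.exp(u) for u in us]
    h = [u - 0.5 for u in us]
    mean_f = sum(f) / num_samples
    mean_h = sum(h) / num_samples
    cov_fh = sum((a - mean_f) * (b - mean_h) for a, b in zip(f, h)) / num_samples
    var_h = sum((b - mean_h) ** 2 for b in h) / num_samples
    var_f = sum((a - mean_f) ** 2 for a in f) / num_samples
    alpha = cov_fh / var_h                      # alpha* = Cov(f, h) / Var(h)
    cv = [a - alpha * b for a, b in zip(f, h)]  # f(U) - alpha * h(U): same mean as f(U)
    mean_cv = sum(cv) / num_samples
    var_cv = sum((c - mean_cv) ** 2 for c in cv) / num_samples
    return mean_f, mean_cv, var_f, var_cv, alpha


print(control_variate_estimate(10_000))
```

Here \(\text{Corr}(e^U, U)^2 \approx 0.98\), so by the formula above the per-sample variance drops by roughly a factor of 60.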

As we have already seen in the derivation above, \(\mathbb{E}_{q(z \mid x, \Lambda)} \nabla_\Lambda \log q(z \mid x, \Lambda) = \int \nabla_\Lambda q(z \mid x, \Lambda) dz = 0\), so the score function is a good candidate for \(h\). Therefore our estimate becomes

\[ \nabla_{\Lambda} \mathcal{L}(\Theta, \Lambda) \approx \frac{1}{L} \sum_{l=1}^L \nabla_{\Lambda} \log q(z^{(l)} \mid x, \Lambda) \circ \left(\log \frac{p(x, z^{(l)} \mid \Theta)}{q(z^{(l)} \mid x, \Lambda)} - \alpha^* \right) \]

where \(\circ\) denotes elementwise multiplication and \(\alpha^*\) is a vector of \(|\Lambda|\) components, with \(\alpha^*_i\) being the baseline for the variational parameter \(\Lambda_i\):

\[ \alpha^*_i = \frac{\text{Cov}(\nabla_{\Lambda_i} \log q(z \mid x, \Lambda)\left( \log p(x, z \mid \Theta) - \log q(z \mid x, \Lambda) \right), \nabla_{\Lambda_i} \log q(z \mid x, \Lambda))}{\text{Var}(\nabla_{\Lambda_i} \log q(z \mid x, \Lambda))} \]
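A sketch of the control-variate version of the score-function estimator, assuming the same hypothetical toy model as a running example (\(z \sim \mathcal{N}(0,1)\), \(x \mid z \sim \mathcal{N}(z, 1)\), \(q = \mathcal{N}(\mu, \sigma^2)\), \(\Lambda = (\mu, \sigma)\)) and assuming \(\alpha^*_i = \text{Cov}(f_i, h_i) / \text{Var}(h_i)\) is estimated from the same batch of samples, with \(f_i = \nabla_{\Lambda_i} \log q \cdot (\log p - \log q)\) and \(h_i = \nabla_{\Lambda_i} \log q\). The exact gradient is \((x - 2\mu,\ 1/\sigma - 2\sigma)\); the function name is hypothetical.

```python
import math
import random

LOG_2PI = math.log(2 * math.pi)


def log_normal(v, mean, std):
    """Log-density of N(mean, std^2) evaluated at v."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((v - mean) / std) ** 2


def log_joint(x, z):
    """Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def _cov(a, b):
    """Population covariance of two equally long lists."""
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    return sum((u - ma) * (v - mb) for u, v in zip(a, b)) / len(a)


def bbvi_grad_with_control_variates(x, mu, sigma, num_samples, rng):
    """Score-function ELBO gradient w.r.t. (mu, sigma), with and without a
    per-component control variate h_i = d/dLambda_i log q (which has mean 0).
    Returns (plain, with_cv, var_plain, var_cv), each a list of 2 components."""
    f = ([], [])  # f_i = score_i * (log p(x, z) - log q(z))
    h = ([], [])  # h_i = score_i
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)
        signal = log_joint(x, z) - log_normal(z, mu, sigma)
        scores = ((z - mu) / sigma ** 2, -1.0 / sigma + (z - mu) ** 2 / sigma ** 3)
        for i in range(2):
            f[i].append(scores[i] * signal)
            h[i].append(scores[i])
    plain, with_cv, var_plain, var_cv = [], [], [], []
    for i in range(2):
        alpha_i = _cov(f[i], h[i]) / _cov(h[i], h[i])  # alpha*_i = Cov(f_i, h_i) / Var(h_i)
        terms = [a - alpha_i * b for a, b in zip(f[i], h[i])]
        plain.append(sum(f[i]) / num_samples)
        with_cv.append(sum(terms) / num_samples)
        var_plain.append(_cov(f[i], f[i]))
        var_cv.append(_cov(terms, terms))
    return plain, with_cv, var_plain, var_cv


print(bbvi_grad_with_control_variates(1.0, 0.0, 1.0, 10_000, random.Random(0)))
```

For \(x = 1\), \(\mu = 0\), \(\sigma = 1\) the per-sample variance of the \(\mu\)-component drops from about 12 to about 3.5, and that of the \(\sigma\)-component from about 47 to about 16.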

Neural Variational Inference and Learning

Hoooray, neural networks! In this section I'll briefly describe a variance reduction technique introduced by A. Mnih and K. Gregor in Neural Variational Inference and Learning in Belief Networks. The idea is surprisingly simple: why not learn an input-dependent baseline \(\alpha(x)\) using a neural network?

\[ \nabla_{\Lambda} \mathcal{L}(\Theta, \Lambda) \approx \frac{1}{L} \sum_{l=1}^L \nabla_{\Lambda} \log q(z^{(l)} \mid x, \Lambda) \circ \left(\log \frac{p(x, z^{(l)} \mid \Theta)}{q(z^{(l)} \mid x, \Lambda)} - \alpha^* - \alpha(x) \right) \]

Where \(\alpha(x)\) is a neural network trained to minimize

\[ \mathbb{E}_{q(z \mid x, \Lambda)} \left( \log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)} - \alpha^* - \alpha(x) \right)^2 \]
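A minimal sketch of training an input-dependent baseline by minimizing this squared error, assuming the hypothetical toy model \(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\), a fixed and deliberately imperfect \(q(z \mid x) = \mathcal{N}(0.4x, 0.6^2)\), and data \(x \sim \text{Uniform}(-3, 3)\). To stay dependency-free, the neural network \(\alpha(x)\) is swapped for a linear model on the features \((1, x, x^2)\), whose bias absorbs the constant \(\alpha^*\). \(q\) is held fixed so that only the baseline learns; in the actual method both are trained jointly.

```python
import math
import random

LOG_2PI = math.log(2 * math.pi)


def log_normal(v, mean, std):
    """Log-density of N(mean, std^2) evaluated at v."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((v - mean) / std) ** 2


def learning_signal(x, z):
    """log p(x, z) - log q(z | x) for the hypothetical model z ~ N(0, 1),
    x | z ~ N(z, 1), with a fixed, deliberately imperfect q(z | x) = N(0.4 x, 0.6^2)."""
    log_p = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
    return log_p - log_normal(z, 0.4 * x, 0.6)


def features(x):
    """Stand-in for a neural network: a linear model on (1, x, x^2)."""
    return (1.0, x, x * x)


def baseline(w, x):
    return sum(wi * p for wi, p in zip(w, features(x)))


def sample_pair(rng):
    x = rng.uniform(-3.0, 3.0)   # hypothetical data distribution
    z = rng.gauss(0.4 * x, 0.6)  # z ~ q(z | x)
    return x, z


def train_baseline(num_steps, lr, rng):
    """Plain SGD on (signal - baseline(x))^2; the bias w[0] plays the role of alpha*."""
    w = [0.0, 0.0, 0.0]
    for _ in range(num_steps):
        x, z = sample_pair(rng)
        residual = learning_signal(x, z) - baseline(w, x)
        # d/dw of residual^2 is -2 * residual * features(x); step against it.
        w = [wi + lr * 2.0 * residual * p for wi, p in zip(w, features(x))]
    return w


def evaluate(w, num_samples, rng):
    """Return (variance of the raw signal, mean squared signal after subtracting the baseline)."""
    raw, centered = [], []
    for _ in range(num_samples):
        x, z = sample_pair(rng)
        s = learning_signal(x, z)
        raw.append(s)
        centered.append(s - baseline(w, x))
    mean_raw = sum(raw) / num_samples
    var_raw = sum((s - mean_raw) ** 2 for s in raw) / num_samples
    mse_centered = sum(c * c for c in centered) / num_samples
    return var_raw, mse_centered


w = train_baseline(30_000, 0.002, random.Random(0))
print(w, evaluate(w, 20_000, random.Random(1)))
```

In this toy setting the conditional mean of the learning signal happens to be quadratic in \(x\), so the regressor can track it closely; what remains after subtracting the baseline is only the within-\(x\) noise coming from \(z\).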

What's the motivation for this objective? A gradient step along \(\nabla_\Lambda \mathcal{L}(\Theta, \Lambda)\) can be seen as pushing \(q(z\mid x, \Lambda)\) towards \(p(x, z \mid \Theta)\). Since \(q\) has to be normalized like any other proper distribution, it's actually pushed towards the true posterior \(p(z \mid x, \Theta)\). Using \(p(x, z \mid \Theta) = p(z \mid x, \Theta) \, p(x \mid \Theta)\), we can rewrite the gradient \(\nabla_\Lambda \mathcal{L}(\Theta, \Lambda)\) as

\[ \begin{align} \nabla_{\Lambda} \mathcal{L}(\Theta, \Lambda) &= \mathbb{E}_{q} \left[\nabla_{\Lambda} \log q(z \mid x, \Lambda) \left(\log p(x, z \mid \Theta) - \log q(z \mid x, \Lambda) \right) \right] \\ &= \mathbb{E}_{q} \left[\nabla_{\Lambda} \log q(z \mid x, \Lambda) \left(\log p(z \mid x, \Theta) - \log q(z \mid x, \Lambda) + \log p(x \mid \Theta) \right) \right] \end{align} \]

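While the additional \(\log p(x \mid \Theta)\) term does not contribute to the expectation (it is constant in \(z\), and the score function has zero mean), it does affect the variance of the estimator. Minimizing the squared error above drives \(\alpha^* + \alpha(x)\) towards \(\mathbb{E}_{q} \log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)}\), which approaches \(\log p(x \mid \Theta)\) as \(q\) gets close to the true posterior. In this sense \(\alpha(x)\) is supposed to estimate the marginal log-likelihood \(\log p(x \mid \Theta)\).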
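The effect of that extra term is easy to see numerically. The sketch below assumes the hypothetical toy model \(z \sim \mathcal{N}(0, 1)\), \(x \mid z \sim \mathcal{N}(z, 1)\) with \(q\) set to the exact posterior \(\mathcal{N}(x/2, 1/2)\), so every sample of \(\log p(x, z) - \log q(z)\) equals \(\log p(x)\). The \(\mu\)-gradient is zero on average either way, but without a baseline each sample is \(\nabla_\mu \log q \cdot \log p(x)\) and is noisy, whereas subtracting \(\log p(x)\) removes the noise entirely.

```python
import math
import random

LOG_2PI = math.log(2 * math.pi)


def log_normal(v, mean, std):
    """Log-density of N(mean, std^2) evaluated at v."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((v - mean) / std) ** 2


def mu_gradient_terms(x, baseline, num_samples, rng):
    """Per-sample mu-components of the score-function estimator when q(z) is
    the exact posterior N(x / 2, 1 / 2) of the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    mu, sigma = x / 2.0, math.sqrt(0.5)
    terms = []
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)
        signal = (log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
                  - log_normal(z, mu, sigma))  # equals log p(x) for every z
        terms.append((z - mu) / sigma ** 2 * (signal - baseline))
    return terms


def mean_var(values):
    """Sample mean and population variance."""
    m = sum(values) / len(values)
    return m, sum((v - m) ** 2 for v in values) / len(values)


x = 1.0
log_px = log_normal(x, 0.0, math.sqrt(2.0))  # marginal p(x) = N(0, 2)
print(mean_var(mu_gradient_terms(x, 0.0, 10_000, random.Random(0))))
print(mean_var(mu_gradient_terms(x, log_px, 10_000, random.Random(0))))
```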

The paper also lists several other variance reduction techniques that can be used in combination with the neural network-based baseline:

  • Constant baseline: an analogue of Control Variates that uses a running average of \(\log p(x, z \mid \Theta) - \log q(z \mid x, \Lambda)\) as the baseline
  • Variance normalization: scales the learning signal to roughly unit variance using a running estimate of its variance, which is equivalent to an adaptive learning rate
  • Local learning signals: out of the scope of this post, as they require model-specific analysis and alterations, and thus can't be used in the black-box regime
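The first two tricks only need running statistics of the learning signal. A minimal sketch, assuming exponential moving averages with a hypothetical `decay` of 0.9 and division by \(\max(1, \text{running std})\) so that small signals are never amplified; these choices are assumptions of this sketch, and the exact update rules in the paper may differ. `SignalNormalizer` is a hypothetical name.

```python
import math


class SignalNormalizer:
    """Centers the learning signal with a running-average (constant) baseline
    and scales it by a running estimate of its standard deviation."""

    def __init__(self, decay=0.9):
        self.decay = decay
        self.mean = 0.0  # running constant baseline
        self.var = 1.0   # running variance of the learning signal

    def __call__(self, signals):
        n = len(signals)
        batch_mean = sum(signals) / n
        batch_var = sum((s - batch_mean) ** 2 for s in signals) / n
        self.mean = self.decay * self.mean + (1.0 - self.decay) * batch_mean
        self.var = self.decay * self.var + (1.0 - self.decay) * batch_var
        # Only shrink large signals; never blow up small ones.
        scale = max(1.0, math.sqrt(self.var))
        return [(s - self.mean) / scale for s in signals]


norm = SignalNormalizer()
print(norm([-60.0, -40.0]))
```

In a training loop, the normalized values would take the place of the baseline-corrected learning signal \(\log \frac{p(x, z \mid \Theta)}{q(z \mid x, \Lambda)} - \alpha^*\) in the gradient estimator.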