Random notes on Computer Science, Mathematics and Software Engineering

Neural Variational Inference: Scaling Up

In the previous post I covered the well-established classical theory developed in the early 2000s. Since then, technology has made huge progress: we now have much more data and a great need to process it, and to process it fast. In the big data era we have huge datasets and cannot afford many full passes over them, which can render classical VI methods impractical. Recently, M. Hoffman et al. dissected classical mean-field VI to introduce stochasticity right into its heart, which resulted in Stochastic Variational Inference.

Stochastic Variational Inference

We start with the model assumptions. There are two types of latent variables: the global latent variable \(\beta\) and a set of local variables \(z_n = (z_{n1}, \dots, z_{nJ})\) for each observation \(x_n\). Recalling our GMM example, \(\beta\) can be thought of as the mixture weights \(\pi\), and \(z_n\) are the membership indicators, as before. These variables are assumed to come from exponential family distributions:

\[ \begin{align} p(x_n, z_n \mid \beta) &= h(x_n, z_n) \exp \left( \beta^T t(x_n, z_n) - a_l(\beta) \right) \\ p(\beta) &= h(\beta) \exp(\alpha^T t(\beta) - a_g(\alpha)) \end{align} \]

Where \(t(\cdot)\) and \(h(\cdot)\) are overloaded by their arguments, so \(t(\beta)\) and \(t(z_{nj})\) denote two different functions. \(t(\cdot)\) gives the sufficient statistics; for the prior, \(t(\beta)\) also contains the natural parameter \(\beta\) of the local distribution itself. \(a_g\) and \(a_l\) are log-normalizers, which for exponential family distributions have a useful property: the gradient of the log-normalizer is the expectation of the sufficient statistics, \(\nabla_\alpha a_g(\alpha) = \mathbb{E}_{p(\beta \mid \alpha)} t(\beta)\).
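To make these identities concrete, here is a minimal numerical check on a deliberately simple exponential family. It is chosen for illustration only and is not one used in the paper: the Bernoulli distribution in natural form, \(p(x \mid \eta) = \exp(\eta x - a(\eta))\), with \(t(x) = x\), \(h(x) = 1\) and \(a(\eta) = \log(1 + e^\eta)\). The sketch compares finite-difference derivatives of \(a\) with the exact mean and variance of \(t(x)\). The second-derivative identity (Hessian of the log-normalizer equals the covariance of the sufficient statistics) is the one used later for the Fisher information.

```python
import math

def log_normalizer(eta: float) -> float:
    """a(eta) = log(1 + exp(eta)) for the Bernoulli in natural form."""
    return math.log1p(math.exp(eta))

def moments(eta: float) -> tuple[float, float]:
    """Exact mean and variance of t(x) = x, summing p(x | eta) over x in {0, 1}."""
    probs = {x: math.exp(eta * x - log_normalizer(eta)) for x in (0, 1)}
    mean = sum(x * p for x, p in probs.items())
    var = sum((x - mean) ** 2 * p for x, p in probs.items())
    return mean, var

def finite_diff(f, eta: float, h: float = 1e-4) -> tuple[float, float]:
    """Central finite-difference estimates of f'(eta) and f''(eta)."""
    d1 = (f(eta + h) - f(eta - h)) / (2 * h)
    d2 = (f(eta + h) - 2 * f(eta) + f(eta - h)) / h ** 2
    return d1, d2

for eta in (-2.0, 0.0, 1.5):
    mean, var = moments(eta)
    d1, d2 = finite_diff(log_normalizer, eta)
    print(f"eta={eta:+.1f}  a'={d1:.6f}  E[t]={mean:.6f}  a''={d2:.6f}  Var[t]={var:.6f}")
```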

From these assumptions we can derive the complete conditionals (the conditional distribution of a variable given all other hidden variables and the observations) for \(\beta\) and \(z_{nj}\):

\[ \begin{align} p(\beta \mid x, z, \alpha) &\propto \prod_{n=1}^N p(x_n, z_n \mid \beta) p(\beta \mid \alpha) \\ &= h(\beta) \prod_{n=1}^N h(x_n, z_n) \exp \left( \beta^T \sum_{n=1}^N t(x_n, z_n) - N a_l(\beta) + \alpha^T t(\beta) - a_g(\alpha) \right) \\ &\propto h(\beta) \exp \left( \eta_g(x, z, \alpha)^T t(\beta) \right) \end{align} \]

Where \(\alpha = (\alpha_1, \alpha_2)\), \(t(\beta) = (\beta, -a_l(\beta))\) and \(\eta_g(x, z, \alpha) = (\alpha_1 + \sum_{n=1}^N t(x_n, z_n), \alpha_2 + N)\). We see that the (unnormalized) complete conditional for \(\beta\) has the same functional form as the (unnormalized) prior \(p(\beta)\), so after normalization it is

\[ p(\beta \mid x, z, \alpha) = h(\beta) \exp \left( \eta_g(x, z, \alpha)^T t(\beta) - a_g(\eta_g(x, z, \alpha)) \right) \]

The same reasoning applies to the local variables \(z_{nj}\), assuming their complete conditionals are also in the exponential family:

\[ p(z_{nj} \mid x_n, z_{n,-j}, \beta) \propto h(z_{nj}) \exp \left( \eta_l(x_n, z_{n,-j}, \beta)^T t(z_{nj}) \right) \] Hence \[ p(z_{nj} \mid x_n, z_{n,-j}, \beta) = h(z_{nj}) \exp \left( \eta_l(x_n, z_{n,-j}, \beta)^T t(z_{nj}) - a_m(\eta_l(x_n, z_{n,-j}, \beta)) \right) \]
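As a concrete instance of "the complete conditional keeps the prior's functional form", consider just the mixture-weight part of the GMM example. Put a Dirichlet prior \(\text{Dir}(\alpha)\) on \(\pi\), and let the indicators \(z_n\) be one-hot with \(p(z_n \mid \pi) = \prod_k \pi_k^{z_{nk}}\).

- The complete conditional of \(\pi\) is again Dirichlet, with parameters \(\alpha_k + \sum_n z_{nk}\).
- In other words, the new parameters are the prior's parameters plus the summed sufficient statistics.
- The sketch uses the usual Dirichlet parametrization. Its natural parameter \(\alpha - 1\) differs from it only by a constant shift, so the same counts are added either way.
- The function name is my own.

```python
def dirichlet_complete_conditional(alpha: list[float], z: list[list[int]]) -> list[float]:
    """Parameters of p(pi | z) for a Dir(alpha) prior and one-hot indicators z_n.

    The posterior stays Dirichlet; each parameter gains the count
    sum_n z_nk, i.e. the summed sufficient statistics.
    """
    counts = [sum(row[k] for row in z) for k in range(len(alpha))]
    return [a_k + c_k for a_k, c_k in zip(alpha, counts)]

alpha = [1.0, 1.0, 1.0]
z = [[1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]  # hypothetical memberships
print(dirichlet_complete_conditional(alpha, z))  # [3.0, 2.0, 2.0]
```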

Even though we’ve managed to find the complete conditional for \(\beta\), it might be intractable to find the posterior for all latent variables \(p(\beta, z \mid x, \alpha)\). We therefore turn to the mean field approximation:

\[ q(z, \beta \mid \Lambda) = q(\beta \mid \lambda) \prod_{n=1}^N \prod_{j=1}^J q(z_{nj} \mid \phi_{nj}) \]

We assume these factors come from the same exponential families as the corresponding complete conditionals:

\[ \begin{align} q(\beta \mid \lambda) &= h(\beta) \exp(\lambda^T t(\beta) - a_g(\lambda)) \\ q(z_{nj} \mid \phi_{nj}) &= h(z_{nj}) \exp(\phi_{nj}^T t(z_{nj}) - a_m(\phi_{nj})) \end{align} \]

Let’s now find the optimal variational parameters by maximizing the ELBO \(\mathcal{L}(\Theta, \Lambda)\) with respect to \(\lambda\) and \(\phi_{nj}\). Here \(\Theta\) denotes the model parameters (\(\alpha\)), and \(\Lambda\) contains the variational parameters \(\phi\) and \(\lambda\):

\[ \begin{align} \mathcal{L}(\lambda) &= \mathbb{E}_{q} \left( \log p(x, z, \beta) - \log q(\beta) - \log q(z) \right) = \mathbb{E}_{q} \left( \log p(\beta \mid x, z) - \log q(\beta) \right) + \text{const} \\ &= \mathbb{E}_{q} \left( \eta_g(x, z, \alpha)^T t(\beta) - \lambda^T t(\beta) + a_g(\lambda) \right) + \text{const} \\ &= \left(\mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \right)^T \mathbb{E}_{q(\beta)} t(\beta) + a_g(\lambda) + \text{const} \\ &= \left(\mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \right)^T \nabla_\lambda a_g(\lambda) + a_g(\lambda) + \text{const} \end{align} \]

Where we used the aforementioned property of exponential family distributions, \(\nabla_\lambda a_g(\lambda) = \mathbb{E}_{q(\beta)} t(\beta)\). Differentiating, the term \(-\nabla_\lambda a_g(\lambda)\) coming from the first summand cancels with the gradient of \(a_g(\lambda)\), so the gradient is \[ \nabla_\lambda \mathcal{L}(\lambda) = \nabla_\lambda^2 a_g(\lambda) \left(\mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \right) \]
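The gradient formula can be sanity-checked in one dimension. The sketch below makes two assumptions:

- The log-normalizer is Bernoulli-type: \(a_g(\lambda) = \log(1 + e^\lambda)\).
- \(\mathbb{E}_{q(z)} \eta_g(x, z, \alpha)\) is replaced by a fixed, hypothetical number `M`.

It then does two checks:

- It compares a finite-difference derivative of the \(\lambda\)-dependent part of the ELBO with \(\nabla_\lambda^2 a_g(\lambda)(M - \lambda)\).
- It confirms on a grid that the maximum sits at \(\lambda = M\).

```python
import math

M = 0.7  # hypothetical stand-in for E_{q(z)} eta_g(x, z, alpha)

def a(lam: float) -> float:
    """Assumed log-normalizer a_g(lam) = log(1 + exp(lam))."""
    return math.log1p(math.exp(lam))

def a_prime(lam: float) -> float:
    """a_g'(lam) = E_q[t(beta)], which here is the logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-lam))

def a_second(lam: float) -> float:
    """a_g''(lam) = Var_q[t(beta)]."""
    s = a_prime(lam)
    return s * (1.0 - s)

def elbo(lam: float) -> float:
    """lambda-dependent part of the ELBO, up to an additive constant."""
    return (M - lam) * a_prime(lam) + a(lam)

def numeric_grad(f, x: float, h: float = 1e-5) -> float:
    """Central finite-difference derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2 * h)

for lam in (-1.0, 0.0, 0.7, 2.0):
    print(lam, numeric_grad(elbo, lam), a_second(lam) * (M - lam))

grid = [i / 100 for i in range(-300, 301)]
print(max(grid, key=elbo))  # 0.7
```

Since \(a_g'' > 0\), the gradient is positive to the left of `M` and negative to the right, so the stationary point is a maximum.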

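After setting it to zero (the Hessian \(\nabla_\lambda^2 a_g(\lambda)\) is the covariance of \(t(\beta)\), which is positive definite for a minimal exponential family), we get the update for the global variational parameter: \(\lambda = \mathbb{E}_{q(z)} \eta_g(x, z, \alpha)\). Following the same reasoning, we derive the optimal update for \(\phi_{nj}\):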

\[ \begin{align} \mathcal{L}(\phi_{nj}) &= \mathbb{E}_{q} \left( \log p(z_{nj} \mid x_n, z_{n,-j}, \beta) - \log q(z_{nj}) \right) + \text{const} \\ &= \mathbb{E}_{q} \left( \eta_l(x_n, z_{n,-j}, \beta)^T t(z_{nj}) - \phi_{nj}^T t(z_{nj}) + a_m(\phi_{nj})\right) + \text{const} \\ &= \left(\mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta) - \phi_{nj} \right)^T \mathbb{E}_{q(z_{nj})} t(z_{nj}) + a_m(\phi_{nj}) + \text{const} \\ \end{align} \]

The gradient is then \(\nabla_{\phi_{nj}} \mathcal{L}(\phi_{nj}) = \nabla_{\phi_{nj}}^2 a_m(\phi_{nj}) \left(\mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta) - \phi_{nj} \right)\), and setting it to zero gives the update \(\phi_{nj} = \mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta)\).
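For a GMM-like model with one categorical indicator per observation (\(J = 1\), so there is no \(z_{n,-j}\)), the local update takes a concrete form. The sketch below assumes, purely for illustration:

- 1-D Gaussian components with known means and unit variance;
- a Dirichlet \(q(\pi \mid \lambda)\) over the mixture weights.

Under these assumptions:

- \(\mathbb{E}_q[\log \pi_k] = \psi(\lambda_k) - \psi(\sum_l \lambda_l)\), where \(\psi\) is the digamma function.
- The natural parameter of \(q(z_n)\) is \(\phi_{nk} = \mathbb{E}_q[\log \pi_k] + \log \mathcal{N}(x_n \mid \mu_k, 1)\).
- The code returns the normalized probabilities \(\text{softmax}(\phi_n)\).

The helper names are hypothetical. The digamma is a hand-rolled approximation, since Python's standard library does not provide one.

```python
import math

def digamma(x: float) -> float:
    """Approximate psi(x) for x > 0: recurrence psi(x) = psi(x + 1) - 1/x, then asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))

def log_normal(x: float, mu: float) -> float:
    """log N(x | mu, 1); unit variance is a toy assumption."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def local_update(x_n: float, lam: list[float], means: list[float]) -> list[float]:
    """Return q(z_n) probabilities softmax(phi_n) for a single observation x_n."""
    psi_total = digamma(sum(lam))
    # phi_n: E_q[log pi_k] + log p(x_n | component k)
    phi_n = [digamma(l_k) - psi_total + log_normal(x_n, mu) for l_k, mu in zip(lam, means)]
    top = max(phi_n)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(v - top) for v in phi_n]
    total = sum(weights)
    return [w / total for w in weights]

lam = [5.0, 3.0, 2.0]     # hypothetical current q(pi | lambda)
means = [-4.0, 0.0, 4.0]  # hypothetical known component means
print(local_update(0.3, lam, means))
```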

So far we have found the mean-field updates, as well as the corresponding gradients of the ELBO with respect to the variational parameters \(\lambda\) and \(\phi_{nj}\). The next step is to transform these gradients into natural gradients. Intuitively, the classical gradient defines a local linear approximation in which locality is measured by Euclidean distance in parameter space. However, the parameters influence the ELBO only through the distributions \(q\), so we might prefer to measure locality by how much the distributions change. This is what the natural gradient does: it defines a local linear approximation where locality means a small distance (symmetrized KL divergence) between distributions. The paper gives a good formal explanation. If you want to read more on the matter, I refer you to a great post by Roger Grosse, Differential geometry for machine learning.

The natural gradient can be obtained from the usual gradient using a simple linear transformation:

\[ \nabla_\lambda^\text{N} f(\lambda) = \mathcal{I}(\lambda)^{-1} \nabla_{\lambda} f(\lambda) \]

Where \(\mathcal{I}(\lambda) := \mathbb{E}_{q(\beta \mid \lambda)} \left[ \nabla_\lambda \log q(\beta \mid \lambda) (\nabla_\lambda \log q(\beta \mid \lambda))^T \right]\) is the Fisher information matrix. Here I wrote it for the parameter \(\lambda\) of the distribution \(q(\beta \mid \lambda)\); the definition for \(\phi_{nj}\) is analogous. For exponential family distributions the Fisher information takes an especially simple form:

\[ \begin{align} \mathcal{I}(\lambda) &= \mathbb{E}_q (t(\beta) - \nabla_\lambda a_g(\lambda)) (t(\beta) - \nabla_\lambda a_g(\lambda))^T = \mathbb{E}_q (t(\beta) - \mathbb{E}_q t(\beta)) (t(\beta) - \mathbb{E}_q t(\beta))^T \\ &= \text{Cov}_q (t(\beta)) = \nabla_\lambda^2 a_g(\lambda) \end{align} \]

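Where we’ve used another identity for exponential families: the Hessian of the log-normalizer equals the covariance of the sufficient statistics. Since \(\mathcal{I}(\lambda) = \nabla_\lambda^2 a_g(\lambda)\) (and analogously \(\mathcal{I}(\phi_{nj}) = \nabla_{\phi_{nj}}^2 a_m(\phi_{nj})\)), multiplying the gradients above by the inverse Fisher information cancels the Hessians, and we obtain the natural gradients of the ELBO: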

\[ \begin{align} \nabla_\lambda^\text{N} \mathcal{L}(\lambda) &= \mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \\ \nabla_{\phi_{nj}}^\text{N} \mathcal{L}(\phi_{nj}) &= \mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta) - \phi_{nj} \end{align} \]

Surprisingly, natural gradients are even cheaper to compute than classical gradients, since no Hessian is involved! There is also an interesting connection between the mean-field update and a natural gradient step. If we take a step along the natural gradient with step size 1, we get \(\lambda^{\text{new}} = \lambda^{\text{old}} + (\mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda^{\text{old}}) = \mathbb{E}_{q(z)} \eta_g(x, z, \alpha)\). The same applies to the parameters \(\phi\). This means that the mean-field updates are exactly natural gradient steps with unit step size, and vice versa.

Recall that we derived each mean-field update by exactly minimizing the KL divergence to the true posterior with respect to one block of variational parameters while keeping the others fixed. A single update therefore brings that block to its optimum. Correspondingly, in the natural gradient formulation a single step of size 1 brings that block to the same optimum.
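A scalar sketch makes this concrete. As before, it assumes a Bernoulli-type log-normalizer \(a_g(\lambda) = \log(1 + e^\lambda)\) and a hypothetical fixed value `M` for \(\mathbb{E}_{q(z)} \eta_g(x, z, \alpha)\). The sketch proceeds in three steps:

1. It computes the ordinary gradient of the ELBO by finite differences.
2. It computes the Fisher information exactly, as the variance of \(t\) over \(t \in \{0, 1\}\).
3. It forms the natural gradient \(\mathcal{I}(\lambda)^{-1} \nabla_\lambda \mathcal{L}\).

The natural gradient comes out as \(M - \lambda\), so a unit natural step lands on `M`; a unit ordinary step does not.

```python
import math

M = 0.7  # hypothetical stand-in for E_{q(z)} eta_g(x, z, alpha)

def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))

def elbo(lam: float) -> float:
    """lambda-dependent part of the ELBO: (M - lam) a'(lam) + a(lam), with a(lam) = log(1 + e^lam)."""
    return (M - lam) * sigmoid(lam) + math.log1p(math.exp(lam))

def fisher(lam: float) -> float:
    """I(lam) = E[(t - E t)^2], summed exactly over t in {0, 1} with P(t = 1) = sigmoid(lam)."""
    p = sigmoid(lam)
    return (1 - p) * (0 - p) ** 2 + p * (1 - p) ** 2

def grad(lam: float, h: float = 1e-5) -> float:
    """Ordinary gradient of the ELBO by central finite differences."""
    return (elbo(lam + h) - elbo(lam - h)) / (2 * h)

lam_old = -1.5
natural = grad(lam_old) / fisher(lam_old)  # I(lam)^{-1} times the ordinary gradient
print(natural, M - lam_old)     # the natural gradient matches M - lam
print(lam_old + natural)        # unit natural step: lands on M
print(lam_old + grad(lam_old))  # unit ordinary step: stops well short of M
```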

Now, the last component is stochasticity itself. So far we have only rearranged the mean-field update scheme a little and discovered its connection to natural gradient optimization. Note that we have two kinds of parameters: the local \(\phi_{nj}\) and the global \(\lambda\).

- The former are cheap to optimize, since each depends only on the \(n\)-th observation \(x_n\) (and the current \(\lambda\)).
- The latter, though, must incorporate information from all the samples, which is computationally prohibitive in the large-scale regime.

Luckily, now that we know the equivalence between the mean-field update and a natural gradient step, we can borrow ideas from stochastic optimization to make this process scalable.

Let’s first reformulate the ELBO to include the sum over samples \(x_n\):

\[ \begin{align} \mathcal{L}(\Theta, \Lambda) &= \mathbb{E}_{q} \left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + \sum_{n=1}^N \left(\log p(x_n, z_n \mid \beta) - \log q(z_n \mid \phi_n) \right) \right] \\ & = \mathbb{E}_{q} \left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + N \mathbb{E}_{I} \left(\log p(x_I, z_I \mid \beta) - \log q(z_I \mid \phi_I) \right) \right] \end{align} \]

Where \(I \sim \text{Unif}\{1, \dots, N\}\) is a uniformly distributed sample index. Now let’s estimate \(\mathcal{L}\) using a sample \(S\) of uniformly chosen indices (assume \(N\) is divisible by the sample size \(|S|\)). This results in an unbiased estimator, and its gradient is also unbiased, so we can maximize the true ELBO by maximizing the estimate. The authors of the paper start with a single-sample derivation and then extend it to minibatches, but I decided to go straight to the minibatch case:

\[ \begin{align} \mathcal{L}_S(\Theta, \Lambda) & := \mathbb{E}_{q} \left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + \frac{N}{|S|} \sum_{i \in S} \left(\log p(x_i, z_i \mid \beta) - \log q(z_i \mid \phi_i) \right) \right] \\ & = \mathbb{E}_{q} \left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + \sum_{n=1}^{N / |S|} \sum_{i \in S} \left(\log p(x_i, z_i \mid \beta) - \log q(z_i \mid \phi_i) \right) \right] \end{align} \]

This estimate is exactly \(\mathcal{L}(\Theta, \Lambda)\) computed on a dataset consisting of \(\{x_i, z_i\}_{i \in S}\) repeated \(N / |S|\) times. Hence its natural gradient with respect to \(\lambda\) is

\[ \nabla_\lambda^\text{N} \mathcal{L}_S(\lambda) = \mathbb{E}_{q(z)} \eta_g(\{x_S\}_{n=1}^{N/|S|}, \{z_S\}_{n=1}^{N/|S|}, \alpha) - \lambda \]
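For GMM mixture weights with a Dirichlet prior \(\text{Dir}(\alpha)\), \(\mathbb{E}_{q(z)} t(x_i, z_i)\) reduces to the responsibility vector \(\phi_i\). The replicated-minibatch target then becomes \(\hat\lambda = \alpha + \frac{N}{|S|} \sum_{i \in S} \phi_i\), written here in the usual Dirichlet parametrization. The sketch below uses made-up responsibilities and a hypothetical helper name. It enumerates every ordered sample of size \(|S| = 2\) drawn with replacement from \(N = 4\) points, and checks that the average of \(\hat\lambda\) equals the full-batch update. In other words, the estimator is unbiased.

```python
from itertools import product

def global_target(alpha, phi, batch):
    """lambda_hat = alpha + (N / |S|) * sum_{i in S} phi_i for Dirichlet mixture weights."""
    scale = len(phi) / len(batch)
    return [a_k + scale * sum(phi[i][k] for i in batch) for k, a_k in enumerate(alpha)]

alpha = [1.0, 1.0]
# Hypothetical responsibilities phi_i = E_q[z_i] for N = 4 observations and K = 2 components.
phi = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5], [0.7, 0.3]]

full = global_target(alpha, phi, range(len(phi)))    # S = all points, so N / |S| = 1
samples = list(product(range(len(phi)), repeat=2))  # all 16 ordered samples with |S| = 2
targets = [global_target(alpha, phi, s) for s in samples]
average = [sum(t[k] for t in targets) / len(targets) for k in range(len(alpha))]
print(full, average)  # the two vectors agree up to floating-point rounding
```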

One important note: in stochastic optimization we can’t use a constant step size, because the noise in the gradient estimates would keep the iterates from converging to the optimum. Following the Robbins-Monro conditions, we use a schedule \(\rho_t\) such that \(\sum_t \rho_t = \infty\) and \(\sum_t \rho_t^2 < \infty\). The update is then \(\lambda^{\text{new}} = \lambda^{\text{old}} + \rho_t \nabla_\lambda^\text{N} \mathcal{L}_S(\lambda) = (1 - \rho_t) \lambda^{\text{old}} + \rho_t \mathbb{E}_{q(z)} \eta_g(\{x_S\}_{n=1}^{N/|S|}, \{z_S\}_{n=1}^{N/|S|}, \alpha)\).
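A commonly used schedule that satisfies both conditions is \(\rho_t = (t + \tau)^{-\kappa}\), with a delay \(\tau \ge 1\) (which keeps every \(\rho_t \le 1\) when \(t\) starts at 0) and a forgetting rate \(\kappa \in (0.5, 1]\).

- \(\sum_t \rho_t\) diverges because \(\kappa \le 1\).
- \(\sum_t \rho_t^2\) converges because \(2\kappa > 1\).

The sketch below applies the convex-combination update to a stream of noisy but unbiased estimates of a scalar target. The function names and the Gaussian noise model are illustrative assumptions. With \(\tau = 1, \kappa = 1\), i.e. \(\rho_t = 1/(t+1)\), the recursion reproduces the running mean of the estimates exactly, which shows how decaying steps average the noise out.

```python
import random

def step_size(t: int, tau: float = 1.0, kappa: float = 0.7) -> float:
    """rho_t = (t + tau)^(-kappa); Robbins-Monro holds for 0.5 < kappa <= 1."""
    return (t + tau) ** (-kappa)

def run(estimates, tau: float, kappa: float, lam0: float = 0.0) -> float:
    """Apply lam <- (1 - rho_t) * lam + rho_t * estimate_t over a stream of estimates."""
    lam = lam0
    for t, estimate in enumerate(estimates):
        rho = step_size(t, tau, kappa)
        lam = (1 - rho) * lam + rho * estimate
    return lam

rng = random.Random(0)
# Hypothetical noisy but unbiased estimates of the target value 3.0.
estimates = [3.0 + rng.gauss(0.0, 1.0) for _ in range(10_000)]

print(run(estimates, tau=1.0, kappa=1.0), sum(estimates) / len(estimates))  # equal: running mean
print(run(estimates, tau=1.0, kappa=0.7))  # also close to 3.0, weighting recent estimates more
```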

Finally, we have the following optimization scheme:
  • Initialize \(\lambda^{(0)}\) randomly
  • For \(t\) from 0 to MAX_ITER
    1. Sample \(S \sim \text{Unif}\{1, \dots, N\}^{|S|}\)
    2. For each \(i \in S\) and \(j = 1, \dots, J\), update the local variational parameter \(\phi_{i,j} = \mathbb{E}_{q(\beta) q(z_{i,-j})} \eta_l(x_i, z_{i,-j}, \beta)\)
    3. Replicate the sample \(N / |S|\) times and compute the intermediate global parameter \(\hat \lambda = \mathbb{E}_{q(z)} \eta_g(\{x_S\}_{n=1}^{N/|S|}, \{z_S\}_{n=1}^{N/|S|}, \alpha)\)
    4. Update the global parameter \(\lambda^{(t+1)} = (1-\rho_t) \lambda^{(t)} + \rho_t \hat \lambda\)
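The steps above can be put together into a small runnable sketch. It is a toy instance chosen for illustration, not the paper's experiments:

- The data are 1-D samples from a mixture of unit-variance Gaussians with known means.
- Only the mixture weights \(\pi\) are global. They get a Dirichlet prior \(\text{Dir}(\alpha)\) and a Dirichlet \(q(\pi \mid \lambda)\).

The algorithm steps map onto the code as follows:

- Step 2 uses \(\mathbb{E}_q[\log \pi_k] = \psi(\lambda_k) - \psi(\sum_l \lambda_l)\), with a hand-rolled digamma \(\psi\).
- Step 3 forms \(\hat\lambda = \alpha + \frac{N}{|S|} \sum_{i \in S} \phi_i\).
- Step 4 uses \(\rho_t = (t + 1)^{-\kappa}\).

The function and argument names are hypothetical.

```python
import math
import random

def digamma(x: float) -> float:
    """Approximate psi(x) for x > 0: recurrence psi(x) = psi(x + 1) - 1/x, then asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))

def log_normal(x: float, mu: float) -> float:
    """log N(x | mu, 1): components have known means and unit variance (toy assumption)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def svi_mixture_weights(data, means, alpha, batch_size, n_iter, kappa=0.7, seed=0):
    """Return the Dirichlet parameters lambda of q(pi | lambda) after n_iter SVI steps."""
    rng = random.Random(seed)
    N, K = len(data), len(means)
    assert N % batch_size == 0, "the derivation assumes N is divisible by |S|"
    lam = [rng.uniform(0.5, 2.0) for _ in range(K)]  # random lambda^(0)
    for t in range(n_iter):
        # Step 1: S ~ Unif{1, ..., N}^{|S|}, i.e. indices drawn with replacement.
        batch = [rng.randrange(N) for _ in range(batch_size)]
        # Step 2: local updates; phi_i = softmax(E_q[log pi_k] + log p(x_i | mu_k)).
        psi_total = digamma(sum(lam))
        e_log_pi = [digamma(l_k) - psi_total for l_k in lam]
        counts = [0.0] * K
        for i in batch:
            logits = [e + log_normal(data[i], mu) for e, mu in zip(e_log_pi, means)]
            top = max(logits)
            weights = [math.exp(v - top) for v in logits]
            total = sum(weights)
            for k in range(K):
                counts[k] += weights[k] / total
        # Step 3: global target as if the batch were replicated N / |S| times.
        lam_hat = [a_k + (N / batch_size) * c_k for a_k, c_k in zip(alpha, counts)]
        # Step 4: Robbins-Monro step with rho_t = (t + 1)^(-kappa).
        rho = (t + 1) ** (-kappa)
        lam = [(1 - rho) * l_k + rho * h_k for l_k, h_k in zip(lam, lam_hat)]
    return lam

rng = random.Random(42)
true_weights, means = [0.2, 0.5, 0.3], [-4.0, 0.0, 4.0]
data = [rng.gauss(rng.choices(means, weights=true_weights)[0], 1.0) for _ in range(3000)]
lam = svi_mixture_weights(data, means, alpha=[1.0, 1.0, 1.0], batch_size=30, n_iter=2000)
print([round(l_k / sum(lam), 3) for l_k in lam])  # posterior mean of pi, expected near true_weights
```

Because \(\rho_0 = 1\) in this schedule, the random \(\lambda^{(0)}\) only affects the first round of local updates and is then overwritten. A larger delay, e.g. \(\rho_t = (t + \tau)^{-\kappa}\) with \(\tau > 1\), would make the early iterations more conservative.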