
Neural Variational Inference: Scaling Up

In the previous post I covered well-established classical theory developed in the early 2000s. Since then technology has made huge progress: we now have much more data, and a great need to process it, and to process it fast. In the big data era datasets are huge, and we cannot afford many full passes over them, which can render classical VI methods impractical. Recently M. Hoffman et al. dissected classical Mean-Field VI to introduce stochasticity right into its heart, which resulted in Stochastic Variational Inference.

Stochastic Variational Inference

We start with model assumptions: we have two types of latent variables, a global latent variable β and a bunch of local variables $z_n$, one for each observation $x_n$. Recalling our GMM example, β can be thought of as the mixture weights π, and the $z_n$ are membership indicators, as previously. These variables are assumed to come from exponential family distributions:

$$
p(x_n, z_n \mid \beta) = h(x_n, z_n) \exp\left( \beta^T t(x_n, z_n) - a_l(\beta) \right)
$$
$$
p(\beta) = h(\beta) \exp\left( \alpha^T t(\beta) - a_g(\alpha) \right)
$$

Where $t(\cdot)$ and $h(\cdot)$ are overloaded by their argument, so $t(\beta)$ and $t(z_{nj})$ correspond to two different functions. $t(\cdot)$ gives a natural parameter and also sufficient statistics. $a_g$ and $a_l$ are log-normalizing constants, which for exponential family distributions have an interesting property, namely, the gradient of the log-normalizing constant is the expectation of the sufficient statistics: $\nabla_\alpha a_g(\alpha) = \mathbb{E}\, t(\beta)$.
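To see this identity in action, here is a tiny numerical check. This is a minimal sketch assuming the distribution is a Dirichlet (sufficient statistics $t(\beta) = \log \beta$), which is exactly the kind of prior we'd use for GMM mixture weights; the parameter values are arbitrary.

```python
import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 5.0])   # arbitrary Dirichlet parameter for the check

# Gradient of the Dirichlet log-normalizer a_g(alpha) = sum(lgamma(alpha)) - lgamma(sum(alpha)):
# d a_g / d alpha_k = digamma(alpha_k) - digamma(sum(alpha))
grad_a = digamma(alpha) - digamma(alpha.sum())

# Monte Carlo estimate of E[t(beta)] = E[log beta] under Dirichlet(alpha)
samples = rng.dirichlet(alpha, size=200_000)
mc_estimate = np.log(samples).mean(axis=0)

print(grad_a)        # analytic expectation of the sufficient statistics
print(mc_estimate)   # should match to a couple of decimal places
```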

From these assumptions we can derive complete conditionals (the conditional distribution of a variable given all the other hidden variables and the observations) for β and $z_{nj}$:

$$
\begin{aligned}
p(\beta \mid x, z, \alpha) &\propto \prod_{n=1}^N p(x_n, z_n \mid \beta)\, p(\beta \mid \alpha) \\
&= h(\beta) \prod_{n=1}^N h(x_n, z_n) \exp\left( \beta^T \sum_{n=1}^N t(x_n, z_n) - N a_l(\beta) + \alpha^T t(\beta) - a_g(\alpha) \right) \\
&\propto h(\beta) \exp\left( \eta_g(x, z, \alpha)^T t(\beta) \right)
\end{aligned}
$$

Where $t(\beta) = (\beta, -a_l(\beta))$ and $\eta_g(x, z, \alpha) = \left( \alpha_1 + \sum_{n=1}^N t(x_n, z_n),\ \alpha_2 + N \right)$. We see that the (unnormalized) posterior distribution for β has the same functional form as the (unnormalized) prior p(β), therefore after normalization it'd be

$$
p(\beta \mid x, z, \alpha) = h(\beta) \exp\left( \eta_g(x, z, \alpha)^T t(\beta) - a_g(\eta_g(x, z, \alpha)) \right)
$$

The same applies to the local variables $z_{nj}$:

$$
p(z_{nj} \mid x_n, z_{n,-j}, \beta) \propto h(z_{nj}) \exp\left( \eta_l(x_n, z_{n,-j}, \beta)^T t(z_{nj}) \right)
$$
Hence
$$
p(z_{nj} \mid x_n, z_{n,-j}, \beta) = h(z_{nj}) \exp\left( \eta_l(x_n, z_{n,-j}, \beta)^T t(z_{nj}) - a_m(\eta_l(x_n, z_{n,-j}, \beta)) \right)
$$
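To make this concrete with the GMM example from above (treating the component densities $p(x_n \mid z_n = k)$ as fixed, for simplicity), the complete conditional of a membership indicator is just the familiar responsibility

$$
p(z_n = k \mid x_n, \pi) \propto \pi_k \, p(x_n \mid z_n = k),
$$

which is again a categorical distribution, i.e. it stays in the same exponential family as the prior $p(z_n \mid \pi)$.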

Even though we've managed to find the complete conditional for β, it might still be intractable to find the posterior over all latent variables $p(\beta, z \mid x, \alpha)$. We therefore turn to the mean-field approximation:

$$
q(z, \beta \mid \Lambda) = q(\beta \mid \lambda) \prod_{n=1}^N \prod_{j=1}^J q(z_{nj} \mid \phi_{nj})
$$

We assume these marginal distributions come from the exponential family:

$$
q(\beta \mid \lambda) = h(\beta) \exp\left( \lambda^T t(\beta) - a_g(\lambda) \right), \qquad
q(z_{nj} \mid \phi_{nj}) = h(z_{nj}) \exp\left( \phi_{nj}^T t(z_{nj}) - a_m(\phi_{nj}) \right)
$$
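For instance, in the GMM example a natural choice would be a Dirichlet $q(\pi \mid \lambda)$ for the mixture weights and a categorical $q(z_n \mid \phi_n)$ for each membership indicator; both are exponential family distributions of exactly this form.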

Let's now find the optimal variational parameters by optimizing the ELBO $L(\Theta, \Lambda)$ (Θ is the model parameters, here just α, and Λ contains the variational parameters φ and λ) with respect to λ and $\phi_{nj}$:

$$
\begin{aligned}
L(\lambda) &= \mathbb{E}_q\left[ \log p(x, z, \beta) - \log q(\beta) - \log q(z) \right] \\
&= \mathbb{E}_q\left[ \log p(\beta \mid x, z) - \log q(\beta) \right] + \text{const} \\
&= \mathbb{E}_q\left[ \eta_g(x, z, \alpha)^T t(\beta) - \lambda^T t(\beta) + a_g(\lambda) \right] + \text{const} \\
&= \left( \mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \right)^T \mathbb{E}_{q(\beta)} t(\beta) + a_g(\lambda) + \text{const} \\
&= \left( \mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \right)^T \nabla_\lambda a_g(\lambda) + a_g(\lambda) + \text{const}
\end{aligned}
$$

Where we used the aforementioned property of exponential family distributions: $\nabla_\lambda a_g(\lambda) = \mathbb{E}_{q(\beta)} t(\beta)$. The gradient then is
$$
\nabla_\lambda L(\lambda) = \nabla^2_\lambda a_g(\lambda) \left( \mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \right)
$$

After setting it to zero we get the update for the global variational parameter: $\lambda = \mathbb{E}_{q(z)} \eta_g(x, z, \alpha)$. Following the same reasoning we derive the optimal update for $\phi_{nj}$:

$$
\begin{aligned}
L(\phi_{nj}) &= \mathbb{E}_q\left[ \log p(z_{nj} \mid x_n, z_{n,-j}, \beta) - \log q(z_{nj}) \right] + \text{const} \\
&= \mathbb{E}_q\left[ \eta_l(x_n, z_{n,-j}, \beta)^T t(z_{nj}) - \phi_{nj}^T t(z_{nj}) + a_m(\phi_{nj}) \right] + \text{const} \\
&= \left( \mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta) - \phi_{nj} \right)^T \mathbb{E}_{q(z_{nj})} t(z_{nj}) + a_m(\phi_{nj}) + \text{const}
\end{aligned}
$$

The gradient then is $\nabla_{\phi_{nj}} L(\phi_{nj}) = \nabla^2_{\phi_{nj}} a_m(\phi_{nj}) \left( \mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta) - \phi_{nj} \right)$ (using $\mathbb{E}_{q(z_{nj})} t(z_{nj}) = \nabla_{\phi_{nj}} a_m(\phi_{nj})$ again), and the update is $\phi_{nj} = \mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta)$.
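As a sanity check, these abstract updates reduce to the familiar ones in the GMM example (again keeping the component densities fixed and putting a Dirichlet prior on π): the global update becomes $\lambda = \alpha + \sum_{n=1}^N \phi_n$ (prior pseudo-counts plus expected membership counts), while the local update becomes $\phi_{nk} \propto \exp\left( \mathbb{E}_{q(\pi)} \log \pi_k \right) p(x_n \mid z_n = k)$.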

So far we have found the mean-field updates, as well as the corresponding gradients of the ELBO with respect to the variational parameters λ and $\phi_{nj}$. The next step is to transform these gradients into natural gradients. Intuitively, the classical gradient defines a local linear approximation, where the notion of locality comes from the Euclidean geometry of the parameter space. However, the parameters influence the ELBO only through the distributions q, so we might like to base our idea of locality on how much the distributions change. This is what the natural gradient does: it defines a local linear approximation where locality means a small distance (symmetrized KL divergence) between distributions. There's a great formal explanation in the paper, and if you want to read more on the matter, I refer you to a great post by Roger Grosse, Differential geometry for machine learning.

The natural gradient can be obtained from the usual gradient using a simple linear transformation:

$$
\nabla^{N}_\lambda f(\lambda) = I(\lambda)^{-1} \nabla_\lambda f(\lambda)
$$

Where $I(\lambda) := \mathbb{E}_{q(\beta \mid \lambda)}\left[ \nabla_\lambda \log q(\beta \mid \lambda) \left( \nabla_\lambda \log q(\beta \mid \lambda) \right)^T \right]$ is the Fisher information matrix. Here I considered the parameter λ of the distribution $q(\beta \mid \lambda)$; the same goes for $\phi_{nj}$, you get the idea. For an exponential family distribution this information matrix takes an especially simple form:

$$
\begin{aligned}
I(\lambda) &= \mathbb{E}_q \left( t(\beta) - \nabla_\lambda a_g(\lambda) \right) \left( t(\beta) - \nabla_\lambda a_g(\lambda) \right)^T \\
&= \mathbb{E}_q \left( t(\beta) - \mathbb{E}_q t(\beta) \right) \left( t(\beta) - \mathbb{E}_q t(\beta) \right)^T \\
&= \operatorname{Cov}_q(t(\beta)) = \nabla^2_\lambda a_g(\lambda)
\end{aligned}
$$

Where we've used another differential identity for the exponential family: the Hessian of the log-normalizer is the covariance of the sufficient statistics. All these calculations lead us to the natural gradients of the ELBO with respect to the variational parameters:

$$
\begin{aligned}
\nabla^{N}_\lambda L(\lambda) &= \mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda \\
\nabla^{N}_{\phi_{nj}} L(\phi_{nj}) &= \mathbb{E}_{q(\beta) q(z_{n,-j})} \eta_l(x_n, z_{n,-j}, \beta) - \phi_{nj}
\end{aligned}
$$
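Before moving on, here's a quick numerical sanity check of the identity $I(\lambda) = \operatorname{Cov}_q(t(\beta)) = \nabla^2_\lambda a_g(\lambda)$ used above. It's a minimal sketch assuming $q(\beta \mid \lambda)$ is a Dirichlet (so $t(\beta) = \log \beta$), with the Hessian of its log-normalizer written via trigamma functions:

```python
import numpy as np
from scipy.special import polygamma

rng = np.random.default_rng(1)
lam = np.array([2.0, 3.0, 5.0])   # arbitrary Dirichlet variational parameter for the check

# Hessian of the Dirichlet log-normalizer a_g(lam) = sum(lgamma(lam)) - lgamma(sum(lam)):
# its (j, k) entry is trigamma(lam_j) * [j == k] - trigamma(sum(lam))
hess = np.diag(polygamma(1, lam)) - polygamma(1, lam.sum())

# Covariance of the sufficient statistics t(beta) = log(beta), estimated by Monte Carlo
samples = rng.dirichlet(lam, size=200_000)
cov = np.cov(np.log(samples), rowvar=False)

print(np.max(np.abs(hess - cov)))  # should be small: the two matrices agree
```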

Surprisingly, computation-wise calculating natural gradients is even simpler than calculating classical gradients! There's also an interesting connection between the mean-field update and a natural gradient step. In particular, if we make a step along the natural gradient with step size equal to 1, we get $\lambda_{\text{new}} = \lambda_{\text{old}} + \left( \mathbb{E}_{q(z)} \eta_g(x, z, \alpha) - \lambda_{\text{old}} \right) = \mathbb{E}_{q(z)} \eta_g(x, z, \alpha)$. The same applies to the parameters φ. This means that mean-field updates are exactly natural gradient steps with unit step size, and vice versa.

Recall that we derived the mean-field updates by finding a minimum of the KL divergence to the true posterior (with the other variational parameters held fixed), that is, just one step (one update) brings us to that minimum. Naturally, we have the same in the natural gradient formulation, where just one unit-size step brings us to the optimum.

Now, the last component is stochasticity itself. So far we have only played a little with the mean-field update scheme and discovered its connection to natural gradient optimization. Note that we have two kinds of parameters: the local $\phi_{nj}$ and the global parameter λ. The first one is easy to optimize over, as it depends only on one, the n-th, sample $x_n$. The second one, though, needs to incorporate information from all the samples, which is computationally prohibitive in the large-scale regime. Luckily, once we know the equivalence between a mean-field update and a natural gradient step, we can borrow ideas from stochastic optimization to make this process more scalable.

Let's first reformulate the ELBO to include the sum over samples $x_n$:

$$
\begin{aligned}
L(\Theta, \Lambda) &= \mathbb{E}_q\left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + \sum_{n=1}^N \left( \log p(x_n, z_n \mid \beta) - \log q(z_n \mid \phi_n) \right) \right] \\
&= \mathbb{E}_q\left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + N\, \mathbb{E}_I \left( \log p(x_I, z_I \mid \beta) - \log q(z_I \mid \phi_I) \right) \right]
\end{aligned}
$$

Where $I \sim \text{Unif}\{1, \ldots, N\}$ is a uniformly distributed index of a sample. Now let's estimate L using a minibatch S of uniformly chosen indices (assume N is divisible by the batch size |S|); this results in an unbiased estimator (its gradient is also unbiased, so we can maximize the true ELBO by maximizing the estimate). The authors of the paper start with a single-sample derivation and then extend it to minibatches, but I decided to go straight to the minibatch case:

$$
\begin{aligned}
L_S(\Theta, \Lambda) &:= \mathbb{E}_q\left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + \frac{N}{|S|} \sum_{i \in S} \left( \log p(x_i, z_i \mid \beta) - \log q(z_i \mid \phi_i) \right) \right] \\
&= \mathbb{E}_q\left[ \log p(\beta \mid \alpha) - \log q(\beta \mid \lambda) + \sum_{n=1}^{N/|S|} \sum_{i \in S} \left( \log p(x_i, z_i \mid \beta) - \log q(z_i \mid \phi_i) \right) \right]
\end{aligned}
$$

This estimate is exactly $L(\Theta, \Lambda)$ calculated on a dataset consisting of $\{x_i, z_i\}_{i \in S}$ repeated N/|S| times. Hence its natural gradient w.r.t. λ is

$$
\nabla^{N}_\lambda L_S(\lambda) = \mathbb{E}_{q(z)} \eta_g\left( \{x_S\}_{n=1}^{N/|S|}, \{z_S\}_{n=1}^{N/|S|}, \alpha \right) - \lambda
$$
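Here's a quick numerical illustration of the unbiasedness claim: scaling a uniformly sampled minibatch sum by N/|S| matches the full-dataset sum on average. This is a minimal sketch with random numbers standing in for the per-data-point ELBO terms:

```python
import numpy as np

rng = np.random.default_rng(2)
N, batch = 10_000, 100
per_point = rng.normal(size=N)           # stand-in for the per-data-point ELBO terms
full_sum = per_point.sum()

# Average many minibatch estimates (N / |S|) * sum_{i in S} term_i
estimates = [N / batch * per_point[rng.integers(0, N, size=batch)].sum()
             for _ in range(5_000)]

print(full_sum, np.mean(estimates))      # the two numbers should be close
```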

One important note: for stochastic optimization we can't just use a constant step size. As the Robbins-Monro conditions suggest, we need a schedule $\rho_t$ such that $\sum_t \rho_t = \infty$ and $\sum_t \rho_t^2 < \infty$. The update then is
$$
\lambda_{\text{new}} = \lambda_{\text{old}} + \rho_t \nabla^{N}_\lambda L_S(\lambda) = (1 - \rho_t) \lambda_{\text{old}} + \rho_t\, \mathbb{E}_{q(z)} \eta_g\left( \{x_S\}_{n=1}^{N/|S|}, \{z_S\}_{n=1}^{N/|S|}, \alpha \right)
$$
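For a concrete example (one common choice, not something dictated by the derivation itself), a schedule like $\rho_t = (t + \tau)^{-\kappa}$ with $\kappa \in (0.5, 1]$ and some delay $\tau > 0$ satisfies both conditions: the steps decay slowly enough that their sum diverges, but fast enough that the sum of their squares converges.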

Finally we have the following optimization scheme:

  • Start with a random initialization $\lambda^{(0)}$
  • For t from 0 to MAX_ITER
    1. Sample a minibatch of indices $S \sim \text{Unif}\{1, \ldots, N\}^{|S|}$
    2. For each sample $i \in S$ update the local variational parameters $\phi_{ij} = \mathbb{E}_{q(\beta) q(z_{i,-j})} \eta_l(x_i, z_{i,-j}, \beta)$
    3. Replicate the minibatch N/|S| times and compute the intermediate global parameter $\hat\lambda = \mathbb{E}_{q(z)} \eta_g\left( \{x_S\}_{n=1}^{N/|S|}, \{z_S\}_{n=1}^{N/|S|}, \alpha \right)$
    4. Update the global parameter $\lambda^{(t+1)} = (1 - \rho_t) \lambda^{(t)} + \rho_t \hat\lambda$
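To tie everything together, here is a compact end-to-end sketch of this scheme on a deliberately tiny example of my own (not from the paper): a two-component Gaussian mixture with known, fixed component means, a Beta prior on the mixture weight π as the global variable, and per-point membership indicators $z_n$ as the local variables.

```python
import numpy as np
from scipy.special import digamma, expit

rng = np.random.default_rng(3)

# --- Toy data: two-component Gaussian mixture with known, fixed component means ---
N, true_pi, mu = 50_000, 0.3, np.array([-2.0, 2.0])
z_true = (rng.random(N) < true_pi).astype(int)        # z_n = 1 with probability true_pi
x = rng.normal(mu[z_true], 1.0)

# --- Variational families: q(pi) = Beta(lam[0], lam[1]), q(z_i) = Bernoulli(phi_i) ---
a0, b0 = 1.0, 1.0                  # Beta prior on the mixture weight (the "alpha" of the post)
lam = np.array([1.0, 1.0])         # global variational parameter, flat initialization
batch, tau, kappa = 100, 1.0, 0.7  # minibatch size and step-size schedule rho_t = (t + tau)^(-kappa)

def log_lik(x, m):
    # Gaussian log-likelihood up to a constant shared by both components
    return -0.5 * (x - m) ** 2

for t in range(2_000):
    # 1. Sample a minibatch of indices uniformly
    S = rng.integers(0, N, size=batch)

    # 2. Local step: optimal q(z_i) given the current q(pi),
    #    log-odds = E[log pi] - E[log(1 - pi)] + log p(x_i | z_i = 1) - log p(x_i | z_i = 0)
    e_log_pi = digamma(lam[0]) - digamma(lam.sum())
    e_log_1mpi = digamma(lam[1]) - digamma(lam.sum())
    phi = expit(e_log_pi - e_log_1mpi + log_lik(x[S], mu[1]) - log_lik(x[S], mu[0]))

    # 3. Intermediate global parameter, pretending the minibatch were the whole dataset
    lam_hat = np.array([a0 + N / batch * phi.sum(),
                        b0 + N / batch * (1.0 - phi).sum()])

    # 4. Natural gradient step with a Robbins-Monro step size
    rho = (t + tau) ** (-kappa)
    lam = (1.0 - rho) * lam + rho * lam_hat

print("E_q[pi] =", lam[0] / lam.sum())   # should end up close to true_pi = 0.3
```

The convex combination in step 4 is written directly in the Beta parameters; since the Beta's natural parameters differ from $(\lambda_a, \lambda_b)$ only by a constant shift, the update is the same either way.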