B.log

Random notes mostly on Machine Learning

Thoughts on Mutual Information: Formal Limitations

This post continues the discussion started in Thoughts on Mutual Information: More Estimators. This time we’ll focus on the drawbacks and limitations of these bounds.

Let’s start with the elephant in the room: a year ago an interesting preprint was uploaded to arXiv, Formal Limitations on the Measurement of Mutual Information, in which the authors essentially argue that if you don’t know any densities (the hardest case, according to my hierarchy), then any distribution-free high-confidence lower bound on the MI would require a number of samples on the order of \(\exp(\text{MI})\), and thus black-box MI lower bounds should be deemed impractical.

Formal Limitations

The paper mounts a massive attack on distribution-free lower bounds on the Mutual Information. Not only do McAllester and Stratos show that existing bounds are inferior, they also kill any black-box lower bounds on the KL divergence. The core result that establishes the impossibility of cheap and good lower bounds on the Mutual Information is Theorem 2, which states (in slightly reformulated notation):

Theorem: Let \(B\) be any distribution-free high-confidence lower bound on \(\mathbb{H}[p(x)]\) computed from a sample \(x_{1:N} \sim p(x)\). More specifically, let \(B(x_{1:N}, \delta)\) be any real-valued function of a sample and a confidence parameter \(\delta\) such that for any \(p(x)\), with probability at least \(1 - \delta\) over a draw of \(x_{1:N}\) from \(p(x)\), we have \[\mathbb{H}[p(x)] \ge B(x_{1:N}, \delta).\] For any such bound, and for \(N \ge 50\) and \(k \ge 2\), with probability at least \(1 - \delta - 1.01/k\) over the draw of \(x_{1:N}\) we have \[B(x_{1:N}, \delta) \le \log(2k N^2).\]
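To get a feel for the numbers in the theorem, here is a minimal Python sketch that evaluates the cap \(\log(2kN^2)\) in nats and inverts it to see how many samples are needed before the cap even permits certifying a given entropy. The helper names `entropy_bound_cap` and `samples_needed` are my own and do not come from the paper.

```python
import math


def entropy_bound_cap(n_samples, k):
    """Cap log(2 k N^2), in nats, on any distribution-free lower bound.

    The theorem is stated for N >= 50 and k >= 2 only.
    """
    if n_samples < 50 or k < 2:
        raise ValueError("the theorem is stated for N >= 50 and k >= 2")
    return math.log(2 * k * n_samples ** 2)


def samples_needed(target_nats, k):
    """Smallest N >= 50 for which the cap log(2 k N^2) reaches target_nats."""
    n = math.ceil(math.sqrt(math.exp(target_nats) / (2 * k)))
    return max(n, 50)


# Illustrative values only: the cap grows logarithmically in N ...
for n in (10**3, 10**6, 10**9):
    print(n, entropy_bound_cap(n, k=100))

# ... so the sample size needed to certify a target grows exponentially in it.
for target in (10.0, 50.0, 100.0):
    print(target, samples_needed(target, k=100))
```

Inverting the cap gives \(N \ge \sqrt{e^{B} / (2k)}\), that is, growth like \(\exp(B/2)\) in the value \(B\) we want to certify.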

Indeed, since (in the discrete case) \(I(X, X) = H(X)\) [1], any good [2] lower bound on the MI would give a good lower bound on the entropy, and the theorem above says there are no such bounds (only those that are either exponentially expensive to compute, or not high-confidence, or not black-box). The authors then argue that one can have good estimators by forgoing the lower bound guarantee and settling for an estimate that is neither a lower nor an upper bound. However, this is undesirable in many cases, especially when we’d like to compare two numbers.

Unfortunately, I found the paper hard to digest and, as far as I know, it’s still not published, so we should probably be cautious about the presented result. Nevertheless, I’ll show below that several often-used bounds do indeed seem to have this limitation.

The Nguyen-Wainwright-Jordan Bound

The process we’ve followed so far to derive a lower bound on the MI has been somewhat cumbersome: we first decomposed the MI into some expectation and then used fancy bounds on some of the terms. An alternative and easier approach is to recall that the MI is a certain KL divergence and take any off-the-shelf lower bound on the KL divergence.

One such lower bound can be obtained using the Fenchel conjugate functions (Nguyen et al., alternatively see the f-GANs paper):

\[ KL(p(x) \mid\mid q(x)) \ge \mathbb{E}_{p(x)} f(x) - \mathbb{E}_{q(x)} \exp(f(x)) + 1 \]

Where \(f(x)\) (a critic) is any function that takes \(x\) as input and outputs a scalar. The optimal choice can be shown to be \(f^*(x) = \ln \tfrac{p(x)}{q(x)}\). And all is nice, except for the lurking menace of the \(\exp\) term. Consider a Monte Carlo estimate in the case of the optimal critic (\(x_{1:N} \sim p(x), y_{1:M} \sim q(y)\)): \[ \frac{1}{N} \sum_{n=1}^N \ln \frac{p(x_n)}{q(x_n)} - \frac{1}{M} \sum_{m=1}^M \frac{p(y_m)}{q(y_m)} + 1 \] The first term is exactly the Monte Carlo estimate of the KL divergence, while the second (the balancing term, as it counterweights the first one) equals 1 in expectation. However, the ratio of densities might take on extremely large values and in general has enormous variance. Indeed, the variance of the balancing term is

\[ \begin{align*} \mathbb{V}_{q(y_{1:M})} \left[ \frac{1}{M} \sum_{m=1}^M \frac{p(y_m)}{q(y_m)} \right] &= \frac{1}{M} \mathbb{V}_{q(y)} \left[ \frac{p(y)}{q(y)} \right] = \frac{1}{M} \E_{q(y)} \left[ \left(\frac{p(y)}{q(y)} \right)^2 - 1 \right] \\ &= \frac{1}{M} \E_{p(y)} \left[ \frac{p(y)}{q(y)} - 1 \right] = \frac{\E_{p(y)} \left[ \exp \log \frac{p(y)}{q(y)} \right] - 1}{M} \\ & \ge \frac{\exp \E_{p(y)} \left[ \log \frac{p(y)}{q(y)} \right] - 1}{M} = \frac{\exp\left(\text{KL}(p(y) \mid\mid q(y))\right) - 1}{M} \end{align*} \]

So one can see that the NWJ bound indeed can’t give us a high-confidence, few-sample lower bound on any KL, not only the MI. This is because the second term corrupts the estimate by contributing large zero-mean noise. The only way to drive the magnitude of this noise down is to take more samples, and as the analysis above shows, the number of samples should be exponential in the KL (the statement could be made more precise by appealing to Chebyshev’s inequality).
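The variance calculation above can be checked on a toy example. The sketch below assumes a setup of my own choosing, not one from any of the papers: \(p = \mathcal{N}(\mu, 1)\) and \(q = \mathcal{N}(0, 1)\), for which \(\text{KL}(p \mid\mid q) = \mu^2/2\) and the density ratio is \(\exp(\mu y - \mu^2/2)\). In this case the variance of the ratio under \(q\) is exactly \(\exp(\mu^2) - 1 = \exp(2\,\text{KL}) - 1\), which indeed dominates \(\exp(\text{KL}) - 1\). All function names are hypothetical.

```python
import numpy as np


def log_ratio(t, mu):
    """log p(t)/q(t) for p = N(mu, 1), q = N(0, 1) (toy assumption)."""
    return mu * t - 0.5 * mu ** 2


def nwj_optimal_critic(mu, n, m, rng):
    """Monte Carlo NWJ estimate with the optimal critic f* = log p/q."""
    x = rng.normal(mu, 1.0, size=n)   # x ~ p
    y = rng.normal(0.0, 1.0, size=m)  # y ~ q
    return float(log_ratio(x, mu).mean() - np.exp(log_ratio(y, mu)).mean() + 1.0)


def ratio_variance(mu):
    """Closed-form Var_q[p/q] = exp(mu^2) - 1 for this Gaussian pair."""
    return float(np.expm1(mu ** 2))


def empirical_ratio_variance(mu, m, rng):
    """Sample variance of p(y)/q(y) with y ~ q."""
    y = rng.normal(0.0, 1.0, size=m)
    return float(np.exp(log_ratio(y, mu)).var())


rng = np.random.default_rng(0)
for mu in (1.0, 2.0, 3.0, 4.0):
    kl = 0.5 * mu ** 2
    # Spread of the estimate over repeated draws with a fixed sample budget.
    runs = np.array([nwj_optimal_critic(mu, 1000, 1000, rng) for _ in range(200)])
    print(f"KL={kl:.1f}  mean={runs.mean():.3f}  std={runs.std():.3f}  "
          f"Var_q[p/q]={ratio_variance(mu):.3g}")
```

With a fixed budget of samples the spread of the estimate should blow up quickly as the true KL grows, even though the critic is optimal.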

The Donsker-Varadhan Estimator

Donsker and Varadhan proposed what is essentially a tighter bound on the KL divergence, of the form \[ KL(p(x) \mid\mid q(x)) \ge \E_{p(x)} f(x) - \log \E_{q(x)} \exp(f(x)) \] with the same \(f^*(x) = \ln \tfrac{p(x)}{q(x)}\) being an optimal critic. There are two key differences from the previous bound. The first is that it uses a logarithm in front of the balancing term, preventing it from contributing huge variance (but this variance still has to go somewhere, and we’ll see where it goes). The second (and the most important) is that this bound is no longer amenable to (unbiased) Monte Carlo estimation due to the logarithm outside of the expectation. In practice people just take an empirical average under the logarithm, thus obtaining a biased estimate (which in general is neither a lower nor an upper bound):

\[ \frac{1}{N} \sum_{n=1}^N \ln \frac{p(x_n)}{q(x_n)} - \log \frac{1}{M} \sum_{m=1}^M \frac{p(y_m)}{q(y_m)} \]

It can be shown that the balancing term now has a huge bias, and that this bias always has the same sign: by Jensen’s inequality the expected log-average is never positive. It’s also easy to see that the bias converges to 0 as we take more samples \(M\), so one might hope that with moderately many samples we’d have some tolerable bias. Well, this doesn’t seem to be the case.

Take a closer look at the bias of the balancing term \[ \E_{q(y_{1:M})} \log \frac{1}{M} \sum_{m=1}^M \frac{p(y_m)}{q(y_m)} \] It can be seen as an asymptotically unbiased estimate (a lower bound in expectation for all finite \(M\)) of the log-normalizing constant of \(p(y)\) (which is 0, since \(p(y)\) is already normalized and its normalizing constant is 1) and is well-studied. In particular, Domke and Sheldon have shown (Theorem 3) that, essentially, the bias of the balancing term converges to 0 at the following rate:

\[ O\left(M^{-1} \mathbb{V}_{q(y)} \left[ \frac{p(y)}{q(y)} \right] \right) \]

Which 1) shows us where the variance has gone and 2) hints that in order to eliminate the bias we’d again need to take an exponential number of samples. I don’t know what happens to the actual variance of the balancing term, but it can only make things worse.
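The bias of the log-average can also be observed numerically. The following sketch again assumes the toy pair \(p = \mathcal{N}(\mu, 1)\), \(q = \mathcal{N}(0, 1)\) (my choice, for which \(\mathbb{V}_{q}[p/q] = \exp(\mu^2) - 1\)) and estimates \(\E \log \frac{1}{M} \sum_m \frac{p(y_m)}{q(y_m)}\) by brute force over many repetitions. The names `log_mean_ratio` and `log_mean_bias` are hypothetical, and the variance-over-\(M\) column only indicates the scale of the rate above, not an exact constant.

```python
import numpy as np


def log_mean_ratio(mu, m, trials, rng):
    """log of the M-sample average of p(y)/q(y), y ~ q, one value per trial.

    Toy assumption: p = N(mu, 1), q = N(0, 1), so log p/q = mu*y - mu^2/2.
    """
    y = rng.normal(0.0, 1.0, size=(trials, m))
    log_w = mu * y - 0.5 * mu ** 2
    # Numerically stable log-mean-exp along the sample axis.
    top = log_w.max(axis=1, keepdims=True)
    return top[:, 0] + np.log(np.exp(log_w - top).mean(axis=1))


def log_mean_bias(mu, m, trials, rng):
    """Brute-force estimate of E log (1/M) sum_m p(y_m)/q(y_m).

    The true value of log E_q[p/q] is 0, so this number is the bias itself.
    """
    return float(log_mean_ratio(mu, m, trials, rng).mean())


rng = np.random.default_rng(0)
mu = 2.0  # KL(p || q) = mu^2 / 2 = 2 nats
for m in (10, 100, 1000):
    bias = log_mean_bias(mu, m, trials=1000, rng=rng)
    scale = np.expm1(mu ** 2) / m  # Var_q[p/q] / M
    print(f"M={m:5d}  bias={bias:+.3f}  Var/M={scale:.3f}")
```

The estimated bias should be negative for every \(M\), in line with Jensen’s inequality, and should shrink only as \(M\) becomes comparable to the variance of the ratio.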

The Contrastive-Predictive-Coding Bound

Let’s leave the realm of lower bounds on the KL now. I have previously presented the InfoNCE bound: \[ \text{MI}[p(x, z)] \ge \E_{p(z_{0:K})} \E_{p(x | z_0)} \log \frac{p(x|z_0)}{\frac{1}{K+1} \sum_{k=0}^K p(x|z_k)} \]

Importantly, this bound does not have access to any marginals of the \(p(x,z)\) joint. It’s easy to show that this lower bound is upper bounded by \(\log (K+1)\), which confirms the thesis:

\[ \begin{align*} \E_{p(z_{0:K})} \E_{p(x | z_0)} \log \frac{p(x|z_0)}{\frac{1}{K+1} \sum\limits_{k=0}^K p(x|z_k)} & = \log (K+1) + \E_{p(z_{0:K})} \E_{p(x | z_0)} \log \frac{p(x|z_0)}{\sum\limits_{k=0}^K p(x|z_k)} \\ & \le \log (K+1) \end{align*} \] Which is due to the log’s argument being between 0 and 1. So this \(\log(K+1)\) upper bound on the lower bound means that if the true MI is much larger than this value, the bound will be very loose.
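The \(\log(K+1)\) ceiling is easy to see in a simulation. The sketch below assumes a toy joint of my own choosing, \(z \sim \mathcal{N}(0, 1)\) and \(x \mid z \sim \mathcal{N}(z, \sigma^2)\), for which the true MI is \(\tfrac{1}{2} \log(1 + 1/\sigma^2)\), and it plugs the exact conditional \(p(x|z)\) into the bound. The function names are hypothetical.

```python
import numpy as np


def log_normal_pdf(x, mean, std):
    """Log-density of N(mean, std^2) evaluated at x."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)


def infonce_estimate(sigma, k, batch, rng):
    """Monte Carlo InfoNCE with the exact conditional p(x|z) as the critic.

    Toy assumption: z ~ N(0, 1), x | z ~ N(z, sigma^2).
    """
    z = rng.normal(size=(batch, k + 1))           # z_0, ..., z_K ~ p(z)
    x = z[:, 0] + sigma * rng.normal(size=batch)  # x ~ p(x | z_0)
    log_lik = log_normal_pdf(x[:, None], z, sigma)  # log p(x | z_k), all k
    # Stable log of the average likelihood over the K + 1 candidates.
    top = log_lik.max(axis=1, keepdims=True)
    log_mean = top[:, 0] + np.log(np.exp(log_lik - top).mean(axis=1))
    return float((log_lik[:, 0] - log_mean).mean())


def true_mi(sigma):
    """Closed-form MI of the toy Gaussian joint, in nats."""
    return float(0.5 * np.log1p(1.0 / sigma ** 2))


rng = np.random.default_rng(0)
k = 9
for sigma in (1.0, 0.1, 0.01):
    est = infonce_estimate(sigma, k, batch=5000, rng=rng)
    print(f"sigma={sigma}  true MI={true_mi(sigma):.3f}  "
          f"InfoNCE={est:.3f}  log(K+1)={np.log(k + 1):.3f}")
```

Once the true MI exceeds \(\log(K+1)\), the estimate should saturate just below that ceiling no matter how good the critic is.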

Given all these negative results, one might ask themselves if knowing the marginal \(p(z)\) would do much better. Consider the “known prior” case:

\[ \text{MI}[p(x, z)] \ge \E_{p(x, z_0)} \E_{q_\phi(z_{1:K}|x)} \log \frac{\hat\varrho_\eta(x|z_0)}{\frac{1}{K+1} \sum_{k=0}^K \hat\varrho_\eta(x|z_k) \frac{p(z_k)}{q_\phi(z_k|x)}} \]

Then we have \[ \begin{align*} \E_{\substack{p(x, z_0) \\ q_\phi(z_{1:K}|x)}} & \log \frac{\hat\varrho_\eta(x|z_0)}{\frac{1}{K+1} \sum_{k=0}^K \hat\varrho_\eta(x|z_k) \frac{p(z_k)}{q_\phi(z_k|x)}} \\ & = \log(K+1) + \E_{\substack{p(x, z_0) \\ q_\phi(z_{1:K}|x)}} \left[ \log \frac{\hat\varrho_\eta(x|z_0) \frac{p(z_0)}{q_\phi(z_0|x)}}{\sum_{k=0}^K \hat\varrho_\eta(x|z_k) \frac{p(z_k)}{q_\phi(z_k|x)}} - \log \frac{p(z_0)}{q_\phi(z_0|x)} \right] \\ & \le \log(K+1) + \E_{\substack{p(x, z_0) \\ q_\phi(z_{1:K}|x)}} \log \frac{q_\phi(z_0|x) p(x|z_0)}{p(z_0) p(x|z_0)} \\ & = \log(K+1) + \E_{p(x, z)} \log \frac{q_\phi(z|x) p(x|z)}{p(z|x) p(x)} \\ & = \log(K+1) + \E_{p(x, z)} \log \frac{p(x|z)}{p(x)} - \E_{p(x, z)} \log \frac{p(z|x)}{q_\phi(z|x)} \\ & = \log(K+1) + \text{MI}[p(x, z)] - \text{KL}(p(x,z) \mid\mid q_\phi(z|x) p(x)) \end{align*} \]

Which shows that by choosing \(q_\phi(z|x) = p(z)\) we essentially threw the baby out with the bathwater: in that case the KL term equals the MI itself and the upper bound collapses back to \(\log(K+1)\). Yes, \(K\) still needs to be exponential, but this time not in the original MI, but rather in \(\text{KL}(p(x,z) \mid\mid q_\phi(z|x) p(x))\), which can be made much smaller with a good choice of the variational distribution \(q_\phi(z|x)\).

Also recall that we can reparametrize the bound in terms of \(\hat\rho_\eta(x, z) = \hat\varrho_\eta(x|z) p(z)\):

\[ \begin{align*} \text{MI}[p(x, z)] & \ge \E_{p(x, z_0)} \E_{q_\phi(z_{1:K}|x)} \log \frac{\hat\rho_\eta(x, z_0)}{\frac{1}{K+1} \sum_{k=0}^K \frac{\hat\rho_\eta(x, z_k) }{q_\phi(z_k|x)}} - \E_{p(z)} \log p(z) \\ \text{MI}[p(x, z)] & = \E_{p(x, z)} \log p(z|x) - \E_{p(z)} \log p(z) \end{align*} \]

Hence \[ \E_{p(x, z_0)} \log p(z_0|x) \ge \E_{p(x, z_0)} \E_{q_\phi(z_{1:K}|x)} \log \frac{\hat\rho_\eta(x, z_0)}{\frac{1}{K+1} \sum_{k=0}^K \frac{\hat\rho_\eta(x, z_k) }{q_\phi(z_k|x)}} \] Now we can choose the \(p(x)\) marginal freely. Let \(p(x) = \delta(x - \tilde{x})\). Then \[ \E_{p(z_0|\tilde{x})} \log p(z_0|\tilde{x}) \ge \E_{p(z_0|\tilde{x})} \E_{q_\phi(z_{1:K}|\tilde{x})} \log \frac{\hat\rho_\eta(\tilde{x}, z_0)}{\frac{1}{K+1} \sum_{k=0}^K \frac{\hat\rho_\eta(\tilde{x}, z_k) }{q_\phi(z_k|\tilde{x})}} \]

So in a sense (and this is indeed how we derived the bound in the first place), this “known prior” lower bound is based on a distribution-free upper bound on the entropy of \(z|x\) and avoids lower bounding any entropies.

But why does it work in practice?

Despite the negative results above, there’s a lot of empirical evidence of successful applications of all of the distribution-free bounds presented above. So what’s going on? Quite possibly, this was the question folks from Google Brain asked themselves in their recent preprint On Mutual Information Maximization for Representation Learning. In this paper the researchers investigated representations obtained by the Mutual Information maximization principle. For example, one finding is that tighter MI estimates surprisingly led to worse performance. Overall, the apparent conclusion of the paper is that the MI estimation perspective does not seem to explain the observed behavior. The authors then suggest the metric learning perspective and reinterpret lower bounds on the MI as metric learning objectives.

Conclusion

All this evidence suggests that estimating the MI is an even harder problem than we used to think. In particular, black-box MI estimation seems to be intractable in non-toy cases. Luckily, representation learning works nevertheless, probably due to a different phenomenon.

However, for many problems it’d be really nice to have a way to quantify the dependence between \(x\) and \(z\). A possible approach here is to consider different divergences between the joint \(p(x, z)\) and the product of marginals \(p(x) p(z)\). For example, one possible direction is to replace the KL divergence with some other \(f\)-divergence (see Lautum Information, for example) or the Wasserstein distance. And there are already some works in this direction: Wasserstein Dependency Measure for Representation Learning explores, unsurprisingly, the Wasserstein distance, and Learning deep representations by mutual information estimation and maximization considers the Jensen-Shannon divergence instead of the KL divergence. It’d be curious to see some theorems / efficient bounds for these and other divergences.

Finally, one additional contribution of the Formal Limitations paper is the impossibility of good lower bounds on the KL divergence (supported by the reasoning above). This raises the question: given the whole family of \(f\)-divergences and their Fenchel conjugate-based black-box lower bounds, do all of them exhibit such computationally unfavorable behavior? If not, which ones do?

Thanks to Ben Poole, Evgenii Egorov and Arseny Kuznetsov for valuable discussions.


  1. Or, in my notation, \(\text{MI}[p(x) \delta(z - x)] = \mathbb{H}[p(x)]\).

  2. By “good lower bound” I mean distribution-free high-confidence lower bound that uses some decent amount of samples, say, polynomial or even linear in the MI.
