B.log

Random notes mostly on Machine Learning

A simpler derivation of f-GANs

I have been looking at the derivation of $f$-GANs as part of my research and found an easier way to derive their lower bound, one that does not invoke convex conjugate functions.

$f$-GANs generalize standard GANs to arbitrary $f$-divergences. Given a convex function $f$ with $f(1) = 0$, the corresponding $f$-divergence can be used to measure the "difference" between the data distribution $p_\text{data}(x)$ and our model $q(x)$:

$$ D_f(p_\text{data}(x) \mid\mid q(x)) = \E_{q(x)} f \left( \frac{p_\text{data}(x)}{q(x)} \right) $$
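To make the definition concrete, here is a minimal sketch in Python/NumPy for discrete distributions on a shared finite support, where every expectation is an exact sum. The generators and the distributions `p`, `q` are illustrative choices of ours, not from the post: $f(t) = t \log t$ gives $\text{KL}(p \mid\mid q)$, $f(t) = -\log t$ gives $\text{KL}(q \mid\mid p)$, and $f(t) = (t - 1)^2$ gives the Pearson $\chi^2$ divergence.

```python
import numpy as np

# A few convex generators f with f(1) = 0 (illustrative choices).
GENERATORS = {
    "kl": lambda t: t * np.log(t),             # D_f = KL(p || q)
    "reverse_kl": lambda t: -np.log(t),        # D_f = KL(q || p)
    "pearson_chi2": lambda t: (t - 1.0) ** 2,  # D_f = sum (p - q)^2 / q
}

def f_divergence(p, q, f):
    """D_f(p || q) = E_q f(p / q) for discrete distributions on a shared support."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float(np.sum(q * f(p / q)))

p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.25, 0.25, 0.25, 0.25])
for name, f in GENERATORS.items():
    print(name, f_divergence(p, q, f))
```

Since $f(1) = 0$, each divergence vanishes when the two distributions coincide.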

Of course, we don't know the data-generating distribution $p_\text{data}(x)$. Moreover, in a typical GAN setting $q(x)$ is an implicit model, so its density is not known either [1]. To make things tractable, GANs therefore resort to sample-based lower bounds [2].

Simple Derivation

Our derivation is based on a simple and very well-known fact about convex functions [3]: a convex function always lies on or above any of its tangent lines, touching it at the point of tangency. Taking the tangent at the point $r(x)$ gives

$$ f\left( \frac{p_\text{data}(x)}{q(x)} \right) \ge f\left( r(x) \right) + f'\left( r(x) \right) \left( \frac{p_\text{data}(x)}{q(x)} - r(x) \right) $$

which holds for any non-negative function $r(x)$. Taking the expectation with respect to $q(x)$ on both sides, we get

$$ \begin{align*} D_f(p_\text{data}(x) \mid\mid q(x)) &\ge \E_{q(x)} \left[ f\left( r(x) \right) + f'\left( r(x) \right) \left( \frac{p_\text{data}(x)}{q(x)} - r(x) \right) \right] \\ & = \E_{q(x)} f\left( r(x) \right) + \E_{p_\text{data}(x)} f'\left( r(x) \right) - \E_{q(x)} f'\left( r(x) \right) r(x) \tag{1} \end{align*} $$
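As an illustration (the choice of $f$ is ours, not part of the argument above), take the KL generator $f(t) = t \log t$, for which $f'(t) = \log t + 1$ and $D_f = \text{KL}(p_\text{data} \mid\mid q)$. Substituting into (1), the $r(x) \log r(x)$ terms cancel:

$$ \E_{q(x)} r(x) \log r(x) + \E_{p_\text{data}(x)} \left[ \log r(x) + 1 \right] - \E_{q(x)} r(x) \left( \log r(x) + 1 \right) = 1 + \E_{p_\text{data}(x)} \log r(x) - \E_{q(x)} r(x) $$

At $r^*(x) = \frac{p_\text{data}(x)}{q(x)}$ the last term equals $\E_{q(x)} \frac{p_\text{data}(x)}{q(x)} = 1$, leaving exactly $\E_{p_\text{data}(x)} \log \frac{p_\text{data}(x)}{q(x)} = \text{KL}(p_\text{data} \mid\mid q)$.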

This bound has several nice properties:

  1. It does not require knowing densities, only having samples.
  2. By construction, it's a lower bound for all $r(x)$.
  3. Plugging in $r^*(x) = \frac{p_\text{data}(x)}{q(x)}$ recovers the $f$-divergence exactly, so the bound is tight.
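Property 1 is what makes (1) practical: every term is an expectation, so it can be estimated by averaging over samples. Below is a minimal sketch assuming a toy setting that is not from the post: $p_\text{data} = \mathcal{N}(1, 1)$, $q = \mathcal{N}(0, 1)$ and the KL generator $f(t) = t \log t$. There the true ratio is $r^*(x) = \exp(x - 1/2)$ and $\text{KL} = 1/2$, so the estimate can be compared against the truth. The helper `bound_1` and the suboptimal guess `r_other` are hypothetical names for illustration.

```python
import numpy as np

def f(t):
    return t * np.log(t)          # KL generator

def f_prime(t):
    return np.log(t) + 1.0

def bound_1(r, x_data, x_model):
    """Monte Carlo estimate of bound (1) from x_data ~ p_data and x_model ~ q."""
    r_model = r(x_model)
    return (np.mean(f(r_model))
            + np.mean(f_prime(r(x_data)))
            - np.mean(f_prime(r_model) * r_model))

rng = np.random.default_rng(0)
n = 200_000
x_data = rng.normal(1.0, 1.0, size=n)   # samples from p_data = N(1, 1)
x_model = rng.normal(0.0, 1.0, size=n)  # samples from q = N(0, 1)

def r_star(x):
    return np.exp(x - 0.5)              # true ratio p_data(x) / q(x)

def r_other(x):
    return np.exp(0.5 * x - 0.125)      # ratio of N(0.5, 1) to N(0, 1): a suboptimal guess

est_star = bound_1(r_star, x_data, x_model)
est_other = bound_1(r_other, x_data, x_model)
print(est_star, est_other)  # roughly 0.5 (the true KL) and 0.375
```

Only samples and evaluations of $r(x)$ are needed. For this $f$ the estimate simplifies algebraically to $1 + \E_{p_\text{data}} \log r - \E_q r$, which for `r_other` has expectation $1 + 0.375 - 1 = 0.375$, strictly below the true divergence.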

However, this formula looks different from the one in the $f$-GANs paper. Are they related? We'll now show they're exactly the same.

$f$-GANs Derivation

The original derivation, which should probably be attributed to "Estimating divergence functionals and the likelihood ratio by convex risk minimization" by XuanLong Nguyen, Martin J. Wainwright, and Michael I. Jordan (2010), is based on the concept of convex conjugation. The convex conjugate $f^*$ of a function $f$ is $$ f^*(t) = \sup_{u \in \text{dom}(f)} \left[ u t - f(u) \right] $$
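The supremum in this definition can be approximated by brute force over a grid of $u$, which is handy for sanity-checking closed forms. A minimal sketch, assuming the KL generator $f(u) = u \log u$: setting the derivative of $ut - u \log u$ to zero gives the maximizer $u = e^{t-1}$ and hence $f^*(t) = e^{t-1}$. The grid bounds and the helper `conjugate_on_grid` are arbitrary choices of ours.

```python
import numpy as np

def conjugate_on_grid(f, t, u_grid):
    """Approximate f*(t) = sup_u [u t - f(u)] by maximizing over a finite grid of u."""
    t = np.atleast_1d(np.asarray(t, dtype=float))
    # Rows index t, columns index u; take the max over u for each t.
    values = np.outer(t, u_grid) - f(u_grid)[None, :]
    return values.max(axis=1)

f_kl = lambda u: u * np.log(u)
u_grid = np.linspace(1e-6, 20.0, 200_001)  # must contain the maximizers e^(t-1)
t = np.linspace(-2.0, 2.0, 9)

numeric = conjugate_on_grid(f_kl, t, u_grid)
closed_form = np.exp(t - 1.0)
print(np.max(np.abs(numeric - closed_form)))  # small: grid discretization error only
```

A grid maximum can only undershoot the true supremum, so this check also confirms that no grid point beats the closed form.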

Nguyen et al. showed the following variational characterization of the $f$-divergence [4]: $$ D_f(p(x) \mid\mid q(x)) = \sup_{T(x)} \left[ \E_{p(x)} T(x) - \E_{q(x)} f^*(T(x)) \right] $$ where $f^*(t)$ is the aforementioned convex conjugate of $f(t)$, and the supremum is taken over all functions $T(x)$. We can, however, safely restrict the range of $T(x)$ to the values where $f^*$ is finite, that is, the set $\mathcal{V} = \{t \in \mathbb{R} \mid f^*(t) < \infty \}$. This form is already amenable to practical use: make $T(x)$ a neural network whose output activation respects $\mathcal{V}$, and maximize the resulting lower bound w.r.t. its parameters. The question is then how to construct this activation.
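The activation matters when $\mathcal{V}$ is a proper subset of $\mathbb{R}$. As an illustration of ours, take the reverse KL generator $f(u) = -\log u$, for which $f^*(t) = -1 - \log(-t)$ and $\mathcal{V} = (-\infty, 0)$; mapping an unconstrained output $v$ through $T = -e^{v}$ always lands in $\mathcal{V}$. The sketch below uses a small discrete example where the "network" is just a free vector of outputs, and checks that the variational objective never exceeds the true divergence $D_f = \text{KL}(q \mid\mid p)$, while the known optimum $T^* = f'(p/q) = -q/p$ attains it.

```python
import numpy as np

# Reverse-KL generator f(u) = -log(u); its conjugate is finite only for t < 0.
def f_star(t):
    return -1.0 - np.log(-t)

def activation(v):
    """Map unconstrained outputs v into V = (-inf, 0)."""
    return -np.exp(v)

def variational_objective(T, p, q):
    """E_p T - E_q f*(T) for discrete p, q, with T given as a vector over the support."""
    return float(np.sum(p * T) - np.sum(q * f_star(T)))

p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.25, 0.25, 0.25, 0.25])
true_div = float(np.sum(q * np.log(q / p)))  # E_q f(p/q) = KL(q || p)

# Random "network outputs" v, squashed into V by the activation.
rng = np.random.default_rng(0)
gaps = [true_div - variational_objective(activation(rng.normal(size=4)), p, q)
        for _ in range(1000)]
worst_gap = min(gaps)

# The optimum T* = f'(p/q) = -q/p attains the divergence.
T_opt = -q / p
print(worst_gap, variational_objective(T_opt, p, q), true_div)
```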

Skipping a more detailed analysis, we note that the optimal $T(x)$ is known to be $$T^*(x) = f'\left( \frac{p(x)}{q(x)} \right)$$ Since we only care about attaining the optimal value, we may as well restrict $T(x)$ to the following parametrization (using a non-negative function $r(x)$): $$ T(x) = f'(r(x)) $$ This family still contains $T^*(x)$ (take $r(x) = \frac{p(x)}{q(x)}$), so the supremum is unchanged, and we obtain the following objective

$$ D_f(p(x) \mid\mid q(x)) = \sup_{r(x)} \left[ \E_{p(x)} f'(r(x)) - \E_{q(x)} f^*(f'(r(x))) \right] \tag{2} $$
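One practical consequence of the parametrization $T(x) = f'(r(x))$, not spelled out above, is that it answers the activation question: by the conjugate identity derived below, $f^*(f'(r)) = r f'(r) - f(r)$ is finite wherever $f$ is differentiable at $r$, so a valid ratio estimate $r(x)$ automatically yields $T(x) \in \mathcal{V}$. The sketch below spot-checks this for three generators. The sets $\mathcal{V}$ in the comments are worked out by hand for this illustration, and producing $r(x)$ by exponentiating unconstrained outputs is our assumption.

```python
import numpy as np

# name: (f', membership test for V = {t : f*(t) < inf}).
# The sets V are derived by hand for this illustration.
GENERATORS = {
    # f(u) = u log u,          f*(t) = exp(t - 1),    V = all reals
    "kl": (lambda r: np.log(r) + 1.0, np.isfinite),
    # f(u) = -log u,           f*(t) = -1 - log(-t),  V = (-inf, 0)
    "reverse_kl": (lambda r: -1.0 / r, lambda t: t < 0),
    # f(u) = (sqrt(u) - 1)^2,  f*(t) = t / (1 - t),   V = (-inf, 1)
    "squared_hellinger": (lambda r: 1.0 - 1.0 / np.sqrt(r), lambda t: t < 1),
}

rng = np.random.default_rng(0)
v = rng.normal(scale=3.0, size=10_000)  # unconstrained "network outputs"
r = np.exp(v)                           # non-negative ratio estimate r(x)

results = {name: bool(np.all(in_V(f_prime(r))))
           for name, (f_prime, in_V) in GENERATORS.items()}
print(results)
```

For reverse KL, for example, $f'(e^{v}) = -e^{-v}$ is itself a valid output activation for $\mathcal{V} = (-\infty, 0)$.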

Finally, we use an important property of convex conjugate functions: $$ \begin{align*} f^*(f'(r(x))) &= \sup_u \left[ u f'(r(x)) - f(u) \right] \\ &= \sup_u \left[ u f'(r(x)) - r(x) f'(r(x)) - f(u) \right] + r(x) f'(r(x)) \\ &= \sup_u \left[ \underbrace{f(r(x)) + f'(r(x)) (u - r(x)) - f(u)}_{\le 0 \text{ due to convexity of } f} \right] + r(x) f'(r(x)) - f(r(x)) \\ &= r(x) f'(r(x)) - f(r(x)) \end{align*} $$ In the last line we used the fact that the tangent of a convex $f(t)$ at any point is a global lower bound, so the bracketed term is never positive, and its supremum, 0, is attained at $u = r(x)$.
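The identity $f^*(f'(r)) = r f'(r) - f(r)$ is easy to spot-check numerically. The sketch below uses closed-form conjugates that we derived by hand for this illustration (they are not taken from the post) for four generators, and compares both sides at random points $r > 0$.

```python
import numpy as np

# name: (f, f', f*) with hand-derived closed-form conjugates.
GENERATORS = {
    "kl": (lambda u: u * np.log(u),
           lambda u: np.log(u) + 1.0,
           lambda t: np.exp(t - 1.0)),
    "reverse_kl": (lambda u: -np.log(u),
                   lambda u: -1.0 / u,
                   lambda t: -1.0 - np.log(-t)),
    "pearson_chi2": (lambda u: (u - 1.0) ** 2,
                     lambda u: 2.0 * (u - 1.0),
                     lambda t: t + t ** 2 / 4.0),
    "squared_hellinger": (lambda u: (np.sqrt(u) - 1.0) ** 2,
                          lambda u: 1.0 - 1.0 / np.sqrt(u),
                          lambda t: t / (1.0 - t)),
}

rng = np.random.default_rng(0)
r = np.exp(rng.normal(size=1000))  # random positive points

results = {}
for name, (f, f_prime, f_star) in GENERATORS.items():
    lhs = f_star(f_prime(r))       # conjugate evaluated at the tangent slope
    rhs = r * f_prime(r) - f(r)    # right-hand side of the identity
    results[name] = bool(np.allclose(lhs, rhs))
print(results)
```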

Plugging the identity $f^*(f'(r(x))) = r(x) f'(r(x)) - f(r(x))$ into objective (2), we obtain $$ \begin{align*} D_f(p(x) \mid\mid q(x)) & = \sup_{r(x)} \left[ \E_{p(x)} f'(r(x)) - \E_{q(x)} \left( r(x) f'(r(x)) - f(r(x)) \right) \right] \\ & = \sup_{r(x)} \left[ \E_{q(x)} f(r(x)) + \E_{p(x)} f'(r(x)) - \E_{q(x)} r(x) f'(r(x)) \right] \end{align*} $$

This is exactly formula (1). Moreover, the conjugate identity holds pointwise, for every realization of the random variables involved, so not only do the bounds (1) and (2) coincide, but so do their stochastic estimates [5].
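Because the identity holds pointwise, sample estimates of (1) and (2) agree up to floating-point round-off as long as both use the same samples from $q(x)$ (footnote 5). A minimal sketch, assuming the KL generator, toy Gaussians $p = \mathcal{N}(1, 1)$ and $q = \mathcal{N}(0, 1)$, and an arbitrary hypothetical ratio estimate $r(x) = e^{0.3x}$; `estimate_1` and `estimate_2` are illustrative names.

```python
import numpy as np

f = lambda u: u * np.log(u)            # KL generator
f_prime = lambda u: np.log(u) + 1.0
f_star = lambda t: np.exp(t - 1.0)     # its convex conjugate

def estimate_1(r, x_p, x_q):
    """Sample estimate of objective (1)."""
    return (np.mean(f(r(x_q))) + np.mean(f_prime(r(x_p)))
            - np.mean(f_prime(r(x_q)) * r(x_q)))

def estimate_2(r, x_p, x_q):
    """Sample estimate of objective (2) with T = f'(r)."""
    return np.mean(f_prime(r(x_p))) - np.mean(f_star(f_prime(r(x_q))))

rng = np.random.default_rng(1)
x_p = rng.normal(1.0, 1.0, size=1_000)
x_q = rng.normal(0.0, 1.0, size=1_000)
x_q_other = rng.normal(0.0, 1.0, size=1_000)  # a fresh batch from q
r = lambda x: np.exp(0.3 * x)                 # an arbitrary positive ratio estimate

same = estimate_1(r, x_p, x_q) - estimate_2(r, x_p, x_q)
different = estimate_1(r, x_p, x_q) - estimate_2(r, x_p, x_q_other)
print(abs(same))       # floating-point round-off only
print(abs(different))  # a genuine Monte Carlo discrepancy
```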

Conclusion

The presented derivation and objective are interesting for several reasons. First, by design, the optimal "discriminator" $r^*(x) = \frac{p_\text{data}(x)}{q(x)}$ does not depend on the particular $f$-divergence used. Second, viewing $r(x)$ as an approximation of the importance weights gives an intuitive reading of the terms in objective (1): the first term is the $f$-divergence computed with the learned density ratio $r(x)$ in place of the true one, and the remaining two terms correct it so that the lower bound is guaranteed. In particular, the last term uses $r(x)$ as an importance weight to "approximate" the second one, so that the two cancel out when $r(x)$ is optimal. Last but not least, the presented derivation is simpler.
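The independence of $r^*$ from $f$ can be seen directly: differentiating the integrand of (1) pointwise in $r$ gives $f''(r)\,\left(p_\text{data}(x) - q(x)\,r\right)$, which vanishes at $r = p_\text{data}(x)/q(x)$ for every $f$, and for strictly convex $f$ it is positive below that point and negative above it. The sketch below, on a small discrete example of our choosing, maximizes the integrand separately at each support point by grid search for four generators and checks that every one lands on $p/q$. The grid and the helper `best_r` are assumptions for illustration.

```python
import numpy as np

# name: (f, f') for a few generators.
GENERATORS = {
    "kl": (lambda u: u * np.log(u), lambda u: np.log(u) + 1.0),
    "reverse_kl": (lambda u: -np.log(u), lambda u: -1.0 / u),
    "pearson_chi2": (lambda u: (u - 1.0) ** 2, lambda u: 2.0 * (u - 1.0)),
    "squared_hellinger": (lambda u: (np.sqrt(u) - 1.0) ** 2,
                          lambda u: 1.0 - 1.0 / np.sqrt(u)),
}

p = np.array([0.1, 0.2, 0.3, 0.4])
q = np.array([0.4, 0.3, 0.2, 0.1])
r_grid = np.linspace(0.01, 10.0, 100_000)

def best_r(f, f_prime):
    """Maximize the integrand of (1) at each support point by grid search over r."""
    R = r_grid[None, :]
    # Shape (support points, grid): q f(r) + p f'(r) - q r f'(r)
    objective = q[:, None] * f(R) + p[:, None] * f_prime(R) - q[:, None] * R * f_prime(R)
    return r_grid[np.argmax(objective, axis=1)]

optima = {name: best_r(f, fp) for name, (f, fp) in GENERATORS.items()}
for name, r_hat in optima.items():
    print(name, np.round(r_hat, 3))  # each row should be close to p / q
print("p / q:", np.round(p / q, 3))
```

Because the objective separates over support points in the discrete case, a per-point search is enough here; with a neural network $r(x)$ one would instead run gradient ascent on the sample estimate.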


  1. Actually, most of the time it does not exist at all. But that's a story for another time. 

  2. Although a lower bound on the loss is not something you'd normally want to minimize, this is how things are done in the GAN realm. 

  3. We assume $f$ is differentiable here; if it's not, the statement still holds with $f'$ replaced by a subgradient. 

  4. Nguyen et al. use a slightly different convention for $f$-divergences, namely $$D_f(p(x) \mid\mid q(x)) = \E_{p(x)} f\left(\frac{q(x)}{p(x)}\right)$$ 

  5. As long as you use the same samples to estimate different expectations over the distribution $q(x)$. 
