This is the final post of the stochastic computation graphs series. Last time we discussed continuous relaxations of discrete stochastic nodes, which allowed us to employ the power of reparametrization.
These methods, however, possess one flaw: they optimize a different (relaxed) model, thus introducing inherent bias: your test-time discrete model will be doing something different from what your training-time model did. Therefore in this post we'll get back to the REINFORCE aka Score Function estimator, and see if we can fix its problems.
Back to REINFORCE
The REINFORCE^{1} estimator arises from the following identity:
$$ \begin{align*} \nabla_\theta \mathcal{F}(\theta) & = \nabla_\theta \mathbb{E}_{p(z \mid \theta)} f(z) = \nabla_\theta \int f(z) p(z \mid \theta) dz = \int f(z) \nabla_\theta p(z \mid \theta) dz \\ &= \int f(z) \nabla_\theta \log p(z \mid \theta) p(z \mid \theta) dz = \mathbb{E}_{p(z \mid \theta)} f(z) \nabla_\theta \log p(z \mid \theta) \end{align*} $$
This allows us to estimate the gradient of the expected objective using Monte Carlo estimation:
$$ \hat{\nabla}_\theta^{\text{SF}} \mathcal{F} = \frac{1}{L} \sum_{l=1}^L f(z^{(l)}) \nabla_\theta \log p(z^{(l)} \mid \theta), \quad \text{where } z^{(l)} \sim p(z \mid \theta) $$
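To make the estimator concrete, here is a minimal NumPy sketch. It assumes a categorical $p(z \mid \theta) = \text{softmax}(\theta)$ over $K$ outcomes and an objective given as a lookup table `f[k]`; the helper names are illustrative, not from any library. For softmax logits $\nabla_\theta \log p(z = k \mid \theta) = \text{onehot}(k) - p$, and the exact gradient is available for comparison.

```python
import numpy as np


def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()


def score_function_grad(theta, f, num_samples, rng):
    """REINFORCE estimate of d/dtheta E_{p(z|theta)} f(z) for p(z|theta) = softmax(theta).

    f is an array with f[k] = f(z = k)."""
    p = softmax(theta)
    z = rng.choice(len(p), size=num_samples, p=p)   # z^(l) ~ p(z|theta)
    score = np.eye(len(p))[z] - p                   # grad_theta log p(z^(l)|theta)
    return (f[z][:, None] * score).mean(axis=0)     # (1/L) sum_l f(z^(l)) * score


def exact_grad(theta, f):
    """Exact gradient: d/dtheta_j sum_k p_k f_k = p_j (f_j - sum_k p_k f_k)."""
    p = softmax(theta)
    return p * (f - p @ f)


rng = np.random.default_rng(0)
theta = np.array([0.1, -0.3, 0.5])
f = np.array([1.0, 3.0, -2.0])
print(score_function_grad(theta, f, 100_000, rng))
print(exact_grad(theta, f))
```

With enough samples the estimate approaches `exact_grad`; how many samples "enough" is turns out to be the whole problem.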
The downside of this method is that it does not use the gradient information of the objective $f$. This makes it applicable in cases where we don't have access to such information, for example, in Reinforcement Learning. However, when working with Stochastic Computation Graphs, we usually do have gradients $\nabla_z f(z)$ available, and I believe methods that intelligently use this gradient should perform better.
Still, even without this information the score function estimator is an unbiased estimator of the true gradient. What's the problem then? The problem is its impractically high variance, which requires an astronomical number of samples to bring the variance down and make optimization actually feasible ^{3}. Recall the intuition behind this from the first post: a REINFORCE estimator $\hat{\nabla}_\theta^{\text{SF}} \mathcal{F}$ is just $L$ single-sample gradients averaged together, and each single-sample gradient $f(z) \nabla_\theta \log p(z \mid \theta)$ essentially implements a random search: it wants to increase the probability of a given sample $z$ proportionally to $f(z)$, and if the latter is negative, to reduce it. Each of the samples then pulls the probability towards itself, and this lack of consensus is the source of the problem.
However, despite REINFORCE being essentially a random search in disguise, not all is lost yet. As we shall see, one can extend it with lots of different tricks, greatly reducing the variance.
Control Variates
One method for reducing variance in statistics (and the main one for this post) is the method of Control Variates, which is based on the idea that if you have two negatively correlated random variables, their sum can have lower variance. Indeed, let's assume we have random variables $X$ and $Y$ such that $\mathbb{D}(X) = \sigma^2_x$, $\mathbb{D}(Y) = \sigma^2_y$ and $\text{Cov}(X, Y) = -\tau \sigma_x \sigma_y$ with $\tau > 0$. Then
$$ \mathbb{D}(X + Y) - \mathbb{D}(X) = \mathbb{D}(Y) + 2 \text{Cov}(X, Y) = \sigma^2_y - 2 \tau \sigma_x \sigma_y = \sigma_y (\sigma_y - 2 \tau \sigma_x) $$
So if $\sigma_y < 2 \tau \sigma_x$, then the sum $X + Y$ will have lower variance than $X$ alone. Of course, $Y$ needs to be centered ($\mathbb{E} Y = 0$) so as not to bias $X$, but centering does not affect the variance.
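A quick numerical sanity check of this argument. The choices $X = e^U$ with $U \sim U[0,1]$ and the coefficient 1.5 are arbitrary illustrations: $Y = -1.5(U - 1/2)$ is centered and negatively correlated with $X$, so $X + Y$ keeps the mean $e - 1$ while its variance drops.

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.uniform(size=100_000)

x = np.exp(u)              # target: E[X] = e - 1
y = -1.5 * (u - 0.5)       # centered (E[Y] = 0) and negatively correlated with X

print("X    : mean %.4f, var %.4f" % (x.mean(), x.var()))
print("X + Y: mean %.4f, var %.4f" % ((x + y).mean(), (x + y).var()))
```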
We'll be considering control variates of a special form: $b(z) \nabla_\theta \log p(z \mid \theta)$, where $b$ is a baseline and can be either a scalar or a vector (in which case the multiplication is pointwise)^{2}. This leads to an estimator of the following form
$$ \hat{\nabla}_\theta^\text{SF} \mathcal{F}(\theta) = (f(z) - b(z)) \nabla_\theta \log p(z \mid \theta) $$
Here I used only one sample to simplify the notation (and will be doing so from now on); in practice you can always average several samples, though that probably won't help you much ^{3}. However, by using a baseline we might have introduced unwanted bias into our gradient estimate. Let's see:
$$ \begin{align*} \mathbb{E}_{p(z \mid \theta)} \hat{\nabla}_\theta^\text{SF} \mathcal{F}(\theta) &= \mathbb{E}_{p(z \mid \theta)} (f(z) - b(z)) \nabla_\theta \log p(z \mid \theta) = \nabla_\theta \mathbb{E}_{p(z \mid \theta)} \left[ f(z) - b(z) \right] \\ &= \nabla_\theta \mathbb{E}_{p(z \mid \theta)} f(z) - \nabla_\theta \mathbb{E}_{p(z \mid \theta)} b(z) \end{align*} $$
Looks like we did indeed bias the estimator! In order to be able to reduce the variance and keep the estimator unbiased, we should remove the introduced bias from the $\hat{\nabla}_\theta^\text{SF} \mathcal{F}(\theta)$:
$$ \hat{\nabla}_\theta^\text{SF} \mathcal{F}(\theta) = (f(z) - b(z)) \nabla_\theta \log p(z \mid \theta) + \nabla_\theta \mathbb{E}_{p(z \mid \theta)} b(z) $$
This, of course, only works if you can compute the last term analytically. Estimating it with REINFORCE won't help you, as you'd then recover the standard Score Function estimator.
The easiest baseline one can think of is a constant baseline. It doesn't introduce any bias: indeed offsetting the target $f(z)$ should not (and does not) change the true gradient of the expectation. However, as we've seen in the first part of the series, it can mess with the variance. So, let's use a baseline that would minimize the total variance of the adjusted estimator:
$$ \hat{\nabla}_\theta^\text{SF-const} \mathcal{F}(\theta) = (f(z) - b) \nabla_\theta \log p(z \mid \theta) $$
The total variance along all $D$ coordinates of this gradient estimator is $$ \begin{align*} \sum_{d=1}^D &\mathbb{D}\left[\hat{\nabla}_{\theta_d}^\text{SF-const} \mathcal{F}(\theta)\right] = \sum_{d=1}^D \mathbb{D}\left[(f(z) - b) \nabla_{\theta_d} \log p(z \mid \theta)\right] \\ &= \sum_{d=1}^D \Bigl( {\scriptsize \mathbb{D}\left[f(z) \nabla_{\theta_d} \log p(z \mid \theta)\right] - 2b \, \text{Cov}\left[f(z) \nabla_{\theta_d} \log p(z \mid \theta), \nabla_{\theta_d} \log p(z \mid \theta)\right] + b^2 \mathbb{D}\left[\nabla_{\theta_d} \log p(z \mid \theta)\right] } \Bigr) \end{align*} $$
The formula does look a bit terrifying, but we only care about $b$ at the moment, and the variance is quadratic in $b$. The optimal value is thus obtained by minimizing this quadratic function:
$$ b = \frac{\sum_{d=1}^D \text{Cov}\left[f(z) \nabla_{\theta_d} \log p(z \mid \theta), \nabla_{\theta_d} \log p(z \mid \theta)\right]}{\sum_{d=1}^D \mathbb{D}\left[\nabla_{\theta_d} \log p(z \mid \theta)\right]} = \frac{\sum_{d=1}^D \mathbb{E}\left[f(z) (\nabla_{\theta_d} \log p(z \mid \theta))^2\right]}{\sum_{d=1}^D \mathbb{E}\left[(\nabla_{\theta_d} \log p(z \mid \theta))^2\right]} $$
Here we used the fact that $\mathbb{E} \nabla_{\theta_d} \log p(z \mid \theta) = 0$ for any $d$. The moments in the formula cannot be computed analytically, but one can estimate them using running averages.
In the same fashion one can derive the optimal vector-valued baseline $b$ (and even the matrix-valued one!), consisting of individual baselines for each dimension of the gradient:
$$ b_d = \frac{\mathbb{E}\left[f(z) (\nabla_{\theta_d} \log p(z \mid \theta))^2\right]}{\mathbb{E}\left[(\nabla_{\theta_d} \log p(z \mid \theta))^2\right]} $$
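For a categorical $p(z \mid \theta) = \text{softmax}(\theta)$ the expectations above can be computed exactly by enumerating the outcomes, which makes the claims easy to check. Below is a minimal sketch (the softmax parametrization and the helper names are assumptions for illustration); in a real model one would replace the exact sums with running averages, as noted above. The scalar optimum should beat no baseline, and the per-coordinate $b_d$ should do at least as well as the scalar one.

```python
import numpy as np


def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()


def total_variance(theta, f, b):
    """Exact total variance sum_d Var[(f(z) - b_d) * score_d(z)] of the single-sample
    estimator for p(z|theta) = softmax(theta); b is a scalar or one value per coordinate."""
    p = softmax(theta)
    score = np.eye(len(p)) - p               # row k: grad_theta log p(z=k|theta)
    g = (f[:, None] - b) * score             # row k: estimator value when z = k
    mean = p @ g
    return float(p @ ((g - mean) ** 2).sum(axis=1))


def optimal_baselines(theta, f):
    """Scalar b and vector b_d from the formulas above, computed exactly."""
    p = softmax(theta)
    sq = (np.eye(len(p)) - p) ** 2           # (grad_d log p(z=k|theta))^2
    num = p @ (f[:, None] * sq)              # E[f(z) (grad_d log p)^2] for every d
    den = p @ sq                             # E[(grad_d log p)^2] for every d
    return num.sum() / den.sum(), num / den


theta = np.array([0.2, -1.0, 0.7, 0.0])
f = np.array([2.0, -1.0, 4.0, 0.5])
b_scalar, b_vector = optimal_baselines(theta, f)
for name, b in [("no baseline", 0.0), ("scalar b", b_scalar), ("vector b_d", b_vector)]:
    print(name, total_variance(theta, f, b))
```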
Self-critical Learning
Ideally, the baseline approximates $f(z)$ as well as possible without using the actual sample $z$ ^{7}. However, it can still depend on $\theta$ without introducing any bias:
$$ \mathbb{E}_{p(z \mid \theta)} b(\theta) \nabla_\theta \log p(z \mid \theta) = b(\theta) \mathbb{E}_{p(z \mid \theta)} \nabla_\theta \log p(z \mid \theta) = 0 $$
So, how can we use $\theta$ and $f$ to approximate $f(z)$ without touching the sample $z$ itself? The authors of the Self-critical Sequence Training for Image Captioning paper suggested replacing the stochastic $z$ with the deterministic most probable outcome:
$$ \hat{z} = \text{argmax}_k \; p(z = k \mid \theta) $$
And then we use $f(\hat z)$ as a baseline:
$$ \hat{\nabla}_\theta^\text{SF-SC} \mathcal{F}(\theta) = (f(z) - f(\hat{z})) \nabla_\theta \log p(z \mid \theta) $$
This is a very interesting baseline. Unlike standard REINFORCE, where each sample pulls probability towards itself, this estimator pulls probability in only for samples that are better than the most likely one. Conversely, for samples that are worse than the most likely one, this estimator pushes probability away. In effect, this is a constant (w.r.t. $z$) baseline that adapts itself to the current distribution, automatically deciding whether the probability of a given sample $z$ should be increased or decreased.
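A minimal sketch of the self-critical estimator, again assuming a categorical $p(z \mid \theta) = \text{softmax}(\theta)$ and a lookup table `f` (both illustrative). Since $f(\hat z)$ does not depend on the sampled $z$, averaging many single-sample estimates should recover the exact gradient $p_j (f_j - \sum_k p_k f_k)$.

```python
import numpy as np


def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()


def self_critical_grad(theta, f, num_samples, rng):
    """Average of single-sample self-critical estimates
    (f(z) - f(z_hat)) * grad_theta log p(z|theta), for p(z|theta) = softmax(theta)."""
    p = softmax(theta)
    z_hat = int(np.argmax(p))                       # deterministic most probable outcome
    z = rng.choice(len(p), size=num_samples, p=p)   # stochastic samples
    score = np.eye(len(p))[z] - p                   # grad_theta log p(z|theta)
    advantage = f[z] - f[z_hat]                     # positive only if z beats the mode
    return (advantage[:, None] * score).mean(axis=0)


theta = np.array([0.1, -0.3, 0.5])
f = np.array([1.0, 3.0, -2.0])
print(self_critical_grad(theta, f, 100_000, np.random.default_rng(0)))
```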
Special Cases
When $f$ is of some special form, one can design ad hoc variance reduction techniques. In particular, we'll consider two of them:
NVIL
NVIL stands for Neural Variational Inference and Learning, after the paper it was introduced in. Essentially, it combines variance reduction tricks that Reinforcement Learning people came up with for REINFORCE (which they usually call the Policy Gradient method). The paper introduced three methods: signal centering, variance normalization and local learning signals. Variance normalization divides the gradient by a running estimate of its standard deviation, which is roughly what, say, the Adam optimizer would do for you automatically, so let's not dwell on it.
Signal centering can be considered as baseline amortization for a context-dependent case. Let me decipher that: oftentimes the random variable $z$ depends on some context $x$ (for example, the state of the environment in RL, or the observation $x$ in amortized variational inference), and the expected objective becomes $\mathcal{F}(\theta \mid x) = \mathbb{E}_{p(z \mid x,\theta)} f(x, z)$. Then we can make the baseline $b$ depend on $x$ as well without introducing any bias:
$$ \hat{\nabla}_\theta^\text{SF-NVIL} \mathcal{F}(\theta) = (f(x, z) - b(x)) \nabla_\theta \log p(z \mid x, \theta) $$
We could reuse the formulas from the previous section, but that would require storing a separate baseline for each $x$ in the training set, which doesn't scale. Therefore we'll instead amortize the baseline using a neural network $b(x \mid \varphi)$ with parameters $\varphi$ and learn it by minimizing the expected squared error ^{4} $$\varphi^* = \text{argmin}_\varphi \mathbb{E}_{p(z \mid x,\theta)} (b(x \mid \varphi) - f(x, z))^2$$
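A sketch of fitting the amortized baseline by minimizing the squared error with SGD. To keep it self-contained, it swaps the neural network for a hypothetical linear baseline $b(x \mid \varphi) = \varphi_0 + \varphi_1 x$, uses a made-up objective $f(x, z) = 3x + z$, and draws $z$ from a fixed Bernoulli(0.5) standing in for $p(z \mid x, \theta)$. The fitted baseline should approach $\mathbb{E}_z f(x, z) = 3x + 0.5$.

```python
import numpy as np

rng = np.random.default_rng(0)


def f(x, z):
    # hypothetical objective depending on the context x and the discrete sample z
    return 3.0 * x + z


phi = np.zeros(2)   # hypothetical linear baseline b(x|phi) = phi[0] + phi[1] * x
lr = 0.05
for step in range(20_000):
    x = rng.uniform(-1.0, 1.0, size=32)     # minibatch of contexts
    z = rng.binomial(1, 0.5, size=32)       # stand-in for z ~ p(z|x, theta)
    err = (phi[0] + phi[1] * x) - f(x, z)   # b(x|phi) - f(x, z)
    # gradient of the minibatch mean of 0.5 * (b(x|phi) - f(x, z))^2 w.r.t. phi
    phi -= lr * np.array([err.mean(), (err * x).mean()])

print(phi)   # expected to approach [0.5, 3.0]
```

The learned $b(x \mid \varphi)$ would then enter $(f(x, z) - b(x \mid \varphi)) \nabla_\theta \log p(z \mid x, \theta)$; since it never looks at $z$, it adds no bias.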
The local learning signal allows you to exploit some non-trivial structure in $f(z)$ (and $p(z \mid \theta)$). Namely, suppose we divide our $z$ into $N$ chunks, $z = (z_1, \dots, z_N)$, and $f$ is a sum of rewards on prefixes: $f(z) = \sum_{n=1}^N f_n(z_{\le n})$ ^{5}. It's then obvious that the choice of a later block $z_n$ does not influence the earlier rewards $f_m$ for $m < n$. Indeed, one can see that the true gradient obeys the following:
$$ \begin{align*} \nabla_\theta \mathcal{F}(\theta) &= \mathbb{E}_{p(z_{\le N} \mid \theta)} \sum_{n=1}^N \left(\sum_{k=1}^N f_k(z_{\le k})\right) \nabla_\theta \log p(z_n \mid z_{<n}, \theta) \\ &= \sum_{n=1}^N \sum_{k=1}^N \mathbb{E}_{p(z_{\le N} \mid \theta)} \left[ f_k(z_{\le k}) \nabla_\theta \log p(z_n \mid z_{<n}, \theta)\right] \\ &= {\scriptsize \sum_{n=1}^N \left(\sum_{k=1}^{n-1} \mathbb{E}_{p(z_{\le n} \mid \theta)} \left[f_k(z_{\le k}) \nabla_\theta \log p(z_n \mid z_{<n}, \theta) \right] + \sum_{k=n}^N \mathbb{E}_{p(z_{\le N} \mid \theta)} \left[f_k(z_{\le k}) \nabla_\theta \log p(z_n \mid z_{<n}, \theta) \right]\right)} \\ &= {\scriptsize \sum_{n=1}^N \left(\mathbb{E}_{z_{<n}} \left[\left(\sum_{k=1}^{n-1} f_k(z_{\le k}) \right) \overbrace{\mathbb{E}_{z_n \mid z_{<n}} \nabla_\theta \log p(z_n \mid z_{<n}, \theta)}^{=0}\right] + \sum_{k=n}^N \mathbb{E}_{z_{\le N}} \left[f_k(z_{\le k}) \nabla_\theta \log p(z_n \mid z_{<n}, \theta) \right]\right)} \\ &= \mathbb{E}_{p(z \mid \theta)} \left[{\scriptsize \sum_{n=1}^N \sum_{k=n}^N f_k(z_{\le k}) \nabla_\theta \log p(z_n \mid z_{<n}, \theta)} \right] \end{align*} $$
Naturally, the part of the gradient corresponding to the $n$-th chunk is weighted by the total reward we'd get after deciding upon $z_n$, since the previous rewards do not depend on $z_n$.
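The weights $\sum_{k=n}^N f_k$ are reverse cumulative sums of the per-chunk rewards. A small NumPy sketch follows (function names are illustrative; the per-chunk scores $\nabla_\theta \log p(z_n \mid z_{<n}, \theta)$ are assumed to be computed elsewhere, e.g. by an autodiff framework, and passed in as rows of `scores`).

```python
import numpy as np


def rewards_to_go(rewards):
    """Given per-chunk rewards [f_1(z_{<=1}), ..., f_N(z_{<=N})], return for every n
    the weight sum_{k=n}^N f_k that multiplies grad_theta log p(z_n | z_{<n}, theta)."""
    rewards = np.asarray(rewards, dtype=float)
    return np.cumsum(rewards[::-1])[::-1]


def local_signal_estimate(rewards, scores):
    """Single-sample estimate sum_n (sum_{k>=n} f_k) * scores[n], where scores[n] is the
    precomputed gradient of log p(z_n | z_{<n}, theta) w.r.t. theta."""
    return rewards_to_go(rewards) @ np.asarray(scores, dtype=float)


print(rewards_to_go([1.0, 2.0, 3.0]))   # [6. 5. 3.]
```

The score of $z_1$ is weighted by all three rewards, while the score of $z_3$ only sees $f_3$.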
Combined with the context-dependent baseline, the estimator would be
$$ \hat{\nabla}_\theta^\text{SF-NVIL} \mathcal{F}(\theta) = {\scriptsize \sum_{n=1}^N \sum_{k=n}^N (f_k(x, z_{\le k}) - b_k(x)) \nabla_\theta \log p(z_n \mid x, z_{<n}, \theta)} $$
Moreover, the baseline can be made dependent on the preceding chunks $z_{<n}$, since such a baseline does not introduce any bias:
$$ \begin{align*} \mathbb{E}_{p(z \mid x, \theta)} & \sum_{n=1}^N \sum_{k=n}^N b_{n,k}(x, z_{<n}) \nabla_\theta \log p(z_n \mid x, z_{<n}, \theta) \\ & = \sum_{n=1}^N \sum_{k=n}^N \mathbb{E}_{p(z_{<n} \mid x, \theta)} \mathbb{E}_{p(z_n \mid x, z_{<n}, \theta)} b_{n,k}(x, z_{<n}) \nabla_\theta \log p(z_n \mid x, z_{<n}, \theta) \\ & = \sum_{n=1}^N \sum_{k=n}^N \mathbb{E}_{p(z_{<n} \mid x, \theta)} b_{n,k}(x, z_{<n}) \overbrace{\mathbb{E}_{p(z_n \mid x, z_{<n}, \theta)} \nabla_\theta \log p(z_n \mid x, z_{<n}, \theta)}^{=0} = 0 \end{align*} $$
However, learning $N(N+1)/2$ different baselines is computationally demanding, so one would probably at least assume some common underlying structure.
VIMCO
Another case of exploiting a particular structure is the VIMCO (Variational Inference for Monte Carlo Objectives) estimator. Again, consider the latent variable $z$ divided into $N$ chunks, but now the $z_n$ are independent and identically distributed samples: $z_n \sim p(z \mid \theta)$. Suppose $f$ has the following form: $f(z) = g\left(\tfrac{1}{N} {\scriptsize\sum_{n=1}^N} h(z_n)\right)$. Then the gradient in REINFORCE form is:
$$ \nabla_\theta \mathcal{F}(\theta) = \mathbb{E}_{p(z \mid \theta)} \sum_{n=1}^N g\left(\tfrac{1}{N} {\scriptsize\sum_{j=1}^N} h(z_j)\right) \nabla_\theta \log p(z_n \mid \theta) $$
The problem with this estimator is that $g(\dots)$ is a common multiplier that defines the magnitude of the gradient for each of the $N$ samples without any distinction, even though some samples $z_n$ might have turned out better than others. We would like to penalise such samples less, performing a kind of credit assignment.
Just as in the previous section, we can consider baselines $b_n$ that depend on samples $z$. To keep them from biasing the gradient estimate, we need to make sure each $b_n$ does not depend on $z_n$. However, it can depend on all the other $z$ (denoted $z_{-n}$), since they are independent of $z_n$. Thus the bias of such a baseline is:
$$ \begin{align*} \mathbb{E}_{p(z \mid \theta)} \sum_{n=1}^N b_n(z_{-n}) \nabla_\theta \log p(z_n \mid \theta) = \sum_{n=1}^N \mathbb{E}_{p(z_{-n} \mid \theta)} b_n(z_{-n}) \overbrace{ \mathbb{E}_{p(z_n \mid \theta)} \nabla_\theta \log p(z_n \mid \theta) }^{=0} = 0 \end{align*} $$
The authors of the VIMCO paper also suggested an interesting trick to avoid learning $b_n(z_{-n})$: we want $b_n(z_{-n})$ to approximate $f(z)$ as well as possible, and we actually have access to everything we need to compute $f(z)$ except the term that depends on $z_n$: $h(z_n)$. However, all samples $z$ are identically distributed, so we can approximate this missing term by the average of the others:
$$ \hat h_n(z_{-n}) = \frac{1}{N-1} \sum_{j \neq n} h(z_j) \stackrel{\text{hopefully}}{\approx} h(z_n) $$
Then our baseline becomes
$$ b_n(z_{-n}) = g\left(\tfrac{{\scriptsize\sum_{j \neq n}} h(z_j) + \hat h_n(z_{-n})}{N} \right) $$
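A minimal sketch of these leave-one-out baselines, with the arithmetic mean as $\hat h_n$ (the function names are hypothetical; the usage example picks $g = \log$ purely for illustration). Note that with the arithmetic mean the expression simplifies: $\big(\sum_{j \neq n} h(z_j) + \hat h_n\big)/N = \hat h_n$, so $b_n = g(\hat h_n)$.

```python
import numpy as np


def vimco_baselines(h_values, g):
    """Leave-one-out baselines b_n(z_{-n}) = g((sum_{j != n} h(z_j) + hhat_n) / N),
    where hhat_n is the arithmetic mean of the other N - 1 values h(z_j)."""
    h = np.asarray(h_values, dtype=float)
    n = len(h)
    others_sum = h.sum() - h             # sum_{j != n} h(z_j), for every n
    h_hat = others_sum / (n - 1)         # stand-in for the missing h(z_n)
    return g((others_sum + h_hat) / n)


def vimco_weights(h_values, g):
    """Per-sample learning signals f(z) - b_n(z_{-n}) multiplying grad log p(z_n|theta)."""
    h = np.asarray(h_values, dtype=float)
    return g(h.mean()) - vimco_baselines(h, g)


h = np.array([1.0, 2.0, 3.0, 6.0])
print(vimco_weights(h, np.log))
```

Here the sample with $h = 6$ gets a positive learning signal and the sample with $h = 1$ a negative one, whereas plain REINFORCE would weight all four by the same $\log 3$.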
One can also consider other averaging schemes for $\hat h_n(z_{-n})$ to approximate $h(z_n)$: geometric, harmonic, Minkowski, etc.
MuProp
So far we have mostly been considering baselines $b$ that do not bias the gradient estimator (their control variates have zero expected value). However, there are cases when we can actually compute the bias a baseline introduces and compensate for it.
The MuProp paper suggests using a Taylor expansion as a baseline, provided we can compute certain moments of the distribution $p(z \mid \theta)$ in closed form. For example, if $p(z \mid \theta) = \mathcal{N}(z \mid \mu(\theta), \Sigma(\theta))$, then we already have access to the 1st and 2nd moments: the mean and the covariance matrix.
Consider a first-order Taylor expansion of $f(z)$ around $\mu(\theta) = \mathbb{E}_{p(z \mid \theta)} z$:
$$ b_\theta(z) = f(\mu(\theta)) + \nabla_z f(\mu(\theta))^T (z - \mu(\theta)) $$
Then the bias introduced by such a baseline would be
$$ \begin{align*} \mathbb{E}_{p(z \mid \theta)} & b_\theta(z) \nabla_\theta \log p(z \mid \theta) \\ &= \mathbb{E}_{p(z \mid \theta)} \left[ f(\mu(\theta)) + \nabla_z f(\mu(\theta))^T (z - \mu(\theta)) \right] \nabla_\theta \log p(z \mid \theta) \\ &= \mathbb{E}_{p(z \mid \theta)} \left[ \nabla_z f(\mu(\theta))^T z + f(\mu(\theta)) - \nabla_z f(\mu(\theta))^T \mu(\theta) \right] \nabla_\theta \log p(z \mid \theta) \\ &= \mathbb{E}_{p(z \mid \theta)} \left[ \nabla_z f(\mu(\theta))^T z \nabla_\theta \log p(z \mid \theta) \right] \\ & \quad\quad\quad + \left[ f(\mu(\theta)) - \nabla_z f(\mu(\theta))^T \mu(\theta) \right] \overbrace{\mathbb{E}_{p(z \mid \theta)} \nabla_\theta \log p(z \mid \theta)}^{=0} \\ &= \nabla_z f(\mu(\theta))^T \mathbb{E}_{p(z \mid \theta)} \left[ z \nabla_\theta \log p(z \mid \theta) \right] = \nabla_z f(\mu(\theta))^T \nabla_\theta \mathbb{E}_{p(z \mid \theta)} \left[ z \right] \\ & = \nabla_z f(\mu(\theta))^T \nabla_\theta \mu(\theta) = \nabla_\theta f(\mu(\theta)) \end{align*} $$
So the (1st order) MuProp estimator has the following form:
$$ \hat{\nabla}_\theta^\text{SF-MuProp} \mathcal{F}(\theta) = (f(z) - f(\mu(\theta)) - \nabla_z f(\mu(\theta))^T (z - \mu(\theta))) \nabla_\theta \log p(z \mid \theta) + \nabla_\theta f(\mu(\theta)) $$
An appealing property is that not only is this gradient estimator unbiased, but it also uses the gradients of $f$ in the $\nabla_\theta f(\mu(\theta))$ term, essentially propagating the learning signal through the mean of the random variable $z$, and then correcting for the introduced bias with REINFORCE.
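A 1-D sanity check of the first-order MuProp estimator. It assumes $p(z \mid \theta) = \mathcal{N}(z \mid \theta, 1)$, so $\mu(\theta) = \theta$ and $\nabla_\theta \log p(z \mid \theta) = z - \theta$, and a toy $f(z) = z^3$, for which $\nabla_\theta \mathbb{E} f(z) = 3\theta^2 + 3$. Both estimators should be unbiased; MuProp only changes what is left for REINFORCE to estimate.

```python
import numpy as np


def f(z):
    return z ** 3            # toy objective


def df(z):
    return 3.0 * z ** 2      # its derivative, used by MuProp


def sf_estimates(theta, eps):
    z = theta + eps          # z ~ N(theta, 1)
    score = z - theta        # grad_theta log N(z | theta, 1)
    return f(z) * score


def muprop_estimates(theta, eps):
    mu = theta               # mean of p(z|theta); here d mu / d theta = 1
    z = mu + eps
    score = z - mu
    baseline = f(mu) + df(mu) * (z - mu)    # first-order Taylor expansion around mu
    # REINFORCE on the residual plus the analytic correction grad_theta f(mu(theta))
    return (f(z) - baseline) * score + df(mu) * 1.0


rng = np.random.default_rng(0)
theta = 0.7
eps = rng.standard_normal(400_000)
sf, mp = sf_estimates(theta, eps), muprop_estimates(theta, eps)
print("true  :", 3 * theta ** 2 + 3)
print("SF    : mean %.3f, var %.1f" % (sf.mean(), sf.var()))
print("MuProp: mean %.3f, var %.1f" % (mp.mean(), mp.var()))
```

The gain here is modest because the cubic remainder dominates; the closer $f$ is to linear around $\mu$, the less is left for REINFORCE (for a linear $f$ the estimator has zero variance).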
One could, of course, envision a second-order baseline, especially considering we have the covariance matrix readily available for many distributions. However, such a baseline would be more computationally demanding, requiring us to compute the Hessian matrix of $f(z)$ and evaluate it at some point, which would cost at least $\text{dim}(z)^2$ computations. Higher-order expansions would require even more computation, so it's hard to achieve a highly non-linear baseline using MuProp alone ^{6}.
REBAR
REBAR^{8} is a clever way to use the Gumbel-Softmax (aka Concrete) relaxation as a baseline.
A naive approach to the task would be to recall the Gumbel-Max trick: as we have already seen, this trick gives us a reparametrization, albeit a non-differentiable one. However, we can move the non-differentiability into $f(z)$ and then invoke REINFORCE to estimate the gradient of the expectation of the non-differentiable function. (From now on we will assume $z$ is a one-hot vector and argmax is an operator that returns a one-hot vector indicating the position of the maximal element in its input; overall, we will be abusing notation, treating the same $z$ as a one-hot vector or as a number depending on context.)
$$ \nabla_\theta \mathbb{E}_{p(z \mid \theta)} f(z) = \nabla_\theta \mathbb{E}_{p(\zeta \mid \theta)} f(\text{argmax} \zeta) = \mathbb{E}_{p(\zeta \mid \theta)} f(\text{argmax} \zeta) \nabla_\theta \log p(\zeta \mid \theta) $$
Here $\zeta_k$ is obtained by shifting an independent standard Gumbel r.v. $\gamma_k$ by the log-probability of the $k$-th outcome:
$$ \zeta_k = \log p(z = k \mid \theta) + \gamma_k, \quad\quad \gamma_k \sim \text{Gumbel}(0, 1) $$
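A minimal NumPy sketch of the Gumbel-Max trick (the function name is illustrative): standard Gumbel noise is produced via the inverse CDF $\gamma = -\log(-\log u)$ with $u \sim U[0, 1]$, and the empirical frequencies of $\text{argmax}\,\zeta$ should match $p$.

```python
import numpy as np


def gumbel_max_sample(log_p, num_samples, rng):
    """Sample z = argmax_k zeta_k with zeta_k = log p_k + gamma_k, gamma_k ~ Gumbel(0, 1)."""
    u = rng.uniform(size=(num_samples, len(log_p)))
    gamma = -np.log(-np.log(u))              # standard Gumbel noise via the inverse CDF
    zeta = log_p + gamma
    return zeta.argmax(axis=1), zeta


rng = np.random.default_rng(0)
p = np.array([0.2, 0.5, 0.3])
z, zeta = gumbel_max_sample(np.log(p), 200_000, rng)
print(np.bincount(z, minlength=3) / len(z))  # empirical frequencies, close to p
```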
Thus $\zeta_k$ also has a Gumbel distribution: $\zeta_k \sim \text{Gumbel}(\log p(z = k \mid \theta), 1)$. Ok, so what have we bought ourselves here? So far it looks like we gained nothing and only complicated the whole thing with these extra $\zeta$s. However, we just obtained a crucial property: we separated the non-differentiability from the reparametrization. We can now sample continuous reparametrizable $\zeta$s, and the troublesome part, the argmax, is now a part of $f$. And this opens up a new way to use baselines with non-zero expectation:
$$ \mathbb{E}_{p(\zeta \mid \theta)} (f(\text{argmax} \zeta) - b(\zeta)) \nabla_\theta \log p(\zeta \mid \theta) + \nabla_\theta \mathbb{E}_{p(\zeta \mid \theta)} b(\zeta) $$
And the most interesting thing is that the bias correction term, $\nabla_\theta \mathbb{E}_{p(\zeta \mid \theta)} b(\zeta)$, is the gradient of an expectation over reparametrizable $\zeta$, so for a differentiable $b$ it can be estimated with the reparametrization trick. Now, that's nice, but we can't just take any $b(\zeta)$ and hope for variance reduction. In order to actually benefit from such a baseline, we need $b(\zeta) \approx f(\text{argmax} \zeta)$. Luckily, we already know a way to organize this: the Gumbel-Softmax relaxation, obtained by setting $b(\zeta) = f(\text{softmax}_\tau(\zeta))$:
$$ \hat{\nabla}_\theta^\text{SF-REBAR-naive} \mathcal{F}(\theta) = (f(\text{argmax} \zeta) - b(\zeta)) \nabla_\theta \log p(\zeta \mid \theta) + \nabla_\theta f(\text{softmax}_\tau(\zeta)) $$
However, there's a reason I called this estimator naive. If you actually try implementing this estimator, you would hardly see any improvements. If you look closely, you'd notice that we actually increased the variance of the REINFORCE estimator by switching to $\zeta$s, and this increase might not be compensated by the GumbelSoftmax baseline we introduced.
I guess it all looks a bit confusing at this point, so let's take a closer look at the original REINFORCE estimator and the naive REBAR without a baseline:
$$ \begin{align*} \hat{\nabla}_\theta^\text{SF} \mathcal{F}(\theta) &= f(z) \nabla_\theta \log p(z \mid \theta) \\ \hat{\nabla}_\theta^\text{SF-REBAR-naive-without-baseline} \mathcal{F}(\theta) &= f(\text{argmax} \zeta) \nabla_\theta \log p(\zeta \mid \theta) \end{align*} $$
You'd think they're the same; however, they're actually quite different. The difference is not in the first factors, $f(z)$ and $f(\text{argmax} \zeta)$, as those are basically the same. It's the second factor that's important to us: vanilla REINFORCE has $\nabla_\theta \log p({\color{red} z} \mid \theta)$, whereas our naive REBAR has $\nabla_\theta \log p({\color{red} \zeta} \mid \theta)$. This seemingly innocent difference is a huge deal! To see why, recall the REINFORCE intuition: it is not a gradient method, but rather a random search in disguise: it tries a bunch of points and increases the probabilities of those that perform well. However, different $\zeta$s can lead to the same $z$: the argmax takes on only a finite number of different values, whereas there's a continuum of different vectors $\zeta$. As a result, our naive REBAR estimate would try some $\zeta$ (corresponding to some $z$) and then pull the probability mass towards (or away from) this point, possibly undoing some useful work it did for a different $\zeta$ (but the same $z$).
To fix this issue we need to stay in the "space of $\nabla_\theta \log p(z \mid \theta)$", that is, use a control variate of the form $b(z) \nabla_\theta \log p(z \mid \theta)$. One can be obtained with the help of the following clever identity:
$$ \begin{align*} \nabla_\theta \mathbb{E}_{p(\zeta \mid \theta)} b(\text{softmax}_\tau(\zeta)) &= \nabla_\theta \mathbb{E}_{p(z, \zeta \mid \theta)} b(\text{softmax}_\tau(\zeta)) \\ &= \nabla_\theta \mathbb{E}_{p(z \mid \theta)} \mathbb{E}_{p(\zeta \mid z, \theta)} b(\text{softmax}_\tau(\zeta)) \\ &= \mathbb{E}_{p(z \mid \theta)} \mathbb{E}_{p(\zeta \mid z, \theta)} b(\text{softmax}_\tau(\zeta)) \nabla_\theta \log p(z \mid \theta) \\& \quad + \mathbb{E}_{p(z \mid \theta)} \nabla_\theta \mathbb{E}_{p(\zeta \mid z, \theta)} b(\text{softmax}_\tau(\zeta)) \end{align*} $$
On the left-hand side we have the usual Gumbel-Softmax relaxed gradient, which we can compute using the reparametrization. On the right-hand side we have a REINFORCE-like gradient (which is a good candidate for a baseline) and another weird-looking term. We can rearrange the terms to express the bias of such a baseline through the other two terms:
$$ \begin{align*} \mathbb{E}_{p(z \mid \theta)} \mathbb{E}_{p(\zeta \mid z, \theta)} & b(\text{softmax}_\tau(\zeta)) \nabla_\theta \log p(z \mid \theta) \\ = \nabla_\theta \mathbb{E}_{p(\zeta \mid \theta)} & b(\text{softmax}_\tau(\zeta)) - \mathbb{E}_{p(z \mid \theta)} \nabla_\theta \mathbb{E}_{p(\zeta \mid z, \theta)} b(\text{softmax}_\tau(\zeta)) \end{align*} $$
But what about that weird-looking last term? Can it be estimated efficiently? First, note that we do not need to differentiate through $z$: the dependence through $z$ was already accounted for. The expectation we need to differentiate is taken over $p(\zeta \mid z, \theta)$, which is a distribution over $\zeta$ such that $\text{argmax} \zeta = z$. A reassuring observation is that such a random variable is continuous. Moreover, the restriction $\text{argmax} \zeta = z$ defines a connected region of $\mathbb{R}^K$, which means there does exist a differentiable reparametrization for such a random variable! We won't be deriving this reparametrization here, please refer to Chris Maddison's blog. That said, the reparametrization is
$$ \zeta_k \mid z = \begin{cases} -\log (-\log v_k), \quad\quad\quad & \text{if } z = k, \\ -\log \left(-\frac{\log v_k}{p(z=k \mid \theta)} - \log v_z \right), \quad\quad\quad & \text{otherwise}. \end{cases} $$
Here $v \sim U[0,1]^K$ is a $K$-dimensional standard uniform r.v. Now, having this reparametrization, we can estimate both terms in the bias correction via the reparametrization trick, which leads to the following estimate (I use the notation $\hat{z} \mid z$ to denote a single object, the conditional relaxed variable; it's not $\hat{z}$ with some $z$ applied to it, nor is it $b(\cdot \mid z)$):
$$ \begin{align*} \hat{\nabla}_\theta^\text{SF-REBAR} \mathcal{F}(\theta) = \left[f(z) - b(\hat{z} \mid z) \right] \nabla_\theta \log p(z \mid \theta) + \nabla_\theta b(\hat{z}) - \nabla_\theta b(\hat{z} \mid z) \end{align*} $$
We use $u \sim U[0,1]^K$ to reparametrize $\zeta$ (which leads to both $z$ and $\hat{z}$), and $v \sim U[0,1]^K$ is used to reparametrize $\zeta \mid z$, see the formula above. The quantities of the REBAR gradient estimate are computed as follows:
$$ \begin{align*} z = \text{argmax} \zeta, \quad\quad \hat{z} = \text{softmax}_\tau(\zeta), \quad\quad \hat{z} \mid z = \text{softmax}_\tau(\zeta \mid z), \\ \zeta_k = \log p(z = k \mid \theta) - \log(-\log u_k), \quad\quad \zeta_k \mid z = \text{given above} \end{align*} $$
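The only non-standard ingredient here is sampling the three quantities. Below is a NumPy sketch for a single categorical variable with normalized probabilities `p` (the function names are illustrative, and $z$ is returned as an index rather than a one-hot vector); it does not compute $b$ or its gradients, which in practice would need an autodiff framework. By construction, the conditional sample always has its maximum at $z$, and averaged over $z \sim p$ it should have the same Gumbel marginals as $\zeta$.

```python
import numpy as np


def softmax_tau(x, tau):
    e = np.exp((x - x.max()) / tau)
    return e / e.sum()


def conditional_gumbel(p, z, v):
    """Reparametrized sample of zeta given argmax(zeta) = z, using uniforms v."""
    zeta_cond = -np.log(-np.log(v) / p - np.log(v[z]))   # entries k != z
    zeta_cond[z] = -np.log(-np.log(v[z]))                # the maximal entry
    return zeta_cond


def rebar_samples(p, tau, rng):
    """Return z, z_hat = softmax_tau(zeta) and z_hat|z = softmax_tau(zeta|z)."""
    u = rng.uniform(size=len(p))
    v = rng.uniform(size=len(p))
    zeta = np.log(p) - np.log(-np.log(u))   # zeta_k = log p_k + Gumbel noise
    z = int(np.argmax(zeta))
    return z, softmax_tau(zeta, tau), softmax_tau(conditional_gumbel(p, z, v), tau)


rng = np.random.default_rng(0)
p = np.array([0.2, 0.5, 0.3])
z, z_hat, z_hat_cond = rebar_samples(p, tau=0.5, rng=rng)
print(z, z_hat, z_hat_cond)
```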
What about $b(\cdot)$? The authors use $b(\cdot) = \eta f(\cdot)$, where $\eta$ is a hyperparameter that regulates the strength of the baseline. But it turns out we can avoid a hyperparameter search for this variable...
Hyperparameter learning and RELAX
An important observation is that the gradient estimator we've obtained is unbiased^{9}. That is, for any choice of the hyperparameters $\tau$ (the Gumbel-Softmax temperature) and $\eta$, the average value of our estimator is equal to the true gradient. Thus, we can actually learn their values! The only question is, well, which objective should we minimize? We can't simply minimize the problem's loss with respect to them, since $\mathcal{F}(\theta)$ does not depend on $\tau$ and $\eta$ at all. The next logical step is to minimize the variance of the gradient estimator:
$$ \text{Var}\left( \hat{\nabla}_\theta^\text{SF-REBAR} \mathcal{F}(\theta) \right) = \sum_{i} \left( \mathbb{E} \left[\left(\hat{\nabla}_{\theta_i}^\text{SF-REBAR} \mathcal{F}(\theta)\right)^2\right] - \left[ \mathbb{E} \hat{\nabla}_{\theta_i}^\text{SF-REBAR} \mathcal{F}(\theta)\right]^2 \right) $$
Here the expectation is taken over all randomness. Moreover, since the estimator is unbiased, we can omit the 2nd term in the sum, as it is constant w.r.t. $\tau$ and $\eta$.
Thus the objective for $\tau$ and $\eta$ is $$ \begin{align*} \tau^*, \eta^* &= \text{argmin}_{\tau, \eta} \text{Var}\left( \hat{\nabla}_\theta^\text{SF-REBAR} \mathcal{F}(\theta) \right) \\ &= \text{argmin}_{\tau, \eta} \mathbb{E} \sum_{i} \left(\hat{\nabla}_{\theta_i}^\text{SF-REBAR} \mathcal{F}(\theta)\right)^2 = \text{argmin}_{\tau, \eta} \mathbb{E} \left\| \hat{\nabla}_{\theta}^\text{SF-REBAR} \mathcal{F}(\theta) \right\|_2^2 \end{align*} $$
This optimization problem can be solved using stochastic optimization. We first get a stochastic estimate of the gradient w.r.t. $\theta$, and then obtain an estimate of the gradient of its squared norm w.r.t. the "hyperparameters" $\tau$ and $\eta$. Practical implementation is somewhat tricky; the MagicBox operator might be useful.
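A simplified illustration of this idea, not REBAR itself: REBAR's $\eta f(\cdot)$ baseline needs an autodiff framework, so instead we learn a constant baseline $\eta$ for the plain score-function estimator $\hat g = (f(z) - \eta) \nabla_\theta \log p(z \mid \theta)$ by SGD on the single-sample $\lVert \hat g \rVert^2$, again assuming a categorical $p(z \mid \theta) = \text{softmax}(\theta)$ and a lookup table `f` (both hypothetical). Since this estimator is unbiased for every $\eta$, the minimizer should coincide with the optimal scalar baseline $b$ derived in the Control Variates section.

```python
import numpy as np


def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()


rng = np.random.default_rng(0)
theta = np.array([0.2, -1.0, 0.7, 0.0])
f = np.array([2.0, -1.0, 4.0, 0.5])
p = softmax(theta)
scores = np.eye(len(p)) - p            # row k: grad_theta log p(z = k | theta)

eta, lr = 0.0, 0.0005                  # baseline strength, learned instead of tuned
for z in rng.choice(len(p), size=80_000, p=p):
    g = (f[z] - eta) * scores[z]       # single-sample gradient estimate for theta
    # d||g||^2 / d eta = 2 g . (d g / d eta) = -2 g . scores[z]
    eta -= lr * (-2.0 * g @ scores[z])

# closed-form minimiser of E||g||^2 over a constant baseline, for comparison
sq = scores ** 2
eta_star = (p @ (f[:, None] * sq)).sum() / (p @ sq).sum()
print(eta, eta_star)
```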
Finally, it's worth noting that although we can't apply this estimator in some scenarios like Reinforcement Learning (because there we don't have access to $f(\cdot)$ beyond its values at the sampled discrete $z$), it's possible to introduce a minor modification to overcome this issue. Remember the moment we decided to set $b(\cdot) = \eta f(\cdot)$? At this moment we could have made any other choice, for example $b(\cdot) = h_\eta(\cdot)$: a neural network with parameters $\eta$ that takes $\hat{z}$ as input and returns the same kind of thing $f(\cdot)$ would return (a scalar in our case). Then we can learn the parameters $\eta$ of this network in the same way as before.
This gives us the so-called RELAX gradient estimator: $$ \hat{\nabla}_\theta^\text{SF-RELAX} \mathcal{F}(\theta) = \left[f(z) - h_\eta(\hat{z} \mid z) \right] \nabla_\theta \log p(z \mid \theta) + \nabla_\theta h_\eta(\hat{z}) - \nabla_\theta h_\eta(\hat{z} \mid z) $$
This estimator no longer needs to evaluate the optimized function $f(\cdot)$ on relaxed inputs, nor does it require its differentiability, so it can be applied in a larger number of scenarios. Of course, when a differentiable $f(\cdot)$ is available, this estimator would be at a disadvantage compared to REBAR, since the latter already has a pretty good idea of what the baseline should look like.
Overall, I like the REBAR/RELAX gradient estimator for its use of the target function's gradient $\nabla_z f(\cdot)$ and a non-linear baseline that approximates the target $f(z)$ fairly closely. However, its effectiveness comes at a cost: you need 3 times more computation, namely one discrete run $f(z)$, one relaxed run $f(\hat{z})$ and one conditionally relaxed run $f(\hat{z} \mid z)$, which is much more than plain Gumbel-Softmax needs.
Conclusion
This post closes the series on Stochastic Computation Graphs. There are many other methods that I left uncovered: maybe I consider them weird mathematical hacks, or I simply didn't know about their existence! Overall, I think the estimators covered in these 3 posts and the reasoning behind them establish a solid toolkit for many problems of practical interest.

^{1} REINFORCE stands for REward Increment = Nonnegative Factor × Offset Reinforcement × Characteristic Eligibility

^{2} One could also use matrix baselines and multiply them by $\nabla_\theta \log p(z \mid \theta)$ as usual, but we won't cover these: this method does not scale well with the number of parameters in $\theta$.

^{3} Monte Carlo averaging isn't very efficient. The variance decreases as $O(1/L)$ for $L$ samples, and the typical error (by invoking the CLT) drops as $O(1 / \sqrt{L})$. That is, to reduce the typical error of an MC approximation by a factor of 1000, you'd need a million times more samples! It's very hard to beat high variance by sampling alone.

^{4} Actually, it'd make much more sense to minimize the variance of the obtained estimator directly; we'll discuss this later when talking about the REBAR and RELAX methods.

^{5} The Evidence Lower Bound of Variational Inference can be presented in this way. Namely, the ELBO is $$ \begin{align*} \mathcal{F}(\theta) &= \mathbb{E}_{q(z_{1, \dots, N} \mid \theta)} \log \frac{p(X, z_1, \dots, z_N \mid \theta)}{q(z_{1, \dots, N} \mid \theta)} \\ &= \mathbb{E}_{q(z_{1, \dots, N} \mid \theta)} \left[ \log p(X \mid z_{1, \dots, N}, \theta) + \sum_{n=1}^N \log \frac{p(z_n \mid z_{<n})}{q(z_n \mid z_{<n}, \theta)} \right] \end{align*} $$ Then each intermediate layer gives you a reward corresponding to the (negative) KL divergence with the prior, and the last layer also gives you the reconstruction reward.

^{6} This might be due to the Taylor expansion being an unfortunate choice. Perhaps some other expansion would be advantageous, but I'm unaware of any such works.

^{7} You might ask: wait, what if we use an independent and identically distributed sample $z'$ in the baseline? Consider the following: $$ \left( f(z) - b(z') \right) \nabla_\theta \log p(z \mid \theta), \quad\quad\quad z, z' \sim p(z \mid \theta) $$ This is a valid and unbiased gradient estimate; however, since $z$ and $z'$ are independent, this is essentially a stochastic version of the following estimator: $$ \left( f(z) - \mathbb{E}_{p(z' \mid \theta)} b(z') \right) \nabla_\theta \log p(z \mid \theta), \quad\quad\quad z \sim p(z \mid \theta) $$ So we're better off approximating that expectation directly with a baseline $b(\theta)$ that is constant w.r.t. $z$, which is what the NVIL method we'll talk about later does.

^{8} The name is a very clever joke. Rebar is a construction term for the steel bars used to reinforce concrete, and the Concrete distribution is the name for the distribution of Gumbel-Softmax relaxed random variables.

^{9} Unlike Gumbel-Softmax, which is biased for all $\tau > 0$. In a sense, REBAR is a debiased version of Gumbel-Softmax.