\(\text{KL}(p\|q)\) Minimization

One form of variational inference minimizes the Kullback-Leibler divergence from \(p(\mathbf{z} \mid \mathbf{x})\) to \(q(\mathbf{z}\;;\;\lambda)\), \[\begin{aligned} \lambda^* &= \arg\min_\lambda \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) )\\ &= \arg\min_\lambda\; \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \log p(\mathbf{z} \mid \mathbf{x}) - \log q(\mathbf{z}\;;\;\lambda) \big].\end{aligned}\] The KL divergence is a non-symmetric, information-theoretic measure of dissimilarity between two probability distributions.

Minimizing an intractable objective function

The \(\text{KL}(p\|q)\) objective we seek to minimize is intractable; it directly involves the posterior \(p(\mathbf{z} \mid \mathbf{x})\). Ignoring this for the moment, consider its gradient \[\begin{aligned} \nabla_\lambda\; \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) ) &= - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big].\end{aligned}\] Both \(\text{KL}(p\|q)\) and its gradient are intractable because of the posterior expectation. We can use importance sampling to both estimate the objective and calculate stochastic gradients (Oh & Berger, 1992).
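
As a short worked step, the gradient above can be derived as follows, assuming the gradient and the expectation can be exchanged. The posterior \(p(\mathbf{z} \mid \mathbf{x})\) does not depend on \(\lambda\), so the gradient passes inside the expectation and the first term vanishes: \[\begin{aligned} \nabla_\lambda\; \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \log p(\mathbf{z} \mid \mathbf{x}) - \log q(\mathbf{z}\;;\;\lambda) \big] &= \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log p(\mathbf{z} \mid \mathbf{x}) \big] - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big]\\ &= 0 - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big].\end{aligned}\]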

Adaptive importance sampling

First rewrite the expectation to be with respect to the variational distribution, \[\begin{aligned} - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big] &= - \mathbb{E}_{q(\mathbf{z}\;;\;\lambda)} \Bigg[ \frac{p(\mathbf{z} \mid \mathbf{x})}{q(\mathbf{z}\;;\;\lambda)} \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \Bigg].\end{aligned}\]
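
This identity is the usual importance sampling trick, spelled out here for an arbitrary integrable function \(f\) (a symbol introduced only for this illustration) and assuming \(q(\mathbf{z}\;;\;\lambda) > 0\) wherever \(p(\mathbf{z} \mid \mathbf{x}) > 0\): \[\begin{aligned} \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ f(\mathbf{z}) \big] &= \int p(\mathbf{z} \mid \mathbf{x})\, f(\mathbf{z})\, d\mathbf{z}\\ &= \int q(\mathbf{z}\;;\;\lambda)\, \frac{p(\mathbf{z} \mid \mathbf{x})}{q(\mathbf{z}\;;\;\lambda)}\, f(\mathbf{z})\, d\mathbf{z}\\ &= \mathbb{E}_{q(\mathbf{z}\;;\;\lambda)} \Bigg[ \frac{p(\mathbf{z} \mid \mathbf{x})}{q(\mathbf{z}\;;\;\lambda)} f(\mathbf{z}) \Bigg].\end{aligned}\] Taking \(f(\mathbf{z}) = \nabla_\lambda \log q(\mathbf{z}\;;\;\lambda)\) recovers the expression above.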

We then use importance sampling to obtain a noisy estimate of this gradient. The basic procedure follows these steps:

  1. draw \(S\) samples \(\{\mathbf{z}_s\}_1^S \sim q(\mathbf{z}\;;\;\lambda)\),
  2. evaluate \(\nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda)\),
  3. compute the normalized importance weights \[\begin{aligned} w_s &= \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s'=1}^{S} \frac{p(\mathbf{z}_{s'} \mid \mathbf{x})}{q(\mathbf{z}_{s'}\;;\;\lambda)}, \end{aligned}\]
  4. compute the weighted sum \(-\sum_{s=1}^S w_s \nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda)\).
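
The four steps above can be sketched in plain Python on a toy problem. This is a minimal illustration under stated assumptions, not Edward's implementation: \(q\) is a one-dimensional Gaussian whose only variational parameter is its mean, and the posterior is a hypothetical `toy_log_posterior` whose normalized density happens to be known. The names `log_normal` and `kl_pq_gradient` are made up for this sketch.

```python
import math
import random


def log_normal(z, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at z."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
            - 0.5 * ((z - mu) / sigma) ** 2)


def kl_pq_gradient(mu, sigma, log_posterior, num_samples, rng):
    """Self-normalized importance sampling estimate of d/dmu KL(p || q)."""
    # Step 1: draw S samples from q(z; lambda), where lambda is the mean mu.
    zs = [rng.gauss(mu, sigma) for _ in range(num_samples)]
    # Step 2: score function of a Gaussian with respect to its mean.
    scores = [(z - mu) / sigma ** 2 for z in zs]
    # Step 3: importance ratios p(z | x) / q(z), normalized to sum to one.
    ratios = [math.exp(log_posterior(z) - log_normal(z, mu, sigma)) for z in zs]
    total = sum(ratios)
    weights = [r / total for r in ratios]
    # Step 4: negative weighted sum of the scores.
    return -sum(w * g for w, g in zip(weights, scores))


def toy_log_posterior(z):
    """Hypothetical posterior for illustration: N(1, 1)."""
    return log_normal(z, 1.0, 1.0)


rng = random.Random(0)
grad = kl_pq_gradient(mu=0.0, sigma=1.5, log_posterior=toy_log_posterior,
                      num_samples=5000, rng=rng)
```

For this toy setup the exact gradient is \(-(1 - 0)/1.5^2 \approx -0.44\), so the estimate in `grad` should land near that value, up to Monte Carlo noise.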

The key insight is that we can use the joint \(p(\mathbf{x},\mathbf{z})\) instead of the posterior when estimating the normalized importance weights \[\begin{aligned} w_s &= \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s'=1}^{S} \frac{p(\mathbf{z}_{s'} \mid \mathbf{x})}{q(\mathbf{z}_{s'}\;;\;\lambda)} \\ &= \frac{p(\mathbf{x}, \mathbf{z}_s)}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s'=1}^{S} \frac{p(\mathbf{x}, \mathbf{z}_{s'})}{q(\mathbf{z}_{s'}\;;\;\lambda)}.\end{aligned}\] This follows from Bayes’ rule \[\begin{aligned} p(\mathbf{z} \mid \mathbf{x}) &= p(\mathbf{x}, \mathbf{z}) / p(\mathbf{x})\\ &= p(\mathbf{x}, \mathbf{z}) / \text{a constant with respect to }\mathbf{z}.\end{aligned}\] Because \(p(\mathbf{x})\) does not depend on \(\mathbf{z}\), it appears as the same factor in every term and cancels in the normalization.
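
A small numerical check of this cancellation, written as a hedged sketch: the log-density values and the value of `log_evidence` (standing in for \(\log p(\mathbf{x})\)) are made-up numbers, and `normalized_weights` is a hypothetical helper. The weights are computed in log space, subtracting the maximum before exponentiating, which is a common way to avoid overflow rather than something the text above prescribes.

```python
import math


def normalized_weights(log_p, log_q):
    """Normalized importance weights from log p(.) and log q(.) at each sample."""
    log_ratios = [lp - lq for lp, lq in zip(log_p, log_q)]
    shift = max(log_ratios)  # subtract the max so exp() cannot overflow
    unnormalized = [math.exp(r - shift) for r in log_ratios]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


# Made-up values of log p(x, z_s) and log q(z_s; lambda) at three samples.
log_joint = [-1.2, -0.3, -2.5]
log_q = [-1.0, -0.9, -1.4]

# Hypothetical log p(x); in practice this constant is unknown.
log_evidence = -7.0
log_posterior = [lj - log_evidence for lj in log_joint]

w_joint = normalized_weights(log_joint, log_q)
w_posterior = normalized_weights(log_posterior, log_q)
```

Both `w_joint` and `w_posterior` sum to one and agree up to floating-point rounding, since the constant shift by `log_evidence` drops out in the normalization.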

Importance sampling thus gives the following biased yet consistent gradient estimate \[\begin{aligned} \nabla_\lambda\; \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) ) &\approx - \sum_{s=1}^S w_s \nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda).\end{aligned}\] The objective \(\text{KL}(p\|q)\) can be estimated in a similar fashion. The only new ingredient for its gradient is the score function \(\nabla_\lambda \log q(\mathbf{z}\;;\;\lambda)\). Edward uses automatic differentiation, specifically with TensorFlow’s computational graphs, making this gradient computation both simple and efficient to distribute.

Adaptive importance sampling follows this gradient to a local optimum using stochastic optimization. It is adaptive because the variational distribution \(q(\mathbf{z}\;;\;\lambda)\) iteratively gets closer to the posterior \(p(\mathbf{z} \mid \mathbf{x})\).
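
The adaptive loop can be illustrated with a toy example in plain Python. This is a sketch under several assumptions and is not Edward's implementation: \(q\) is a one-dimensional Gaussian parameterized by `mu` and `log_sigma`, the model is a hypothetical `log_joint` that is proportional to a \(\mathcal{N}(2, 0.5^2)\) density in \(\mathbf{z}\), the scores are written by hand instead of using automatic differentiation, and the optimizer is plain stochastic gradient descent with a constant step size.

```python
import math
import random


def log_q(z, mu, log_sigma):
    """Log density of the Gaussian variational distribution q(z; mu, log_sigma)."""
    sigma = math.exp(log_sigma)
    return (-0.5 * math.log(2.0 * math.pi) - log_sigma
            - 0.5 * ((z - mu) / sigma) ** 2)


def log_joint(z):
    """Hypothetical unnormalized log p(x, z), proportional to N(2, 0.5^2) in z."""
    return -0.5 * ((z - 2.0) / 0.5) ** 2 - 3.0  # -3.0 mimics an unknown constant


def gradient_estimate(mu, log_sigma, num_samples, rng):
    """Importance sampling estimate of the gradient of KL(p || q)."""
    sigma = math.exp(log_sigma)
    zs = [rng.gauss(mu, sigma) for _ in range(num_samples)]
    # Normalized weights from the joint, computed stably in log space.
    log_ratios = [log_joint(z) - log_q(z, mu, log_sigma) for z in zs]
    shift = max(log_ratios)
    unnormalized = [math.exp(r - shift) for r in log_ratios]
    total = sum(unnormalized)
    weights = [u / total for u in unnormalized]
    # Hand-derived scores of log q with respect to mu and log_sigma.
    grad_mu = -sum(w * (z - mu) / sigma ** 2 for w, z in zip(weights, zs))
    grad_log_sigma = -sum(w * (((z - mu) / sigma) ** 2 - 1.0)
                          for w, z in zip(weights, zs))
    return grad_mu, grad_log_sigma


def adapt(num_steps=500, num_samples=200, step_size=0.05, seed=0):
    """Follow the noisy gradient with constant-step stochastic gradient descent."""
    rng = random.Random(seed)
    mu, log_sigma = 0.0, 0.0  # start from q = N(0, 1)
    for _ in range(num_steps):
        g_mu, g_log_sigma = gradient_estimate(mu, log_sigma, num_samples, rng)
        mu -= step_size * g_mu
        log_sigma -= step_size * g_log_sigma
    return mu, math.exp(log_sigma)


mu_hat, sigma_hat = adapt()
```

In this toy case the posterior lies inside the variational family, so `mu_hat` and `sigma_hat` should end up close to 2 and 0.5. Each iteration draws fresh samples from the current \(q\), which is what makes the proposal adapt toward the posterior.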

For more details, see the API as well as its implementation in Edward’s code base.

References

Oh, M.-S., & Berger, J. O. (1992). Adaptive importance sampling in Monte Carlo integration. Journal of Statistical Computation and Simulation, 41(3-4), 143–168.