\(\text{KL}(p\|q)\) minimization

One form of variational inference minimizes the Kullback-Leibler divergence from \(p(\mathbf{z} \mid \mathbf{x})\) to \(q(\mathbf{z}\;;\;\lambda)\), \[\begin{aligned} \lambda^* &= \arg\min_\lambda \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) )\\ &= \arg\min_\lambda\; \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \log p(\mathbf{z} \mid \mathbf{x}) - \log q(\mathbf{z}\;;\;\lambda) \big].\end{aligned}\] The KL divergence is an asymmetric, information-theoretic measure of the difference between two probability distributions.
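To make this concrete, here is a minimal NumPy sketch, separate from Edward, that evaluates the divergence in both directions for two made-up discrete distributions:

import numpy as np

# Two discrete distributions over three outcomes; the numbers are made up for illustration.
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.4, 0.4, 0.2])

def kl(p, q):
  # KL(p || q) = sum_i p_i * (log p_i - log q_i)
  return np.sum(p * (np.log(p) - np.log(q)))

print(kl(p, q))  # ~0.18
print(kl(q, p))  # ~0.19; not equal to kl(p, q), illustrating the asymmetry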

Minimizing an intractable objective function

The \(\text{KL}(p\|q)\) objective we seek to minimize is intractable: it directly involves the posterior \(p(\mathbf{z} \mid \mathbf{x})\). Ignoring this for the moment, consider its gradient \[\begin{aligned} \nabla_\lambda\; \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) ) &= - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big],\end{aligned}\] where the \(\log p(\mathbf{z} \mid \mathbf{x})\) term vanishes because the expectation is taken under \(p(\mathbf{z} \mid \mathbf{x})\), which does not depend on \(\lambda\). Both \(\text{KL}(p\|q)\) and its gradient are intractable because of the posterior expectation. We can use importance sampling both to estimate the objective and to calculate stochastic gradients (Oh & Berger, 1992).

Adaptive importance sampling

First rewrite the expectation to be with respect to the variational distribution, \[\begin{aligned} - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big] &= - \mathbb{E}_{q(\mathbf{z}\;;\;\lambda)} \Bigg[ \frac{p(\mathbf{z} \mid \mathbf{x})}{q(\mathbf{z}\;;\;\lambda)} \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \Bigg].\end{aligned}\]

We then use importance sampling to obtain a noisy estimate of this gradient. The basic procedure follows these steps:

  1. draw \(S\) samples \(\{\mathbf{z}_s\}_1^S \sim q(\mathbf{z}\;;\;\lambda)\),
  2. evaluate \(\nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda)\),
  3. compute the normalized importance weights \[\begin{aligned} w_s &= \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s'=1}^{S} \frac{p(\mathbf{z}_{s'} \mid \mathbf{x})}{q(\mathbf{z}_{s'}\;;\;\lambda)}, \end{aligned}\]
  4. compute the weighted average of these gradients using the weights \(w_s\).

The key insight is that we can use the joint \(p(\mathbf{x},\mathbf{z})\) instead of the posterior when computing the normalized importance weights, \[\begin{aligned} w_s &= \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s'=1}^{S} \frac{p(\mathbf{z}_{s'} \mid \mathbf{x})}{q(\mathbf{z}_{s'}\;;\;\lambda)} \\ &= \frac{p(\mathbf{x}, \mathbf{z}_s)}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s'=1}^{S} \frac{p(\mathbf{x}, \mathbf{z}_{s'})}{q(\mathbf{z}_{s'}\;;\;\lambda)}.\end{aligned}\] This follows from Bayes' rule, \[\begin{aligned} p(\mathbf{z} \mid \mathbf{x}) &= p(\mathbf{x}, \mathbf{z}) / p(\mathbf{x}),\end{aligned}\] where the marginal likelihood \(p(\mathbf{x})\) does not depend on \(\mathbf{z}\) and therefore cancels between the numerator and the denominator of the normalized weights.
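As a minimal illustration of steps 1-3 outside of Edward, the following NumPy sketch assumes a toy one-dimensional model: log_joint is a hypothetical stand-in for \(\log p(\mathbf{x}, \mathbf{z})\) at fixed data, and \(q\) is a Gaussian whose parameters play the role of \(\lambda\).

import numpy as np

def log_joint(z):
  # Hypothetical stand-in for log p(x, z) at fixed data x; not part of Edward.
  return -0.5 * (z - 3.0) ** 2

mu, sigma = 0.0, 1.0                        # variational parameters lambda of a Gaussian q
z = np.random.normal(mu, sigma, size=50)    # step 1: draw S = 50 samples from q(z; lambda)
log_q = -0.5 * np.log(2 * np.pi * sigma ** 2) - 0.5 * ((z - mu) / sigma) ** 2
log_w = log_joint(z) - log_q                # unnormalized log weights, using the joint
w = np.exp(log_w - np.max(log_w))           # shift by the max before exponentiating
w_norm = w / np.sum(w)                      # step 3: normalized weights, summing to one

Only evaluations of the joint and of \(q\) are required; the posterior itself never appears.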

Importance sampling thus gives a noisy estimate of the gradient, \[\begin{aligned} \nabla_\lambda\; \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) ) &\approx - \sum_{s=1}^S w_s \nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda).\end{aligned}\] Because the weights are self-normalized, the estimate is biased for finite \(S\) but consistent as \(S\) grows. The objective \(\text{KL}(p\|q)\) can be estimated in a similar fashion, up to the additive constant \(\log p(\mathbf{x})\) introduced by replacing the posterior with the joint. The only new ingredient for the gradient is the score function \(\nabla_\lambda \log q(\mathbf{z}\;;\;\lambda)\). Edward uses automatic differentiation, specifically TensorFlow's computational graphs, making this gradient computation both simple to implement and efficient to distribute.
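In the implementation below, this estimate is produced by automatically differentiating a surrogate objective in which the weights \(w_s\) and samples \(\mathbf{z}_s\) are held fixed with respect to \(\lambda\), \[\begin{aligned} \nabla_\lambda \Bigg( - \sum_{s=1}^S w_s \log q(\mathbf{z}_s\;;\;\lambda) \Bigg) &= - \sum_{s=1}^S w_s \nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda).\end{aligned}\] Holding the weights and samples fixed is what the stop-gradient operations accomplish in the code.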

Adaptive importance sampling follows this gradient to a local optimum using stochastic optimization. It is adaptive because the proposal, the variational distribution \(q(\mathbf{z}\;;\;\lambda)\), iteratively moves closer to the posterior \(p(\mathbf{z} \mid \mathbf{x})\), which in turn improves the quality of the importance weights.
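To see the adaptation in action outside of Edward, here is a NumPy sketch of the full loop on the same toy one-dimensional model as above; log_joint and the Gaussian form of \(q(\mathbf{z}\;;\;\lambda)\), with \(\lambda = (\mu, \log\sigma)\), are assumptions for illustration only.

import numpy as np

def log_joint(z):
  # Hypothetical stand-in for log p(x, z) at fixed data x; the implied posterior is N(3, 1).
  return -0.5 * (z - 3.0) ** 2

def log_q(z, lam):
  mu, log_sigma = lam
  sigma = np.exp(log_sigma)
  return -0.5 * np.log(2 * np.pi) - log_sigma - 0.5 * ((z - mu) / sigma) ** 2

def grad_log_q(z, lam):
  # Score function: gradient of log q(z; lam) with respect to (mu, log_sigma).
  mu, log_sigma = lam
  sigma = np.exp(log_sigma)
  return np.stack([(z - mu) / sigma ** 2,
                   ((z - mu) / sigma) ** 2 - 1.0], axis=1)

np.random.seed(0)
lam = np.array([0.0, 1.0])    # initial (mu, log_sigma); a broad proposal helps early on
step_size = 0.05
for t in range(2000):
  mu, log_sigma = lam
  z = mu + np.exp(log_sigma) * np.random.randn(50)   # 1. sample from the current proposal
  log_w = log_joint(z) - log_q(z, lam)               # 2.-3. log weights from the joint
  w = np.exp(log_w - np.max(log_w))
  w_norm = w / np.sum(w)                             # self-normalized weights
  grad_kl = -w_norm @ grad_log_q(z, lam)             # 4. weighted average of the score terms
  lam = lam - step_size * grad_kl                    # stochastic gradient step on KL(p || q)

print(lam)   # mu should drift toward 3.0 and log_sigma toward 0.0, with sampling noise

Because the proposal improves as \(\lambda\) is updated, the importance weights become better behaved over the course of optimization; when \(q\) overlaps poorly with the posterior, the weights degenerate and progress is slow.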

Implementation

Note: These details are outdated since Edward v1.1.3.

We implement \(\text{KL}(p\|q)\) minimization with the importance sampling gradient in the class KLpq. It inherits from VariationalInference, which provides a collection of default methods such as an optimizer.

import six
import tensorflow as tf

# VariationalInference is the base class provided by Edward (import path may
# differ across Edward versions); log_sum_exp is available in edward.util.
from edward.inferences import VariationalInference
from edward.util import log_sum_exp


class KLpq(VariationalInference):
  def __init__(self, *args, **kwargs):
    super(KLpq, self).__init__(*args, **kwargs)

  def initialize(self, n_samples=1, *args, **kwargs):
    # Number of samples drawn from the variational model per iteration.
    self.n_samples = n_samples
    return super(KLpq, self).initialize(*args, **kwargs)

  def build_loss(self):
    x = self.data
    # Draw n_samples samples from each latent variable's variational factor.
    z = {key: rv.sample([self.n_samples])
         for key, rv in six.iteritems(self.latent_vars)}

    # log q(z_s; lambda), summed over the latent dimensions of each sample.
    # stop_gradient treats the samples as constants, so gradients with respect
    # to lambda flow only through the density, not through the sampling.
    q_log_prob = 0.0
    for key, rv in six.iteritems(self.latent_vars):
      q_log_prob += tf.reduce_sum(rv.log_prob(tf.stop_gradient(z[key])),
                                  list(range(1, len(rv.get_batch_shape()) + 1)))

    # Normalized importance weights, computed in log space:
    # log w_s = log p(x, z_s) - log q(z_s; lambda).
    log_w = self.model_wrapper.log_prob(x, z) - q_log_prob
    log_w_norm = log_w - log_sum_exp(log_w)
    w_norm = tf.exp(log_w_norm)

    # Monte Carlo estimate of the objective, stored to track progress of the inference.
    self.loss = tf.reduce_mean(w_norm * log_w)
    # Surrogate loss: differentiating it yields the score-function gradient estimate,
    # with the weights held fixed via stop_gradient.
    return -tf.reduce_mean(q_log_prob * tf.stop_gradient(w_norm))

Two methods are added: initialize() and build_loss(). The method initialize() follows the same initialization as VariationalInference and adds one argument, n_samples, the number of samples drawn from the variational model at each iteration.

The method build_loss() implements the \(\text{KL}(p\|q)\) objective and its gradient. It draws self.n_samples samples from the variational model. It registers the Monte Carlo estimate of \(\text{KL}(p\|q)\) (up to an additive constant) in TensorFlow's computational graph and stores it in self.loss, to track progress of the inference for diagnostics.

The method returns an object whose automatic differentiation is a stochastic gradient of \(\text{KL}(p\|q)\) (up to the constant factor \(1/S\) from the mean, which does not affect the optimization). The TensorFlow function tf.stop_gradient() tells the computational graph to stop traversing nodes when propagating gradients. In this case, the only gradients taken are \(\nabla_\lambda \log q(\mathbf{z}_s\;;\;\lambda)\), one for each sample \(\mathbf{z}_s\). We multiply these element-wise by w_norm and return the mean.
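For readers unfamiliar with stop-gradient, here is a tiny standalone illustration, written against eager-mode TensorFlow 2 for brevity (the class above targets TensorFlow's graph mode):

import tensorflow as tf

x = tf.Variable(2.0)
with tf.GradientTape() as tape:
  y = tf.stop_gradient(x * x) * x   # the first factor is treated as a constant
print(tape.gradient(y, x))          # 4.0 (the value of x**2), not 3 * x**2 = 12.0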

Computing the normalized importance weights is a numerically challenging task; underflow is a common issue. Here we work with the weights in log space, using the log_sum_exp function available in edward.util. However, the normalized weights must still be exponentiated to compute the gradient. As with importance sampling in general, this inference method does not scale well to high dimensions.
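For reference, the trick amounts to subtracting the maximum before exponentiating. A minimal NumPy sketch (not the edward.util implementation):

import numpy as np

def log_sum_exp(x):
  # Numerically stable log(sum(exp(x))): shift by the max so exp cannot overflow.
  x_max = np.max(x)
  return x_max + np.log(np.sum(np.exp(x - x_max)))

log_w = np.array([-1000.0, -1001.0, -1002.0])    # naive np.exp(log_w) underflows to zero
log_w_norm = log_w - log_sum_exp(log_w)          # normalize in log space
print(np.exp(log_w_norm))                        # ~[0.665, 0.245, 0.090]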

See the API for more details.

References

Oh, M.-S., & Berger, J. O. (1992). Adaptive importance sampling in Monte Carlo integration. Journal of Statistical Computation and Simulation, 41(3-4), 143-168.