Inference networks are a flexible construction for parameterizing approximating distributions during inference. They are used in Helmholtz machines (Dayan, Hinton, Neal, & Zemel, 1995), deep Boltzmann machines (Salakhutdinov & Larochelle, 2010), and variational auto-encoders (Kingma & Welling, 2014; Rezende, Mohamed, & Wierstra, 2014).
Recall that probabilistic models often have local latent variables, that is, latent variables associated with a single data point; for example, the latent cluster assignment of a data point or the hidden representation of a text or image. Variational inference on models with local latent variables requires optimizing over each variational factor, \[q(\mathbf{z}; \mathbf{\lambda}) = \prod_{n=1}^N q(z_n; \lambda_n),\] where \(z_n\) are the latent variables associated with a data point \(x_n\). The set of parameters \(\mathbf{\lambda}=\{\lambda_n\}\) grows with the size of the data. This makes computation difficult when the data does not fit in memory. Further, a new set of parameters \(\mathbf{\lambda}'=\{\lambda_n'\}\) must be optimized at test time, where there may be a new set of data points \(\{x_n'\}\).
An inference network replaces local variational parameters with global parameters. It is a neural network which takes \(x_n\) as input and which outputs its local variational parameters \(\lambda_n\), \[q(\mathbf{z}\mid \mathbf{x}; \theta) = \prod_{n=1}^N q(z_n \mid x_n; \lambda_n) = \prod_{n=1}^N q(z_n; \lambda_n = \mathrm{NN}(x_n; \theta)).\] This amortizes inference by defining only a set of global parameters, namely, the parameters \(\theta\) of the neural network. These parameters have a fixed size, so the memory complexity of inference is constant in the number of data points. Further, the same set of parameters can be used at test time: whether a data point \(x_n'\) is new or old, we simply pass it through the neural network (a forward pass) to obtain its corresponding variational factor \(q(z_n'; \lambda_n' = \mathrm{NN}(x_n'; \theta))\).
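As a rough sketch of this mapping, the NumPy code below implements a one-hidden-layer network that maps each \(x_n\) to the mean and standard deviation of a Gaussian factor. The layer sizes mirror a 28 by 28 image and a 10-dimensional latent variable; the function names (init_params, inference_network, num_params), the random weights, and the random stand-in data are assumptions made for illustration and are not part of Edward.

```python
import numpy as np


def softplus(a):
    # Numerically stable log(1 + exp(a)); keeps the scale positive.
    return np.logaddexp(0.0, a)


def init_params(rng, d_in=28 * 28, d_hidden=256, d_latent=10):
    """Global parameters theta of a one-hidden-layer inference network."""
    def dense(n_in, n_out):
        return rng.normal(0.0, 0.01, size=(n_in, n_out)), np.zeros(n_out)

    return {
        "hidden": dense(d_in, d_hidden),
        "loc": dense(d_hidden, d_latent),
        "scale": dense(d_hidden, d_latent),
    }


def inference_network(x, theta):
    """Map a batch x of shape [M, d_in] to local parameters (loc, scale)."""
    W, b = theta["hidden"]
    h = np.maximum(x @ W + b, 0.0)  # ReLU hidden layer
    W, b = theta["loc"]
    loc = h @ W + b
    W, b = theta["scale"]
    scale = softplus(h @ W + b)
    return loc, scale


def num_params(theta):
    """Total number of scalars in theta; does not depend on the data size."""
    return sum(W.size + b.size for W, b in theta.values())


rng = np.random.default_rng(0)
theta = init_params(rng)

# Random stand-in for five previously unseen, flattened images.
x_new = rng.random((5, 28 * 28))

# A single forward pass yields the variational factor for each new point;
# no per-data-point optimization is needed.
loc, scale = inference_network(x_new, theta)

# Hypothetical data set size, used only to compare parameter counts.
N = 60000
local_count = N * 2 * 10  # one 10-dim mean and one 10-dim scale per point
global_count = num_params(theta)
```

With these sizes the network has 206,100 parameters regardless of \(N\), whereas a separate 10-dimensional mean and standard deviation for each of 60,000 data points would require 1,200,000 local parameters.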
Note that the inference network is an approximation. It is a common misunderstanding that the inference network produces a more expressive variational model. In fact, the inference network restricts the local parameters to those the network can output from a fixed-size set of global parameters, so it can do at best as well as the original variational distribution without the inference network.
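One way to make this precise, using notation introduced here only for illustration: write \(\mathcal{L}(\mathbf{\lambda})\) for the variational objective (to be maximized) as a function of the local parameters \(\mathbf{\lambda}=\{\lambda_n\}\). Every setting of \(\theta\) yields one particular \(\mathbf{\lambda}\), so the amortized family is a subset of the original one, and \[\max_{\theta} \mathcal{L}\big(\{\mathrm{NN}(x_n; \theta)\}_{n=1}^N\big) \le \max_{\mathbf{\lambda}} \mathcal{L}\big(\{\lambda_n\}_{n=1}^N\big).\] Equality requires that some \(\theta\) makes the network output an optimal \(\lambda_n\) for every \(x_n\).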
Depending on the problem setting, there are often good reasons to use an inference network beyond computational ones. Originating from Helmholtz machines, they are motivated by the hypothesis that "the human perceptual system is a statistical inference engine whose function is to infer the probable causes of sensory input." The inference network is this function. In cognition, "humans and robots are in the setting of amortized inference: they have to solve many similar inference problems, and can thus offload part of the computational work to shared precomputation and adaptation over time" (Stuhlmüller, Taylor, & Goodman, 2013).
Inference networks are easy to build in Edward. In the example below, a data point x is a 28 by 28 pixel image (from MNIST). We define a TensorFlow placeholder x_ph for the neural network’s input, which is a batch of data points.
# M is the batch size; each row is one flattened 28 x 28 image.
x_ph = tf.placeholder(tf.float32, [M, 28 * 28])
We build a neural network using Keras. It takes a flattened 28 by 28 pixel image and produces two \(10\)-dimensional outputs, one for the mean (loc) and one for the standard deviation (scale).
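from edward.models import Normal
from keras.layers import Dense

# Hidden layer shared by both outputs.
hidden = Dense(256, activation='relu')(x_ph)
# Mean and standard deviation of q(z | x); softplus keeps the scale positive.
qz = Normal(loc=Dense(10)(hidden),
            scale=Dense(10, activation='softplus')(hidden))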
The variational model is parameterized by the neural network’s output.
During each step of inference, we pass in a batch of data points to feed the placeholder. Then we train the variational parameters (inference network’s parameters) according to gradients of the variational objective.
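The loop below is a minimal NumPy sketch of such a training procedure, not Edward's implementation. It assumes a toy model chosen for illustration, \(z_n \sim \mathrm{Normal}(0, 1)\) and \(x_n \mid z_n \sim \mathrm{Normal}(z_n, 1)\), and a linear inference network with hypothetical parameters w, b, and c. For this model the variational objective and its gradients are available in closed form, so the sketch uses exact gradients on each batch where Edward would use stochastic gradient estimates.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model (an assumption for illustration):
#   z_n ~ Normal(0, 1),  x_n | z_n ~ Normal(z_n, 1).
N, M = 1000, 32  # data set size and batch size
z_true = rng.normal(size=N)
x_train = z_true + rng.normal(size=N)

# Inference network parameters theta = (w, b, c), shared by all data points:
#   q(z_n | x_n) = Normal(loc = w * x_n + b, scale = exp(c)).
w, b, c = 0.0, 0.0, 0.0
learning_rate = 0.05


def elbo(x, w, b, c):
    """Average per-data-point ELBO, in closed form for the toy model."""
    loc, scale = w * x + b, np.exp(c)
    return np.mean(
        -0.5 * np.log(2.0 * np.pi) + 0.5
        - 0.5 * ((x - loc) ** 2 + loc ** 2)
        - scale ** 2 + c
    )


elbo_before = elbo(x_train, w, b, c)

for step in range(500):
    # Draw a batch; this plays the role of feeding the placeholder.
    x_batch = x_train[rng.choice(N, size=M, replace=False)]

    # Gradients of the batch ELBO with respect to the global parameters.
    loc = w * x_batch + b
    d_loc = x_batch - 2.0 * loc           # d ELBO / d loc, per data point
    grad_w = np.mean(d_loc * x_batch)
    grad_b = np.mean(d_loc)
    grad_c = 1.0 - 2.0 * np.exp(2.0 * c)  # d ELBO / d c

    # Gradient ascent on the variational objective.
    w += learning_rate * grad_w
    b += learning_rate * grad_b
    c += learning_rate * grad_c

elbo_after = elbo(x_train, w, b, c)
```

In this toy model the exact posterior is \(\mathrm{Normal}(x_n / 2, 1/2)\), which a linear network can represent, so training should drive w toward 0.5, b toward 0, and the variance exp(2c) toward 0.5. Only the three global parameters are ever updated; no per-data-point parameters are stored.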
An example script using this variational model can be found at examples/vae.py in the GitHub repository. An example with a convolutional architecture can be found at examples/vae_convolutional_prettytensor.py in the GitHub repository.
Inference networks are also known as recognition models, recognition networks, or inverse mappings.
Variational models parameterized by inference networks are also known as probabilistic encoders, a term borrowed from coding theory as the counterpart to probabilistic decoders. We recommend against this terminology, in favor of making explicit the separation of model and inference. That is, variational inference with inference networks is a general technique that can be applied to models beyond probabilistic decoders.
This tutorial is taken from text originating in Tran, Ranganath, & Blei (2016). We thank Kevin Murphy for motivating the tutorial, which is based on our discussions with him and on related discussion with Jaan Altosaar.
Dayan, P., Hinton, G. E., Neal, R. M., & Zemel, R. S. (1995). The Helmholtz machine. Neural Computation, 7(5), 889–904.
Kingma, D., & Welling, M. (2014). Auto-encoding variational Bayes. In International conference on learning representations.
Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In International conference on machine learning (pp. 1278–1286).
Salakhutdinov, R., & Larochelle, H. (2010). Efficient learning of deep Boltzmann machines. In International conference on artificial intelligence and statistics.
Stuhlmüller, A., Taylor, J., & Goodman, N. (2013). Learning stochastic inverses. In Neural information processing systems.
Tran, D., Ranganath, R., & Blei, D. M. (2016). The variational Gaussian process. In International conference on learning representations.