ed.ImplicitKLqp

Class ImplicitKLqp

Inherits From: GANInference

Aliases:

  • Class ed.ImplicitKLqp
  • Class ed.inferences.ImplicitKLqp

Defined in edward/inferences/implicit_klqp.py.

Variational inference with implicit probabilistic models (Tran et al., 2017).

It minimizes the KL divergence

\(\text{KL}( q(z, \beta; \lambda) \| p(z, \beta \mid x) ),\)

where \(z\) are local variables associated with a data point and \(\beta\) are global variables shared across data points.

Global latent variables require log_prob() and need to return a random sample when fetched from the graph. Local latent variables and observed variables require only a random sample when fetched from the graph. (This is true for both \(p\) and \(q\).)

All variational factors must be reparameterizable: each of the random variables (rv) satisfies rv.is_reparameterized and rv.is_continuous.

Notes

Unlike GANInference, discriminator takes dicts as input and must subset to the appropriate values through lexical scoping from the previously defined model and latent variables. This is necessary because the discriminator can take an arbitrary set of data, latent, and global variables.

The type of discriminator's output depends on the scale argument passed to initialize().

  • If scale has at most one item, then discriminator outputs a tensor whose multiplication with that element is broadcastable. (For example, the output is a tensor and the single scale factor is a scalar.)
  • If scale has more than one item, then in order to scale its corresponding output, discriminator must output a dictionary of same size and keys as scale.
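
The two cases above can be sketched in plain Python. This only illustrates the shape contract and does not call Edward: the helper name `scale_outputs` is hypothetical, lists of floats stand in for tensors, and the string keys `"x"` and `"y"` stand in for whatever keys the real scale dict uses.

```python
def scale_outputs(disc_output, scale):
    """Apply `scale` to a ratio-estimator output (hypothetical helper)."""
    if len(scale) <= 1:
        # Zero or one scale factor: the output is a single "tensor"
        # (here a list of floats) multiplied by one scalar.
        factor = next(iter(scale.values())) if scale else 1.0
        return [factor * r for r in disc_output]
    # Several scale factors: the output must be a dict with the same keys.
    if set(disc_output) != set(scale):
        raise ValueError("discriminator output keys must match scale keys")
    return {key: [scale[key] * r for r in disc_output[key]] for key in scale}


# One scale factor: a plain batch of outputs is enough.
single = scale_outputs([0.5, -1.0], {"x": 10.0})

# Two scale factors: the output is a dict keyed like `scale`.
multi = scale_outputs({"x": [0.5], "y": [2.0]}, {"x": 10.0, "y": 3.0})
```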

Methods

__init__

__init__(
    latent_vars,
    data=None,
    discriminator=None,
    global_vars=None
)

Create an inference algorithm.

Args:

  • discriminator: function. Function (with parameters). Unlike GANInference, it is interpreted as a ratio estimator rather than a discriminator. It takes three arguments: a data dict, local latent variable dict, and global latent variable dict. As with GAN discriminators, it can take a batch of data points and local variables, of size \(M\), and output a vector of length \(M\).
  • global_vars: dict of RandomVariable to RandomVariable, optional. Identifies which variables in latent_vars are global variables, shared across data points. These are not included in the ratio estimation problem and are instead estimated with tractable variational approximations.
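
A minimal sketch of the calling convention described for discriminator, assuming plain dicts of Python lists stand in for dicts of tensors. The function name `ratio_estimator`, the string keys `"x"`, `"z"`, and `"beta"`, and the linear form of the output are all hypothetical; a real ratio estimator would typically be a neural network whose dict keys come from the model's own variables.

```python
def ratio_estimator(data, local_vars, global_vars):
    """Toy ratio estimator: returns one value per data point in the batch."""
    x = data["x"]                # batch of M observations
    z = local_vars["z"]          # batch of M local latent variables
    beta = global_vars["beta"]   # one global variable shared by the batch
    w = 0.1                      # stand-in for a trainable parameter
    return [w * (x_n - beta * z_n) for x_n, z_n in zip(x, z)]


# A batch of M = 3 data points yields a vector of length 3.
out = ratio_estimator(
    {"x": [1.0, 2.0, 3.0]},
    {"z": [0.5, 0.5, 0.5]},
    {"beta": 2.0},
)
```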

build_loss_and_gradients

build_loss_and_gradients(var_list)

Build the loss function

\(-\Big(\mathbb{E}_{q(\beta)} [\log p(\beta) - \log q(\beta) ] + \sum_{n=1}^N \mathbb{E}_{q(\beta)q(z_n\mid\beta)} [ r^*(x_n, z_n, \beta) ] \Big).\)

We minimize it with respect to parameterized variational families \(q(z, \beta; \lambda)\).

\(r^*(x_n, z_n, \beta)\) is a function of a single data point \(x_n\), single local variable \(z_n\), and all global variables \(\beta\). It is equal to the log-ratio

\(\log p(x_n, z_n\mid \beta) - \log q(x_n, z_n\mid \beta),\)

where \(q(x_n)\) is the empirical data distribution. Rather than explicit calculation, \(r^*(x, z, \beta)\) is the solution to a ratio estimation problem, minimizing the specified ratio_loss.
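
To see why a ratio estimation problem can recover the log-ratio, the following sketch checks the point-wise optimum numerically. It assumes the ratio loss is the logistic (sigmoid cross-entropy) loss, which is one common choice and is not confirmed here as the library's exact definition. The density values `p` and `q` at a single point are made up for illustration.

```python
import math


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def pointwise_log_loss(r, p, q):
    # p * (-log sigmoid(r)) + q * (-log(1 - sigmoid(r)))
    return p * softplus(-r) + q * softplus(r)


p, q = 0.6, 0.2  # hypothetical density values at one point (x, z, beta)

# Brute-force search for the minimizing r on a fine grid over [-5, 5].
grid = [i / 1000.0 for i in range(-5000, 5001)]
r_best = min(grid, key=lambda r: pointwise_log_loss(r, p, q))

# The minimizer should agree with log p - log q up to the grid spacing.
r_true = math.log(p) - math.log(q)
```

Under this assumption the minimizer satisfies sigmoid(r) = p / (p + q), which rearranges to r = log p - log q.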

Gradients are taken using the reparameterization trick (Kingma and Welling, 2014).
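
As a rough illustration of the reparameterization trick (not Edward's implementation), the sketch below estimates gradients of \(\mathbb{E}[z^2]\) for a Gaussian \(z\) by writing \(z = \mu + \sigma\epsilon\) with \(\epsilon \sim N(0, 1)\). The test function \(f(z) = z^2\), the parameter values, and the sample size are arbitrary choices made for the example.

```python
import random

random.seed(0)
mu, sigma = 1.5, 0.5
n = 20000

grad_mu = 0.0
grad_sigma = 0.0
for _ in range(n):
    eps = random.gauss(0.0, 1.0)  # noise that does not depend on (mu, sigma)
    z = mu + sigma * eps          # reparameterized sample
    # f(z) = z**2, so df/dz = 2z; chain rule with dz/dmu = 1, dz/dsigma = eps.
    grad_mu += 2.0 * z / n
    grad_sigma += 2.0 * z * eps / n

# Exact values for comparison: E[z**2] = mu**2 + sigma**2,
# so the gradients are 2 * mu and 2 * sigma.
exact_grad_mu = 2.0 * mu
exact_grad_sigma = 2.0 * sigma
```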

Notes

This also includes model parameters \(p(x, z, \beta; \theta)\) and variational distributions with inference networks \(q(z\mid x)\).

Several extensions would be straightforward to add to this implementation:

  • further factorizations can be used to better leverage the graph structure for more complicated models;
  • score function gradients for global variables;
  • use more samples; this would require the copy() utility function for q's as well, and an additional loop. We opt not to because it complicates the code;
  • analytic KL/swapping out the penalty term for the globals.

finalize

finalize()

Function to call after convergence.

initialize

initialize(
    ratio_loss='log',
    *args,
    **kwargs
)

Initialize inference algorithm. It initializes hyperparameters and builds ops for the algorithm's computation graph.

Args:

  • ratio_loss: str or fn, optional. Loss function minimized to get the ratio estimator. Either 'log' or 'hinge'. Alternatively, one can pass in a function that takes two inputs, psamples and qsamples, and outputs a point-wise value with shape matching the shapes of the two inputs.
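
A sketch of what a two-input ratio loss might look like, written with Python lists in place of tensors. It assumes psamples and qsamples hold the ratio estimator's outputs on draws from \(p\) and \(q\) respectively, and it uses the standard logistic and hinge forms. Both points are assumptions made for illustration and may differ from the library's own definitions.

```python
import math


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def log_loss(psamples, qsamples):
    """Point-wise logistic loss: push psamples up and qsamples down."""
    return [softplus(-rp) + softplus(rq) for rp, rq in zip(psamples, qsamples)]


def hinge_loss(psamples, qsamples):
    """Point-wise hinge loss with unit margin on each side."""
    return [max(0.0, 1.0 - rp) + max(0.0, 1.0 + rq)
            for rp, rq in zip(psamples, qsamples)]


psamples = [2.0, 0.0]   # estimator outputs on draws from p (made-up values)
qsamples = [-2.0, 0.0]  # estimator outputs on draws from q (made-up values)

# Each loss returns one value per element, matching the input shapes.
log_values = log_loss(psamples, qsamples)
hinge_values = hinge_loss(psamples, qsamples)
```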

print_progress

print_progress(info_dict)

Print progress to output.

run

run(
    variables=None,
    use_coordinator=True,
    *args,
    **kwargs
)

A simple wrapper to run inference.

  1. Initialize algorithm via initialize.
  2. (Optional) Build a TensorFlow summary writer for TensorBoard.
  3. (Optional) Initialize TensorFlow variables.
  4. (Optional) Start queue runners.
  5. Run update for self.n_iter iterations.
  6. While running, print_progress.
  7. Finalize algorithm via finalize.
  8. (Optional) Stop queue runners.

To customize the way inference is run, run these steps individually.
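
Running the steps individually might look like the loop below. To keep the example self-contained, `ToyInference` is a hypothetical stand-in that only mimics the method names listed on this page; it is not an Edward class, and the TensorFlow-specific steps (summary writer, variable initialization, queue runners) are left out.

```python
class ToyInference:
    """Stand-in with the same method names as an inference object."""

    def initialize(self, n_iter=3):
        self.n_iter = n_iter
        self.t = 0
        self.log = []

    def update(self):
        self.t += 1
        return {"t": self.t, "loss": 1.0 / self.t}  # made-up loss values

    def print_progress(self, info_dict):
        self.log.append("iter {t}: loss {loss:.3f}".format(**info_dict))

    def finalize(self):
        self.log.append("done")


inference = ToyInference()
inference.initialize(n_iter=3)            # step 1
# Steps 2-4 and 8 are TensorFlow-specific and omitted from this stand-in.
for _ in range(inference.n_iter):         # step 5
    info_dict = inference.update()
    inference.print_progress(info_dict)   # step 6
inference.finalize()                      # step 7
```

Splitting the loop out this way leaves room for custom logic between updates, such as early stopping or extra logging.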

Args:

  • variables: list, optional. A list of TensorFlow variables to initialize during inference. Default is to initialize all variables (this includes reinitializing variables that were already initialized). To avoid initializing any variables, pass in an empty list.
  • use_coordinator: bool, optional. Whether to start and stop queue runners during inference using a TensorFlow coordinator. For example, queue runners are necessary for batch training with file readers.
  • *args, **kwargs: Passed into initialize.

update

update(
    feed_dict=None,
    variables=None
)

Run one iteration of optimization.

Args:

  • feed_dict: dict, optional. Feed dictionary for a TensorFlow session run. It is used to feed placeholders that are not fed during initialization.
  • variables: str, optional. Which set of variables to update. Either "Disc" or "Gen". Default is both.

Returns:

dict. Dictionary of algorithm-specific information. In this case, the iteration number and generative and discriminative losses.

Notes

The returned iteration number is the total number of calls to update. Each update may include updating only a subset of parameters.
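
The counting behaviour can be illustrated with a small stand-in. `ToyGanStyleInference` is hypothetical and performs no optimization; it only tracks which parameter set each call would touch. The schedule of three "Disc" updates per "Gen" update is an arbitrary choice for the example.

```python
class ToyGanStyleInference:
    """Stand-in that counts calls the way the note above describes."""

    def __init__(self):
        self.t = 0
        self.disc_steps = 0
        self.gen_steps = 0

    def update(self, feed_dict=None, variables=None):
        # variables=None means both parameter sets are updated.
        if variables is None or variables == "Disc":
            self.disc_steps += 1
        if variables is None or variables == "Gen":
            self.gen_steps += 1
        self.t += 1  # every call counts, whichever subset was updated
        return {"t": self.t}


inference = ToyGanStyleInference()
for _ in range(2):
    for _ in range(3):
        inference.update(variables="Disc")
    info_dict = inference.update(variables="Gen")
```

Here the returned iteration number reaches 8 even though the "Gen" parameters were updated only twice.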