ed.WGANInference

Class WGANInference

Inherits From: GANInference

Aliases:

  • Class ed.WGANInference
  • Class ed.inferences.WGANInference

Defined in edward/inferences/wgan_inference.py.

Parameter estimation with GAN-style training (Goodfellow et al., 2014), using the Wasserstein distance (Arjovsky, Chintala, & Bottou, 2017).

Works for the class of implicit (and differentiable) probabilistic models. These models do not require a tractable density and assume only a program that generates samples.
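To build intuition, the NumPy sketch below estimates the two WGAN objectives from samples. It assumes a hypothetical implicit generator, where noise is pushed through a fixed function and no density is ever evaluated, and a hypothetical linear critic on 1-D data. It follows the standard formulation of Arjovsky et al. (2017) and is not Edward's TensorFlow implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def generator(z, theta):
    """Hypothetical implicit model: a program mapping noise to samples; no density is evaluated."""
    return np.tanh(z) * theta[0] + theta[1]

def critic(x, w, b):
    """Hypothetical linear critic f(x) = w * x + b on 1-D data."""
    return w * x + b

def wgan_losses(x_real, x_fake, w, b):
    """Monte Carlo estimates of the standard WGAN objectives.

    The critic minimizes mean f(fake) - mean f(real), i.e. it maximizes the gap
    between its average score on real and on generated samples.
    The generator minimizes -mean f(fake).
    """
    mean_real = np.mean(critic(x_real, w, b))
    mean_fake = np.mean(critic(x_fake, w, b))
    loss_critic = mean_fake - mean_real
    loss_generator = -mean_fake
    return loss_critic, loss_generator

x_real = rng.normal(loc=3.0, scale=1.0, size=1000)   # stand-in for observed data
z = rng.normal(size=1000)                             # generator noise
x_fake = generator(z, theta=np.array([1.0, 0.0]))
loss_critic, loss_generator = wgan_losses(x_real, x_fake, w=0.5, b=0.0)
```

Because the critic's loss compares average scores, a constant offset `b` in the critic cancels; only the difference between how the critic scores real and generated samples matters.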

Notes

Argument-wise, the only difference from GANInference is conceptual: the discriminator is better described as a test function or critic. WGANInference keeps the name discriminator only so that it can share methods and attributes with GANInference.
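The term critic comes from the dual form of the Wasserstein-1 distance that motivates WGAN training (Arjovsky, Chintala, & Bottou, 2017). Writing p for the data distribution and q for the model's distribution (notation introduced here, not used elsewhere on this page), the distance is a supremum over 1-Lipschitz test functions f:

```latex
W(p, q) = \sup_{\|f\|_{L} \le 1} \; \mathbb{E}_{x \sim p}\left[f(x)\right] - \mathbb{E}_{x \sim q}\left[f(x)\right]
```

The critic approximates the maximizing f, so its output is read as a real-valued score rather than as a probability that a sample is real. Weight clipping and the gradient penalty (the clip and penalty options of initialize) are two ways of keeping f approximately Lipschitz.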

The objective function also includes the sum of all tensors in the REGULARIZATION_LOSSES collection.
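A minimal sketch of what this means for the loss value. It assumes a hypothetical L2 weight-decay term standing in for one entry of the REGULARIZATION_LOSSES collection; in TensorFlow the collection holds tensors, while plain NumPy values are used here.

```python
import numpy as np

def l2_regularizer(params, scale):
    """Hypothetical weight-decay term, playing the role of one collected regularization loss."""
    return scale * sum(np.sum(p ** 2) for p in params)

def total_objective(base_loss, regularization_losses):
    """Base objective plus the sum of every collected regularization term."""
    return base_loss + sum(regularization_losses)

critic_weights = [np.array([[0.5, -0.5], [1.0, 0.0]]), np.array([0.1, 0.2])]
reg_losses = [l2_regularizer(critic_weights, scale=0.01)]
loss = total_objective(base_loss=-1.5, regularization_losses=reg_losses)
```

With an empty collection the sum is zero and the objective is unchanged.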

Examples

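import edward as ed
import tensorflow as tf
from edward.models import Normal
# generative_network, discriminator, and x_data are user-defined.
z = Normal(loc=tf.zeros([100, 10]), scale=tf.ones([100, 10]))
x = generative_network(z)

inference = ed.WGANInference({x: x_data}, discriminator)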

Methods

__init__

__init__(
    *args,
    **kwargs
)

build_loss_and_gradients

build_loss_and_gradients(var_list)

finalize

finalize()

Function to call after convergence.

initialize

initialize(
    penalty=10.0,
    clip=None,
    *args,
    **kwargs
)

Initialize inference algorithm. It initializes hyperparameters and builds ops for the algorithm’s computation graph.

Args:

  • penalty: float. Weight of the gradient penalty, which encourages the norm of the critic's gradient with respect to its inputs to stay close to 1 (Gulrajani, Ahmed, Arjovsky, Dumoulin, & Courville, 2017). Set to None (or 0.0) to use no penalty.
  • clip: float. Value to clip weights by. Default is no clipping.
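The two options are two ways of keeping the critic approximately Lipschitz. The NumPy sketch below shows a gradient penalty in the style of Gulrajani et al. (2017), evaluated at random interpolates between real and generated samples, and weight clipping in the style of Arjovsky et al. (2017). The one-hidden-layer tanh critic and its hand-derived input gradient are assumptions for illustration; a TensorFlow implementation would rely on automatic differentiation instead.

```python
import numpy as np

rng = np.random.default_rng(0)

def critic(x, W, b, v):
    """Hypothetical critic f(x) = v . tanh(W x + b), applied to each row of x (shape (n, d))."""
    return np.tanh(x @ W.T + b) @ v

def critic_grad(x, W, b, v):
    """Analytic gradient of the critic with respect to each input row, shape (n, d)."""
    h = np.tanh(x @ W.T + b)
    return (v * (1.0 - h ** 2)) @ W

def gradient_penalty(x_real, x_fake, W, b, v, penalty=10.0):
    """Soft constraint pushing the critic's gradient norm toward 1 at interpolated points."""
    if not penalty:  # None or 0.0 disables the penalty
        return 0.0
    eps = rng.uniform(size=(x_real.shape[0], 1))      # one mixing weight per sample
    x_hat = eps * x_real + (1.0 - eps) * x_fake
    slopes = np.linalg.norm(critic_grad(x_hat, W, b, v), axis=1)
    return penalty * np.mean((slopes - 1.0) ** 2)

def clip_weights(params, clip):
    """Weight clipping: project every critic parameter into [-clip, clip]."""
    return [np.clip(p, -clip, clip) for p in params]

d, k, n = 2, 8, 64
W, b, v = rng.normal(size=(k, d)), rng.normal(size=k), rng.normal(size=k)
x_real = rng.normal(loc=2.0, size=(n, d))
x_fake = rng.normal(loc=0.0, size=(n, d))

gp = gradient_penalty(x_real, x_fake, W, b, v, penalty=10.0)
W_c, b_c, v_c = clip_weights([W, b, v], clip=0.01)
```

In this formulation the penalty term would be added to the critic's loss, so it discourages, but does not forbid, gradient norms away from 1. Clipping instead bounds the weights directly after each critic update.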

print_progress

print_progress(info_dict)

Print progress to output.

run

run(
    variables=None,
    use_coordinator=True,
    *args,
    **kwargs
)

A simple wrapper to run inference.

  1. Initialize algorithm via initialize.
  2. (Optional) Build a TensorFlow summary writer for TensorBoard.
  3. (Optional) Initialize TensorFlow variables.
  4. (Optional) Start queue runners.
  5. Run update for self.n_iter iterations.
  6. While running, print_progress.
  7. Finalize algorithm via finalize.
  8. (Optional) Stop queue runners.

To customize the way inference is run, run these steps individually.

Args:

  • variables: list. A list of TensorFlow variables to initialize during inference. Default is to initialize all variables (this includes reinitializing variables that were already initialized). To avoid initializing any variables, pass in an empty list.
  • use_coordinator: bool. Whether to start and stop queue runners during inference using a TensorFlow coordinator. For example, queue runners are necessary for batch training with file readers.
  • *args, **kwargs: Passed into initialize.
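The steps listed for run can be sketched as an ordinary Python loop. ToyInference is a hypothetical stand-in that only records which methods are called; the TensorFlow-specific steps (summary writer, variable initialization, queue runners) are omitted, so this shows the control flow rather than Edward's actual code.

```python
class ToyInference:
    """Hypothetical stand-in that records the order in which run() calls its methods."""

    def __init__(self):
        self.calls = []

    def initialize(self, n_iter=3):
        self.n_iter = n_iter
        self.t = 0
        self.calls.append("initialize")

    def update(self, feed_dict=None):
        self.t += 1
        self.calls.append("update")
        return {"t": self.t}

    def print_progress(self, info_dict):
        self.calls.append("print_progress")

    def finalize(self):
        self.calls.append("finalize")


def run(inference, *args, **kwargs):
    """Mirror of steps 1, 5, 6 and 7; extra arguments are passed to initialize()."""
    inference.initialize(*args, **kwargs)        # 1. initialize the algorithm
    for _ in range(inference.n_iter):            # 5. run update for n_iter iterations
        info_dict = inference.update()
        inference.print_progress(info_dict)      # 6. report progress while running
    inference.finalize()                         # 7. finalize the algorithm


inference = ToyInference()
run(inference, n_iter=2)
```

Writing the loop by hand is what makes customization possible, for example passing a different feed_dict to each update call.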

update

update(
    feed_dict=None,
    variables=None
)

References

Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein GAN. In International Conference on Machine Learning.

Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., … Bengio, Y. (2014). Generative adversarial nets. In Neural Information Processing Systems.

Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., & Courville, A. (2017). Improved training of Wasserstein GANs. arXiv.org. Retrieved from http://arxiv.org/abs/1704.00028v1