observations.multi_mnist

multi_mnist(
    path,
    max_digits=2,
    canvas_size=50,
    seed=42
)

Load the multiple MNIST data set (Eslami et al., 2016). It modifies the original MNIST so that each image contains a random number of non-overlapping MNIST digits, with each possible number of digits equally likely.
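To make "non-overlapping" concrete, here is a minimal sketch of one way to paste digit patches onto a blank canvas while rejecting positions where ink would collide. It is not the library's own routine: the helper name place_digits, its max_tries parameter, and the pixel-level overlap test are all assumptions made for illustration.

```python
import numpy as np


def place_digits(digits, canvas_size=50, rng=None, max_tries=100):
    """Paste 2-D uint8 patches onto a square canvas without ink overlap.

    Hypothetical helper for illustration only; the data set's actual
    generation procedure may differ.
    """
    rng = np.random.default_rng() if rng is None else rng
    canvas = np.zeros((canvas_size, canvas_size), dtype=np.uint8)
    for digit in digits:
        h, w = digit.shape
        ink = digit > 0
        for _ in range(max_tries):
            # Draw a top-left corner that keeps the patch inside the canvas.
            top = rng.integers(0, canvas_size - h + 1)
            left = rng.integers(0, canvas_size - w + 1)
            region = canvas[top:top + h, left:left + w]  # view into canvas
            # Accept only if no inked pixel lands on existing ink.
            if not np.any((region > 0) & ink):
                region[ink] = digit[ink]
                break
        else:
            raise RuntimeError("could not place digit without overlap")
    return canvas
```

Rejection sampling keeps the sketch short; it can fail on a crowded canvas, which is why the number of attempts is capped.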

Args:

  • path: str. Path to the directory that stores the file. If the file is not found there, it is downloaded and extracted into that directory. Filename is 'multi_mnist_{}_{}_{}.npz'.format(max_digits, canvas_size, seed).
  • max_digits: int, optional. Maximum number of non-overlapping MNIST digits per image to generate if not cached.
  • canvas_size: int, optional. Width and height in pixels of the generated square images if not cached.
  • seed: int, optional. Random seed to generate the data set from MNIST if not cached.
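
Because the cached filename depends on max_digits, canvas_size, and seed, each distinct combination maps to its own file. A small sketch of how the name resolves for given arguments follows; the helper cache_filename and the directory '~/data' are hypothetical and only mirror the format string given in the path description.

```python
import os


def cache_filename(max_digits=2, canvas_size=50, seed=42):
    # Same format string as quoted in the `path` argument description.
    return 'multi_mnist_{}_{}_{}.npz'.format(max_digits, canvas_size, seed)


# '~/data' is a placeholder directory, not a location the library requires.
full_path = os.path.join(os.path.expanduser('~/data'), cache_filename())
print(cache_filename())
```

With the default arguments the name is multi_mnist_2_50_42.npz, so changing any of the three values would lead to a different cache file.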

Returns:

Tuple of two pairs, (x_train, y_train), (x_test, y_test). Each x is an np.ndarray of dtype uint8, and each y is a list. Each element of a y is an np.ndarray of labels, one label for each digit in the corresponding image.
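
Since the label arrays vary in length from image to image, the y's cannot be treated as one rectangular array. The sketch below shows one way to work with that structure. It uses stand-in data rather than a real download; the helper digit_count_histogram and the assumed image shape (N, 50, 50) are illustrative assumptions, not part of the documented API.

```python
from collections import Counter

import numpy as np

# Typical call, following the signature above (commented out because it
# would download data; '~/data' is a placeholder directory):
# from observations import multi_mnist
# (x_train, y_train), (x_test, y_test) = multi_mnist('~/data')


def digit_count_histogram(y):
    """Map number of digits per image -> how many images have that count."""
    return Counter(len(labels) for labels in y)


# Stand-in data mimicking the documented return types: a uint8 image array
# and a list holding one label array per image.
x_demo = np.zeros((3, 50, 50), dtype=np.uint8)  # shape is an assumption
y_demo = [np.array([7]), np.array([3, 5]), np.array([1])]

hist = digit_count_histogram(y_demo)
print(sorted(hist.items()))
```

A histogram like this is also a quick way to check that the digit counts in a loaded split are roughly balanced.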

Eslami, S. A., Heess, N., Weber, T., Tassa, Y., Szepesvari, D., Hinton, G. E., et al. (2016). Attend, infer, repeat: Fast scene understanding with generative models. In Advances in Neural Information Processing Systems.