Recipes using Funsor

This module provides a number of high-level algorithms using Funsor.

forward_filter_backward_rsample(factors: Dict[str, Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], sample_inputs: Dict[str, type] = {}, rng_key=None)

A forward-filter backward-batched-reparametrized-sample algorithm for use in variational inference. The motivating use case is performing Gaussian tensor variable elimination over structured variational posteriors.
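To make the forward-filter backward-sample pattern concrete, here is a minimal sketch on a scalar linear-Gaussian chain (x_t = a*x_{t-1} plus noise of variance q, observed as y_t = x_t plus noise of variance r). It is plain Python, not funsor code. The model, the function ffbs_rsample, and all of its parameters are hypothetical illustrations. The forward pass is a Kalman filter. The backward pass draws the last state from its filtered marginal and then each earlier x_t from p(x_t | x_{t+1}, y_0..y_t). Every draw is computed as mean + std * eps from externally supplied white noise eps, so the samples are reparametrized. The returned log_prob is the posterior log density of the joint sample.

```python
import math
import random
from typing import List, Sequence, Tuple


def normal_logpdf(x: float, mean: float, var: float) -> float:
    """Log density of the univariate normal N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def ffbs_rsample(
    ys: Sequence[float],
    a: float,   # hypothetical transition coefficient
    q: float,   # hypothetical transition noise variance
    r: float,   # hypothetical observation noise variance
    m0: float,  # prior mean of x_0
    v0: float,  # prior variance of x_0
    eps: Sequence[float],  # white noise, one standard normal draw per step
) -> Tuple[List[float], float]:
    """Forward-filter backward-sample on a scalar linear-Gaussian chain.

    Returns (samples, log_prob), where log_prob is the posterior log density
    p(x_0, ..., x_{T-1} | y) evaluated at the drawn samples.
    """
    # Forward filter: Kalman-filtered means and variances of p(x_t | y_0..y_t).
    means: List[float] = []
    variances: List[float] = []
    for t, y in enumerate(ys):
        if t == 0:
            m_pred, v_pred = m0, v0
        else:
            m_pred = a * means[-1]
            v_pred = a * a * variances[-1] + q
        gain = v_pred / (v_pred + r)
        means.append(m_pred + gain * (y - m_pred))
        variances.append((1.0 - gain) * v_pred)

    # Backward sample: x_{T-1} from its filtered marginal, then each earlier
    # x_t from p(x_t | x_{t+1}, y_0..y_t). Every draw is mean + std * eps.
    T = len(ys)
    xs = [0.0] * T
    mean, var = means[-1], variances[-1]
    xs[-1] = mean + math.sqrt(var) * eps[-1]
    log_prob = normal_logpdf(xs[-1], mean, var)
    for t in range(T - 2, -1, -1):
        j = variances[t] * a / (a * a * variances[t] + q)
        mean = means[t] + j * (xs[t + 1] - a * means[t])
        var = variances[t] - j * a * variances[t]
        xs[t] = mean + math.sqrt(var) * eps[t]
        log_prob += normal_logpdf(xs[t], mean, var)
    return xs, log_prob


if __name__ == "__main__":
    rng = random.Random(0)
    noise = [rng.gauss(0.0, 1.0) for _ in range(3)]
    print(ffbs_rsample([0.5, 1.0, 1.5], a=0.9, q=0.1, r=0.2, m0=0.0, v0=1.0, eps=noise))
```

Passing eps in explicitly, rather than drawing it inside the function, keeps each sample a deterministic function of the model parameters. In an autodiff framework the same structure yields reparametrized gradients.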

Parameters
  • factors (dict) – A dictionary mapping each sample site name to the Funsor factor created at that site.

  • eliminate (frozenset) – A set of names of latent variables to marginalize and plates to aggregate.

  • plates (frozenset) – A set of names of plates to aggregate.

  • sample_inputs (dict) – An optional dict of enclosing sample indices over which samples will be drawn in batch.

  • rng_key – A random number key for the JAX backend.

Returns

A pair (samples: Dict[str, Tensor], log_prob: Tensor) of samples and the log density evaluated at each of those samples. If sample_inputs is nonempty, both outputs are batched over those sample inputs.

Return type

tuple
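As a rough illustration of the batched case, the sketch below draws several reparametrized samples from a bivariate Gaussian stored in information form (a precision matrix and an information vector, a common representation in Gaussian variable elimination) and returns one log density per sample. It is a hypothetical pure-Python stand-in, not funsor code. batched_rsample and its num_samples argument simplify a single sample input down to an integer batch size, whereas the real sample_inputs maps names to types, per the signature above.

```python
import math
import random
from typing import List, Tuple

Vec = Tuple[float, float]
Mat = Tuple[Tuple[float, float], Tuple[float, float]]


def cholesky_2x2(m: Mat) -> Mat:
    """Lower-triangular L with L @ L.T == m, for a 2x2 symmetric positive-definite m."""
    l00 = math.sqrt(m[0][0])
    l10 = m[1][0] / l00
    l11 = math.sqrt(m[1][1] - l10 * l10)
    return ((l00, 0.0), (l10, l11))


def batched_rsample(
    info_vec: Vec, precision: Mat, num_samples: int, rng: random.Random
) -> Tuple[List[Vec], List[float]]:
    """Draw num_samples reparametrized samples and one log density per sample."""
    L = cholesky_2x2(precision)
    # mean = precision^{-1} @ info_vec: solve L z = info_vec, then L.T mean = z.
    z0 = info_vec[0] / L[0][0]
    z1 = (info_vec[1] - L[1][0] * z0) / L[1][1]
    m1 = z1 / L[1][1]
    m0 = (z0 - L[1][0] * m1) / L[0][0]
    log_det_L = math.log(L[0][0]) + math.log(L[1][1])

    samples: List[Vec] = []
    log_probs: List[float] = []
    for _ in range(num_samples):  # the batch ("particle") dimension
        e0, e1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        # Solve L.T d = eps so that d has covariance precision^{-1}.
        d1 = e1 / L[1][1]
        d0 = (e0 - L[1][0] * d1) / L[0][0]
        samples.append((m0 + d0, m1 + d1))
        # log N(x; mean, precision^{-1}); the quadratic form reduces to |eps|^2.
        log_probs.append(-math.log(2 * math.pi) + log_det_L - 0.5 * (e0 * e0 + e1 * e1))
    return samples, log_probs


if __name__ == "__main__":
    xs, lps = batched_rsample((1.0, 0.5), ((2.0, 0.5), (0.5, 1.0)), 4, random.Random(0))
    print(len(xs), len(lps))  # one sample and one log density per batch element
```

Sampling through the Cholesky factor of the precision avoids forming the covariance explicitly: solving L.T @ d = eps gives a perturbation d whose covariance is the inverse precision, and the same eps directly gives the quadratic term of each log density.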

forward_filter_backward_precondition(factors: Dict[str, Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], aux_name: str = 'aux')

A forward-filter backward-precondition algorithm for use in variational inference or preconditioning in Hamiltonian Monte Carlo. The motivating use case is performing Gaussian tensor variable elimination over structured variational posteriors, and optionally using the learned posterior to determine momentum in HMC.

Parameters
  • factors (dict) – A dictionary mapping each sample site name to the Funsor factor created at that site.

  • eliminate (frozenset) – A set of names of latent variables to marginalize and plates to aggregate.

  • plates (frozenset) – A set of names of plates to aggregate.

  • aux_name (str) – Name of the auxiliary variable containing white noise.

Returns

A pair (samples: Dict[str, Tensor], log_prob: Tensor) of samples and the log density evaluated at each of those samples. Both outputs depend on a vector named by aux_name, e.g. aux: Reals[d], where d is the total number of elements in the eliminated variables.

Return type

tuple
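One way to read the aux contract is that the samples are a deterministic, differentiable map of a single white-noise vector of length d. Under that reading, an HMC sampler could move in aux space, where a well-fit posterior looks roughly standard normal. The sketch below illustrates this with a dense Gaussian and is not funsor's elimination-based implementation. The function precondition, the lower-triangular Cholesky factor chol_cov of the posterior covariance, and the rule of splitting the flat vector in dictionary order are all hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def numel(shape: Tuple[int, ...]) -> int:
    """Number of elements in a tensor of the given shape (1 for a scalar)."""
    n = 1
    for size in shape:
        n *= size
    return n


def precondition(
    mean: Sequence[float],
    chol_cov: Sequence[Sequence[float]],
    shapes: Dict[str, Tuple[int, ...]],
    aux: Sequence[float],
) -> Tuple[Dict[str, List[float]], float]:
    """Map white noise aux to per-variable samples and their log density.

    Hypothetical dense-Gaussian sketch: the posterior over the flattened
    eliminated variables is N(mean, L @ L.T) with L = chol_cov lower-triangular.
    """
    # d is the total number of elements across all eliminated variables.
    d = sum(numel(shape) for shape in shapes.values())
    if not (len(aux) == len(mean) == d):
        raise ValueError(f"expected aux and mean of length {d}")

    # Flat sample x = mean + L @ aux; only j <= i contributes since L is lower-triangular.
    flat = [mean[i] + sum(chol_cov[i][j] * aux[j] for j in range(i + 1)) for i in range(d)]

    # Split the flat vector into per-variable chunks (kept flat, in dict order).
    samples: Dict[str, List[float]] = {}
    offset = 0
    for name, shape in shapes.items():
        n = numel(shape)
        samples[name] = flat[offset:offset + n]
        offset += n

    # log N(x; mean, L @ L.T), using L^{-1} (x - mean) == aux.
    log_prob = (
        -0.5 * d * math.log(2 * math.pi)
        - sum(math.log(chol_cov[i][i]) for i in range(d))
        - 0.5 * sum(v * v for v in aux)
    )
    return samples, log_prob


if __name__ == "__main__":
    # x with shape (2,) and a scalar y give d = 3, so aux has length 3.
    shapes = {"x": (2,), "y": ()}
    L = [[1.0, 0.0, 0.0], [0.5, 2.0, 0.0], [0.0, 0.0, 1.0]]
    print(precondition([1.0, 2.0, 3.0], L, shapes, [1.0, 0.0, -1.0]))
```

Because both outputs are functions of aux alone, setting aux to zero recovers the posterior mean, and the log density depends on aux only through its squared norm plus a constant.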