Recipes using Funsor
This module provides a number of high-level algorithms using Funsor.
forward_filter_backward_rsample(factors: Dict[str, funsor.terms.Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], sample_inputs: Dict[str, type] = {}, rng_key=None)
A forward-filter backward-batched-reparametrized-sample algorithm for use in variational inference. The motivating use case is performing Gaussian tensor variable elimination over structured variational posteriors.
Parameters: - factors (dict) – A dictionary mapping sample site name to a Funsor factor created at that sample site.
- eliminate (frozenset) – A set of names of latent variables to marginalize and plates to aggregate.
- plates (frozenset) – A set of names of plates to aggregate.
- sample_inputs (dict) – An optional dict of enclosing sample indices over which samples will be drawn in batch.
- rng_key – A random number key for the JAX backend.
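To make the roles of eliminate and plates concrete, the NumPy sketch below computes the log normalizer that tensor variable elimination produces for a hypothetical model with a discrete global latent z and a plate i. Plate dimensions are aggregated by summing log factors (a product in probability space), and latent variables are marginalized with log-sum-exp. A discrete latent stands in for the Gaussian case to keep the arithmetic explicit, and the factor values are made up; in the terms used above, this model would have eliminate = frozenset({"z", "i"}) and plates = frozenset({"i"}).

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


rng = np.random.default_rng(0)
# Hypothetical factors, stored in log space:
log_prior = np.log(np.array([0.2, 0.5, 0.3]))  # inputs {z}; z has 3 states
log_lik = rng.normal(size=(3, 4))              # inputs {z, i}; plate i has size 4

# z sits outside plate i, so aggregate the plate first:
# a product over i becomes a sum of log factors.
log_joint = log_prior + log_lik.sum(axis=1)    # inputs {z}
# Then marginalize the latent z: a sum over states becomes logsumexp.
log_Z = logsumexp(log_joint, axis=0)           # no inputs left
# Normalized posterior over z, which a backward pass would sample from.
posterior_z = np.exp(log_joint - log_Z)
print(log_Z, posterior_z.sum())
```

If a latent variable lived inside the plate (one copy per i), the order would flip: it would be marginalized within each plate slice before the plate is aggregated.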
Returns: A pair (samples: Dict[str, Tensor], log_prob: Tensor) of samples and the log density evaluated at each of those samples. If sample_inputs is nonempty, both outputs will be batched.
Return type: tuple
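As a rough illustration of the forward-filter backward-sample idea, the plain NumPy sketch below handles a hypothetical one-dimensional Gaussian random-walk model rather than going through funsor. A Kalman filter runs forward; latents are then drawn backward as deterministic functions of standard-normal noise eps (the reparametrization), and the returned log_prob is the posterior log density log p(x | y) at each sample. The leading axis of eps plays the role of a sample_inputs batch dimension. The model, the function name ffbs_rsample, and its arguments are assumptions for illustration, whereas forward_filter_backward_rsample itself operates on general Funsor factors.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def normal_logpdf(x, mean, var):
    """Elementwise log density of a scalar Normal(mean, var)."""
    return -0.5 * (LOG_2PI + np.log(var) + (x - mean) ** 2 / var)


def ffbs_rsample(y, p0, q, r, eps):
    """Forward-filter backward-sample for a hypothetical 1-D Gaussian random walk.

    Model (illustrative only):
        x[0] ~ N(0, p0),  x[t] ~ N(x[t-1], q),  y[t] ~ N(x[t], r).
    eps has shape (num_samples, T) and holds standard-normal noise, so the
    samples are a deterministic, smooth function of eps and the parameters.
    Returns (samples, log_prob) with shapes (num_samples, T) and (num_samples,),
    where log_prob is the posterior log density log p(x | y) at each sample.
    """
    T = len(y)
    m_filt = np.empty(T)
    P_filt = np.empty(T)
    m_pred, P_pred = 0.0, p0
    # Forward filter: Kalman updates give p(x[t] | y[0..t]).
    for t in range(T):
        K = P_pred / (P_pred + r)
        m_filt[t] = m_pred + K * (y[t] - m_pred)
        P_filt[t] = (1.0 - K) * P_pred
        m_pred, P_pred = m_filt[t], P_filt[t] + q

    # Backward sample: draw the last state from its filtered marginal...
    x = np.empty_like(eps)
    mean, var = m_filt[-1], P_filt[-1]
    x[:, -1] = mean + np.sqrt(var) * eps[:, -1]
    log_prob = normal_logpdf(x[:, -1], mean, var)
    # ...then each earlier state from p(x[t] | x[t+1], y[0..t]).
    for t in range(T - 2, -1, -1):
        J = P_filt[t] / (P_filt[t] + q)
        mean = m_filt[t] + J * (x[:, t + 1] - m_filt[t])
        var = P_filt[t] * q / (P_filt[t] + q)
        x[:, t] = mean + np.sqrt(var) * eps[:, t]
        log_prob = log_prob + normal_logpdf(x[:, t], mean, var)
    return x, log_prob


# Usage with made-up observations; the leading axis of eps is the sample batch.
y = np.array([0.3, -0.1, 0.8, 1.2])
eps = np.random.default_rng(0).standard_normal((1000, len(y)))
samples, log_prob = ffbs_rsample(y, p0=1.0, q=0.5, r=0.2, eps=eps)
print(samples.shape, log_prob.shape)  # (1000, 4) (1000,)
```

Because the product of the backward conditionals equals the joint posterior, summing their log densities gives log p(x | y) exactly, which matches the "log density evaluated at each of those samples" output described above.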
forward_filter_backward_precondition(factors: Dict[str, funsor.terms.Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], aux_name: str = 'aux')
A forward-filter backward-precondition algorithm for use in variational inference or preconditioning in Hamiltonian Monte Carlo. The motivating use case is performing Gaussian tensor variable elimination over structured variational posteriors, and optionally using the learned posterior to determine momentum in HMC.
Parameters: - factors (dict) – A dictionary mapping sample site name to a Funsor factor created at that sample site.
- eliminate (frozenset) – A set of names of latent variables to marginalize and plates to aggregate.
- plates (frozenset) – A set of names of plates to aggregate.
- aux_name (str) – Name of the auxiliary variable containing white noise.
Returns: A pair (samples: Dict[str, Tensor], log_prob: Tensor) of samples and the log density evaluated at each of those samples. Both outputs depend on a vector named by aux_name, e.g. aux: Reals[d], where d is the total number of elements in eliminated variables.
Return type: tuple
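As a rough illustration of what it means for both outputs to depend on a white-noise vector aux of length d, the NumPy sketch below takes a made-up Gaussian posterior over a (1 element) and b (2 elements), so d = 3, and maps aux to named samples through an affine transform built from the Cholesky factor of the posterior precision. It returns the posterior log density at those samples. In these coordinates the log density is a standard normal in aux plus a constant, so when the Gaussian approximates the target posterior well, HMC run over aux sees a nearly isotropic target; this is the sense in which the learned posterior can precondition HMC. The helper make_preconditioner, the Cholesky-based parametrization, and the example numbers are assumptions for illustration; funsor derives the transform from the factors and may parametrize it differently.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def make_preconditioner(precision, info, sizes):
    """Build a map aux -> (samples, log_prob) for the Gaussian N(mu, precision^-1).

    Hypothetical helper, not part of funsor. sizes maps each eliminated variable
    name to its number of elements, so d = sum(sizes.values()) is the length of aux.
    With the Cholesky factorization precision = L @ L.T, samples = mu + L^-T @ aux.
    """
    mu = np.linalg.solve(precision, info)
    L = np.linalg.cholesky(precision)           # lower triangular
    d = len(mu)
    assert d == sum(sizes.values())
    half_logdet = np.sum(np.log(np.diag(L)))    # 0.5 * log det(precision)

    def precondition(aux):
        # aux has shape (d,) or (num_samples, d): one white-noise entry per element.
        delta = np.linalg.solve(L.T, aux.T).T   # L^-T @ aux for each row
        x = mu + delta
        # Posterior log density at x; algebraically equal to
        # log N(aux; 0, I) + half_logdet, i.e. isotropic in aux.
        log_prob = (-0.5 * d * LOG_2PI + half_logdet
                    - 0.5 * np.sum(delta * (delta @ precision), axis=-1))
        # Split the flat vector back into named variables.
        samples, start = {}, 0
        for name, size in sizes.items():
            samples[name] = x[..., start:start + size]
            start += size
        return samples, log_prob

    return precondition


# Made-up posterior over a (1 element) and b (2 elements), so d = 3.
precision = np.array([[2.0, -0.8, 0.0],
                      [-0.8, 1.5, 0.3],
                      [0.0, 0.3, 1.0]])
info = np.array([0.5, -0.3, 0.2])
precondition = make_preconditioner(precision, info, {"a": 1, "b": 2})
aux = np.random.default_rng(0).standard_normal((4, 3))
samples, log_prob = precondition(aux)
print(samples["b"].shape, log_prob.shape)  # (4, 2) (4,)
```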