Interpretations

Interpreter

exception PatternMissingError

Bases: NotImplementedError

interpretation(new)
pop_interpretation()
push_interpretation(new)
reinterpret(x)

Overloaded reinterpretation of a deferred expression.

This handles a limited class of expressions, raising ValueError in unhandled cases.

Parameters

x (A funsor or data structure holding funsors.) – An input, typically involving deferred Funsors.

Returns

A reinterpreted version of the input.

Raises

ValueError

Interpretations

class CallableInterpretation(interpret)

Bases: Interpretation

A simple callable interpretation.

Example usage:

@CallableInterpretation
def my_interpretation(cls, *args):
    return ...

Parameters

interpret (callable) – A function implementing interpretation.

set_callable(interpret)

Resets the callable .interpret attribute.

class DispatchedInterpretation(name='dispatched')

Bases: Interpretation

An interpretation based on pattern matching.

Example usage:

my_interpretation = DispatchedInterpretation("my_interpretation")

# Register a funsor pattern and rule.
@my_interpretation.register(...)
def my_impl(cls, *args):
    ...

# Use the new interpretation.
with my_interpretation:
    ...

class Interpretation(name)

Bases: ContextDecorator, ABC

Abstract base class for Funsor interpretations.

Instances may be used as context managers or decorators.

Parameters

name (str) – A name used for printing and debugging (required).

class Memoize(base_interpretation, cache=None)

Bases: Interpretation

Exploits cons-hashing to do implicit common subexpression elimination.

Parameters
  • base_interpretation (Interpretation) – The interpretation to memoize.

  • cache (dict) – An optional temporary cache where results will be memoized.
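
With cons-hashing, structurally identical subexpressions are the same object, so a cache keyed on the term type and its arguments can return an earlier result instead of recomputing it. The following is a minimal sketch of that idea in plain Python, not Funsor's code: memoized, base_interpret, and the tuple key are assumptions made for illustration.

```python
def memoized(base_interpret, cache=None):
    """Wrap an interpret(cls, *args) function with a result cache.

    Hypothetical helper: keys are (cls, args), which requires hashable args.
    """
    if cache is None:
        cache = {}

    def interpret(cls, *args):
        key = (cls, args)
        if key not in cache:
            cache[key] = base_interpret(cls, *args)
        return cache[key]

    return interpret, cache


calls = []


def base_interpret(cls, *args):
    # Stand-in for an expensive interpretation; records each real evaluation.
    calls.append((cls, args))
    return cls(args)


interpret, cache = memoized(base_interpret)
first = interpret(tuple, 1, 2)
second = interpret(tuple, 1, 2)  # served from the cache
```

Passing an explicit cache dict lets the caller inspect or discard memoized results once a computation is finished.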

class StatefulInterpretation(name='stateful')

Bases: Interpretation

Base class for interpretations with instance-dependent state or parameters.

Example usage:

class MyInterpretation(StatefulInterpretation):

    def __init__(self, my_param):
        self.my_param = my_param

@MyInterpretation.register(...)
def my_impl(interpretation_state, cls, *args):
    my_param = interpretation_state.my_param
    ...

with MyInterpretation(my_param=0.1):
    ...

eager = eager/normalize/reflect

Eager exact naive interpretation wherever possible.

lazy = lazy/reflect

Performs substitutions eagerly, but constructs lazy funsors for everything else.

memoize(cache=None)

Context manager wrapping Memoize and yielding the cache dict.

moment_matching = moment_matching/eager/normalize/reflect

A moment matching interpretation of Reduce expressions. This falls back to eager in other cases.

normalize = normalize/reflect

Normalize modulo associativity and commutativity, but do not evaluate any numerical operations.

sequential = sequential/eager/normalize/reflect

Eagerly execute ops with known implementations; additionally execute vectorized ops sequentially if no known vectorized implementation exists.
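
Names such as eager/normalize/reflect read as fallback chains. The sketch below assumes (this is not stated above) that each interpretation in a chain may decline a term by returning None, in which case the next one is tried. The names chain, toy_eager, and toy_reflect are hypothetical, not Funsor APIs.

```python
def chain(*interpreters):
    """Try each interpreter in order and return the first non-None result."""

    def interpret(op, *args):
        for interpreter in interpreters:
            result = interpreter(op, *args)
            if result is not None:
                return result
        raise NotImplementedError(f"no rule for {op!r}")

    return interpret


def toy_eager(op, *args):
    # Evaluate only when every argument is a concrete number.
    if op == "add" and all(isinstance(a, (int, float)) for a in args):
        return sum(args)
    return None  # decline, so the next interpreter is tried


def toy_reflect(op, *args):
    # Always succeeds by building a lazy (unevaluated) expression.
    return (op, *args)


toy = chain(toy_eager, toy_reflect)
evaluated = toy("add", 1, 2)   # handled by toy_eager
deferred = toy("add", "x", 2)  # falls back to toy_reflect
```

Under this reading, the last entry in a chain acts as the catch-all that guarantees every term gets some result.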

Monte Carlo

class MonteCarlo(*, rng_key=None, **sample_inputs)

Bases: StatefulInterpretation

A Monte Carlo interpretation of Integrate expressions. This falls back to the previous interpreter in other cases.

Parameters

rng_key

Preconditioning

class Precondition(aux_name='aux')

Bases: StatefulInterpretation

Preconditioning interpretation for adjoint computations.

This interpretation is intended to be used once, followed by a call to combine_subs() as follows:

# Lazily build a factor graph.
with reflect:
    log_joint = Gaussian(...) + ... + Gaussian(...)
    log_Z = log_joint.reduce(ops.logaddexp)

# Run backward sampling under the precondition interpretation.
with Precondition() as p:
    marginals = adjoint(
        ops.logaddexp, ops.add, log_Z, batch_vars=p.sample_vars
    )
combine_subs = p.combine_subs()

# Extract samples from Delta distributions.
samples = {
    k: v(**combine_subs)
    for name, delta in marginals.items()
    for k, v in funsor.montecarlo.extract_samples(delta).items()
}

See forward_filter_backward_precondition() for complete usage.

Parameters

aux_name (str) – Name of the auxiliary variable containing white noise.

combine_subs()

Method to create a combining substitution after preconditioning is complete. The returned substitution replaces per-factor auxiliary variables with slices into a single combined auxiliary variable.

Returns

A substitution indexing each factor-wise auxiliary variable into a single global auxiliary variable.

Return type

dict
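
To illustrate what "slices into a single combined auxiliary variable" means, the sketch below lays out per-factor noise vectors end to end in one global vector. It uses plain Python lists and slice objects rather than funsors, and the names combine_slices, sizes, aux_0, and aux_1 are hypothetical.

```python
def combine_slices(sizes):
    """Map each per-factor auxiliary name to a slice of one global vector.

    `sizes` maps a hypothetical per-factor name to its noise dimension.
    """
    subs = {}
    start = 0
    for name, size in sizes.items():
        subs[name] = slice(start, start + size)
        start += size
    return subs, start  # the second value is the total global dimension


subs, total = combine_slices({"aux_0": 2, "aux_1": 3})

# One global noise vector; each factor reads its own segment.
global_aux = [0.1, 0.2, 0.3, 0.4, 0.5]
per_factor = {name: global_aux[s] for name, s in subs.items()}
```

The point of the combined layout is that a single white-noise input of size total drives every factor's sample.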

Approximations

argmax_approximate = argmax_approximate

Point-approximate at the argmax of the provided guide.

compute_argmax(model, approx_vars)
compute_argmax(model: Tensor, approx_vars)
compute_argmax(model: Gaussian, approx_vars)
compute_argmax(model: Contraction[Union[LogaddexpOp, NullOp], AddOp, frozenset, Tuple[Union[Tensor, Number], Gaussian]], approx_vars)
compute_argmax(model: Contraction[NullOp, AddOp, frozenset, tuple], approx_vars)

Computes argmax of a funsor.

Parameters
  • model (Funsor) – A function of the approximated vars.

  • approx_vars (frozenset) – A frozenset of Variables to maximize.

Returns

A dict mapping name (str) to point estimate (Funsor), for each variable name in approx_vars.

Return type

dict

laplace_approximate = laplace_approximate

Gaussian approximate using the value and Hessian of the model, evaluated at the mode of the guide.

mean_approximate = mean_approximate

Point-approximate at the mean of the provided guide.

Evidence lower bound

class Elbo(guide, approx_vars)

Bases: StatefulInterpretation

Given an approximating guide funsor, approximates:

model.reduce(ops.logaddexp, approx_vars)

by the lower bound:

Integrate(guide, model - guide, approx_vars)

Parameters
  • guide (Funsor) – A guide or proposal funsor.

  • approx_vars (frozenset) – The variables being integrated.
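
The bound follows from Jensen's inequality when the guide is a normalized log-density. As a quick numerical check on a single discrete variable, the sketch below uses only the standard library. The values of model and guide are made up, and treating Integrate as a sum weighted by exp(guide) is an assumption for this illustration.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def elbo(guide, model):
    """Sum over states of exp(guide) * (model - guide)."""
    return sum(math.exp(g) * (m - g) for g, m in zip(guide, model))


# Unnormalized model log-weights over three states (made-up values).
model = [0.5, -1.0, 2.0]
log_z = logsumexp(model)  # plays the role of model.reduce(ops.logaddexp, ...)

# A normalized but mismatched guide gives a strict lower bound.
uniform = [math.log(1 / 3)] * 3
loose = elbo(uniform, model)

# The exact posterior as guide makes the bound tight.
posterior = [m - log_z for m in model]
tight = elbo(posterior, model)
```

The gap between log_z and the bound is the KL divergence from the guide to the normalized model, which is why a better guide tightens the approximation.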