Funsor is a tensor-like library for functions and distributions¶
Operations¶
Operation classes¶
- class LogAbsDetJacobianOp(*args, **kwargs)¶
Bases:
BinaryOp
- static default(x, y, fn)¶
- dispatcher = <dispatched log_abs_det_jacobian>¶
- name = 'log_abs_det_jacobian'¶
- signature = <Signature (x, y, fn)>¶
- class Op(*args, **kwargs)[source]¶
Bases:
object
Abstract base class for all mathematical operations on ground terms.
Ops take arity-many leftmost positional args that may be funsors, followed by additional non-funsor args and kwargs. The additional args and kwargs must have default values.
When wrapping new backend ops, keep in mind these restrictions, which may require you to wrap backend functions before making them into ops (a short sketch follows this class entry):
- Create new ops only by decorating a default implementation with @UnaryOp.make, @BinaryOp.make, etc.
- Register backend-specific implementations via @my_op.register(type1), @my_op.register(type1, type2), etc. for arity 1, 2, etc. Patterns may include only the first arity-many types.
- Only the first arity-many arguments may be funsors. Remaining args and kwargs must all be ground Python data.
- Variables
~.arity (int) – The number of funsor arguments this op takes. Must be defined by subclasses.
- Parameters
*args –
**kwargs – All extra arguments to this op, excluding the arguments up to .arity.
- arity = NotImplemented¶
- register(*pattern)¶
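As a usage sketch of the restrictions above (softplus is a hypothetical example op, not part of funsor.ops; a NumPy backend is assumed):

import math

import numpy as np

from funsor.ops import UnaryOp

# Hypothetical example op: the decorated function is the default implementation.
@UnaryOp.make
def softplus(x):
    return math.log1p(math.exp(x))

# Backend-specific implementation, dispatched on the argument type.
@softplus.register(np.ndarray)
def _(x):
    return np.log1p(np.exp(x))

softplus(0.0)          # uses the default implementation
softplus(np.zeros(3))  # uses the NumPy implementation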
- class TransformOp(*args, **kwargs)[source]¶
Bases:
UnaryOp
- set_inv(fn)[source]¶
- Parameters
fn (callable) – A function that inputs an arg y and outputs a value x such that y = self(x).
- class WrappedTransformOp(*args, **kwargs)¶
Bases:
TransformOp
Wrapper for a backend Transform object that provides .inv and .log_abs_det_jacobian. This additionally validates shapes on the first __call__().
- static default(x, fn, *, validate_args=True)¶
Wrapper for a backend Transform object that provides .inv and .log_abs_det_jacobian. This additionally validates shapes on the first __call__().
- dispatcher = <dispatched wrapped_transform>¶
- property inv¶
- property log_abs_det_jacobian¶
- name = 'wrapped_transform'¶
- signature = <Signature (x, fn, *, validate_args=True)>¶
Builtin operations¶
- abs = ops.abs¶
Return the absolute value of the argument.
- add = ops.add¶
Same as a + b.
- and_ = ops.and_¶
Same as a & b.
- atanh = ops.atanh¶
Return the inverse hyperbolic tangent of x.
- eq = ops.eq¶
Same as a == b.
- exp = ops.exp¶
Return e raised to the power of x.
- floordiv = ops.floordiv¶
Same as a // b.
- ge = ops.ge¶
Same as a >= b.
- getitem = ops.getitem¶
- getslice = ops.getslice¶
- gt = ops.gt¶
Same as a > b.
- invert = ops.invert¶
Same as ~a.
- le = ops.le¶
Same as a <= b.
- lgamma = ops.lgamma¶
Natural logarithm of absolute value of Gamma function at x.
- log = ops.log¶
- log1p = ops.log1p¶
Return the natural logarithm of 1+x (base e).
The result is computed in a way which is accurate for x near zero.
- lshift = ops.lshift¶
Same as a << b.
- lt = ops.lt¶
Same as a < b.
- matmul = ops.matmul¶
Same as a @ b.
- max = ops.max¶
- min = ops.min¶
- mod = ops.mod¶
Same as a % b.
- mul = ops.mul¶
Same as a * b.
- ne = ops.ne¶
Same as a != b.
- neg = ops.neg¶
Same as -a.
- null = ops.null¶
Placeholder associative op that unifies with any other op
- or_ = ops.or_¶
Same as a | b.
- pos = ops.pos¶
Same as +a.
- pow = ops.pow¶
Same as a ** b.
- reciprocal = ops.reciprocal¶
- rshift = ops.rshift¶
Same as a >> b.
- safediv = ops.safediv¶
- safesub = ops.safesub¶
- sigmoid = ops.sigmoid¶
- sqrt = ops.sqrt¶
Return the square root of x.
- sub = ops.sub¶
Same as a - b.
- tanh = ops.tanh¶
Return the hyperbolic tangent of x.
- truediv = ops.truediv¶
Same as a / b.
- xor = ops.xor¶
Same as a ^ b.
Array operations¶
- all = ops.all¶
- amax = ops.amax¶
- amin = ops.amin¶
- any = ops.any¶
- argmax = ops.argmax¶
- argmin = ops.argmin¶
- astype = ops.astype¶
- cat = ops.cat¶
- cholesky = ops.cholesky¶
Like numpy.linalg.cholesky() but uses sqrt for scalar matrices.
- cholesky_inverse = ops.cholesky_inverse¶
Like torch.cholesky_inverse() but supports batching and gradients.
- cholesky_solve = ops.cholesky_solve¶
- clamp = ops.clamp¶
- detach = ops.detach¶
- diagonal = ops.diagonal¶
- einsum = ops.einsum¶
- expand = ops.expand¶
- finfo = ops.finfo¶
- flip = ops.flip¶
- full_like = ops.full_like¶
- isnan = ops.isnan¶
- logaddexp = ops.logaddexp¶
- logsumexp = ops.logsumexp¶
- mean = ops.mean¶
- new_arange = ops.new_arange¶
- new_eye = ops.new_eye¶
- new_full = ops.new_full¶
- new_zeros = ops.new_zeros¶
- permute = ops.permute¶
- prod = ops.prod¶
- qr = ops.qr¶
- randn = ops.randn¶
- sample = ops.sample¶
- scatter = ops.scatter¶
- scatter_add = ops.scatter_add¶
- stack = ops.stack¶
- std = ops.std¶
- sum = ops.sum¶
- transpose = ops.transpose¶
- triangular_inv = ops.triangular_inv¶
- triangular_solve = ops.triangular_solve¶
- unsqueeze = ops.unsqueeze¶
- var = ops.var¶
Domains¶
- class Bint[source]¶
Bases:
object
Factory for bounded integer types:
Bint[5] # integers ranging in {0,1,2,3,4} Bint[2, 3, 3] # 3x3 matrices with entries in {0,1}
- dtype = None¶
- shape = None¶
- class Dependent(fn)[source]¶
Bases:
object
Type hint for dependently type-decorated functions.
Examples:
Dependent[Real] # a constant known domain Dependent[lambda x: Array[x.dtype, x.shape[1:]] # args are Domains Dependent[lambda x, y: Bint[x.size + y.size]]
- Parameters
fn (callable) – A lambda taking named arguments (in any order) which will be filled in with the domain of the similarly named funsor argument to the decorated function. This lambda should compute a desired resulting domain given domains of arguments.
- class Reals[source]¶
Bases:
object
Type of a real-valued array with known shape:
Reals[()] = Real # scalar Reals[8] # vector of length 8 Reals[3, 3] # 3x3 matrix
- shape = None¶
- find_domain(op, *domains)[source]¶
- find_domain(op: UnaryOp, domain)
- find_domain(op: AstypeOp, domain)
- find_domain(op: ExpOp, domain)
- find_domain(op: LogOp, domain)
- find_domain(op: ReductionOp, domain)
- find_domain(op: ReshapeOp, domain)
- find_domain(op: GetitemOp, lhs_domain, rhs_domain)
- find_domain(op: GetsliceOp, domain)
- find_domain(op: BinaryOp, lhs, rhs)
- find_domain(op: ComparisonOp, lhs, rhs)
- find_domain(op: FloordivOp, lhs, rhs)
- find_domain(op: ModOp, lhs, rhs)
- find_domain(op: MatmulOp, lhs, rhs)
- find_domain(op: AssociativeOp, *domains)
- find_domain(op: WrappedTransformOp, domain)
- find_domain(op: LogAbsDetJacobianOp, domain, codomain)
- find_domain(op: StackOp, parts)
- find_domain(op: CatOp, parts)
- find_domain(op: EinsumOp, operands)
Finds the Domain resulting when applying op to domains.
- Parameters
op (callable) – An operation.
*domains (Domain) – One or more input domains.
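A small usage sketch, assuming Real, Reals, and find_domain are importable from funsor.domains:

from funsor import ops
from funsor.domains import Real, Reals, find_domain

# Elementwise binary ops broadcast shapes and keep the dtype.
assert find_domain(ops.add, Reals[3, 4], Real) == Reals[3, 4]

# Matrix multiplication contracts the inner dimension.
assert find_domain(ops.matmul, Reals[3, 4], Reals[4, 5]) == Reals[3, 5]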
Interpretations¶
Interpreter¶
- exception PatternMissingError[source]¶
Bases:
NotImplementedError
- reinterpret(x)[source]¶
Overloaded reinterpretation of a deferred expression.
This handles a limited class of expressions, raising ValueError in unhandled cases.
- Parameters
x (a funsor or data structure holding funsors) – An input, typically involving deferred Funsors.
- Returns
A reinterpreted version of the input.
- Raises
ValueError
Interpretations¶
- class CallableInterpretation(interpret)[source]¶
Bases:
Interpretation
A simple callable interpretation.
Example usage:
@CallableInterpretation
def my_interpretation(cls, *args):
    return ...
- Parameters
interpret (callable) – A function implementing interpretation.
- class DispatchedInterpretation(name='dispatched')[source]¶
Bases:
Interpretation
An interpretation based on pattern matching.
Example usage:
my_interpretation = DispatchedInterpretation("my_interpretation")

# Register a funsor pattern and rule.
@my_interpretation.register(...)
def my_impl(cls, *args):
    ...

# Use the new interpretation.
with my_interpretation:
    ...
- class Interpretation(name)[source]¶
Bases:
ContextDecorator
,ABC
Abstract base class for Funsor interpretations.
Instances may be used as context managers or decorators.
- Parameters
name (str) – A name used for printing and debugging (required).
- class Memoize(base_interpretation, cache=None)[source]¶
Bases:
Interpretation
Exploits cons-hashing to do implicit common subexpression elimination.
- Parameters
base_interpretation (Interpretation) – The interpretation to memoize.
cache (dict) – An optional temporary cache where results will be memoized.
- class StatefulInterpretation(name='stateful')[source]¶
Bases:
Interpretation
Base class for interpretations with instance-dependent state or parameters.
Example usage:
class MyInterpretation(StatefulInterpretation):
    def __init__(self, my_param):
        self.my_param = my_param

@MyInterpretation.register(...)
def my_impl(interpretation_state, cls, *args):
    my_param = interpretation_state.my_param
    ...

with MyInterpretation(my_param=0.1):
    ...
- eager = eager/normalize/reflect¶
Eager exact naive interpretation wherever possible.
- lazy = lazy/reflect¶
Performs substitutions eagerly, but constructs lazy funsors for everything else.
- moment_matching = moment_matching/eager/normalize/reflect¶
A moment matching interpretation of Reduce expressions. This falls back to eager in other cases.
- normalize = normalize/reflect¶
Normalize modulo associativity and commutativity, but do not evaluate any numerical operations.
- sequential = sequential/eager/normalize/reflect¶
Eagerly execute ops with known implementations; additionally execute vectorized ops sequentially if no known vectorized implementation exists.
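A minimal sketch of switching interpretations (assuming lazy is importable from funsor.interpretations and reinterpret from funsor.interpreter):

import funsor
from funsor.interpretations import lazy
from funsor.interpreter import reinterpret

with lazy:
    expr = funsor.Number(1.0) + funsor.Number(2.0)

# Under lazy the sum is built but not evaluated; reinterpreting it under the
# default eager interpretation produces Number(3.0).
result = reinterpret(expr)
print(result)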
Monte Carlo¶
- class MonteCarlo(*, rng_key=None, **sample_inputs)[source]¶
Bases:
StatefulInterpretation
A Monte Carlo interpretation of Integrate expressions. This falls back to the previous interpreter in other cases.
- Parameters
rng_key –
Preconditioning¶
- class Precondition(aux_name='aux')[source]¶
Bases:
StatefulInterpretation
Preconditioning interpretation for adjoint computations.
This interpretation is intended to be used once, followed by a call to combine_subs() as follows:

# Lazily build a factor graph.
with reflect:
    log_joint = Gaussian(...) + ... + Gaussian(...)
    log_Z = log_joint.reduce(ops.logaddexp)

# Run a backward sampling under the precondition interpretation.
with Precondition() as p:
    marginals = adjoint(
        ops.logaddexp, ops.add, log_Z, batch_vars=p.sample_vars
    )
combine_subs = p.combine_subs()

# Extract samples from Delta distributions.
samples = {
    k: v(**combine_subs)
    for name, delta in marginals.items()
    for k, v in funsor.montecarlo.extract_samples(delta).items()
}
See forward_filter_backward_precondition() for complete usage.
- Parameters
aux_name (str) – Name of the auxiliary variable containing white noise.
- combine_subs()[source]¶
Method to create a combining substitution after preconditioning is complete. The returned substitution replaces per-factor auxiliary variables with slices into a single combined auxiliary variable.
- Returns
A substitution indexing each factor-wise auxiliary variable into a single global auxiliary variable.
- Return type
Approximations¶
- argmax_approximate = argmax_approximate¶
Point-approximate at the argmax of the provided guide.
- compute_argmax(model, approx_vars)[source]¶
- compute_argmax(model: Tensor, approx_vars)
- compute_argmax(model: Gaussian, approx_vars)
- compute_argmax(model: Contraction[Union[LogaddexpOp, NullOp], AddOp, frozenset, Tuple[Union[Tensor, Number], Gaussian]], approx_vars)
- compute_argmax(model: Contraction[NullOp, AddOp, frozenset, tuple], approx_vars)
Computes argmax of a funsor.
- laplace_approximate = laplace_approximate¶
Gaussian approximate using the value and Hessian of the model, evaluated at the mode of the guide.
- mean_approximate = mean_approximate¶
Point-approximate at the mean of the provided guide.
Evidence lower bound¶
- class Elbo(guide, approx_vars)[source]¶
Bases:
StatefulInterpretation
Given an approximating guide funsor, approximates:

model.reduce(ops.logaddexp, approx_vars)
by the lower bound:
Integrate(guide, model - guide, approx_vars)
Funsors¶
Basic Funsors¶
- class Approximate(*args, **kwargs)[source]¶
Bases:
Funsor
Interpretation-specific approximation wrt a set of variables.
The default eager interpretation should be exact. The user-facing interface is the Funsor.approximate() method.
- class Cat(name, parts, part_name=None)[source]¶
Bases:
Funsor
Concatenate funsors along an existing input dimension.
- Parameters
- class Funsor(*args, **kwargs)[source]¶
Bases:
object
Abstract base class for immutable functional tensors.
Concrete derived classes must implement __init__() methods taking hashable *args and no optional **kwargs so as to support cons hashing.
Derived classes with .fresh variables must implement an eager_subs() method. Derived classes with .bound variables must implement an _alpha_convert() method.
- Parameters
inputs (OrderedDict) – A mapping from input name to domain. This can be viewed as a typed context or a mapping from free variables to domains.
output (Domain) – An output domain.
- property dtype¶
- property shape¶
- property requires_grad¶
- sample(sampled_vars, sample_inputs=None, rng_key=None)[source]¶
Create a Monte Carlo approximation to this funsor by replacing functions of sampled_vars with Deltas.
The result is a Funsor with the same .inputs and .output as the original funsor (plus sample_inputs if provided), so that self can be replaced by the sample in expectation computations:

y = x.sample(sampled_vars)
assert y.inputs == x.inputs
assert y.output == x.output
exact = (x.exp() * integrand).reduce(ops.add)
approx = (y.exp() * integrand).reduce(ops.add)
If sample_inputs is provided, this creates a batch of samples.
- Parameters
sampled_vars (str, Variable, or set or frozenset thereof.) – A set of input variables to sample.
sample_inputs (OrderedDict) – An optional mapping from variable name to Domain over which samples will be batched.
rng_key (None or JAX's random.PRNGKey) – a PRNG state to be used by the JAX backend to generate random samples
- align(names)[source]¶
Align this funsor to match given names. This is mainly useful in preparation for extracting .data of a funsor.tensor.Tensor.
- eager_subs(subs)[source]¶
Internal substitution function. This relies on the user-facing __call__() method to coerce non-Funsors to Funsors. Once all inputs are Funsors, eager_subs() implementations can recurse to call Subs.
- class Independent(*args, **kwargs)[source]¶
Bases:
Funsor
Creates an independent diagonal distribution.
This is equivalent to substitution followed by reduction:
f = ...  # a batched distribution
assert f.inputs['x_i'] == Reals[4, 5]
assert f.inputs['i'] == Bint[3]

g = Independent(f, 'x', 'i', 'x_i')
assert g.inputs['x'] == Reals[3, 4, 5]
assert 'x_i' not in g.inputs
assert 'i' not in g.inputs

x = Variable('x', Reals[3, 4, 5])
g == f(x_i=x['i']).reduce(ops.add, 'i')
- Parameters
- class Lambda(*args, **kwargs)[source]¶
Bases:
Funsor
Lazy inverse to ops.getitem.
This is useful to simulate higher-order functions of integers by representing those functions as arrays.
- Parameters
var (Variable) – A variable to bind.
expr (funsor) – A funsor.
- class Number(data, dtype=None)[source]¶
Bases:
Funsor
Funsor backed by a Python number.
- Parameters
data (numbers.Number) – A python number.
dtype – A nonnegative integer or the string “real”.
- class Reduce(*args, **kwargs)[source]¶
Bases:
Funsor
Lazy reduction over multiple variables.
The user-facing interface is the Funsor.reduce() method.
- Parameters
op (AssociativeOp) – An associative operator.
arg (funsor) – An argument to be reduced.
reduced_vars (frozenset) – A set of variables over which to reduce.
- class Scatter(*args, **kwargs)[source]¶
Bases:
Funsor
Transpose of structurally linear Subs, followed by Reduce.
For injective scatter operations this should satisfy the equation:

if destin = Scatter(op, subs, source, frozenset())
then source = Subs(destin, subs)
The reduced_vars is merely for computational efficiency, and could always be split out into a separate .reduce(). For example in the following equation, the left hand side uses much less memory than the right hand side:

Scatter(op, subs, source, reduced_vars) == Scatter(op, subs, source, frozenset()).reduce(op, reduced_vars)
Warning
This is currently implemented only for injective scatter operations. In particular, this does not allow accumulation behavior like scatter-add.
Note
Scatter(ops.add, ...) is the funsor analog of numpy.add.at() or torch.index_put() or jax.lax.scatter_add(). For injective substitutions, Scatter(ops.add, ...) is roughly equivalent to the tensor operation:

result = zeros(...)  # since zero is the additive unit
result[subs] = source
- Parameters
- class Stack(*args, **kwargs)[source]¶
Bases:
Funsor
Stack of funsors along a new input dimension.
- Parameters
- class Slice(name, *args, **kwargs)[source]¶
Bases:
Funsor
Symbolic representation of a Python slice object.
- Parameters
- class Subs(arg, subs)[source]¶
Bases:
Funsor
Lazy substitution of the form x(u=y, v=z).
- Parameters
arg (Funsor) – A funsor being substituted into.
subs (tuple) – A tuple of (name, value) pairs, where name is a string and value can be coerced to a Funsor via to_funsor().
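Substitution is normally written through the call syntax, which coerces values to funsors; a minimal sketch:

from funsor import Number, Variable
from funsor.domains import Real

x = Variable("x", Real)
expr = x * x + Number(2.0)

value = expr(x=3.0)        # builds a Subs and evaluates it eagerly
assert value.data == 11.0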
- class Unary(*args, **kwargs)[source]¶
Bases:
Funsor
Lazy unary operation.
- Parameters
op (Op) – A unary operator.
arg (Funsor) – An argument.
- class Variable(*args, **kwargs)[source]¶
Bases:
Funsor
Funsor representing a single free variable.
- Parameters
name (str) – A variable name.
output (funsor.domains.Domain) – A domain.
- to_data(x, name_to_dim=None, **kwargs)[source]¶
- to_data(x: Funsor, name_to_dim=None)
- to_data(x: Number, name_to_dim=None)
- to_data(x: Tensor, name_to_dim=None)
- to_data(funsor_dist: Distribution, name_to_dim=None)
- to_data(funsor_dist: Independent[Union[Independent, Distribution], str, str, str], name_to_dim=None)
- to_data(funsor_dist: Gaussian, name_to_dim=None)
- to_data(funsor_dist: Contraction[Union[LogaddexpOp, NullOp], AddOp, frozenset, Tuple[Union[Tensor, Number], Gaussian]], name_to_dim=None)
- to_data(funsor_dist: Multinomial, name_to_dim=None)
- to_data(funsor_dist: Delta, name_to_dim=None)
- to_data(expr: Unary[TransformOp, Union[Unary, Variable]], name_to_dim=None)
- to_data(x: Constant, name_to_dim=None)
Extract a python object from a Funsor.
Raises a ValueError if free variables remain or if the funsor is lazy.
- Parameters
x – An object, possibly a Funsor.
name_to_dim (OrderedDict) – An optional inputs hint.
- Returns
A non-funsor equivalent to x.
- Raises
ValueError if any free variables remain.
- Raises
PatternMissingError if funsor is not fully evaluated.
- to_funsor(x, output=None, dim_to_name=None, **kwargs)[source]¶
- to_funsor(x: Funsor, output=None, dim_to_name=None)
- to_funsor(name: str, output=None)
- to_funsor(x: Number, output=None, dim_to_name=None)
- to_funsor(s: slice, output=None, dim_to_name=None)
- to_funsor(args: tuple, output=None, dim_to_name=None)
- to_funsor(x: generic, output=None, dim_to_name=None)
- to_funsor(x: ndarray, output=None, dim_to_name=None)
- to_funsor(backend_dist: Beta, output=None, dim_to_name=None)
- to_funsor(backend_dist: Cauchy, output=None, dim_to_name=None)
- to_funsor(backend_dist: Chi2, output=None, dim_to_name=None)
- to_funsor(backend_dist: _PyroWrapper_BernoulliProbs, output=None, dim_to_name=None)
- to_funsor(backend_dist: _PyroWrapper_BernoulliLogits, output=None, dim_to_name=None)
- to_funsor(backend_dist: Binomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: Categorical, output=None, dim_to_name=None)
- to_funsor(backend_dist: _PyroWrapper_CategoricalLogits, output=None, dim_to_name=None)
- to_funsor(pyro_dist: Delta, output=None, dim_to_name=None)
- to_funsor(backend_dist: Dirichlet, output=None, dim_to_name=None)
- to_funsor(backend_dist: DirichletMultinomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: Exponential, output=None, dim_to_name=None)
- to_funsor(backend_dist: Gamma, output=None, dim_to_name=None)
- to_funsor(backend_dist: GammaPoisson, output=None, dim_to_name=None)
- to_funsor(backend_dist: Geometric, output=None, dim_to_name=None)
- to_funsor(backend_dist: Gumbel, output=None, dim_to_name=None)
- to_funsor(backend_dist: HalfCauchy, output=None, dim_to_name=None)
- to_funsor(backend_dist: HalfNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Laplace, output=None, dim_to_name=None)
- to_funsor(backend_dist: Logistic, output=None, dim_to_name=None)
- to_funsor(backend_dist: LowRankMultivariateNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Multinomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: MultivariateNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedBeta, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedDirichlet, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedGamma, output=None, dim_to_name=None)
- to_funsor(backend_dist: NonreparameterizedNormal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Normal, output=None, dim_to_name=None)
- to_funsor(backend_dist: Pareto, output=None, dim_to_name=None)
- to_funsor(backend_dist: Poisson, output=None, dim_to_name=None)
- to_funsor(backend_dist: StudentT, output=None, dim_to_name=None)
- to_funsor(backend_dist: Uniform, output=None, dim_to_name=None)
- to_funsor(backend_dist: VonMises, output=None, dim_to_name=None)
- to_funsor(backend_dist: ContinuousBernoulli, output=None, dim_to_name=None)
- to_funsor(backend_dist: FisherSnedecor, output=None, dim_to_name=None)
- to_funsor(backend_dist: NegativeBinomial, output=None, dim_to_name=None)
- to_funsor(backend_dist: OneHotCategorical, output=None, dim_to_name=None)
- to_funsor(backend_dist: RelaxedBernoulli, output=None, dim_to_name=None)
- to_funsor(backend_dist: Weibull, output=None, dim_to_name=None)
- to_funsor(tfm: Transform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: ExpTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: TanhTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: SigmoidTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: _InverseTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(tfm: ComposeTransform, output=None, dim_to_name=None, real_inputs=None)
- to_funsor(backend_dist: ExpandedDistribution, output=None, dim_to_name=None)
- to_funsor(backend_dist: Independent, output=None, dim_to_name=None)
- to_funsor(backend_dist: MaskedDistribution, output=None, dim_to_name=None)
- to_funsor(backend_dist: TransformedDistribution, output=None, dim_to_name=None)
- to_funsor(pyro_dist: Bernoulli, output=None, dim_to_name=None)
- to_funsor(x: ProvenanceTensor, output=None, dim_to_name=None)
- to_funsor(x: Tensor, output=None, dim_to_name=None)
- to_funsor(pyro_dist: FunsorDistribution, output=None, dim_to_name=None)
Convert to a Funsor. Only Funsors and scalars are accepted.
- Parameters
x – An object.
output (funsor.domains.Domain) – An optional output hint.
dim_to_name (OrderedDict) – An optional mapping from negative batch dimensions to name strings.
- Returns
A Funsor equivalent to x.
- Return type
- Raises
ValueError
Delta¶
- class Delta(*args)[source]¶
Bases:
Funsor
Normalized delta distribution binding multiple variables.
There are three syntaxes supported for constructing Deltas:
Delta((
    (name1, (point1, log_density1)),
    (name2, (point2, log_density2)),
    (name3, (point3, log_density3)),
))
or for a single name:
Delta(name, point, log_density)
or for default log_density == 0:

Delta(name, point)
- Parameters
terms (tuple) – A tuple of tuples of the form (name, (point, log_density)).
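A construction sketch using the single-name syntax (torch backend assumed):

import torch
import funsor
from funsor.delta import Delta
from funsor.tensor import Tensor

funsor.set_backend("torch")

point = Tensor(torch.tensor([1.0, 2.0, 3.0]))  # output domain Reals[3]
d = Delta("x", point)                          # log_density defaults to 0
assert d.inputs["x"] == funsor.Reals[3]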
Tensor¶
- Einsum(equation, *operands)[source]¶
Wrapper around torch.einsum() or np.einsum() to operate on real-valued Funsors.
Note this operates only on the output tensor. To perform sum-product contractions on named dimensions, instead use + and Reduce.
- Parameters
equation (str) – A torch.einsum() or np.einsum() equation.
operands (tuple) – A tuple of input funsors.
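A batched usage sketch (torch backend assumed); the named input "i" is carried through unchanged while the equation applies only to output shapes:

import torch
import funsor
from funsor.domains import Bint, Reals
from funsor.tensor import Einsum, Tensor

funsor.set_backend("torch")

x = Tensor(torch.randn(5, 3, 4), {"i": Bint[5]})  # output Reals[3, 4]
y = Tensor(torch.randn(5, 4, 2), {"i": Bint[5]})  # output Reals[4, 2]
z = Einsum("ab,bc->ac", x, y)                     # matmul over output shapes
assert z.output == Reals[3, 2] and "i" in z.inputs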
- class Function(*args, **kwargs)[source]¶
Bases:
Funsor
Funsor wrapped by a native PyTorch or NumPy function.
Functions are assumed to support broadcasting and can be eagerly evaluated on funsors with free variables of int type (i.e. batch dimensions).
Functions are usually created via the function() decorator.
- class Tensor(data, inputs=None, dtype='real')[source]¶
Bases:
Funsor
Funsor backed by a PyTorch Tensor or a NumPy ndarray.
This follows the torch.distributions convention of arranging named "batch" dimensions on the left and remaining "event" dimensions on the right. The output shape is determined by all remaining dims. For example:

data = torch.zeros(5, 4, 3, 2)
x = Tensor(data, {"i": Bint[5], "j": Bint[4]})
assert x.output == Reals[3, 2]
Operators like matmul and .sum() operate only on the output shape, and will not change the named inputs.
- Parameters
- property requires_grad¶
- new_arange(name, *args, **kwargs)[source]¶
Helper to create a named torch.arange() or np.arange() funsor. In some cases this can be replaced by a symbolic Slice.
- align_tensor(new_inputs, x, expand=False)[source]¶
Permute and add dims to a tensor to match desired new_inputs.
- Parameters
new_inputs (OrderedDict) – A target set of inputs.
x (funsor.terms.Funsor) – A Tensor or Number.
expand (bool) – If False (default), set result size to 1 for any input of x not in new_inputs; if True, expand to new_inputs size.
- Returns
a number or torch.Tensor or np.ndarray that can be broadcast to other tensors with inputs new_inputs.
- Return type
int or float or torch.Tensor or np.ndarray
- align_tensors(*args, **kwargs)[source]¶
Permute multiple tensors before applying a broadcasted op.
This is mainly useful for implementing eager funsor operations.
- Parameters
*args (funsor.terms.Funsor) – Multiple Tensors and Numbers.
expand (bool) – Whether to expand input tensors. Defaults to False.
- Returns
a pair (inputs, tensors) where tensors are all torch.Tensors or np.ndarrays that can be broadcast together to a single data with given inputs.
- Return type
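A sketch of using align_tensors to implement a broadcasted op on raw data (torch backend assumed):

import torch
import funsor
from funsor.domains import Bint
from funsor.tensor import Tensor, align_tensors

funsor.set_backend("torch")

x = Tensor(torch.randn(2), {"i": Bint[2]})
y = Tensor(torch.randn(3), {"j": Bint[3]})

inputs, (x_data, y_data) = align_tensors(x, y)
# x_data and y_data are raw torch.Tensors with unit-size dims inserted so
# they broadcast against each other; wrap the result back up as a funsor.
z = Tensor(x_data + y_data, inputs)
assert dict(z.inputs) == {"i": Bint[2], "j": Bint[3]}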
- function(*signature)[source]¶
Decorator to wrap a PyTorch/NumPy function, using either type hints or explicit type annotations.
Example:
# Using type hints:
@funsor.tensor.function
def matmul(x: Reals[3, 4], y: Reals[4, 5]) -> Reals[3, 5]:
    return torch.matmul(x, y)

# Using explicit type annotations:
@funsor.tensor.function(Reals[3, 4], Reals[4, 5], Reals[3, 5])
def matmul(x, y):
    return torch.matmul(x, y)

@funsor.tensor.function(Reals[10], Reals[10, 10], Reals[10], Real)
def mvn_log_prob(loc, scale_tril, x):
    d = torch.distributions.MultivariateNormal(loc, scale_tril)
    return d.log_prob(x)
To support functions that output nested tuples of tensors, specify a nested Tuple of output types, for example:

@funsor.tensor.function
def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
    return torch.max(x, dim=-1)
- Parameters
*signature – A sequence of input domains followed by a final output domain or nested tuple of output domains.
- tensordot(x, y, dims)[source]¶
Wrapper around torch.tensordot() or np.tensordot() to operate on real-valued Funsors.
Note this operates only on the output tensor. To perform sum-product contractions on named dimensions, instead use + and Reduce.
Arguments should satisfy:

len(x.shape) >= dims
len(y.shape) >= dims
dims == 0 or x.shape[-dims:] == y.shape[:dims]
Gaussian¶
- class BlockMatrix(shape)[source]¶
Bases:
object
Jit-compatible helper to build blockwise matrices. Syntax is similar to torch.zeros():

x = BlockMatrix((100, 20, 20))
x[..., 0:4, 0:4] = x11
x[..., 0:4, 6:10] = x12
x[..., 6:10, 0:4] = x12.transpose(-1, -2)
x[..., 6:10, 6:10] = x22
x = x.as_tensor()
assert x.shape == (100, 20, 20)
- class BlockVector(shape)[source]¶
Bases:
object
Jit-compatible helper to build blockwise vectors. Syntax is similar to torch.zeros():

x = BlockVector((100, 20))
x[..., 0:4] = x1
x[..., 6:10] = x2
x = x.as_tensor()
assert x.shape == (100, 20)
- class Gaussian(white_vec=None, prec_sqrt=None, inputs=None, *, mean=None, info_vec=None, precision=None, scale_tril=None, covariance=None)[source]¶
Bases:
Funsor
Funsor representing a batched Gaussian log-density function.
Gaussians are the internal representation for joint and conditional multivariate normal distributions and multivariate normal likelihoods. Mathematically, a Gaussian represents the quadratic log density function:
f(x) = -0.5 * || x @ prec_sqrt - white_vec ||^2
     = -0.5 * < x @ prec_sqrt - white_vec | x @ prec_sqrt - white_vec >
     = -0.5 * < x | prec_sqrt @ prec_sqrt.T | x > + < x | prec_sqrt | white_vec > - 0.5 * || white_vec ||^2
Internally Gaussians use a square root information filter (SRIF) representation consisting of a square root of the precision matrix prec_sqrt and a vector in the whitened space white_vec. This representation allows space-efficient construction of Gaussians with incomplete information, i.e. with zero eigenvalues in the precision matrix. These incomplete log densities arise when making low-dimensional observations of higher-dimensional hidden state. Sampling and marginalization are supported only for full-rank Gaussians or full-rank subsets of Gaussians. See the rank() and is_full_rank() properties.
Note
Gaussians are not normalized probability distributions; rather they are canonicalized to evaluate to zero log density at their maximum: f(prec_sqrt \ white_vec) = 0. Not only are Gaussians non-normalized, but they may be rank deficient and non-normalizable, in which case sampling and marginalization are supported only on full-rank subsets of variables.
- Parameters
white_vec (torch.Tensor) – A batched white noise vector, where white_vec = prec_sqrt.T @ mean. Alternatively you can specify one of the kwargs mean or info_vec, which will be converted to white_vec.
prec_sqrt (torch.Tensor) – A batched square root of the positive semidefinite precision matrix. This need not be square, and typically has shape prec_sqrt.shape == white_vec.shape[:-1] + (dim, rank), where dim is the total flattened size of real inputs and rank = white_vec.shape[-1]. Alternatively you can specify one of the kwargs precision, covariance, or scale_tril, which will be converted to prec_sqrt.
inputs (OrderedDict) – Mapping from name to Domain.
- compression_threshold = 2¶
- classmethod set_compression_threshold(threshold: float)[source]¶
Context manager to set rank compression threshold.
To save space Gaussians compress wide prec_sqrt matrices down to square. However compression uses a QR decomposition which can be expensive and which has unstable gradients when the resulting precision matrix is rank deficient. To balance space and time costs and numerical stability, compression is triggered only on prec_sqrt matrices whose width-to-height ratio is greater than threshold.
- Parameters
threshold (float) – Defaults to 2. To optimize for space and deterministic computations, set threshold = 1. To optimize for fewest QR decompositions and numerical stability, set threshold = math.inf.
- property rank¶
- property is_full_rank¶
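A minimal Gaussian construction sketch (torch backend assumed); here the mean and precision kwargs are used and converted internally to white_vec and prec_sqrt:

from collections import OrderedDict

import torch
import funsor
from funsor.domains import Reals
from funsor.gaussian import Gaussian

funsor.set_backend("torch")

g = Gaussian(
    mean=torch.zeros(2),      # converted to white_vec
    precision=torch.eye(2),   # converted to prec_sqrt
    inputs=OrderedDict(x=Reals[2]),
)
assert g.is_full_rank
assert g(x=torch.zeros(2)).data == 0.0  # zero log density at the mode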
Joint¶
Contraction¶
- class Contraction(*args, **kwargs)[source]¶
Bases:
Funsor
Declarative representation of a finitary sum-product operation.
After normalization via the normalize() interpretation, contractions will canonically order their terms by type:
- GaussianMixture¶
alias of
Contraction
Integrate¶
Constant¶
- class ConstantMeta(name, bases, dct)[source]¶
Bases:
FunsorMeta
Wrapper to convert const_inputs to a tuple.
- class Constant(const_inputs, arg)[source]¶
Bases:
Funsor
Funsor that is constant wrt const_inputs.
Constant can be used for provenance tracking.
Examples:
a = Constant(OrderedDict(x=Real, y=Bint[3]), Number(0))
a(y=1)       # returns Constant(OrderedDict(x=Real), Number(0))
a(x=2, y=1)  # returns Number(0)

d = Tensor(torch.tensor([1, 2, 3]))["y"]
a + d        # returns Constant(OrderedDict(x=Real), d)

c = Constant(OrderedDict(x=Bint[3]), Number(1))
c.reduce(ops.add, "x")  # returns Number(3)
- Parameters
const_inputs (dict) – A mapping from input name (str) to datatype (funsor.domain.Domain).
arg (funsor) – A funsor that is constant wrt const_inputs.
Optimizer¶
Adjoint Algorithms¶
- class AdjointTape[source]¶
Bases:
Interpretation
- adjoint_contract_unary(adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, arg)[source]¶
- adjoint_contract_generic(adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms)[source]¶
Sum-Product Algorithms¶
- partial_unroll(factors, eliminate=frozenset({}), plate_to_step={})[source]¶
Performs partial unrolling of plated factor graphs to standard factor graphs. Only plates with history={0, 1} are supported.
For plates (history=0) the unrolling operation appends a _{i} suffix to variable names for index i in the plate (e.g., "x" -> "x_0" for i=0). For markov dimensions (history=1) the unrolling operation renames the suffixes var_prev to var_{i} and var_curr to var_{i+1} for index i (e.g., "x_prev" -> "x_0" and "x_curr" -> "x_1" for i=0). Markov vars are assumed to have names that follow var_suffix formatting, specifically var_0 for the initial factor (e.g., ("x_0", "x_prev", "x_curr") for history=1).
- Parameters
eliminate (frozenset) – A set of free variables to unroll, including both sum variables and product variables.
plate_to_step (dict) – A dict mapping markov dimensions to step collections that contain ordered sequences of Markov variable names (e.g., {"time": frozenset({("x_0", "x_prev", "x_curr")})}). Plates are passed with an empty step.
- Returns
a list of partially unrolled Funsors, a frozenset of partially unrolled variable names, and a frozenset of remaining plates.
- partial_sum_product(sum_op, prod_op, factors, eliminate=frozenset({}), plates=frozenset({}), pedantic=False)[source]¶
Performs partial sum-product contraction of a collection of factors.
- Returns
a list of partially contracted Funsors.
- Return type
- dynamic_partial_sum_product(sum_op, prod_op, factors, eliminate=frozenset({}), plate_to_step={})[source]¶
Generalization of the tensor variable elimination algorithm of funsor.sum_product.partial_sum_product() to handle higher-order markov dimensions in addition to plate dimensions. Markov dimensions in transition factors are eliminated efficiently using the parallel-scan algorithm in funsor.sum_product.sarkka_bilmes_product(). The resulting factors are then combined with the initial factors and final states are eliminated. Therefore, when a Markov dimension is eliminated, factors has to contain initial factors and transition factors.
- Parameters
sum_op (AssociativeOp) – A semiring sum operation.
prod_op (AssociativeOp) – A semiring product operation.
eliminate (frozenset) – A set of free variables to eliminate, including both sum variables and product variables.
plate_to_step (dict) – A dict mapping markov dimensions to step collections that contain ordered sequences of Markov variable names (e.g., {"time": frozenset({("x_0", "x_prev", "x_curr")})}). Plates are passed with an empty step.
- Returns
a list of partially contracted Funsors.
- Return type
- modified_partial_sum_product(sum_op, prod_op, factors, eliminate=frozenset({}), plate_to_step={})[source]¶
Generalization of the tensor variable elimination algorithm of funsor.sum_product.partial_sum_product() to handle markov dimensions in addition to plate dimensions. Markov dimensions in transition factors are eliminated efficiently using the parallel-scan algorithm in funsor.sum_product.sequential_sum_product(). The resulting factors are then combined with the initial factors and final states are eliminated. Therefore, when a Markov dimension is eliminated, factors has to contain pairs of initial factors and transition factors.
- Parameters
sum_op (AssociativeOp) – A semiring sum operation.
prod_op (AssociativeOp) – A semiring product operation.
eliminate (frozenset) – A set of free variables to eliminate, including both sum variables and product variables.
plate_to_step (dict) – A dict mapping markov dimensions to step collections that contain ordered sequences of Markov variable names (e.g., {"time": frozenset({("x_0", "x_prev", "x_curr")})}). Plates are passed with an empty step.
- Returns
a list of partially contracted Funsors.
- Return type
- sum_product(sum_op, prod_op, factors, eliminate=frozenset({}), plates=frozenset({}), pedantic=False)[source]¶
Performs sum-product contraction of a collection of factors.
- Returns
a single contracted Funsor.
- Return type
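A small contraction sketch (torch backend assumed):

import torch
import funsor
from funsor import ops
from funsor.domains import Bint
from funsor.sum_product import sum_product
from funsor.tensor import Tensor

funsor.set_backend("torch")

f = Tensor(torch.randn(2, 3), {"a": Bint[2], "b": Bint[3]})
g = Tensor(torch.randn(3, 4), {"b": Bint[3], "c": Bint[4]})

# Sum out all variables in the (logaddexp, add) semiring, yielding a scalar.
z = sum_product(ops.logaddexp, ops.add, [f, g],
                eliminate=frozenset({"a", "b", "c"}),
                plates=frozenset())
assert not z.inputs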
- sequential_sum_product(sum_op, prod_op, trans, time, step)[source]¶
For a funsor trans with dimensions time, prev and curr, computes a recursion equivalent to:

tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
tail = sequential_sum_product(sum_op, prod_op,
                              trans(time=tail_time),
                              time, {"prev": "curr"})
return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
    .reduce(sum_op, "drop")
but does so efficiently in parallel in O(log(time)).
- Parameters
sum_op (AssociativeOp) – A semiring sum operation.
prod_op (AssociativeOp) – A semiring product operation.
trans (Funsor) – A transition funsor.
time (Variable) – The time input dimension.
step (dict) – A dict mapping previous variables to current variables. This can contain multiple pairs of prev->curr variable names.
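A sketch of marginalizing a discrete Markov chain (torch backend assumed):

import torch
import funsor
from funsor import ops
from funsor.domains import Bint
from funsor.sum_product import sequential_sum_product
from funsor.tensor import Tensor
from funsor.terms import Variable

funsor.set_backend("torch")

# A transition factor over a chain of length 10 with 3 latent states.
trans = Tensor(torch.randn(10, 3, 3),
               {"time": Bint[10], "prev": Bint[3], "curr": Bint[3]})

total = sequential_sum_product(ops.logaddexp, ops.add, trans,
                               Variable("time", Bint[10]),
                               {"prev": "curr"})
# "time" is eliminated; the first "prev" and last "curr" remain.
assert set(total.inputs) == {"prev", "curr"}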
- mixed_sequential_sum_product(sum_op, prod_op, trans, time, step, num_segments=None)[source]¶
For a funsor trans with dimensions time, prev and curr, computes a recursion equivalent to:

tail_time = 1 + arange("time", trans.inputs["time"].size - 1)
tail = sequential_sum_product(sum_op, prod_op,
                              trans(time=tail_time),
                              time, {"prev": "curr"})
return prod_op(trans(time=0)(curr="drop"), tail(prev="drop")) \
    .reduce(sum_op, "drop")
by mixing parallel and serial scan algorithms over num_segments segments.
- Parameters
sum_op (AssociativeOp) – A semiring sum operation.
prod_op (AssociativeOp) – A semiring product operation.
trans (Funsor) – A transition funsor.
time (Variable) – The time input dimension.
step (dict) – A dict mapping previous variables to current variables. This can contain multiple pairs of prev->curr variable names.
num_segments (int) – number of segments for the first stage
- sarkka_bilmes_product(sum_op, prod_op, trans, time_var, global_vars=frozenset({}), num_periods=1)[source]¶
- class MarkovProductMeta(name, bases, dct)[source]¶
Bases:
FunsorMeta
Wrapper to convert step to a tuple and fill in default step_names.
- class MarkovProduct(sum_op, prod_op, trans, time, step, step_names=None)[source]¶
Bases:
Funsor
Lazy representation of sequential_sum_product().
- Parameters
sum_op (AssociativeOp) – A marginalization op.
prod_op (AssociativeOp) – A Bayesian fusion op.
trans (Funsor) – A sequence of transition factors, usually varying along the time input.
step (dict) – A str-to-str mapping of "previous" inputs of trans to "current" inputs of trans.
step_names (dict) – Optional, for internal use by alpha conversion.
Affine Pattern Matching¶
- affine_inputs(fn)[source]¶
Returns a [sound sub]set of real inputs of fn wrt which fn is known to be affine.
- extract_affine(fn)[source]¶
Extracts an affine representation of a funsor, satisfying:
x = ...
const, coeffs = extract_affine(x)
y = sum(Einsum(eqn, coeff, Variable(var, coeff.output))
        for var, (coeff, eqn) in coeffs.items())
assert_close(y, x)
assert frozenset(coeffs) == affine_inputs(x)
The coeffs will have one key per input wrt which fn is known to be affine (via affine_inputs()), and const and coeffs.values will all be constant wrt these inputs.
The affine approximation is computed by evaluating fn at zero and each basis vector. To improve performance, users may want to run under the Memoize() interpretation.
Funsor Factory¶
- class Fresh(fn)[source]¶
Bases:
object
Type hint for make_funsor() decorated functions. This provides hints for fresh variables (names) and the return type.
Examples:
Fresh[Real]                                   # a constant known domain
Fresh[lambda x: Array[x.dtype, x.shape[1:]]]  # args are Domains
Fresh[lambda x, y: Bint[x.size + y.size]]
- Parameters
fn (callable) – A lambda taking named arguments (in any order) which will be filled in with the domain of the similarly named funsor argument to the decorated function. This lambda should compute a desired resulting domain given domains of arguments.
- class Bound[source]¶
Bases:
object
Type hint for make_funsor() decorated functions. This provides hints for bound variables (names).
- class Has(bound)[source]¶
Bases:
object
Type hint for make_funsor() decorated functions.
This hint asserts that a set of Bound variables always appears in the .inputs of the annotated argument.
For example, we could write a named matmul function that asserts that both arguments always contain the reduced input, and cannot be constant with respect to that input:

@make_funsor
def MatMul(
    x: Has[{"i"}],
    y: Has[{"i"}],
    i: Bound,
) -> Fresh[lambda x: x]:
    return (x * y).reduce(ops.add, i)
Here the string "i" in the annotations for x and y refers to the argument i of our MatMul function, which is known to be Bound (i.e. it does not appear in the .inputs of evaluating MatMul(x, y, "i")).
Warning
This annotation is experimental and may be removed in the future.
Note that because Funsor is inherently extensional, violating a Has constraint only raises a SyntaxWarning rather than a full TypeError, and even then only under the reflect() interpretation.
As such, Has annotations should be used sparingly, reserved for cases where the programmer has complete control over the inputs to a function and knows that an argument will always depend on a bound variable, e.g. when writing one-off Funsor terms to describe custom layers in a neural network.
- Parameters
bound (set) – A set of strings of names of Bound arguments of a make_funsor()-decorated function.
- make_funsor(fn)[source]¶
Decorator to dynamically create a subclass of Funsor, together with a single default eager pattern.
This infers inputs, outputs, fresh, and bound variables from type hints, following the convention below:
Funsor inputs are typed Funsor.
Bound variable inputs (names) are typed Bound.
Fresh variable inputs (names) are typed Fresh together with a lambda to compute the dependent domain.
Ground value inputs (e.g. Python ints) are typed Value together with their actual data type, e.g. Value[int].
The return value is typed Fresh together with a lambda to compute the dependent return domain.
For example to unflatten a single coordinate into a pair of coordinates we could define:
@make_funsor
def Unflatten(
    x: Funsor,
    i: Bound,
    i_over_2: Fresh[lambda i: Bint[i.size // 2]],
    i_mod_2: Fresh[lambda: Bint[2]],
) -> Fresh[lambda x: x]:
    assert i.output.size % 2 == 0
    return x(**{i.name: i_over_2 * Number(2, 3) + i_mod_2})
- Parameters
fn (callable) – A type annotated function of Funsors.
- Return type
subclass of Funsor
Testing Utilities¶
- class ActualExpected(actual, expected)[source]¶
Bases:
LazyComparison
Lazy string formatter for test assertions.
- random_tensor(inputs, output=Real)[source]¶
Creates a random funsor.tensor.Tensor with given inputs and output.
- random_gaussian(inputs)[source]¶
Creates a random funsor.gaussian.Gaussian with given inputs.
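A usage sketch (torch backend assumed):

from collections import OrderedDict

import funsor
from funsor.domains import Bint, Reals
from funsor.testing import random_gaussian, random_tensor

funsor.set_backend("torch")

t = random_tensor(OrderedDict(i=Bint[3]), Reals[2])
assert t.output == Reals[2]

g = random_gaussian(OrderedDict(i=Bint[3], x=Reals[2]))
assert g.inputs["x"] == Reals[2]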
Typing Utilities¶
- class GenericTypeMeta(name, bases, dct)[source]¶
Bases:
type
Metaclass to support subtyping with parameters for pattern matching, e.g. Number[int, int].
- deep_isinstance(obj, cls)[source]¶
Enhanced version of isinstance() that can handle basic structured typing types, including Funsor terms and other GenericTypeMeta instances, Union, Tuple, and FrozenSet.
Does not support TypeVar, arbitrary Generic, forward references, or mutable generic collection types like List. Will attempt to fall back to isinstance() when it encounters an unsupported type in obj or cls.
Usage:
x = (1, ("a", "b")) assert deep_isinstance(x, typing.Tuple[int, tuple]) assert deep_isinstance(x, typing.Tuple[typing.Any, typing.Tuple[str, ...]])
- Parameters
obj – An object that may be an instance of cls.
cls – A class that may be a parent class of obj.
- deep_issubclass(subcls, cls)[source]¶
Enhanced version of issubclass() that can handle structured types, including Funsor terms, Tuple, and FrozenSet.
Does not support more advanced typing features such as TypeVar, arbitrary Generic subtypes, forward references, or mutable collection types like List. Will attempt to fall back to issubclass() when it encounters a type in subcls or cls that it does not understand.
Usage:
class A: pass
class B(A): pass

assert deep_issubclass(typing.Tuple[int, B], typing.Tuple[int, A])
assert not deep_issubclass(typing.Tuple[int, A], typing.Tuple[int, B])
assert deep_issubclass(typing.Tuple[A, A], typing.Tuple[A, ...])
assert not deep_issubclass(typing.Tuple[B], typing.Tuple[A, ...])
- Parameters
subcls – A class that may be a subclass of cls.
cls – A class that may be a parent class of subcls.
- deep_type(obj)[source]¶
- deep_type(obj: tuple)
- deep_type(obj: frozenset)
An enhanced version of type() that reconstructs structured typing types for a limited set of immutable data structures, notably tuple and frozenset. Mostly intended for internal use in Funsor interpretation pattern-matching.
Example:
assert deep_type((1, ("a",))) is typing.Tuple[int, typing.Tuple[str]] assert deep_type(frozenset(["a"])) is typing.FrozenSet[str]
- get_type_hints(obj, globalns=None, localns=None)[source]¶
Return type hints for an object.
This is often the same as obj.__annotations__, but it handles forward references encoded as string literals, and if necessary adds Optional[t] if a default value equal to None is set.
The argument may be a module, class, method, or function. The annotations are returned as a dictionary. For classes, annotations include also inherited members.
TypeError is raised if the argument is not of a type that can contain annotations, and an empty dictionary is returned if no annotations are present.
BEWARE – the behavior of globalns and localns is counterintuitive (unless you are familiar with how eval() and exec() work). The search order is locals first, then globals.
If no dict arguments are passed, an attempt is made to use the globals from obj (or the respective module’s globals for classes), and these are also used as the locals. If the object does not appear to have globals, an empty dictionary is used.
If one dict argument is passed, it is used for both globals and locals.
If two dict arguments are passed, they specify globals and locals, respectively.
- register_subclasscheck(cls)[source]¶
Decorator for registering a custom __subclasscheck__ method for cls which is only ever invoked in deep_issubclass().
This is primarily intended for working with the typing library at runtime. Prefer overriding __subclasscheck__ in the usual way with a metaclass where possible.
Recipes using Funsor¶
This module provides a number of high-level algorithms using Funsor.
- forward_filter_backward_rsample(factors: Dict[str, Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], sample_inputs: Dict[str, type] = {}, rng_key=None)[source]¶
A forward-filter backward-batched-reparametrized-sample algorithm for use in variational inference. The motivating use case is performing Gaussian tensor variable elimination over structured variational posteriors.
- Parameters
factors (dict) – A dictionary mapping sample site name to a Funsor factor created at that sample site.
eliminate (frozenset) – A set of names of latent variables to marginalize and plates to aggregate.
plates – A set of names of plates to aggregate.
sample_inputs (dict) – An optional dict of enclosing sample indices over which samples will be drawn in batch.
rng_key – A random number key for the JAX backend.
- Returns
A pair samples: Dict[str, Tensor], log_prob: Tensor of samples and log density evaluated at each of those samples. If sample_inputs is nonempty, both outputs will be batched.
- Return type
- forward_filter_backward_precondition(factors: Dict[str, Funsor], eliminate: FrozenSet[str], plates: FrozenSet[str], aux_name: str = 'aux')[source]¶
A forward-filter backward-precondition algorithm for use in variational inference or preconditioning in Hamiltonian Monte Carlo. The motivating use case is performing Gaussian tensor variable elimination over structured variational posteriors, and optionally using the learned posterior to determine momentum in HMC.
- Parameters
factors (dict) – A dictionary mapping sample site name to a Funsor factor created at that sample site.
eliminate (frozenset) – A set of names of latent variables to marginalize and plates to aggregate.
plates – A set of names of plates to aggregate.
aux_name (str) – Name of the auxiliary variable containing white noise.
- Returns
A pair samples: Dict[str, Tensor], log_prob: Tensor of samples and log density evaluated at each of those samples. Both outputs depend on a vector named by aux_name, e.g. aux: Reals[d] where d is the total number of elements in eliminated variables.
- Return type
Pyro-Compatible Distributions¶
This interface provides a number of PyTorch-style distributions that use funsors internally to perform inference. These high-level objects are based on a wrapping class, FunsorDistribution, which wraps a funsor in a PyTorch-distributions-compatible interface. FunsorDistribution objects can be used directly in Pyro models (using the standard Pyro backend).
FunsorDistribution Base Class¶
- class FunsorDistribution(funsor_dist, batch_shape=torch.Size([]), event_shape=torch.Size([]), dtype='real', validate_args=None)[source]¶
Bases:
TorchDistribution
Distribution wrapper around a Funsor for use in Pyro code. This is typically used as a base class for specific funsor inference algorithms wrapped in a distribution interface.
- Parameters
funsor_dist (funsor.terms.Funsor) – A funsor with an input named “value” that is treated as a random variable. The distribution should be normalized over “value”.
batch_shape (torch.Size) – The distribution's batch shape. This must be in the same order as the input of the funsor_dist, but may contain extra dims of size 1.
event_shape – The distribution's event shape.
- arg_constraints = {}¶
- property support¶
Hidden Markov Models¶
- class DiscreteHMM(initial_logits, transition_logits, observation_dist, validate_args=None)[source]¶
Bases:
FunsorDistribution
Hidden Markov Model with discrete latent state and arbitrary observation distribution. This uses [1] to parallelize over time, achieving O(log(time)) parallel complexity.
The event_shape of this distribution includes time on the left:
event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_logits and observation_dist. However, because time is included in this distribution's event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:

# homogeneous + homogeneous case:
event_shape = (1,) + observation_dist.event_shape
This class should be interchangeable with pyro.distributions.hmm.DiscreteHMM.
References:
- [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
“Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
- Parameters
initial_logits (Tensor) – A logits tensor for an initial categorical distribution over latent states. Should have rightmost size state_dim and be broadcastable to batch_shape + (state_dim,).
transition_logits (Tensor) – A logits tensor for transition conditional distributions between latent states. Should have rightmost shape (state_dim, state_dim) (old, new), and be broadcastable to batch_shape + (num_steps, state_dim, state_dim).
observation_dist (Distribution) – A conditional distribution of observed data conditioned on latent state. The .batch_shape should have rightmost size state_dim and be broadcastable to batch_shape + (num_steps, state_dim). The .event_shape may be arbitrary.
- property has_rsample¶
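A homogeneous-case sketch (the import path funsor.pyro.hmm is assumed):

import torch
import pyro.distributions as dist
from funsor.pyro.hmm import DiscreteHMM  # assumed import path

state_dim, num_steps = 3, 5
initial_logits = torch.randn(state_dim)
transition_logits = torch.randn(state_dim, state_dim)
observation_dist = dist.Normal(torch.randn(state_dim), torch.ones(state_dim))

hmm = DiscreteHMM(initial_logits, transition_logits, observation_dist)
# Homogeneous case: event_shape is (1,), broadcastable to any sequence length.
data = torch.randn(num_steps)
log_prob = hmm.log_prob(data)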
- class GaussianHMM(initial_dist, transition_matrix, transition_dist, observation_matrix, observation_dist, validate_args=None)[source]¶
Bases:
FunsorDistribution
Hidden Markov Model with Gaussians for initial, transition, and observation distributions. This adapts [1] to parallelize over time to achieve O(log(time)) parallel complexity, however it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.
This corresponds to the generative model:
z = initial_distribution.sample()
x = []
for t in range(num_steps):
    z = z @ transition_matrix + transition_dist.sample()
    x.append(z @ observation_matrix + observation_dist.sample())
The event_shape of this distribution includes time on the left:
event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution's event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:

event_shape = (1, obs_dim)  # homogeneous + homogeneous case
This class should be compatible with pyro.distributions.hmm.GaussianHMM, but additionally supports funsor adjoint algorithms.
References:
- [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
“Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
- Variables
- Parameters
initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
transition_matrix (Tensor) – A linear transformation of hidden state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, hidden_dim) where the rightmost dims are ordered (old, new).
transition_dist (MultivariateNormal) – A process noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim,).
observation_matrix (Tensor) – A linear transformation from hidden to observed state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, obs_dim).
observation_dist (MultivariateNormal or Independent of Normal) – An observation noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (obs_dim,).
- has_rsample = True¶
- arg_constraints = {}¶
- class GaussianMRF(initial_dist, transition_dist, observation_dist, validate_args=None)[source]¶
Bases:
FunsorDistribution
Temporal Markov Random Field with Gaussian factors for initial, transition, and observation distributions. This adapts [1] to parallelize over time to achieve O(log(time)) parallel complexity, however it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.
The event_shape of this distribution includes time on the left:
event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution's event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:

event_shape = (1, obs_dim)  # homogeneous + homogeneous case
This class should be compatible with pyro.distributions.hmm.GaussianMRF, but additionally supports funsor adjoint algorithms.
References:
- [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
“Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
- Variables
- Parameters
initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
transition_dist (MultivariateNormal) – A joint distribution factor over a pair of successive time steps. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + hidden_dim,) (old+new).
observation_dist (MultivariateNormal) – A joint distribution factor over a hidden and an observed state. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + obs_dim,).
- has_rsample = True¶
- class SwitchingLinearHMM(initial_logits, initial_mvn, transition_logits, transition_matrix, transition_mvn, observation_matrix, observation_mvn, exact=False, validate_args=None)[source]¶
Bases:
FunsorDistribution
Switching Linear Dynamical System represented as a Hidden Markov Model.
This corresponds to the generative model:
z = Categorical(logits=initial_logits).sample()
y = initial_mvn[z].sample()
x = []
for t in range(num_steps):
    z = Categorical(logits=transition_logits[t, z]).sample()
    y = y @ transition_matrix[t, z] + transition_mvn[t, z].sample()
    x.append(y @ observation_matrix[t, z] + observation_mvn[t, z].sample())
Viewed as a dynamic Bayesian network:
z[t-1] ----> z[t] ---> z[t+1]         Discrete latent class
   |  \       |  \       |  \
   | y[t-1] ----> y[t] ----> y[t+1]   Gaussian latent state
   |   /      |   /      |   /
   V  /       V  /       V  /
x[t-1]      x[t]      x[t+1]          Gaussian observation
Let class be the latent class, state be the latent multivariate normal state, and value be the observed multivariate normal value.
- Parameters
initial_logits (Tensor) – Represents p(class[0]).
initial_mvn (MultivariateNormal) – Represents p(state[0] | class[0]).
transition_logits (Tensor) – Represents p(class[t+1] | class[t]).
transition_matrix (Tensor) –
transition_mvn (MultivariateNormal) – Together with transition_matrix, this represents p(state[t], state[t+1] | class[t]).
observation_matrix (Tensor) –
observation_mvn (MultivariateNormal) – Together with observation_matrix, this represents p(value[t+1], state[t+1] | class[t+1]).
exact (bool) – If True, perform exact inference at cost exponential in num_steps. If False, use a moment_matching() approximation and a parallel scan algorithm to reduce parallel complexity to logarithmic in num_steps. Defaults to False.
- has_rsample = True¶
- arg_constraints = {}¶
- filter(value)[source]¶
Compute posterior over final state given a sequence of observations.
- Parameters
value (Tensor) – A sequence of observations.
- Returns
A posterior distribution over latent states at the final time step, represented as a pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components. This can then be used to initialize a sequential Pyro model for prediction.
- Return type
Conversion Utilities¶
This module follows a convention for converting between funsors and PyTorch distribution objects. This convention is compatible with NumPy/PyTorch-style broadcasting. Following PyTorch distributions (and Tensorflow distributions), we consider “event shapes” to be on the right and broadcast-compatible “batch shapes” to be on the left.
This module also aims to be forgiving in inputs and pedantic in outputs:
methods accept either the superclass torch.distributions.Distribution
objects or the subclass pyro.distributions.TorchDistribution
objects.
Methods return only the narrower subclass
pyro.distributions.TorchDistribution
objects.
- tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype='real')[source]¶
Convert a torch.Tensor to a funsor.tensor.Tensor.
Note this should not touch data, but may trigger a torch.Tensor.reshape() op.
- Parameters
tensor (torch.Tensor) – A PyTorch tensor.
event_inputs (tuple) – A tuple of names for rightmost tensor dimensions. If tensor has these names, they will be converted to result.inputs.
event_output (int) – The number of tensor dimensions assigned to result.output. These must be on the right of any event_input dimensions.
- Returns
A funsor.
- Return type
- funsor_to_tensor(funsor_, ndims, event_inputs=())[source]¶
Convert a funsor.tensor.Tensor to a torch.Tensor.
Note this should not touch data, but may trigger a torch.Tensor.reshape() op.
- Parameters
funsor_ (funsor.tensor.Tensor) – A funsor.
ndims (int) – The number of result dims, == result.dim().
event_inputs (tuple) – Names assigned to rightmost dimensions.
- Returns
A PyTorch tensor.
- Return type
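As a rough illustration of these two helpers, the following minimal sketch (the tensor shape and the "time" name are made up for the example) converts a tensor to a funsor and back:
import torch
import funsor
from funsor.pyro.convert import funsor_to_tensor, tensor_to_funsor

funsor.set_backend("torch")

t = torch.randn(3, 4)
# Name the remaining batch dim "time" and keep one rightmost dim as the output.
f = tensor_to_funsor(t, event_inputs=("time",), event_output=1)
print(f.inputs)   # OrderedDict([('time', Bint[3])])
print(f.output)   # Reals[4]

# Convert back to a tensor with the same number of dims.
t2 = funsor_to_tensor(f, ndims=2, event_inputs=("time",))
assert t2.shape == (3, 4)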
- dist_to_funsor(pyro_dist, event_inputs=())[source]¶
Convert a PyTorch distribution to a Funsor.
- Parameters
pyro_dist (torch.distributions.Distribution) – A PyTorch distribution.
- Returns
A funsor.
- Return type
- mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs={})[source]¶
Convert a joint torch.distributions.MultivariateNormal distribution into a Funsor with multiple real inputs.
This should satisfy:
sum(d.num_elements for d in real_inputs.values()) == pyro_dist.event_shape[0]
- Parameters
pyro_dist (torch.distributions.MultivariateNormal) – A multivariate normal distribution over one or more variables of real or vector or tensor type.
event_inputs (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
real_inputs (OrderedDict) – A dict mapping real variable name to appropriately sized Real. The sum of all .numel() of all real inputs should be equal to the pyro_dist dimension.
- Returns
A funsor with given real_inputs and possibly additional Bint inputs.
- Return type
- funsor_to_mvn(gaussian, ndims, event_inputs=())[source]¶
Convert a Funsor to a pyro.distributions.MultivariateNormal, dropping the normalization constant.
- Parameters
gaussian (funsor.gaussian.Gaussian or funsor.joint.Joint) – A Gaussian funsor.
ndims (int) – The number of batch dimensions in the result.
event_inputs (tuple) – A tuple of names to assign to rightmost dimensions.
- Returns
a multivariate normal distribution.
- Return type
- funsor_to_cat_and_mvn(funsor_, ndims, event_inputs)[source]¶
Converts a labeled Gaussian mixture model to a pair of distributions.
- Parameters
funsor (funsor.joint.Joint) – A Gaussian mixture funsor.
ndims (int) – The number of batch dimensions in the result.
- Returns
A pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components.
- matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(), x_name='value_x', y_name='value_y')[source]¶
Convert a noisy affine function to a Gaussian. The noisy affine function is defined as:
y = x @ matrix + mvn.sample()
The result is a non-normalized Gaussian funsor with two real inputs, x_name and y_name, corresponding to a conditional distribution of real vector y given real vector x.
- Parameters
matrix (torch.Tensor) – A matrix with rightmost shape (x_size, y_size).
mvn (torch.distributions.MultivariateNormal or torch.distributions.Independent of torch.distributions.Normal) – A multivariate normal distribution with event_shape == (y_size,).
event_dims (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
x_name (str) – The name of the x random variable.
y_name (str) – The name of the y random variable.
- Returns
A funsor with given real_inputs and possibly additional Bint inputs.
- Return type
Distribution Funsors¶
This interface provides a number of standard normalized probability distributions implemented as funsors.
- class Distribution(*args, **kwargs)[source]¶
Bases:
Funsor
Funsor backed by a PyTorch/JAX distribution object.
- Parameters
*args – Distribution-dependent parameters. These can be either funsors or objects that can be coerced to funsors via
to_funsor()
. See derived classes for details.
- dist_class = 'defined by derived classes'¶
- property has_enumerate_support¶
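For instance, a Normal log-density funsor with a free variable can be constructed and then evaluated by substitution. This is a minimal sketch using the torch backend:
import torch
import funsor
import funsor.torch.distributions as dist
from funsor.domains import Real
from funsor.terms import Variable

funsor.set_backend("torch")

x = Variable("x", Real)
# A lazy Normal log-density with a free real input "x".
logp = dist.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0), value=x)
print(logp.inputs)                # includes "x" of type Real
print(logp(x=torch.tensor(0.5)))  # substituting a value evaluates the log-density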
- class Beta(*args, **kwargs)¶
Bases:
Distribution
- class Cauchy(*args, **kwargs)¶
Bases:
Distribution
- class Chi2(*args, **kwargs)¶
Bases:
Distribution
- class BernoulliProbs(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
_PyroWrapper_BernoulliProbs
- class BernoulliLogits(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
_PyroWrapper_BernoulliLogits
- class Binomial(*args, **kwargs)¶
Bases:
Distribution
- class Categorical(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
Categorical
- class CategoricalLogits(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
_PyroWrapper_CategoricalLogits
- class Delta(*args, **kwargs)¶
Bases:
Distribution
- class Dirichlet(*args, **kwargs)¶
Bases:
Distribution
- class DirichletMultinomial(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
DirichletMultinomial
- class Exponential(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
Exponential
- class Gamma(*args, **kwargs)¶
Bases:
Distribution
- class GammaPoisson(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
GammaPoisson
- class Geometric(*args, **kwargs)¶
Bases:
Distribution
- class Gumbel(*args, **kwargs)¶
Bases:
Distribution
- class HalfCauchy(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
HalfCauchy
- class HalfNormal(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
HalfNormal
- class Laplace(*args, **kwargs)¶
Bases:
Distribution
- class Logistic(*args, **kwargs)¶
Bases:
Distribution
- class LowRankMultivariateNormal(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
LowRankMultivariateNormal
- class Multinomial(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
Multinomial
- class MultivariateNormal(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
MultivariateNormal
- class NonreparameterizedBeta(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
NonreparameterizedBeta
- class NonreparameterizedDirichlet(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
NonreparameterizedDirichlet
- class NonreparameterizedGamma(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
NonreparameterizedGamma
- class NonreparameterizedNormal(*args, **kwargs)¶
Bases:
Distribution
- dist_class¶
alias of
NonreparameterizedNormal
- class Normal(*args, **kwargs)¶
Bases:
Distribution
- class Pareto(*args, **kwargs)¶
Bases:
Distribution
- class Poisson(*args, **kwargs)¶
Bases:
Distribution
- class StudentT(*args, **kwargs)¶
Bases:
Distribution
- class Uniform(*args, **kwargs)¶
Bases:
Distribution
- class VonMises(*args, **kwargs)¶
Bases:
Distribution
Mini-Pyro Interface¶
This interface provides a backend for the Pyro probabilistic programming
language. This interface is intended to be used indirectly by writing standard
Pyro code and setting pyro_backend("funsor")
. See examples/minipyro.py for
example usage.
Mini Pyro¶
This file contains a minimal implementation of the Pyro Probabilistic
Programming Language. The API (method signatures, etc.) match that of
the full implementation as closely as possible. This file is independent
of the rest of Pyro, with the exception of the pyro.distributions
module.
An accompanying example that makes use of this implementation can be found at examples/minipyro.py.
- class CondIndepStackFrame(name, size, dim)¶
Bases:
tuple
- property dim¶
Alias for field number 2
- property name¶
Alias for field number 0
- property size¶
Alias for field number 1
- class ClippedAdam(optim_args)[source]¶
Bases:
PyroOptim
- TorchOptimizer¶
alias of
ClippedAdam
Einsum Interface¶
This interface implements tensor variable elimination among tensors. In particular it does not implement continuous variable elimination.
- naive_plated_einsum(eqn, *terms, **kwargs)[source]¶
Implements Tensor Variable Elimination (Algorithm 1 in [Obermeyer et al 2019])
- [Obermeyer et al 2019] Obermeyer, F., Bingham, E., Jankowiak, M., Chiu, J.,
Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for Plated Factor Graphs, 2019
- einsum(eqn, *terms, **kwargs)[source]¶
Top-level interface for optimized tensor variable elimination.
- Parameters
equation (str) – An einsum equation.
*terms (funsor.terms.Funsor) – One or more operands.
plates (set) – Optional keyword argument denoting which funsor dimensions are plate dimensions. Among all input dimensions (from terms): dimensions in plates but not in outputs are product-reduced; dimensions in neither plates nor outputs are sum-reduced.
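A minimal sketch of a plain matrix-matrix contraction through this interface; dimension letters in the equation correspond to funsor input names, and the backend keyword value is an assumption here:
import torch
import funsor
from funsor.einsum import einsum
from funsor.tensor import Tensor

funsor.set_backend("torch")

x = Tensor(torch.randn(3, 4))["i", "j"]
y = Tensor(torch.randn(4, 5))["j", "k"]
# Sum-product contraction over the shared dimension "j".
z = einsum("ij,jk->ik", x, y, backend="torch")
print(z.inputs)  # {"i": Bint[3], "k": Bint[5]}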
Compiler & Tracer¶
- lower(expr: Funsor) Funsor [source]¶
Lower a funsor expression:
- eliminate bound variables
- convert Contraction to Binary
- trace_function(fn, kwargs: dict, *, allow_constants=False)[source]¶
Traces a function to an
OpProgram
that runs on backend values.
Example:
# Create a function involving ops.
def fn(a, b, x):
    return ops.add(ops.matmul(a, x), b)

# Evaluate via Funsor substitution.
data = dict(a=randn(3, 3), b=randn(3), x=randn(3))
expected = fn(**data)

# Alternatively evaluate via a program.
program = trace_function(fn, data)
actual = program(**data)
assert (actual == expected).all()
- class OpProgram(constants, inputs, operations)[source]¶
Bases:
object
Backend program for evaluating a symbolic funsor expression.
Programs depend on the funsor library only via
funsor.ops
and op registrations; program evaluation does not involve funsor interpretation or rewriting. Programs can be pickled and unpickled.- Parameters
constants (iterable) – A list of built-in constants (leaves).
inputs (iterable) – A list of string names of program inputs (leaves).
operations (iterable) – A list of program operations defining non-leaf nodes in the program dag. Each operations is a tuple
(op, arg_ids)
where op is a funsor op andarg_ids
is a tuple of positions of values, starting from zero and counting: constants, inputs, and operation outputs.
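Because program evaluation does not involve funsor interpretation, a traced program can be pickled and re-run later directly on backend values. A minimal sketch (the funsor.compiler module path is an assumption):
import pickle

import funsor
import funsor.ops as ops
from funsor.compiler import trace_function  # assumed module path
from funsor.testing import randn

funsor.set_backend("torch")

def fn(a, b, x):
    return ops.add(ops.matmul(a, x), b)

data = dict(a=randn(3, 3), b=randn(3), x=randn(3))
program = trace_function(fn, data)

# The traced program contains no funsor terms, so it pickles cleanly.
restored = pickle.loads(pickle.dumps(program))
assert (restored(**data) == fn(**data)).all()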
Named tensor notation with funsors (Part 1)¶
Introduction¶
Mathematical notation with named axes introduced in Named Tensor Notation (Chiang, Rush, Barak 2021) improves the readability of mathematical formulas involving multidimensional arrays. This includes tensor operations such as elementwise operations, reductions, contractions, renaming, indexing, and broadcasting. In this tutorial we translate examples from Named Tensor Notation into funsors to demonstrate the implementation of these operations in the funsor library and familiarize readers with funsor syntax. Part 1 covers examples from Section 2 (Informal Overview), Section 3.4.2 (Advanced Indexing), and Section 5 (Formal Definitions).
First, let’s import some dependencies.
[ ]:
!pip install funsor[torch]@git+https://github.com/pyro-ppl/funsor
[1]:
from torch import tensor
import funsor
import funsor.ops as ops
from funsor import Number, Tensor, Variable
from funsor.domains import Bint
funsor.set_backend("torch")
Named Tensors¶
Each tensor axis is given a name:
[2]:
A = Tensor(tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]]))["height", "width"]
Access elements of \(A\) using named indices:
[3]:
# A(height=0, width=2) =
A(width=2, height=0)
[3]:
Tensor(tensor(4))
Partial indexing:
[4]:
A(height=0)
[4]:
Tensor(tensor([3, 1, 4]), {'width': Bint[3]})
[5]:
A(width=2)
[5]:
Tensor(tensor([4, 9, 5]), {'height': Bint[3]})
Named tensor operations¶
Elementwise operations and broadcasting¶
Elementwise operations:
[6]:
# A.sigmoid() =
# ops.sigmoid(A) =
# 1 / (1 + ops.exp(-A)) =
1 / (1 + (-A).exp())
[6]:
Tensor(tensor([[0.9526, 0.7311, 0.9820],
[0.7311, 0.9933, 0.9999],
[0.8808, 0.9975, 0.9933]]), {'height': Bint[3], 'width': Bint[3]})
Tensors with different shapes are automatically broadcasted against each other before an operation is applied. Let \(x\) and \(y\) be defined as follows:
[7]:
x = Tensor(tensor([2, 7, 1]))["height"]
y = Tensor(tensor([1, 4, 1]))["width"]
Binary addition operation:
[8]:
# ops.add(A, x) =
A + x
[8]:
Tensor(tensor([[ 5, 3, 6],
[ 8, 12, 16],
[ 3, 7, 6]]), {'height': Bint[3], 'width': Bint[3]})
[9]:
# ops.add(A, y) =
A + y
[9]:
Tensor(tensor([[ 4, 5, 5],
[ 2, 9, 10],
[ 3, 10, 6]]), {'height': Bint[3], 'width': Bint[3]})
Binary multiplication operation:
[10]:
# ops.mul(A, x) =
A * x
[10]:
Tensor(tensor([[ 6, 2, 8],
[ 7, 35, 63],
[ 2, 6, 5]]), {'height': Bint[3], 'width': Bint[3]})
Binary maximum operation:
[11]:
ops.max(A, y)
[11]:
Tensor(tensor([[3, 4, 4],
[1, 5, 9],
[2, 6, 5]]), {'height': Bint[3], 'width': Bint[3]})
Reductions¶
Named axes can be reduced over by calling the .reduce
method and specifying the reduction operator and names of reduced axes. Note that reduction is defined only for operators that are associative and commutative.
[12]:
A.reduce(ops.add, "height")
[12]:
Tensor(tensor([ 6, 12, 18]), {'width': Bint[3]})
[13]:
A.reduce(ops.add, "width")
[13]:
Tensor(tensor([ 8, 15, 13]), {'height': Bint[3]})
Reduction over multiple axes:
[14]:
A.reduce(ops.add, {"height", "width"})
[14]:
Tensor(tensor(36))
Multiplication reduction:
[15]:
A.reduce(ops.mul, "height")
[15]:
Tensor(tensor([ 6, 30, 180]), {'width': Bint[3]})
Max reduction:
[16]:
A.reduce(ops.max, "height")
[16]:
Tensor(tensor([3, 6, 9]), {'width': Bint[3]})
Contraction¶
A contraction operation can be written as elementwise multiplication followed by summation over an axis:
[17]:
(A * y).reduce(ops.add, "width")
[17]:
Tensor(tensor([11, 30, 31]), {'height': Bint[3]})
Some other operations from linear algebra:
[18]:
(x * x).reduce(ops.add, "height")
[18]:
Tensor(tensor(54))
[19]:
x * y
[19]:
Tensor(tensor([[ 2, 8, 2],
[ 7, 28, 7],
[ 1, 4, 1]]), {'height': Bint[3], 'width': Bint[3]})
[20]:
(A * y).reduce(ops.add, "width")
[20]:
Tensor(tensor([11, 30, 31]), {'height': Bint[3]})
[21]:
(x * A).reduce(ops.add, "height")
[21]:
Tensor(tensor([15, 43, 76]), {'width': Bint[3]})
[22]:
B = Tensor(
tensor([[3, 2, 5], [5, 4, 0], [8, 3, 6]]),
)["width", "width2"]
(A * B).reduce(ops.add, "width")
[22]:
Tensor(tensor([[ 46, 22, 39],
[100, 49, 59],
[ 76, 43, 40]]), {'height': Bint[3], 'width2': Bint[3]})
Contraction can be generalized to other binary and reduction operations:
[23]:
(A + y).reduce(ops.max, "width")
[23]:
Tensor(tensor([ 5, 10, 10]), {'height': Bint[3]})
Renaming and reshaping¶
Renaming funsors is simple:
[24]:
# A(height=Variable("height2", Bint[3]))
A(height="height2")
[24]:
Tensor(tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]), {'height2': Bint[3], 'width': Bint[3]})
[25]:
layer = Variable("layer", Bint[9])
A_layer = A(height=layer // Number(3, 4), width=layer % Number(3, 4))
A_layer
[25]:
Tensor(tensor([3, 1, 4, 1, 5, 9, 2, 6, 5]), {'layer': Bint[9]})
[26]:
height = Variable("height", Bint[3])
width = Variable("width", Bint[3])
A_layer(layer=height * Number(3, 4) + width % Number(3, 4))
[26]:
Tensor(tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]), {'height': Bint[3], 'width': Bint[3]})
Advanced indexing¶
All forms of advanced indexing can be achieved through name substitution in funsors.
Partial indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,i)\):
[27]:
E = Tensor(
tensor([[2, 1, 5], [3, 4, 2], [1, 3, 7], [1, 4, 3], [5, 9, 2]]),
)["vocab", "emb"]
E(vocab=2)
[27]:
Tensor(tensor([1, 3, 7]), {'emb': Bint[3]})
Integer array indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,I)\):
[28]:
I = Tensor(tensor([3, 2, 4, 0]), dtype=5)["seq"]
E(vocab=I)
[28]:
Tensor(tensor([[1, 4, 3],
[1, 3, 7],
[5, 9, 2],
[2, 1, 5]]), {'seq': Bint[4], 'emb': Bint[3]})
Gather operation \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(P,I)\):
[29]:
P = Tensor(
tensor([[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]),
)["vocab", "seq"]
P(vocab=I)
[29]:
Tensor(tensor([1, 5, 2, 2]), {'seq': Bint[4]})
Indexing with two integer arrays:
[30]:
I1 = Tensor(tensor([1, 2, 0]), dtype=4)["subseq"]
I2 = Tensor(tensor([3, 0, 4]), dtype=5)["subseq"]
P(seq=I1, vocab=I2)
[30]:
Tensor(tensor([3, 4, 5]), {'subseq': Bint[3]})
Example: Adam optimizer¶
import argparse
import torch
import funsor
import funsor.ops as ops
from funsor.adam import Adam
from funsor.domains import Real, Reals
from funsor.tensor import Tensor
from funsor.terms import Variable
def main(args):
funsor.set_backend("torch")
# Problem definition.
N = 100
P = 10
data = Tensor(torch.randn(N, P))["n"]
true_weight = Tensor(torch.randn(P))
true_bias = Tensor(torch.randn(()))
truth = true_bias + true_weight @ data
# Model.
weight = Variable("weight", Reals[P])
bias = Variable("bias", Real)
pred = bias + weight @ data
loss = (pred - truth).abs().reduce(ops.add, "n")
# Inference.
with Adam(args.num_steps, lr=args.learning_rate, log_every=args.log_every) as optim:
loss.reduce(ops.min, {"weight", "bias"})
print(f"True bias\n{true_bias}")
print("Learned bias\n{}".format(optim.param("bias")))
print(f"True weight\n{true_weight}")
print("Learned weight\n{}".format(optim.param("weight")))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Linear regression example using Adam interpretation"
)
parser.add_argument("-P", "--num-features", type=int, default=10)
parser.add_argument("-N", "--num-data", type=int, default=100)
parser.add_argument("-n", "--num-steps", type=int, default=201)
parser.add_argument("-lr", "--learning-rate", type=float, default=0.05)
parser.add_argument("--log-every", type=int, default=20)
args = parser.parse_args()
main(args)
Example: Discrete HMM¶
import argparse
from collections import OrderedDict
import torch
import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.interpreter import reinterpret
from funsor.optimizer import apply_optimizer
def main(args):
funsor.set_backend("torch")
# Declare parameters.
trans_probs = torch.tensor([[0.2, 0.8], [0.7, 0.3]], requires_grad=True)
emit_probs = torch.tensor([[0.4, 0.6], [0.1, 0.9]], requires_grad=True)
params = [trans_probs, emit_probs]
# A discrete HMM model.
def model(data):
log_prob = funsor.to_funsor(0.0)
trans = dist.Categorical(
probs=funsor.Tensor(
trans_probs,
inputs=OrderedDict([("prev", funsor.Bint[args.hidden_dim])]),
)
)
emit = dist.Categorical(
probs=funsor.Tensor(
emit_probs,
inputs=OrderedDict([("latent", funsor.Bint[args.hidden_dim])]),
)
)
x_curr = funsor.Number(0, args.hidden_dim)
for t, y in enumerate(data):
x_prev = x_curr
# A delayed sample statement.
x_curr = funsor.Variable("x_{}".format(t), funsor.Bint[args.hidden_dim])
log_prob += trans(prev=x_prev, value=x_curr)
if not args.lazy and isinstance(x_prev, funsor.Variable):
log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)
log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))
log_prob = log_prob.reduce(ops.logaddexp)
return log_prob
# Train model parameters.
data = torch.ones(args.time_steps, dtype=torch.long)
optim = torch.optim.Adam(params, lr=args.learning_rate)
for step in range(args.train_steps):
optim.zero_grad()
if args.lazy:
with funsor.interpretations.lazy:
log_prob = apply_optimizer(model(data))
log_prob = reinterpret(log_prob)
else:
log_prob = model(data)
assert not log_prob.inputs, "free variables remain"
loss = -log_prob.data
loss.backward()
optim.step()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Kalman filter example")
parser.add_argument("-t", "--time-steps", default=10, type=int)
parser.add_argument("-n", "--train-steps", default=101, type=int)
parser.add_argument("-lr", "--learning-rate", default=0.05, type=float)
parser.add_argument("-d", "--hidden-dim", default=2, type=int)
parser.add_argument("--lazy", action="store_true")
parser.add_argument("--filter", action="store_true")
parser.add_argument("--xfail-if-not-implemented", action="store_true")
args = parser.parse_args()
if args.xfail_if_not_implemented:
try:
main(args)
except NotImplementedError:
print("XFAIL")
else:
main(args)
Example: Switching Linear Dynamical System EEG¶
We use a switching linear dynamical system [1] to model an EEG time series dataset. For inference we use a moment-matching approximation enabled by funsor.interpretations.moment_matching.
References
[1] Anderson, B., and J. Moore. “Optimal filtering. Prentice-Hall, Englewood Cliffs.” New Jersey (1979).
import argparse
import time
from collections import OrderedDict
from os.path import exists
from urllib.request import urlopen
import numpy as np
import pyro
import torch
import torch.nn as nn
import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.pyro.convert import (
funsor_to_cat_and_mvn,
funsor_to_mvn,
matrix_and_mvn_to_funsor,
mvn_to_funsor,
)
# download dataset from UCI archive
def download_data():
if not exists("eeg.dat"):
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/00264/EEG%20Eye%20State.arff"
with open("eeg.dat", "wb") as f:
f.write(urlopen(url).read())
class SLDS(nn.Module):
def __init__(
self,
num_components, # the number of switching states K
hidden_dim, # the dimension of the continuous latent space
obs_dim, # the dimension of the continuous outputs
fine_transition_matrix=True, # controls whether the transition matrix depends on s_t
fine_transition_noise=False, # controls whether the transition noise depends on s_t
fine_observation_matrix=False, # controls whether the observation matrix depends on s_t
fine_observation_noise=False, # controls whether the observation noise depends on s_t
moment_matching_lag=1,
): # controls the expense of the moment matching approximation
self.num_components = num_components
self.hidden_dim = hidden_dim
self.obs_dim = obs_dim
self.moment_matching_lag = moment_matching_lag
self.fine_transition_noise = fine_transition_noise
self.fine_observation_matrix = fine_observation_matrix
self.fine_observation_noise = fine_observation_noise
self.fine_transition_matrix = fine_transition_matrix
assert moment_matching_lag > 0
assert (
fine_transition_noise
or fine_observation_matrix
or fine_observation_noise
or fine_transition_matrix
), (
"The continuous dynamics need to be coupled to the discrete dynamics in at least one way [use at "
+ "least one of the arguments --ftn --ftm --fon --fom]"
)
super(SLDS, self).__init__()
# initialize the various parameters of the model
self.transition_logits = nn.Parameter(
0.1 * torch.randn(num_components, num_components)
)
if fine_transition_matrix:
transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(
num_components, hidden_dim, hidden_dim
)
else:
transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(
hidden_dim, hidden_dim
)
self.transition_matrix = nn.Parameter(transition_matrix)
if fine_transition_noise:
self.log_transition_noise = nn.Parameter(
0.1 * torch.randn(num_components, hidden_dim)
)
else:
self.log_transition_noise = nn.Parameter(0.1 * torch.randn(hidden_dim))
if fine_observation_matrix:
self.observation_matrix = nn.Parameter(
0.3 * torch.randn(num_components, hidden_dim, obs_dim)
)
else:
self.observation_matrix = nn.Parameter(
0.3 * torch.randn(hidden_dim, obs_dim)
)
if fine_observation_noise:
self.log_obs_noise = nn.Parameter(
0.1 * torch.randn(num_components, obs_dim)
)
else:
self.log_obs_noise = nn.Parameter(0.1 * torch.randn(obs_dim))
# define the prior distribution p(x_0) over the continuous latent at the initial time step t=0
x_init_mvn = pyro.distributions.MultivariateNormal(
torch.zeros(self.hidden_dim), torch.eye(self.hidden_dim)
)
self.x_init_mvn = mvn_to_funsor(
x_init_mvn,
real_inputs=OrderedDict([("x_0", funsor.Reals[self.hidden_dim])]),
)
# we construct the various funsors used to compute the marginal log probability and other model quantities.
# these funsors depend on the various model parameters.
def get_tensors_and_dists(self):
# normalize the transition probabilities
trans_logits = self.transition_logits - self.transition_logits.logsumexp(
dim=-1, keepdim=True
)
trans_probs = funsor.Tensor(
trans_logits, OrderedDict([("s", funsor.Bint[self.num_components])])
)
trans_mvn = pyro.distributions.MultivariateNormal(
torch.zeros(self.hidden_dim), self.log_transition_noise.exp().diag_embed()
)
obs_mvn = pyro.distributions.MultivariateNormal(
torch.zeros(self.obs_dim), self.log_obs_noise.exp().diag_embed()
)
event_dims = (
("s",) if self.fine_transition_matrix or self.fine_transition_noise else ()
)
x_trans_dist = matrix_and_mvn_to_funsor(
self.transition_matrix, trans_mvn, event_dims, "x", "y"
)
event_dims = (
("s",)
if self.fine_observation_matrix or self.fine_observation_noise
else ()
)
y_dist = matrix_and_mvn_to_funsor(
self.observation_matrix, obs_mvn, event_dims, "x", "y"
)
return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist
# compute the marginal log probability of the observed data using a moment-matching approximation
@funsor.interpretations.moment_matching
def log_prob(self, data):
(
trans_logits,
trans_probs,
trans_mvn,
obs_mvn,
x_trans_dist,
y_dist,
) = self.get_tensors_and_dists()
log_prob = funsor.Number(0.0)
s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
x_vars = {}
for t, y in enumerate(data):
# construct free variables for s_t and x_t
s_vars[t] = funsor.Variable(f"s_{t}", funsor.Bint[self.num_components])
x_vars[t] = funsor.Variable(f"x_{t}", funsor.Reals[self.hidden_dim])
# incorporate the discrete switching dynamics
log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])
# incorporate the prior term p(x_t | x_{t-1})
if t == 0:
log_prob += self.x_init_mvn(value=x_vars[t])
else:
log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])
# do a moment-matching reduction. at this point log_prob depends on (moment_matching_lag + 1)-many
# pairs of free variables.
if t > self.moment_matching_lag - 1:
log_prob = log_prob.reduce(
ops.logaddexp,
{
s_vars[t - self.moment_matching_lag],
x_vars[t - self.moment_matching_lag],
},
)
# incorporate the observation p(y_t | x_t, s_t)
log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)
T = data.shape[0]
# reduce any remaining free variables
for t in range(self.moment_matching_lag):
log_prob = log_prob.reduce(
ops.logaddexp,
{
s_vars[T - self.moment_matching_lag + t],
x_vars[T - self.moment_matching_lag + t],
},
)
# assert that we've reduced all the free variables in log_prob
assert not log_prob.inputs, "unexpected free variables remain"
# return the PyTorch tensor behind log_prob (which we can directly differentiate)
return log_prob.data
# do filtering, prediction, and smoothing using a moment-matching approximation.
# here we implicitly use a moment matching lag of L = 1. the general logic follows
# the logic in the log_prob method.
@torch.no_grad()
@funsor.interpretations.moment_matching
def filter_and_predict(self, data, smoothing=False):
(
trans_logits,
trans_probs,
trans_mvn,
obs_mvn,
x_trans_dist,
y_dist,
) = self.get_tensors_and_dists()
log_prob = funsor.Number(0.0)
s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
x_vars = {-1: None}
predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
test_LLs = []
for t, y in enumerate(data):
s_vars[t] = funsor.Variable(f"s_{t}", funsor.Bint[self.num_components])
x_vars[t] = funsor.Variable(f"x_{t}", funsor.Reals[self.hidden_dim])
log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])
if t == 0:
log_prob += self.x_init_mvn(value=x_vars[t])
else:
log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])
if t > 0:
log_prob = log_prob.reduce(
ops.logaddexp, {s_vars[t - 1], x_vars[t - 1]}
)
# do 1-step prediction and compute test LL
if t > 0:
predictive_x_dists.append(log_prob)
_log_prob = log_prob - log_prob.reduce(ops.logaddexp)
predictive_y_dist = y_dist(s=s_vars[t], x=x_vars[t]) + _log_prob
test_LLs.append(
predictive_y_dist(y=y).reduce(ops.logaddexp).data.item()
)
predictive_y_dist = predictive_y_dist.reduce(
ops.logaddexp, {f"x_{t}", f"s_{t}"}
)
predictive_y_dists.append(funsor_to_mvn(predictive_y_dist, 0, ()))
log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)
# save filtering dists for forward-backward smoothing
if smoothing:
filtering_dists.append(log_prob)
# do the backward recursion using previously computed ingredients
if smoothing:
# seed the backward recursion with the filtering distribution at t=T
smoothing_dists = [filtering_dists[-1]]
T = data.size(0)
s_vars = {
t: funsor.Variable(f"s_{t}", funsor.Bint[self.num_components])
for t in range(T)
}
x_vars = {
t: funsor.Variable(f"x_{t}", funsor.Reals[self.hidden_dim])
for t in range(T)
}
# do the backward recursion.
# let p[t|t-1] be the predictive distribution at time step t.
# let p[t|t] be the filtering distribution at time step t.
# let f[t] denote the prior (transition) density at time step t.
# then the smoothing distribution p[t|T] at time step t is
# given by the following recursion.
# p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
# where <...> denotes integration of the latent variables at time step t.
for t in reversed(range(T - 1)):
integral = smoothing_dists[-1] - predictive_x_dists[t]
integral += dist.Categorical(
trans_probs(s=s_vars[t]), value=s_vars[t + 1]
)
integral += x_trans_dist(s=s_vars[t], x=x_vars[t], y=x_vars[t + 1])
integral = integral.reduce(
ops.logaddexp, {s_vars[t + 1], x_vars[t + 1]}
)
smoothing_dists.append(filtering_dists[t] + integral)
# compute predictive test MSE and predictive variances
predictive_means = torch.stack([d.mean for d in predictive_y_dists]) # T-1 ydim
predictive_vars = torch.stack(
[d.covariance_matrix.diagonal(dim1=-1, dim2=-2) for d in predictive_y_dists]
)
predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)
if smoothing:
# compute smoothed mean function
smoothing_dists = [
funsor_to_cat_and_mvn(d, 0, (f"s_{t}",))
for t, d in enumerate(reversed(smoothing_dists))
]
means = torch.stack([d[1].mean for d in smoothing_dists]) # T 2 xdim
means = torch.matmul(means.unsqueeze(-2), self.observation_matrix).squeeze(
-2
) # T 2 ydim
probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
probs = probs / probs.sum(-1, keepdim=True) # T 2
smoothing_means = (probs.unsqueeze(-1) * means).sum(-2) # T ydim
smoothing_probs = probs[:, 1]
return (
predictive_mse,
torch.tensor(np.array(test_LLs)),
predictive_means,
predictive_vars,
smoothing_means,
smoothing_probs,
)
else:
return predictive_mse, torch.tensor(np.array(test_LLs))
def main(args):
funsor.set_backend("torch")
# download and pre-process EEG data if not in test mode
if not args.test:
download_data()
N_val, N_test = 149, 200
data = np.loadtxt("eeg.dat", delimiter=",", skiprows=19)
print(f"[raw data shape] {data.shape}")
data = data[::20, :]
print(f"[data shape after thinning] {data.shape}")
eye_state = [int(d) for d in data[:, -1].tolist()]
data = torch.tensor(data[:, :-1]).float()
# in test mode (for continuous integration on github) so create fake data
else:
data = torch.randn(10, 3)
N_val, N_test = 2, 2
T, obs_dim = data.shape
N_train = T - N_test - N_val
np.random.seed(0)
rand_perm = np.random.permutation(N_val + N_test)
val_indices = rand_perm[0:N_val]
test_indices = rand_perm[N_val:]
data_mean = data[0:N_train, :].mean(0)
data -= data_mean
data_std = data[0:N_train, :].std(0)
data /= data_std
print(f"Length of time series T: {T} Observation dimension: {obs_dim}")
print(f"N_train: {N_train} N_val: {N_val} N_test: {N_test}")
torch.manual_seed(args.seed)
# set up model
slds = SLDS(
num_components=args.num_components,
hidden_dim=args.hidden_dim,
obs_dim=obs_dim,
fine_observation_noise=args.fon,
fine_transition_noise=args.ftn,
fine_observation_matrix=args.fom,
fine_transition_matrix=args.ftm,
moment_matching_lag=args.moment_matching_lag,
)
# set up optimizer
adam = torch.optim.Adam(
slds.parameters(),
lr=args.learning_rate,
betas=(args.beta1, 0.999),
amsgrad=True,
)
scheduler = torch.optim.lr_scheduler.ExponentialLR(adam, gamma=args.gamma)
ts = [time.time()]
report_frequency = 1
# training loop
for step in range(args.num_steps):
nll = -slds.log_prob(data[0:N_train, :]) / N_train
nll.backward()
if step == 5:
scheduler.base_lrs[0] *= 0.20
adam.step()
scheduler.step()
adam.zero_grad()
if step % report_frequency == 0 or step == args.num_steps - 1:
step_dt = ts[-1] - ts[-2] if step > 0 else 0.0
pred_mse, pred_LLs = slds.filter_and_predict(
data[0 : N_train + N_val + N_test, :]
)
val_mse = pred_mse[val_indices].mean().item()
test_mse = pred_mse[test_indices].mean().item()
val_ll = pred_LLs[val_indices].mean().item()
test_ll = pred_LLs[test_indices].mean().item()
stats = "[step %03d] train_nll: %.5f val_mse: %.5f val_ll: %.5f test_mse: %.5f test_ll: %.5f\t(dt: %.2f)"
print(
stats % (step, nll.item(), val_mse, val_ll, test_mse, test_ll, step_dt)
)
ts.append(time.time())
# plot predictions and smoothed means
if args.plot:
assert not args.test
(
predicted_mse,
LLs,
pred_means,
pred_vars,
smooth_means,
smooth_probs,
) = slds.filter_and_predict(data, smoothing=True)
pred_means = pred_means.data.numpy()
pred_stds = pred_vars.sqrt().data.numpy()
smooth_means = smooth_means.data.numpy()
smooth_probs = smooth_probs.data.numpy()
import matplotlib
matplotlib.use("Agg") # noqa: E402
import matplotlib.pyplot as plt
f, axes = plt.subplots(4, 1, figsize=(12, 8), sharex=True)
T = data.size(0)
N_valtest = N_val + N_test
to_seconds = 117.0 / T
for k, ax in enumerate(axes[:-1]):
which = [0, 4, 10][k]
ax.plot(to_seconds * np.arange(T), data[:, which], "ko", markersize=2)
ax.plot(
to_seconds * np.arange(N_train),
smooth_means[:N_train, which],
ls="solid",
color="r",
)
ax.plot(
to_seconds * (N_train + np.arange(N_valtest)),
pred_means[-N_valtest:, which],
ls="solid",
color="b",
)
ax.fill_between(
to_seconds * (N_train + np.arange(N_valtest)),
pred_means[-N_valtest:, which] - 1.645 * pred_stds[-N_valtest:, which],
pred_means[-N_valtest:, which] + 1.645 * pred_stds[-N_valtest:, which],
color="lightblue",
)
ax.set_ylabel(f"$y_{which + 1}$", fontsize=20)
ax.tick_params(axis="both", which="major", labelsize=14)
axes[-1].plot(to_seconds * np.arange(T), eye_state, "k", ls="solid")
axes[-1].plot(to_seconds * np.arange(T), smooth_probs, "r", ls="solid")
axes[-1].set_xlabel("Time (s)", fontsize=20)
axes[-1].set_ylabel("Eye state", fontsize=20)
axes[-1].tick_params(axis="both", which="major", labelsize=14)
plt.tight_layout(pad=0.7)
plt.savefig("eeg.pdf")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Switching linear dynamical system")
parser.add_argument("-n", "--num-steps", default=3, type=int)
parser.add_argument("-s", "--seed", default=15, type=int)
parser.add_argument("-hd", "--hidden-dim", default=5, type=int)
parser.add_argument("-k", "--num-components", default=2, type=int)
parser.add_argument("-lr", "--learning-rate", default=0.5, type=float)
parser.add_argument("-b1", "--beta1", default=0.75, type=float)
parser.add_argument("-g", "--gamma", default=0.99, type=float)
parser.add_argument("-mml", "--moment-matching-lag", default=1, type=int)
parser.add_argument("--plot", action="store_true")
parser.add_argument("--fon", action="store_true")
parser.add_argument("--ftm", action="store_true")
parser.add_argument("--fom", action="store_true")
parser.add_argument("--ftn", action="store_true")
parser.add_argument("--test", action="store_true")
args = parser.parse_args()
main(args)
Example: Forward-Backward algorithm¶
import argparse
from collections import OrderedDict
from typing import Dict, List
import funsor.ops as ops
from funsor import Funsor, Tensor
from funsor.adjoint import AdjointTape
from funsor.domains import Bint
from funsor.testing import assert_close, random_tensor
def forward_algorithm(
factors: List[Funsor],
step: Dict[str, str],
) -> Funsor:
"""
Calculate log marginal probability using the forward algorithm:
Z = p(y[0:T])
Transition and emission probabilities are given by init and trans factors:
init = p(y[0], x[0])
trans[t] = p(y[t], x[t] | x[t-1])
Forward probabilities are computed inductively:
alpha[t] = p(y[0:t], x[t])
alpha[0] = init
alpha[t+1] = alpha[t] @ trans[t+1]
"""
step = OrderedDict(sorted(step.items()))
drop = tuple("_drop_{}".format(i) for i in range(len(step)))
prev_to_drop = dict(zip(step.keys(), drop))
curr_to_drop = dict(zip(step.values(), drop))
reduce_vars = frozenset(drop)
# base case
alpha = factors[0]
alphas = [alpha]
# inductive steps
for trans in factors[1:]:
alpha = (alpha(**curr_to_drop) + trans(**prev_to_drop)).reduce(
ops.logaddexp, reduce_vars
)
alphas.append(alpha)
else:
Z = alpha(**curr_to_drop).reduce(ops.logaddexp, reduce_vars)
return Z
def forward_backward_algorithm(
factors: List[Funsor],
step: Dict[str, str],
) -> List[Tensor]:
"""
Calculate marginal probabilities:
p(x[t], x[t-1] | Y)
"""
step = OrderedDict(sorted(step.items()))
drop = tuple("_drop_{}".format(i) for i in range(len(step)))
prev_to_drop = dict(zip(step.keys(), drop))
curr_to_drop = dict(zip(step.values(), drop))
reduce_vars = frozenset(drop)
# Base cases
alpha = factors[0] # alpha[0] = p(y[0], x[0])
beta = Tensor(
ops.full_like(alpha.data, 0.0), alpha(x_curr="x_prev").inputs
) # beta[T] = 1
# Backward algorithm
# beta[t] = p(y[t+1:T] | x[t])
# beta[t] = trans[t+1] @ beta[t+1]
betas = [beta]
for trans in factors[:0:-1]:
beta = (trans(**curr_to_drop) + beta(**prev_to_drop)).reduce(
ops.logaddexp, reduce_vars
)
betas.append(beta)
else:
init = factors[0]
Z = (init(**curr_to_drop) + beta(**prev_to_drop)).reduce(
ops.logaddexp, reduce_vars
)
betas.reverse()
# Forward algorithm & Marginal computations
marginal = alpha + betas[0](**{"x_prev": "x_curr"}) - Z # p(x[0] | Y)
marginals = [marginal]
# inductive steps
for trans, beta in zip(factors[1:], betas[1:]):
# alpha[t-1] * trans[t] = p(y[0:t], x[t-1], x[t])
alpha_trans = alpha(**{"x_curr": "x_prev"}) + trans
# alpha[t] = p(y[0:t], x[t])
alpha = alpha_trans.reduce(ops.logaddexp, "x_prev")
# alpha[t-1] * trans[t] * beta[t] / Z = p(x[t-1], x[t] | Y)
marginal = alpha_trans + beta(**{"x_prev": "x_curr"}) - Z
marginals.append(marginal)
return marginals
def main(args):
"""
Compute marginal probabilities p(x[t], x[t-1] | Y) for an HMM:
x[0] -> ... -> x[t-1] -> x[t] -> ... -> x[T]
| | | |
v v v v
y[0] y[t-1] y[t] y[T]
Z = p(y[0:T])
alpha[t] = p(y[0:t], x[t])
beta[t] = p(y[t+1:T] | x[t])
trans[t] = p(y[t], x[t] | x[t-1])
p(x[t], x[t-1] | Y) = alpha[t-1] * beta[t] * trans[t] / Z
d Z / d trans[t] = alpha[t-1] * beta[t]
**References:**
[1] Jason Eisner (2016)
"Inside-Outside and Forward-Backward Algorithms Are Just Backprop
(Tutorial Paper)"
https://www.cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf
[2] Zhifei Li and Jason Eisner (2009)
"First- and Second-Order Expectation Semirings
with Applications to Minimum-Risk Training on Translation Forests"
http://www.cs.jhu.edu/~zfli/pubs/semiring_translation_zhifei_emnlp09.pdf
"""
# transition and emission probabilities
init = random_tensor(OrderedDict([("x_curr", Bint[args.hidden_dim])]))
factors = [init]
for t in range(args.time_steps - 1):
factors.append(
random_tensor(
OrderedDict(x_prev=Bint[args.hidden_dim], x_curr=Bint[args.hidden_dim])
)
)
# Compute marginal probabilities using the forward-backward algorithm
marginals = forward_backward_algorithm(factors, {"x_prev": "x_curr"})
# Compute marginal probabilities using backpropagation
with AdjointTape() as tape:
Z = forward_algorithm(factors, {"x_prev": "x_curr"})
result = tape.adjoint(ops.logaddexp, ops.add, Z, factors)
adjoint_terms = list(result.values())
t = 0
for expected, adj, trans in zip(marginals, adjoint_terms, factors):
# adjoint returns dZ / dtrans = alpha[t-1] * beta[t]
# marginal = alpha[t-1] * beta[t] * trans / Z
actual = adj + trans - Z
assert_close(expected, actual.align(tuple(expected.inputs)), rtol=1e-4)
print("")
print(f"Marginal term: p(x[{t}], x[{t-1}] | Y)")
print("Forward-backward algorithm:\n", expected.data)
print("Differentiating forward algorithm:\n", actual.data)
t += 1
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Forward-Backward Algorithm Is Just Backprop"
)
parser.add_argument("-t", "--time-steps", default=10, type=int)
parser.add_argument("-d", "--hidden-dim", default=3, type=int)
args = parser.parse_args()
main(args)
Example: Kalman Filter¶
import argparse
import torch
import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.interpreter import reinterpret
from funsor.optimizer import apply_optimizer
def main(args):
funsor.set_backend("torch")
# Declare parameters.
trans_noise = torch.tensor(0.1, requires_grad=True)
emit_noise = torch.tensor(0.5, requires_grad=True)
params = [trans_noise, emit_noise]
# A Gaussian HMM model.
def model(data):
log_prob = funsor.to_funsor(0.0)
x_curr = funsor.Tensor(torch.tensor(0.0))
for t, y in enumerate(data):
x_prev = x_curr
# A delayed sample statement.
x_curr = funsor.Variable("x_{}".format(t), funsor.Real)
log_prob += dist.Normal(1 + x_prev / 2.0, trans_noise, value=x_curr)
# Optionally marginalize out the previous state.
if t > 0 and not args.lazy:
log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)
# An observe statement.
log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)
# Marginalize out all remaining delayed variables.
log_prob = log_prob.reduce(ops.logaddexp)
return log_prob
# Train model parameters.
torch.manual_seed(0)
data = torch.randn(args.time_steps)
optim = torch.optim.Adam(params, lr=args.learning_rate)
for step in range(args.train_steps):
optim.zero_grad()
if args.lazy:
with funsor.interpretations.lazy:
log_prob = apply_optimizer(model(data))
log_prob = reinterpret(log_prob)
else:
log_prob = model(data)
assert not log_prob.inputs, "free variables remain"
loss = -log_prob.data
loss.backward()
optim.step()
if args.verbose and step % 10 == 0:
print("step {} loss = {}".format(step, loss.item()))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Kalman filter example")
parser.add_argument("-t", "--time-steps", default=10, type=int)
parser.add_argument("-n", "--train-steps", default=101, type=int)
parser.add_argument("-lr", "--learning-rate", default=0.05, type=float)
parser.add_argument("--lazy", action="store_true")
parser.add_argument("--filter", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
args = parser.parse_args()
main(args)
Example: Mini Pyro¶
import argparse
import torch
from pyroapi import distributions as dist
from pyroapi import infer, optim, pyro, pyro_backend
from torch.distributions import constraints
import funsor
from funsor.montecarlo import MonteCarlo
def main(args):
funsor.set_backend("torch")
# Define a basic model with a single Normal latent random variable `loc`
# and a batch of Normally distributed observations.
def model(data):
loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
with pyro.plate("data", len(data), dim=-1):
pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)
# Define a guide (i.e. variational distribution) with a Normal
# distribution over the latent random variable `loc`.
def guide(data):
guide_loc = pyro.param("guide_loc", torch.tensor(0.0))
guide_scale = pyro.param(
"guide_scale", torch.tensor(1.0), constraint=constraints.positive
)
pyro.sample("loc", dist.Normal(guide_loc, guide_scale))
# Generate some data.
torch.manual_seed(0)
data = torch.randn(100) + 3.0
# Because the API in minipyro matches that of Pyro proper,
# training code works with generic Pyro implementations.
with pyro_backend(args.backend), MonteCarlo():
# Construct an SVI object so we can do variational inference on our
# model/guide pair.
Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
elbo = Elbo()
adam = optim.Adam({"lr": args.learning_rate})
svi = infer.SVI(model, guide, adam, elbo)
# Basic training loop
pyro.get_param_store().clear()
for step in range(args.num_steps):
loss = svi.step(data)
if args.verbose and step % 100 == 0:
print(f"step {step} loss = {loss}")
# Report the final values of the variational parameters
# in the guide after training.
if args.verbose:
for name in pyro.get_param_store():
value = pyro.param(name).data
print(f"{name} = {value.detach().cpu().numpy()}")
# For this simple (conjugate) model we know the exact posterior. In
# particular we know that the variational distribution should be
# centered near 3.0. So let's check this explicitly.
assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Minipyro demo")
parser.add_argument("-b", "--backend", default="funsor")
parser.add_argument("-n", "--num-steps", default=1001, type=int)
parser.add_argument("-lr", "--learning-rate", default=0.02, type=float)
parser.add_argument("--jit", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
args = parser.parse_args()
main(args)
Example: PCFG¶
import argparse
import math
from collections import OrderedDict
import torch
import funsor
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import Bint
from funsor.tensor import Tensor
from funsor.terms import Number, Stack, Variable
def Uniform(components):
components = tuple(components)
size = len(components)
if size == 1:
return components[0]
var = Variable("v", Bint[size])
return Stack(var.name, components).reduce(ops.logaddexp, var.name) - math.log(size)
# @of_shape(*([Bint[2]] * size))
def model(size, position=0):
if size == 1:
name = str(position)
return Uniform((Delta(name, Number(0, 2)), Delta(name, Number(1, 2))))
return Uniform(
model(t, position) + model(size - t, t + position) for t in range(1, size)
)
def main(args):
funsor.set_backend("torch")
torch.manual_seed(args.seed)
print_ = print if args.verbose else lambda msg: None
print_("Data:")
data = torch.distributions.Categorical(torch.ones(2)).sample((args.size,))
assert data.shape == (args.size,)
data = Tensor(data, OrderedDict(i=Bint[args.size]), dtype=2)
print_(data)
print_("Model:")
m = model(args.size)
print_(m.pretty())
print_("Eager log_prob:")
obs = {str(i): data(i) for i in range(args.size)}
log_prob = m(**obs)
print_(log_prob)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="PCFG example")
parser.add_argument("-s", "--size", default=3, type=int)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("-v", "--verbose", action="store_true")
args = parser.parse_args()
main(args)
Example: Biased Kalman Filter¶
import argparse
import itertools
import math
import os
import pyro.distributions as dist
import torch
import torch.nn as nn
from torch.optim import Adam
import funsor
import funsor.ops as ops
import funsor.torch.distributions as f_dist
from funsor.domains import Reals
from funsor.pyro.convert import dist_to_funsor, funsor_to_mvn
from funsor.tensor import Tensor, Variable
# We use a 2D continuous-time NCV dynamics model throughout.
# See http://webee.technion.ac.il/people/shimkin/Estimation09/ch8_target.pdf
TIME_STEP = 1.0
NCV_PROCESS_NOISE = torch.tensor(
[
[1 / 3, 0.0, 1 / 2, 0.0],
[0.0, 1 / 3, 0.0, 1 / 2],
[1 / 2, 0.0, 1.0, 0.0],
[0.0, 1 / 2, 0.0, 1.0],
]
)
NCV_TRANSITION_MATRIX = torch.tensor(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 0.0, 1.0],
]
)
@torch.no_grad()
def generate_data(num_frames, num_sensors):
"""
Generate data from a damped NCV dynamics model
"""
dt = TIME_STEP
bias_scale = 4.0
obs_noise = 1.0
trans_noise = 0.3
# define dynamics
z = torch.cat([10.0 * torch.randn(2), torch.rand(2)]) # position # velocity
damp = 0.1 # damp the velocities
f = torch.tensor(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[dt * math.exp(-damp * dt), 0, math.exp(-damp * dt), 0],
[0, dt * math.exp(-damp * dt), 0, math.exp(-damp * dt)],
]
)
trans_dist = dist.MultivariateNormal(
torch.zeros(4),
scale_tril=trans_noise * torch.linalg.cholesky(NCV_PROCESS_NOISE),
)
# define biased sensors
sensor_bias = bias_scale * torch.randn(2 * num_sensors)
h = torch.eye(4, 2).unsqueeze(-1).expand(-1, -1, num_sensors).reshape(4, -1)
obs_dist = dist.MultivariateNormal(
sensor_bias, scale_tril=obs_noise * torch.eye(2 * num_sensors)
)
states = []
observations = []
for t in range(num_frames):
z = z @ f + trans_dist.sample()
states.append(z)
x = z @ h + obs_dist.sample()
observations.append(x)
states = torch.stack(states)
observations = torch.stack(observations)
assert observations.shape == (num_frames, num_sensors * 2)
return observations, states, sensor_bias
class Model(nn.Module):
def __init__(self, num_sensors):
super(Model, self).__init__()
self.num_sensors = num_sensors
# learnable params
self.log_bias_scale = nn.Parameter(torch.tensor(0.0))
self.log_obs_noise = nn.Parameter(torch.tensor(0.0))
self.log_trans_noise = nn.Parameter(torch.tensor(0.0))
def forward(self, observations, add_bias=True):
obs_dim = 2 * self.num_sensors
bias_scale = self.log_bias_scale.exp()
obs_noise = self.log_obs_noise.exp()
trans_noise = self.log_trans_noise.exp()
# bias distribution
bias = Variable("bias", Reals[obs_dim])
assert not torch.isnan(bias_scale), "bias scales was nan"
bias_dist = dist_to_funsor(
dist.MultivariateNormal(
torch.zeros(obs_dim),
scale_tril=bias_scale * torch.eye(2 * self.num_sensors),
)
)(value=bias)
init_dist = dist.MultivariateNormal(
torch.zeros(4), scale_tril=100.0 * torch.eye(4)
)
self.init = dist_to_funsor(init_dist)(value="state")
# hidden states
prev = Variable("prev", Reals[4])
curr = Variable("curr", Reals[4])
self.trans_dist = f_dist.MultivariateNormal(
loc=prev @ NCV_TRANSITION_MATRIX,
scale_tril=trans_noise * torch.linalg.cholesky(NCV_PROCESS_NOISE),
value=curr,
)
state = Variable("state", Reals[4])
obs = Variable("obs", Reals[obs_dim])
observation_matrix = Tensor(
torch.eye(4, 2)
.unsqueeze(-1)
.expand(-1, -1, self.num_sensors)
.reshape(4, -1)
)
assert observation_matrix.output.shape == (
4,
obs_dim,
), observation_matrix.output.shape
obs_loc = state @ observation_matrix
if add_bias:
obs_loc += bias
self.observation_dist = f_dist.MultivariateNormal(
loc=obs_loc, scale_tril=obs_noise * torch.eye(obs_dim), value=obs
)
logp = bias_dist
curr = "state_init"
logp += self.init(state=curr)
for t, x in enumerate(observations):
prev, curr = curr, "state_{}".format(t)
logp += self.trans_dist(prev=prev, curr=curr)
logp += self.observation_dist(state=curr, obs=x)
# marginalize out previous state
logp = logp.reduce(ops.logaddexp, prev)
# marginalize out bias variable
logp = logp.reduce(ops.logaddexp, "bias")
# save posterior over the final state
assert set(logp.inputs) == {"state_{}".format(len(observations) - 1)}
posterior = funsor_to_mvn(logp, ndims=0)
# marginalize out remaining variables
logp = logp.reduce(ops.logaddexp)
assert isinstance(logp, Tensor) and logp.shape == (), logp.pretty()
return logp.data, posterior
def track(args):
results = {} # keyed on (seed, bias, num_frames)
for seed in args.seed:
torch.manual_seed(seed)
observations, states, sensor_bias = generate_data(
max(args.num_frames), args.num_sensors
)
for bias, num_frames in itertools.product(args.bias, args.num_frames):
print(
"tracking with seed={}, bias={}, num_frames={}".format(
seed, bias, num_frames
)
)
model = Model(args.num_sensors)
optim = Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.8))
losses = []
for i in range(args.num_epochs):
optim.zero_grad()
log_prob, posterior = model(observations[:num_frames], add_bias=bias)
loss = -log_prob
loss.backward()
losses.append(loss.item())
if i % 10 == 0:
print(loss.item())
optim.step()
# Collect evaluation metrics.
final_state_true = states[num_frames - 1]
assert final_state_true.shape == (4,)
final_pos_true = final_state_true[:2]
final_vel_true = final_state_true[2:]
final_state_est = posterior.loc
assert final_state_est.shape == (4,)
final_pos_est = final_state_est[:2]
final_vel_est = final_state_est[2:]
final_pos_error = float(torch.norm(final_pos_true - final_pos_est))
final_vel_error = float(torch.norm(final_vel_true - final_vel_est))
print("final_pos_error = {}".format(final_pos_error))
results[seed, bias, num_frames] = {
"args": args,
"observations": observations[:num_frames],
"states": states[:num_frames],
"sensor_bias": sensor_bias,
"losses": losses,
"bias_scale": float(model.log_bias_scale.exp()),
"obs_noise": float(model.log_obs_noise.exp()),
"trans_noise": float(model.log_trans_noise.exp()),
"final_state_estimate": posterior,
"final_pos_error": final_pos_error,
"final_vel_error": final_vel_error,
}
if args.metrics_filename:
print("saving output to: {}".format(args.metrics_filename))
torch.save(results, args.metrics_filename)
return results
def main(args):
funsor.set_backend("torch")
if (
args.force
or not args.metrics_filename
or not os.path.exists(args.metrics_filename)
):
# Increase compression threshold for numerical stability.
with funsor.gaussian.Gaussian.set_compression_threshold(3):
results = track(args)
else:
results = torch.load(args.metrics_filename)
if args.plot_filename:
import matplotlib
matplotlib.use("Agg")
import numpy as np
from matplotlib import pyplot
seeds = set(seed for seed, _, _ in results)
X = args.num_frames
pyplot.figure(figsize=(5, 1.4), dpi=300)
pos_error = np.array(
[
[results[s, 0, f]["final_pos_error"] for s in seeds]
for f in args.num_frames
]
)
mse = (pos_error**2).mean(axis=1)
std = (pos_error**2).std(axis=1) / len(seeds) ** 0.5
pyplot.plot(X, mse**0.5, "k--")
pyplot.fill_between(
X, (mse - std) ** 0.5, (mse + std) ** 0.5, color="black", alpha=0.15, lw=0
)
pos_error = np.array(
[
[results[s, 1, f]["final_pos_error"] for s in seeds]
for f in args.num_frames
]
)
mse = (pos_error**2).mean(axis=1)
std = (pos_error**2).std(axis=1) / len(seeds) ** 0.5
pyplot.plot(X, mse**0.5, "r-")
pyplot.fill_between(
X, (mse - std) ** 0.5, (mse + std) ** 0.5, color="red", alpha=0.15, lw=0
)
pyplot.ylabel("Position RMSE")
pyplot.xlabel("Track Length")
pyplot.xticks((5, 10, 15, 20, 25, 30))
pyplot.xlim(5, 30)
pyplot.tight_layout(0)
pyplot.savefig(args.plot_filename)
def int_list(args):
result = []
for arg in args.split(","):
if "-" in arg:
beg, end = map(int, arg.split("-"))
result.extend(range(beg, 1 + end))
else:
result.append(int(arg))
return result
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Biased Kalman filter")
parser.add_argument(
"--seed",
default="0",
type=int_list,
help="random seed, comma delimited for multiple runs",
)
parser.add_argument(
"--bias",
default="0,1",
type=int_list,
help="whether to model bias, comma deliminted for multiple runs",
)
parser.add_argument(
"-f",
"--num-frames",
default="5,10,15,20,25,30",
type=int_list,
help="number of sensor frames, comma delimited for multiple runs",
)
parser.add_argument("--num-sensors", default=5, type=int)
parser.add_argument("-n", "--num-epochs", default=50, type=int)
parser.add_argument("--lr", default=0.1, type=float)
parser.add_argument("--metrics-filename", default="", type=str)
parser.add_argument("--plot-filename", default="", type=str)
parser.add_argument("--force", action="store_true")
args = parser.parse_args()
main(args)
Example: Switching Linear Dynamical System¶
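The model below is a two-state switching linear dynamical system. Reading off the code, with transition matrix P = trans_probs, per-state transition scale sigma_s = trans_noise[s], and emission scale sigma_y = emit_noise (Normal takes loc and scale), the generative process is

s_t ~ Categorical(P[s_{t-1}]),    x_t ~ Normal(x_{t-1}, sigma_{s_t}),    y_t ~ Normal(x_t, sigma_y),

starting from s = 0 and x = 0, so the discrete state s_t selects between a low-noise and a high-noise transition regime at each step. All latent variables are constructed lazily and marginalized under the moment_matching interpretation.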
import argparse
import torch
import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
def main(args):
funsor.set_backend("torch")
# Declare parameters.
trans_probs = funsor.Tensor(
torch.tensor([[0.9, 0.1], [0.1, 0.9]], requires_grad=True)
)
trans_noise = funsor.Tensor(
torch.tensor(
[0.1, 1.0],  # [low noise component, high noise component]
requires_grad=True,
)
)
emit_noise = funsor.Tensor(torch.tensor(0.5, requires_grad=True))
params = [trans_probs.data, trans_noise.data, emit_noise.data]
# A Gaussian HMM model.
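# Under the moment_matching interpretation, reducing over the discrete state
# approximates the resulting mixture of Gaussians by a single Gaussian with
# matched moments, so the cost of filtering stays roughly linear in the number
# of time steps rather than exponential in the number of switching histories.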
@funsor.interpretations.moment_matching
def model(data):
log_prob = funsor.Number(0.0)
# s is the discrete latent state,
# x is the continuous latent state,
# y is the observed state.
s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
x_curr = funsor.Tensor(torch.tensor(0.0))
for t, y in enumerate(data):
s_prev = s_curr
x_prev = x_curr
# A delayed sample statement.
s_curr = funsor.Variable(f"s_{t}", funsor.Bint[2])
log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)
# A delayed sample statement.
x_curr = funsor.Variable(f"x_{t}", funsor.Real)
log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)
# Marginalize out previous delayed sample statements.
if t > 0:
log_prob = log_prob.reduce(ops.logaddexp, {s_prev.name, x_prev.name})
# An observe statement.
log_prob += dist.Normal(x_curr, emit_noise, value=y)
log_prob = log_prob.reduce(ops.logaddexp)
return log_prob
# Train model parameters.
torch.manual_seed(0)
data = torch.randn(args.time_steps)
optim = torch.optim.Adam(params, lr=args.learning_rate)
for step in range(args.train_steps):
optim.zero_grad()
log_prob = model(data)
assert not log_prob.inputs, "free variables remain"
loss = -log_prob.data
loss.backward()
optim.step()
if args.verbose and step % 10 == 0:
print(f"step {step} loss = {loss.item()}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Switching linear dynamical system")
parser.add_argument("-t", "--time-steps", default=10, type=int)
parser.add_argument("-n", "--train-steps", default=101, type=int)
parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
parser.add_argument("--filter", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
args = parser.parse_args()
main(args)
Example: Talbot’s method for numerical inversion of the Laplace transform¶
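The inverse Laplace transform is defined by the Bromwich integral

f(t) = \frac{1}{2\pi i} \int_{c - i\infty}^{c + i\infty} e^{st} F(s)\, ds,

which rarely has a closed form. Talbot's method, in the fixed-parameter form of Abate and Whitt (2006), deforms the contour and truncates it to a finite sum; with M = num_steps terms,

f(t) \approx \frac{2}{5t} \sum_{k=0}^{M-1} \mathrm{Re}\!\left[ \gamma_k \, F\!\left(\frac{\delta_k}{t}\right) \right],
\qquad
\delta_0 = \frac{2M}{5}, \quad
\delta_k = \frac{2\pi k}{5}\left(\cot\frac{\pi k}{M} + i\right), \quad
\gamma_0 = \frac{1}{2} e^{\delta_0}, \quad
\gamma_k = \left[1 + i\,\frac{\pi k}{M}\left(1 + \cot^2\frac{\pi k}{M}\right) - i\cot\frac{\pi k}{M}\right] e^{\delta_k}

for 1 <= k <= M - 1. This is the sum implemented by the talbot rule below; the 0.4 / t factor in the code is 2 / (5t).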
import argparse
import math
import torch
import funsor
import funsor.ops as ops
from funsor.adam import Adam
from funsor.domains import Real
from funsor.factory import Bound, Fresh, Has, make_funsor
from funsor.interpretations import StatefulInterpretation
from funsor.tensor import Tensor
from funsor.terms import Funsor, Variable
from funsor.util import get_backend
@make_funsor
def InverseLaplace(
F: Has[{"s"}], t: Funsor, s: Bound # noqa: F821
) -> Fresh[lambda F: F]:
"""
Inverse Laplace transform of function F(s).
There is no closed-form solution for arbitrary F(s), so we resort to
numerical approximations, which are registered as separate interpretations.
For example, see Talbot's method below.
:param F: function of the Laplace-domain variable s.
:param t: times at which to evaluate the inverse Laplace transform of F.
:param s: the Laplace-domain Variable bound by this term.
"""
return None
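# Returning None above means InverseLaplace has no eager implementation: the
# term stays lazy, and numerical rules are supplied separately by registering
# patterns under an interpretation, such as the Talbot interpretation below.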
class Talbot(StatefulInterpretation):
"""
Talbot's method for numerical inversion of the Laplace transform.
Reference:
Abate, Joseph, and Ward Whitt. "A Unified Framework for Numerically
Inverting Laplace Transforms." INFORMS Journal on Computing 18.4
(2006): 408-421. (http://www.columbia.edu/~ww2040/allpapers.html)
The implementation here is adapted from Tucker McClure (2021), Numerical
Inverse Laplace Transform
(https://www.mathworks.com/matlabcentral/fileexchange/39035-numerical-inverse-laplace-transform),
MATLAB Central File Exchange. Retrieved April 4, 2021.
:param num_steps: number of terms to sum for each t.
"""
def __init__(self, num_steps):
super().__init__("talbot")
self.num_steps = num_steps
@Talbot.register(InverseLaplace, Funsor, Funsor, Variable)
def talbot(self, F, t, s):
if get_backend() == "torch":
import torch
k = torch.arange(1, self.num_steps)
delta = torch.zeros(self.num_steps, dtype=torch.complex64)
delta[0] = 2 * self.num_steps / 5
delta[1:] = (
2 * math.pi / 5 * k * (1 / (math.pi / self.num_steps * k).tan() + 1j)
)
gamma = torch.zeros(self.num_steps, dtype=torch.complex64)
gamma[0] = 0.5 * delta[0].exp()
gamma[1:] = (
1
+ 1j
* math.pi
/ self.num_steps
* k
* (1 + 1 / (math.pi / self.num_steps * k).tan() ** 2)
- 1j / (math.pi / self.num_steps * k).tan()
) * delta[1:].exp()
delta = Tensor(delta)["num_steps"]
gamma = Tensor(gamma)["num_steps"]
ilt = 0.4 / t * (gamma * F(**{s.name: delta / t})).reduce(ops.add, "num_steps")
return Tensor(ilt.data.real, ilt.inputs)
else:
raise NotImplementedError(f"Unsupported backend {get_backend()}")
def main(args):
"""
Reference for the n-step sequential model used here:
Aaron L. Lucius et al. (2003).
"General Methods for Analysis of Sequential 'n-step' Kinetic Mechanisms:
Application to Single Turnover Kinetics of Helicase-Catalyzed DNA Unwinding"
https://www.sciencedirect.com/science/article/pii/S0006349503746487
"""
funsor.set_backend("torch")
# Problem definition.
true_rate = 20
true_nsteps = 4
rate = Variable("rate", Real)
nsteps = Variable("nsteps", Real)
s = Variable("s", Real)
time = Tensor(torch.arange(0.04, 1.04, 0.04))["timepoint"]
noise = Tensor(torch.randn(time.inputs["timepoint"].size))["timepoint"] / 500
data = (
Tensor(1 - torch.igammac(torch.tensor(true_nsteps), true_rate * time.data))[
"timepoint"
]
+ noise
)
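# For integer nsteps, F(s) = rate**nsteps / (s * (rate + s)**nsteps) has the
# closed-form inverse transform P(nsteps, rate * t), the regularized lower
# incomplete gamma function, so the synthetic data above (1 - igammac) is the
# exact target curve plus noise.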
F = rate**nsteps / (s * (rate + s) ** nsteps)
# Inverse Laplace.
pred = InverseLaplace(F, time, "s")
# Loss function.
loss = (pred - data).abs().reduce(ops.add, "timepoint")
init_params = {
"rate": Tensor(torch.tensor(5.0, requires_grad=True)),
"nsteps": Tensor(torch.tensor(2.0, requires_grad=True)),
}
with Talbot(num_steps=args.talbot_num_steps):
# Fit the data
with Adam(
args.num_steps,
lr=args.learning_rate,
log_every=args.log_every,
params=init_params,
) as optim:
loss.reduce(ops.min, {"rate", "nsteps"})
# Fitted curve.
fitted_curve = pred(rate=optim.param("rate"), nsteps=optim.param("nsteps"))
print(f"Data\n{data}")
print(f"Fit curve\n{fitted_curve}")
print(f"True rate\n{true_rate}")
print("Learned rate\n{}".format(optim.param("rate").item()))
print(f"True number of steps\n{true_nsteps}")
print("Learned number of steps\n{}".format(optim.param("nsteps").item()))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Numerical inversion of the Laplace transform using Talbot's method"
)
parser.add_argument("-N", "--talbot-num-steps", type=int, default=32)
parser.add_argument("-n", "--num-steps", type=int, default=501)
parser.add_argument("-lr", "--learning-rate", type=float, default=0.1)
parser.add_argument("--log-every", type=int, default=20)
args = parser.parse_args()
main(args)
Example: VAE MNIST¶
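The loss_function below computes, under the MonteCarlo interpretation, a single-sample estimate of

\mathrm{loss} = -\frac{N}{B} \sum_{x \in \mathrm{batch}} \mathbb{E}_{q(z \mid x)}\left[ \log p(x \mid z) - \log q(z \mid x) \right],

where q(z | x) is the diagonal-Gaussian guide produced by the encoder, p(x | z) is the pixelwise Bernoulli likelihood produced by the decoder, and N / B is the subsample_scale correction for minibatching; funsor.Integrate evaluates the expectation against a sample drawn lazily from the guide.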
import argparse
import os
import typing
from collections import OrderedDict
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.domains import Bint, Reals
REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_PATH = os.path.join(REPO_PATH, "data")
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
def forward(self, image: Reals[28, 28]) -> typing.Tuple[Reals[20], Reals[20]]:
image = image.reshape(image.shape[:-2] + (-1,))
h1 = F.relu(self.fc1(image))
loc = self.fc21(h1)
scale = self.fc22(h1).exp()
return loc, scale
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def forward(self, z: Reals[20]) -> Reals[28, 28]:
h3 = F.relu(self.fc3(z))
out = torch.sigmoid(self.fc4(h3))
return out.reshape(out.shape[:-1] + (28, 28))
def main(args):
funsor.set_backend("torch")
# XXX Temporary fix after https://github.com/pyro-ppl/pyro/pull/2701
import pyro
pyro.enable_validation(False)
encoder = Encoder()
decoder = Decoder()
# These rely on type hints on the .forward() methods.
encode = funsor.function(encoder)
decode = funsor.function(decoder)
@funsor.montecarlo.MonteCarlo()
def loss_function(data, subsample_scale):
# Lazily sample from the guide.
loc, scale = encode(data)
q = funsor.Independent(
dist.Normal(loc["i"], scale["i"], value="z_i"), "z", "i", "z_i"
)
# Evaluate the model likelihood at the lazy value z.
probs = decode("z")
p = dist.Bernoulli(probs["x", "y"], value=data["x", "y"])
p = p.reduce(ops.add, {"x", "y"})
# Construct an elbo. This is where sampling happens.
elbo = funsor.Integrate(q, p - q, "z")
elbo = elbo.reduce(ops.add, "batch") * subsample_scale
loss = -elbo
return loss
train_loader = torch.utils.data.DataLoader(
MNIST(DATA_PATH, train=True, download=True, transform=transforms.ToTensor()),
batch_size=args.batch_size,
shuffle=True,
)
encoder.train()
decoder.train()
optimizer = optim.Adam(
list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3
)
for epoch in range(args.num_epochs):
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
subsample_scale = float(len(train_loader.dataset) / len(data))
data = data[:, 0, :, :]
data = funsor.Tensor(data, OrderedDict(batch=Bint[len(data)]))
optimizer.zero_grad()
loss = loss_function(data, subsample_scale)
assert isinstance(loss, funsor.Tensor), loss.pretty()
loss.data.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 50 == 0:
print(f" loss = {loss.item()}")
if batch_idx and args.smoke_test:
return
print(f"epoch {epoch} train_loss = {train_loss}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="VAE MNIST Example")
parser.add_argument("-n", "--num-epochs", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--smoke-test", action="store_true")
args = parser.parse_args()
main(args)