Funsors

Basic Funsors

class Funsor(inputs, output, fresh=None, bound=None)[source]

Bases: object

Abstract base class for immutable functional tensors.

Concrete derived classes must implement __init__() methods taking hashable *args and no optional **kwargs so as to support cons hashing.

Derived classes with .fresh variables must implement an eager_subs() method. Derived classes with .bound variables must implement an _alpha_convert() method.
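Cons hashing, as required above, means that constructing a term twice from equal arguments yields the same object, which is why constructor arguments must be hashable. The following is a minimal sketch of that idea only, not Funsor's implementation; `ConsHashedMeta` and `Term` are hypothetical names introduced for illustration.

```python
import weakref


class ConsHashedMeta(type):
    """Hypothetical metaclass that returns a cached instance for equal args."""

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        # One cache per class; entries vanish once the instance is unreferenced.
        cls._cons_cache = weakref.WeakValueDictionary()

    def __call__(cls, *args):
        # Only positional, hashable args are accepted, so `args` can be the key.
        try:
            return cls._cons_cache[args]
        except KeyError:
            instance = super().__call__(*args)
            cls._cons_cache[args] = instance
            return instance


class Term(metaclass=ConsHashedMeta):
    """Hypothetical immutable term with hashable constructor arguments."""

    def __init__(self, name, size):
        self.name = name
        self.size = size


a = Term("x", 3)
b = Term("x", 3)
c = Term("y", 3)
```

Because equal arguments map to one object, equality of terms can be checked by identity (`a is b`), which is the usual motivation for this pattern.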

Parameters:
  • inputs (OrderedDict) – A mapping from input name to domain. This can be viewed as a typed context or a mapping from free variables to domains.
  • output (Domain) – An output domain.
dtype
shape
input_vars[source]
quote()[source]
pretty(*args, **kwargs)[source]
item()[source]
requires_grad
reduce(op, reduced_vars=None)[source]

Reduce along all or a subset of inputs.

Parameters:
  • op (AssociativeOp or ReductionOp) – A reduction operation.
  • reduced_vars (str, Variable, or set or frozenset thereof.) – An optional input name or set of names to reduce. If unspecified, all inputs will be reduced.
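The semantics of reducing over named inputs can be sketched with a plain dictionary standing in for a funsor. This is an illustration under stated assumptions, not the library's code: `reduce_named` is a hypothetical helper, inputs map names to integer sizes, and the table maps index tuples to values.

```python
from collections import OrderedDict
from functools import reduce as fold
from itertools import product
import operator


def reduce_named(op, inputs, table, reduced_vars=None):
    """Reduce `table` over `reduced_vars`; if unspecified, reduce all inputs."""
    if reduced_vars is None:
        reduced_vars = set(inputs)
    elif isinstance(reduced_vars, str):
        reduced_vars = {reduced_vars}
    kept = OrderedDict((k, v) for k, v in inputs.items() if k not in reduced_vars)
    groups = {}
    for index, value in table.items():
        # Keep only the coordinates of inputs that survive the reduction.
        key = tuple(i for name, i in zip(inputs, index) if name not in reduced_vars)
        groups.setdefault(key, []).append(value)
    return kept, {key: fold(op, values) for key, values in groups.items()}


inputs = OrderedDict(i=2, j=3)
table = {(i, j): 10 * i + j for i, j in product(range(2), range(3))}

kept, summed_over_j = reduce_named(operator.add, inputs, table, "j")
_, total = reduce_named(operator.add, inputs, table)
```

Reducing over `"j"` leaves a table that still depends on `i`; omitting `reduced_vars` collapses every input and leaves a single value keyed by the empty tuple.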
approximate(op, guide, approx_vars=None)[source]

Approximate wrt all or a subset of inputs.

Parameters:
  • op (AssociativeOp) – A reduction operation.
  • guide (Funsor) – A guide funsor (e.g. a proposal distribution).
  • approx_vars (str, Variable, or set or frozenset thereof.) – An optional input name or set of names to approximate. If unspecified, all inputs will be approximated.
sample(sampled_vars, sample_inputs=None, rng_key=None)[source]

Create a Monte Carlo approximation to this funsor by replacing functions of sampled_vars with Deltas.

The result is a Funsor with the same .inputs and .output as the original funsor (plus sample_inputs if provided), so that self can be replaced by the sample in expectation computations:

y = x.sample(sampled_vars)
assert y.inputs == x.inputs
assert y.output == x.output
exact = (x.exp() * integrand).reduce(ops.add)
approx = (y.exp() * integrand).reduce(ops.add)

If sample_inputs is provided, this creates a batch of samples.

Parameters:
  • sampled_vars (str, Variable, or set or frozenset thereof.) – A set of input variables to sample.
  • sample_inputs (OrderedDict) – An optional mapping from variable name to Domain over which samples will be batched.
  • rng_key (None or JAX's random.PRNGKey) – A PRNG state used by the JAX backend to generate random samples.
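The idea of replacing a log density with weighted point masses can be illustrated without the library. The sketch below assumes a normalized discrete log density over three values and uses Python's `random` module; `sample_deltas` is a hypothetical helper, and each sample carries an equal share of the total mass.

```python
import math
import random

# A normalized log density over the values 0, 1, 2 and an integrand to average.
log_density = [math.log(0.2), math.log(0.3), math.log(0.5)]
integrand = [1.0, 2.0, 4.0]

# Exact expectation: sum over all values of exp(log_density) * integrand.
exact = sum(math.exp(lp) * f for lp, f in zip(log_density, integrand))


def sample_deltas(log_density, num_samples, rng):
    """Return (point, log_weight) pairs that approximate exp(log_density)."""
    weights = [math.exp(lp) for lp in log_density]
    points = rng.choices(range(len(weights)), weights=weights, k=num_samples)
    # Every point mass carries an equal share of the total mass.
    log_weight = math.log(sum(weights) / num_samples)
    return [(point, log_weight) for point in points]


rng = random.Random(0)
samples = sample_deltas(log_density, 20000, rng)

# The same expectation, computed against the point masses instead.
approx = sum(math.exp(lw) * integrand[point] for point, lw in samples)
```

The two quantities agree in expectation, so `approx` approaches `exact` as the number of samples grows.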
align(names)[source]

Align this funsor to match given names. This is mainly useful in preparation for extracting .data of a funsor.tensor.Tensor.

Parameters: names (tuple) – A tuple of strings representing all names but in a new order.
Returns: A permuted funsor equivalent to self.
Return type: Funsor
eager_subs(subs)[source]

Internal substitution function. This relies on the user-facing __call__() method to coerce non-Funsors to Funsors. Once all inputs are Funsors, eager_subs() implementations can recurse to call Subs.

eager_unary(op)[source]
eager_reduce(op, reduced_vars)[source]
sequential_reduce(op, reduced_vars)[source]
moment_matching_reduce(op, reduced_vars)[source]
abs()[source]
atanh()[source]
sqrt()[source]
exp()[source]
log()[source]
log1p()[source]
sigmoid()[source]
tanh()[source]
reshape(shape)[source]
all(axis=None, keepdims=False)[source]
any(axis=None, keepdims=False)[source]
argmax(axis=None, keepdims=False)[source]
argmin(axis=None, keepdims=False)[source]
max(axis=None, keepdims=False)[source]
min(axis=None, keepdims=False)[source]
sum(axis=None, keepdims=False)[source]
prod(axis=None, keepdims=False)[source]
logsumexp(axis=None, keepdims=False)[source]
mean(axis=None, keepdims=False)[source]
std(axis=None, ddof=0, keepdims=False)[source]
var(axis=None, ddof=0, keepdims=False)[source]
to_funsor(x, output=None, dim_to_name=None, **kwargs)[source]

Convert to a Funsor. Only Funsors and scalars are accepted.

Parameters:
  • x – An object.
  • output (funsor.domains.Domain) – An optional output hint.
  • dim_to_name (OrderedDict) – An optional mapping from negative batch dimensions to name strings.
Returns:

A Funsor equivalent to x.

Return type:

Funsor

Raises:

ValueError

to_data(x, name_to_dim=None, **kwargs)[source]

Extract a Python object from a Funsor.

Raises a ValueError if free variables remain or if the funsor is lazy.

Parameters:
  • x – An object, possibly a Funsor.
  • name_to_dim (OrderedDict) – An optional inputs hint.
Returns:

A non-funsor equivalent to x.

Raises:

ValueError if any free variables remain.

Raises:

PatternMissingError if funsor is not fully evaluated.

class Variable(name, output)[source]

Bases: funsor.terms.Funsor

Funsor representing a single free variable.

Parameters:
  • name (str) – A variable name.
  • output (funsor.domains.Domain) – A domain.
eager_subs(subs)[source]
class Subs(arg, subs)[source]

Bases: funsor.terms.Funsor

Lazy substitution of the form x(u=y, v=z).

Parameters:
  • arg (Funsor) – A funsor being substituted into.
  • subs (tuple) – A tuple of (name, value) pairs, where name is a string and value can be coerced to a Funsor via to_funsor().
class Unary(op, arg)[source]

Bases: funsor.terms.Funsor

Lazy unary operation.

Parameters:
  • op (Op) – A unary operator.
  • arg (Funsor) – An argument.
class Binary(op, lhs, rhs)[source]

Bases: funsor.terms.Funsor

Lazy binary operation.

Parameters:
  • op (Op) – A binary operator.
  • lhs (Funsor) – A left hand side argument.
  • rhs (Funsor) – A right hand side argument.
class Reduce(op, arg, reduced_vars)[source]

Bases: funsor.terms.Funsor

Lazy reduction over multiple variables.

The user-facing interface is the Funsor.reduce() method.

Parameters:
  • op (AssociativeOp) – An associative operator.
  • arg (Funsor) – An argument to be reduced.
  • reduced_vars (frozenset) – A set of variables over which to reduce.
class Scatter(op, subs, source, reduced_vars)[source]

Bases: funsor.terms.Funsor

Transpose of structurally linear Subs, followed by Reduce.

For injective scatter operations this should satisfy the equation:

if destin = Scatter(op, subs, source, frozenset())
then source = Subs(destin, subs)

The reduced_vars is merely for computational efficiency, and could always be split out into a separate .reduce(). For example in the following equation, the left hand side uses much less memory than the right hand side:

Scatter(op, subs, source, reduced_vars) ==
  Scatter(op, subs, source, frozenset()).reduce(op, reduced_vars)

Warning

This is currently implemented only for injective scatter operations. In particular, this does not allow accumulation behavior like scatter-add.

Note

Scatter(ops.add, ...) is the funsor analog of numpy.add.at() or torch.index_put() or jax.lax.scatter_add(). For injective substitutions, Scatter(ops.add, ...) is roughly equivalent to the tensor operation:

result = zeros(...)  # since zero is the additive unit
result[subs] = source
Parameters:
  • op (AssociativeOp) – An op. The unit of this op will be used as default value.
  • subs (tuple) – A substitution.
  • source (Funsor) – A source for data to be scattered from.
  • reduced_vars (frozenset) – A set of variables over which to reduce.
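The relationship between an injective scatter and the substitution that undoes it can be sketched on plain lists. This is only an analogy under stated assumptions: `scatter` and `gather` are hypothetical helpers, a list of integer positions stands in for subs, and the unit of the op (zero for addition) fills the unwritten slots.

```python
def scatter(unit, size, index, source):
    """Write source[k] at position index[k]; other slots hold the op's unit."""
    if len(set(index)) != len(index):
        # Mirrors the warning above: accumulation is not supported.
        raise ValueError("index must be injective")
    destin = [unit] * size
    for position, value in zip(index, source):
        destin[position] = value
    return destin


def gather(destin, index):
    """Read back the scattered values, playing the role of the substitution."""
    return [destin[position] for position in index]


source = [1.5, 2.5, 3.5]
index = [4, 0, 2]
destin = scatter(0.0, 6, index, source)
recovered = gather(destin, index)
```

For an injective index, gathering from the scattered result returns the original source, which is the list analog of the equation relating Scatter and Subs.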
eager_subs(subs)[source]
class Approximate(op, model, guide, approx_vars)[source]

Bases: funsor.terms.Funsor

Interpretation-specific approximation wrt a set of variables.

The default eager interpretation should be exact. The user-facing interface is the Funsor.approximate() method.

Parameters:
  • op (AssociativeOp) – An associative operator.
  • model (Funsor) – An exact funsor depending on approx_vars.
  • guide (Funsor) – A proposal funsor guiding optional approximation.
  • approx_vars (frozenset) – A set of variables over which to approximate.
class Number(data, dtype=None)[source]

Bases: funsor.terms.Funsor

Funsor backed by a Python number.

Parameters:
  • data (numbers.Number) – A Python number.
  • dtype – A nonnegative integer or the string “real”.
item()[source]
eager_unary(op)[source]
class Slice(name, start, stop, step, dtype)[source]

Bases: funsor.terms.Funsor

Symbolic representation of a Python slice object.

Parameters:
  • name (str) – A name for the new slice dimension.
  • start (int) –
  • stop (int) –
  • step (int) – Three args following slice semantics.
  • dtype (int) – An optional bounded integer type of this slice.
eager_subs(subs)[source]
class Stack(name, parts)[source]

Bases: funsor.terms.Funsor

Stack of funsors along a new input dimension.

Parameters:
  • name (str) – The name of the new input variable along which to stack.
  • parts (tuple) – A tuple of Funsors of homogeneous output domain.
eager_subs(subs)[source]
eager_reduce(op, reduced_vars)[source]
class Cat(name, parts, part_name=None)[source]

Bases: funsor.terms.Funsor

Concatenate funsors along an existing input dimension.

Parameters:
  • name (str) – The name of the input variable along which to concatenate.
  • parts (tuple) – A tuple of Funsors of homogeneous output domain.
eager_subs(subs)[source]
class Lambda(var, expr)[source]

Bases: funsor.terms.Funsor

Lazy inverse to ops.getitem.

This is useful to simulate higher-order functions of integers by representing those functions as arrays.

Parameters:
  • var (Variable) – A variable to bind.
  • expr (Funsor) – A funsor.
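Representing a function of a bounded integer as an array, so that indexing recovers the function value, can be sketched in a few lines. The helper `tabulate` and the example function are hypothetical and stand in for binding a variable; they are not part of the library.

```python
def tabulate(fn, size):
    """Bind an integer variable by evaluating fn at every index in range(size)."""
    return [fn(i) for i in range(size)]


def square_plus_one(i):
    return i * i + 1


# The table plays the role of the bound expression; indexing inverts the binding.
table = tabulate(square_plus_one, 4)
```

Indexing `table[i]` gives back `square_plus_one(i)` for every valid `i`, which is the sense in which binding and getitem are inverse to each other.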
class Independent(fn, reals_var, bint_var, diag_var)[source]

Bases: funsor.terms.Funsor

Creates an independent diagonal distribution.

This is equivalent to substitution followed by reduction:

f = ...  # a batched distribution
assert f.inputs['x_i'] == Reals[4, 5]
assert f.inputs['i'] == Bint[3]

g = Independent(f, 'x', 'i', 'x_i')
assert g.inputs['x'] == Reals[3, 4, 5]
assert 'x_i' not in g.inputs
assert 'i' not in g.inputs

x = Variable('x', Reals[3, 4, 5])
g == f(x_i=x['i']).reduce(ops.add, 'i')
Parameters:
  • fn (Funsor) – A funsor.
  • reals_var (str) – The name of a real-tensor input.
  • bint_var (str) – The name of a new batch input of fn.
  • diag_var – The name of a smaller-shape real input of fn.
eager_subs(subs)[source]
mean()[source]
variance()[source]
entropy()[source]
of_shape(*shape)[source]

Delta

solve(expr, value)[source]

Tries to solve for free inputs of an expr such that expr == value, and computes the log-abs-det-Jacobian of the resulting substitution.

Parameters:
  • expr (Funsor) – An expression with a free variable.
  • value (Funsor) – A target value.
Returns:

A tuple (name, point, log_abs_det_jacobian)

Return type:

tuple

Raises:

ValueError

class Delta(terms)[source]

Bases: funsor.terms.Funsor

Normalized delta distribution binding multiple variables.

There are three syntaxes supported for constructing Deltas:

Delta(((name1, (point1, log_density1)),
       (name2, (point2, log_density2)),
       (name3, (point3, log_density3))))

or for a single name:

Delta(name, point, log_density)

or for default log_density == 0:

Delta(name, point)
Parameters: terms (tuple) – A tuple of tuples of the form (name, (point, log_density)).
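The three constructor syntaxes all describe the same canonical tuple of (name, (point, log_density)) pairs. A minimal sketch of that normalization follows; `normalize_delta_args` is a hypothetical helper written for illustration and does not claim to match how the library parses its arguments.

```python
def normalize_delta_args(*args):
    """Convert any of the three call forms to ((name, (point, log_density)), ...)."""
    if len(args) == 1:
        # Already a tuple of (name, (point, log_density)) pairs.
        return tuple(args[0])
    if len(args) == 2:
        # Single name with the default log_density of 0.
        name, point = args
        return ((name, (point, 0.0)),)
    if len(args) == 3:
        name, point, log_density = args
        return ((name, (point, log_density)),)
    raise TypeError("expected 1, 2, or 3 positional arguments")


short_form = normalize_delta_args("x", 1.5)
full_form = normalize_delta_args("x", 1.5, -0.5)
tuple_form = normalize_delta_args((("x", (1.5, 0.0)), ("y", (2.0, -1.0))))
```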
align(names)[source]
eager_subs(subs)[source]
eager_reduce(op, reduced_vars)[source]

Tensor

ignore_jit_warnings()[source]
class Tensor(data, inputs=None, dtype='real')[source]

Bases: funsor.terms.Funsor

Funsor backed by a PyTorch Tensor or a NumPy ndarray.

This follows the torch.distributions convention of arranging named “batch” dimensions on the left and remaining “event” dimensions on the right. The output shape is determined by all remaining dims. For example:

data = torch.zeros(5,4,3,2)
x = Tensor(data, {"i": Bint[5], "j": Bint[4]})
assert x.output == Reals[3, 2]

Operators like matmul and .sum() operate only on the output shape, and will not change the named inputs.

Parameters:
  • data (numeric_array) – A PyTorch tensor or NumPy ndarray.
  • inputs (dict) – An optional mapping from input name (str) to datatype (funsor.domains.Domain). Defaults to empty.
  • dtype (int or the string "real".) – optional output datatype. Defaults to “real”.
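The batch-left, event-right convention means the output shape is whatever remains of the data shape after one leading dimension per named input. The sketch below shows that bookkeeping on plain tuples; `split_shape` is a hypothetical helper, and integer sizes stand in for bounded-integer domains.

```python
from collections import OrderedDict


def split_shape(data_shape, input_sizes):
    """Split a data shape into (batch_shape, output_shape) given named inputs."""
    num_inputs = len(input_sizes)
    batch_shape = tuple(data_shape[:num_inputs])
    output_shape = tuple(data_shape[num_inputs:])
    if batch_shape != tuple(input_sizes.values()):
        raise ValueError("leading dims must match the sizes of the named inputs")
    return batch_shape, output_shape


# Mirrors the example above: inputs i and j consume the two leading dims.
batch_shape, output_shape = split_shape((5, 4, 3, 2), OrderedDict(i=5, j=4))
```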
item()[source]
clamp_finite()[source]
requires_grad
align(names)[source]
eager_subs(subs)[source]
eager_unary(op)[source]
eager_reduce(op, reduced_vars)[source]
new_arange(name, *args, **kwargs)[source]

Helper to create a named torch.arange() or np.arange() funsor. In some cases this can be replaced by a symbolic Slice.

Parameters:
  • name (str) – A variable name.
  • start (int) –
  • stop (int) –
  • step (int) – Three args following slice semantics.
  • dtype (int) – An optional bounded integer type of this slice.
Return type:

Tensor

materialize(x)[source]

Attempt to convert a Funsor to a Number or Tensor by substituting arange()s into its free variables.

Parameters: x (Funsor) – A funsor.
Return type: Funsor
align_tensor(new_inputs, x, expand=False)[source]

Permute and add dims to a tensor to match desired new_inputs.

Parameters:
  • new_inputs (OrderedDict) – A target set of inputs.
  • x (funsor.terms.Funsor) – A Tensor or Number.
  • expand (bool) – If False (default), set result size to 1 for any input in new_inputs that x lacks; if True, expand to the new_inputs size.
Returns:

A number or torch.Tensor or np.ndarray that can be broadcast to other tensors with inputs new_inputs.

Return type:

int or float or torch.Tensor or np.ndarray
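The shape bookkeeping behind such an alignment can be sketched on names and sizes alone. This is an illustration under assumptions, not the library's code: both helpers are hypothetical, inputs map names to integer sizes, and names in new_inputs that the tensor lacks are assumed to get size 1 unless expand is set.

```python
from collections import OrderedDict


def aligned_batch_shape(new_inputs, old_inputs, expand=False):
    """Batch shape after alignment: size 1 (or full size) for missing names."""
    extra = set(old_inputs) - set(new_inputs)
    if extra:
        raise ValueError("old inputs must be a subset of new inputs")
    return tuple(
        size if (name in old_inputs or expand) else 1
        for name, size in new_inputs.items()
    )


def batch_permutation(new_inputs, old_inputs):
    """Order in which the old batch dims must be permuted to match new_inputs."""
    old_names = list(old_inputs)
    return tuple(old_names.index(name) for name in new_inputs if name in old_inputs)


new_inputs = OrderedDict(a=2, b=3, c=4)
old_inputs = OrderedDict(c=4, a=2)

broadcastable = aligned_batch_shape(new_inputs, old_inputs)
expanded = aligned_batch_shape(new_inputs, old_inputs, expand=True)
perm = batch_permutation(new_inputs, old_inputs)
```

The size-1 entries are what make the result broadcastable against other tensors that do depend on the missing names.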

align_tensors(*args, **kwargs)[source]

Permute multiple tensors before applying a broadcasted op.

This is mainly useful for implementing eager funsor operations.

Parameters:
Returns:

A pair (inputs, tensors) where tensors are all torch.Tensors or np.ndarrays that can be broadcast together to a single data with given inputs.

Return type:

tuple

class Function(fn, output, args)[source]

Bases: funsor.terms.Funsor

Funsor wrapped by a native PyTorch or NumPy function.

Functions are assumed to support broadcasting and can be eagerly evaluated on funsors with free variables of int type (i.e. batch dimensions).

Functions are usually created via the function() decorator.

Parameters:
  • fn (callable) – A native PyTorch or NumPy function to wrap.
  • output (type) – An output domain.
  • args (Funsor) – Funsor arguments.
function(*signature)[source]

Decorator to wrap a PyTorch/NumPy function, using either type hints or explicit type annotations.

Example:

# Using type hints:
@funsor.tensor.function
def matmul(x: Reals[3, 4], y: Reals[4, 5]) -> Reals[3, 5]:
    return torch.matmul(x, y)

# Using explicit type annotations:
@funsor.tensor.function(Reals[3, 4], Reals[4, 5], Reals[3, 5])
def matmul(x, y):
    return torch.matmul(x, y)

@funsor.tensor.function(Reals[10], Reals[10, 10], Reals[10], Real)
def mvn_log_prob(loc, scale_tril, x):
    # Pass scale_tril by keyword; the second positional argument is covariance_matrix.
    d = torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril)
    return d.log_prob(x)

To support functions that output nested tuples of tensors, specify a nested Tuple of output types, for example:

@funsor.tensor.function
def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
    return torch.max(x, dim=-1)
Parameters: *signature – A sequence of input domains followed by a final output domain or nested tuple of output domains.
Einsum(equation, *operands)[source]

Wrapper around torch.einsum() or np.einsum() to operate on real-valued Funsors.

Note this operates only on the output tensor. To perform sum-product contractions on named dimensions, instead use + and Reduce.

Parameters:
  • equation (str) – A torch.einsum() or np.einsum() equation.
  • operands (tuple) – A tuple of input funsors.
tensordot(x, y, dims)[source]

Wrapper around torch.tensordot() or np.tensordot() to operate on real-valued Funsors.

Note this operates only on the output tensor. To perform sum-product contractions on named dimensions, instead use + and Reduce.

Arguments should satisfy:

len(x.shape) >= dims
len(y.shape) >= dims
dims == 0 or x.shape[-dims:] == y.shape[:dims]
Parameters:
  • x (Funsor) – A left hand argument.
  • y (Funsor) – A right hand argument.
  • dims (int) – The number of overlapping dimensions of the output shapes.
Return type:

Funsor
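The three shape conditions above, and the resulting output shape, can be checked on plain tuples. `tensordot_shape` is a hypothetical helper for illustration; it uses an explicit split index because `x_shape[-0:]` would select the whole shape when dims is 0.

```python
def tensordot_shape(x_shape, y_shape, dims):
    """Validate the overlap conditions and return the contracted output shape."""
    x_shape, y_shape = tuple(x_shape), tuple(y_shape)
    if len(x_shape) < dims or len(y_shape) < dims:
        raise ValueError("both operands need at least `dims` output dimensions")
    split = len(x_shape) - dims
    if x_shape[split:] != y_shape[:dims]:
        raise ValueError("trailing dims of x must equal leading dims of y")
    # The overlapping dims are summed out; the rest are concatenated.
    return x_shape[:split] + y_shape[dims:]


contracted = tensordot_shape((3, 4, 5), (4, 5, 6), 2)
outer = tensordot_shape((2,), (3,), 0)
```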

Gaussian

class BlockVector(shape)[source]

Bases: object

Jit-compatible helper to build blockwise vectors. Syntax is similar to torch.zeros():

x = BlockVector((100, 20))
x[..., 0:4] = x1
x[..., 6:10] = x2
x = x.as_tensor()
assert x.shape == (100, 20)
as_tensor()[source]
class BlockMatrix(shape)[source]

Bases: object

Jit-compatible helper to build blockwise matrices. Syntax is similar to torch.zeros():

x = BlockMatrix((100, 20, 20))
x[..., 0:4, 0:4] = x11
x[..., 0:4, 6:10] = x12
x[..., 6:10, 0:4] = x12.transpose(-1, -2)
x[..., 6:10, 6:10] = x22
x = x.as_tensor()
assert x.shape == (100, 20, 20)
as_tensor()[source]
align_gaussian(new_inputs, old, expand=False)[source]

Align data of a Gaussian distribution to a new inputs shape.

class Gaussian(white_vec, prec_sqrt, inputs)[source]

Bases: funsor.terms.Funsor

Funsor representing a batched Gaussian log-density function.

Gaussians are the internal representation for joint and conditional multivariate normal distributions and multivariate normal likelihoods. Mathematically, a Gaussian represents the quadratic log density function:

f(x) = -0.5 * || x @ prec_sqrt - white_vec ||^2
     = -0.5 * < x @ prec_sqrt - white_vec | x @ prec_sqrt - white_vec >
     = -0.5 * < x | prec_sqrt @ prec_sqrt.T | x>
       + < x | prec_sqrt | white_vec > - 0.5 ||white_vec||^2

Internally Gaussians use a square root information filter (SRIF) representation consisting of a square root of the precision matrix prec_sqrt and a vector in the whitened space white_vec. This representation allows space-efficient construction of Gaussians with incomplete information, i.e. with zero eigenvalues in the precision matrix. These incomplete log densities arise when making low-dimensional observations of higher-dimensional hidden state. Sampling and marginalization are supported only for full-rank Gaussians or full-rank subsets of Gaussians. See the rank() and is_full_rank() properties.

Note

Gaussians are not normalized probability distributions; rather, they are canonicalized to evaluate to zero log density at their maximum: f(prec_sqrt \ white_vec) = 0. Not only are Gaussians non-normalized, but they may be rank deficient and non-normalizable, in which case sampling and marginalization are supported only on full-rank subsets of variables.

Parameters:
  • white_vec (torch.Tensor) – A batched white noise vector, where white_vec = prec_sqrt.T @ mean. Alternatively you can specify one of the kwargs mean or info_vec, which will be converted to white_vec.
  • prec_sqrt (torch.Tensor) – A batched square root of the positive semidefinite precision matrix. This need not be square, and typically has shape prec_sqrt.shape == white_vec.shape[:-1] + (dim, rank), where dim is the total flattened size of real inputs and rank = white_vec.shape[-1]. Alternatively you can specify one of the kwargs precision, covariance, or scale_tril, which will be converted to prec_sqrt.
  • inputs (OrderedDict) – Mapping from name to Domain.
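The quadratic log-density formula and its expansion can be checked numerically on a small unbatched example. The sketch below uses plain Python lists rather than tensors; the helper names are hypothetical, and the matrix and mean are arbitrary values chosen for illustration.

```python
def vec_mat(x, m):
    """Compute x @ m for a vector x of length dim and a dim-by-rank matrix m."""
    return [sum(x[i] * m[i][j] for i in range(len(x))) for j in range(len(m[0]))]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def log_density(x, prec_sqrt, white_vec):
    """f(x) = -0.5 * || x @ prec_sqrt - white_vec ||^2."""
    residual = [a - b for a, b in zip(vec_mat(x, prec_sqrt), white_vec)]
    return -0.5 * dot(residual, residual)


def log_density_expanded(x, prec_sqrt, white_vec):
    """The same value via the expanded quadratic, linear, and constant terms."""
    x_white = vec_mat(x, prec_sqrt)
    return (
        -0.5 * dot(x_white, x_white)
        + dot(x_white, white_vec)
        - 0.5 * dot(white_vec, white_vec)
    )


prec_sqrt = [[2.0, 0.0], [1.0, 3.0]]
mean = [0.5, -1.0]
# white_vec = prec_sqrt.T @ mean, which equals mean @ prec_sqrt for a vector.
white_vec = vec_mat(mean, prec_sqrt)

x = [1.0, 2.0]
direct = log_density(x, prec_sqrt, white_vec)
expanded = log_density_expanded(x, prec_sqrt, white_vec)
at_mean = log_density(mean, prec_sqrt, white_vec)
```

Both forms agree at an arbitrary point, and the density evaluates to zero at the mean, matching the canonicalization described in the note.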
compression_threshold = 2
classmethod set_compression_threshold(threshold: float)[source]

Context manager to set rank compression threshold.

To save space, Gaussians compress wide prec_sqrt matrices down to square. However, compression uses a QR decomposition, which can be expensive and which has unstable gradients when the resulting precision matrix is rank deficient. To balance space and time costs and numerical stability, compression is triggered only on prec_sqrt matrices whose width-to-height ratio is greater than threshold.

Parameters: threshold (float) – Defaults to 2. To optimize for space and deterministic computations, set threshold = 1. To optimize for fewest QR decompositions and numerical stability, set threshold = math.inf.
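The trigger rule can be read as a simple comparison of the matrix's width-to-height ratio against the threshold. The sketch below is an assumption about that rule rather than the library's code: `should_compress` is a hypothetical helper, with height taken as dim and width as rank, following the (dim, rank) shape of prec_sqrt.

```python
import math


def should_compress(dim, rank, threshold=2.0):
    """Return True when a dim-by-rank prec_sqrt is wider than threshold * dim."""
    # Equivalent to rank / dim > threshold, without dividing.
    return rank > threshold * dim


default_wide = should_compress(dim=3, rank=7)            # ratio 7/3 exceeds 2
default_borderline = should_compress(dim=3, rank=6)      # ratio exactly 2
eager = should_compress(dim=3, rank=4, threshold=1.0)    # any wide matrix
never = should_compress(dim=3, rank=1000, threshold=math.inf)
```

A threshold of 1 compresses every matrix that is wider than square, while an infinite threshold disables compression entirely.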
rank
is_full_rank
log_normalizer[source]
align(names)[source]
eager_subs(subs)[source]
eager_reduce(op, reduced_vars)[source]

Joint

moment_matching_contract_default(*args)[source]
moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian)[source]
eager_reduce_exp(op, arg, reduced_vars)[source]
eager_independent_joint(joint, reals_var, bint_var, diag_var)[source]

Contraction

class Contraction(red_op, bin_op, reduced_vars, terms)[source]

Bases: funsor.terms.Funsor

Declarative representation of a finitary sum-product operation.

After normalization via the normalize() interpretation, contractions will canonically order their terms by type:

Delta, Number, Tensor, Gaussian
align(names)[source]
GaussianMixture

alias of funsor.cnf.Contraction

children_contraction(x)[source]
eager_contraction_generic_to_tuple(red_op, bin_op, reduced_vars, *terms)[source]
eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms)[source]
eager_contraction_to_reduce(red_op, bin_op, reduced_vars, term)[source]
eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs)[source]
eager_contraction_tensor(red_op, bin_op, reduced_vars, *terms)[source]
eager_contraction_gaussian(red_op, bin_op, reduced_vars, x, y)[source]
normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_vars, *terms)[source]
normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture)[source]
normalize_contraction_generic_args(red_op, bin_op, reduced_vars, *terms)[source]
normalize_trivial(red_op, bin_op, reduced_vars, term)[source]
normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms)[source]
binary_to_contract(op, lhs, rhs)[source]
reduce_funsor(op, arg, reduced_vars)[source]
unary_neg_variable(op, arg)[source]
do_fresh_subs(arg, subs)[source]
distribute_subs_contraction(arg, subs)[source]
normalize_fuse_subs(arg, subs)[source]
binary_subtract(op, lhs, rhs)[source]
binary_divide(op, lhs, rhs)[source]
unary_log_exp(op, arg)[source]
unary_contract(op, arg)[source]

Integrate

class Integrate(log_measure, integrand, reduced_vars)[source]

Bases: funsor.terms.Funsor

Funsor representing an integral wrt a log density funsor.

Parameters:
  • log_measure (Funsor) – A log density funsor treated as a measure.
  • integrand (Funsor) – An integrand funsor.
  • reduced_vars (str, Variable, or set or frozenset thereof.) – An input name or set of names to reduce.
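For discrete variables, integrating against a log density amounts to a weighted sum over the reduced variable. The sketch below works on nested lists indexed by two hypothetical inputs `i` and `j` and reduces over `j` only; `integrate_over_j` is an illustrative helper, not the library's implementation.

```python
import math


def integrate_over_j(log_measure, integrand):
    """For each i, sum exp(log_measure[i][j]) * integrand[i][j] over j."""
    return [
        sum(math.exp(lm) * f for lm, f in zip(lm_row, f_row))
        for lm_row, f_row in zip(log_measure, integrand)
    ]


# Each row is a normalized log density over j, one row per value of i.
log_measure = [
    [math.log(0.5), math.log(0.5)],
    [math.log(0.25), math.log(0.75)],
]
integrand = [
    [2.0, 4.0],
    [4.0, 8.0],
]

result = integrate_over_j(log_measure, integrand)
```

The result still depends on `i`, since only `j` was reduced.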

Constant

class ConstantMeta(name, bases, dct)[source]

Bases: funsor.terms.FunsorMeta

Wrapper to convert const_inputs to a tuple.

class Constant(const_inputs, arg)[source]

Bases: funsor.terms.Funsor

Funsor that is constant wrt const_inputs.

Constant can be used for provenance tracking.

Examples:

a = Constant(OrderedDict(x=Real, y=Bint[3]), Number(0))
a(y=1)  # returns Constant(OrderedDict(x=Real), Number(0))
a(x=2, y=1)  # returns Number(0)

d = Tensor(torch.tensor([1, 2, 3]))["y"]
a + d  # returns Constant(OrderedDict(x=Real), d)

c = Constant(OrderedDict(x=Bint[3]), Number(1))
c.reduce(ops.add, "x")  # returns Number(3)
Parameters:
  • const_inputs (dict) – A mapping from input name (str) to datatype (funsor.domains.Domain).
  • arg (Funsor) – A funsor that is constant wrt const_inputs.
eager_subs(subs)[source]
eager_reduce(op, reduced_vars)[source]
align(names)[source]
materialize(x)[source]

Attempt to convert a Funsor to a Number or Tensor by substituting arange()s into its free variables.

Parameters: x (Funsor) – A funsor.
Return type: Funsor
eager_reduce_add(op, arg, reduced_vars)[source]
eager_binary_constant_constant(op, lhs, rhs)[source]
eager_binary_constant_tensor(op, lhs, rhs)[source]
eager_binary_tensor_constant(op, lhs, rhs)[source]
eager_unary(op, arg)[source]