# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import functools
import itertools
from collections import Counter, OrderedDict, defaultdict
from functools import reduce
from typing import Tuple, Union
import opt_einsum
import funsor
import funsor.ops as ops
from funsor.affine import affine_inputs
from funsor.delta import Delta
from funsor.domains import find_domain
from funsor.gaussian import Gaussian
from funsor.interpretations import eager, normalize, reflect
from funsor.interpreter import children
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp
from funsor.tensor import Tensor
from funsor.terms import (
_INFIX,
Align,
Binary,
Funsor,
Number,
Reduce,
Subs,
Unary,
Variable,
to_funsor,
)
from funsor.typing import Variadic
from funsor.util import broadcast_shape, get_backend, quote
class Contraction(Funsor):
    """
    Declarative representation of a finitary sum-product operation.

    After normalization via the :func:`~funsor.terms.normalize` interpretation
    contractions will canonically order their terms by type::

        Delta, Number, Tensor, Gaussian

    :param AssociativeOp red_op: Operation used to reduce ``reduced_vars``.
        When it is a ``NullOp``, ``reduced_vars`` must be empty.
    :param AssociativeOp bin_op: Operation used to combine ``terms``.
        When it is a ``NullOp`` (and ``red_op`` is not), there must be
        exactly one term.
    :param frozenset reduced_vars: The :class:`Variable` s that are reduced.
    :param tuple terms: A nonempty tuple of funsors. A single funsor is
        accepted and wrapped into a 1-tuple.
    """

    def __init__(self, red_op, bin_op, reduced_vars, terms):
        terms = (terms,) if isinstance(terms, Funsor) else terms
        assert isinstance(red_op, AssociativeOp)
        assert isinstance(bin_op, AssociativeOp)
        assert all(isinstance(v, Funsor) for v in terms)
        assert isinstance(reduced_vars, frozenset)
        assert all(isinstance(v, Variable) for v in reduced_vars)
        assert isinstance(terms, tuple) and len(terms) > 0

        # At least one of the two ops must be a real operation.
        assert not (isinstance(red_op, NullOp) and isinstance(bin_op, NullOp))
        if isinstance(red_op, NullOp):
            assert not reduced_vars
        elif isinstance(bin_op, NullOp):
            assert len(terms) == 1
        else:
            assert reduced_vars and len(terms) > 1
            assert (red_op, bin_op) in DISTRIBUTIVE_OPS

        fresh = frozenset()
        bound = {v.name: v.output for v in reduced_vars}
        # Free inputs are the union of term inputs, minus the reduced names.
        inputs = OrderedDict()
        for v in terms:
            inputs.update((k, d) for k, d in v.inputs.items() if k not in bound)

        if bin_op is ops.null:
            output = terms[0].output
        else:
            # Fold find_domain over the term outputs, starting from the last term.
            output = reduce(
                lambda lhs, rhs: find_domain(bin_op, lhs, rhs),
                [v.output for v in reversed(terms)],
            )
        super(Contraction, self).__init__(inputs, output, fresh, bound)
        self.red_op = red_op
        self.bin_op = bin_op
        self.terms = terms
        self.reduced_vars = reduced_vars

    def __repr__(self):
        # Use infix notation when the binary op has an infix symbol.
        if self.bin_op in _INFIX:
            bin_op = " " + _INFIX[self.bin_op] + " "
            return "{}.reduce({}, {})".format(
                bin_op.join(map(repr, self.terms)),
                self.red_op,
                str(set(self.reduced_vars)),
            )
        return super().__repr__()

    def __str__(self):
        # Use infix notation when the binary op has an infix symbol.
        if self.bin_op in _INFIX:
            bin_op = " " + _INFIX[self.bin_op] + " "
            return "({}).reduce({}, {})".format(
                bin_op.join(map(str, self.terms)),
                self.red_op,
                str(set(map(str, self.reduced_vars))),
            )
        return super().__str__()

    def _sample(self, sampled_vars, sample_inputs, rng_key):
        """
        Sample ``sampled_vars`` from this contraction.

        Returns ``self`` unchanged when none of the requested variables are
        inputs of this funsor, or when all of them are fresh in a
        :class:`Delta` term.

        :raises NotImplementedError: for an unhandled combination of terms
            under ``ops.add``.
        :raises TypeError: when the ``(red_op, bin_op)`` pair cannot be
            sampled through.
        """
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self
        # Drop variables that are fresh in one of the Delta terms.
        for term in self.terms:
            if isinstance(term, Delta):
                sampled_vars -= term.fresh
        if not sampled_vars:
            return self

        if self.red_op in (ops.null, ops.logaddexp):
            # One key per term on the jax backend, placeholders otherwise.
            if rng_key is not None and get_backend() == "jax":
                import jax

                rng_keys = jax.random.split(rng_key, len(self.terms))
            else:
                rng_keys = [None] * len(self.terms)

            if self.bin_op in (ops.null, ops.logaddexp):
                # Design choice: we sample over logaddexp reductions, but leave
                # logaddexp binary choices symbolic.
                terms = [
                    term._sample(
                        sampled_vars.intersection(term.inputs), sample_inputs, rng_key
                    )
                    for term, rng_key in zip(self.terms, rng_keys)
                ]
                return Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)

            if self.bin_op is ops.add:
                # Sample variables greedily in order of the terms in which they appear.
                for term in self.terms:
                    greedy_vars = sampled_vars.intersection(term.inputs)
                    if greedy_vars:
                        break
                assert greedy_vars
                # Split terms into those that involve the greedy variables
                # and those that pass through untouched.
                greedy_terms, terms = [], []
                for term in self.terms:
                    if greedy_vars.isdisjoint(term.inputs):
                        terms.append(term)
                    elif isinstance(term, Delta) and greedy_vars.isdisjoint(term.fresh):
                        terms.append(term)
                    else:
                        greedy_terms.append(term)

                if len(greedy_terms) == 1:
                    term = greedy_terms[0]
                    terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0]))
                    result = Contraction(
                        self.red_op, self.bin_op, self.reduced_vars, *terms
                    )
                elif (
                    len(greedy_terms) == 2
                    and isinstance(greedy_terms[0], Tensor)
                    and isinstance(greedy_terms[1], Gaussian)
                ):
                    # Move the Gaussian's log normalizer onto the discrete
                    # term, sample that sum, and keep the Gaussian with its
                    # normalizer subtracted.
                    discrete, gaussian = greedy_terms
                    term = discrete + gaussian.log_normalizer
                    terms.append(gaussian)
                    terms.append(-gaussian.log_normalizer)
                    terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0]))
                    result = Contraction(
                        self.red_op, self.bin_op, self.reduced_vars, *terms
                    )
                elif any(
                    isinstance(term, funsor.distribution.Distribution)
                    and not greedy_vars.isdisjoint(term.value.inputs)
                    for term in greedy_terms
                ):
                    # Only Distribution terms whose value mentions a greedy
                    # variable are sampled here.
                    sampled_terms = [
                        term._sample(
                            greedy_vars.intersection(term.value.inputs),
                            sample_inputs,
                            rng_key,
                        )
                        for term, rng_key in zip(greedy_terms, rng_keys)
                        if isinstance(term, funsor.distribution.Distribution)
                        and not greedy_vars.isdisjoint(term.value.inputs)
                    ]
                    result = Contraction(
                        self.red_op,
                        self.bin_op,
                        self.reduced_vars,
                        *(terms + sampled_terms)
                    )
                else:
                    raise NotImplementedError(
                        "Unhandled case: {}".format(
                            ", ".join(str(type(t)) for t in greedy_terms)
                        )
                    )
                # Recurse to sample whatever variables are still outstanding.
                return result._sample(
                    sampled_vars - greedy_vars, sample_inputs, rng_keys[1]
                )

        raise TypeError(
            "Cannot sample through ops ({}, {})".format(self.red_op, self.bin_op)
        )

    def align(self, names):
        """
        Align each term to the requested input order.

        :param tuple names: Input names, all of which must be inputs of this
            funsor.
        :return: A :class:`Contraction` over the aligned terms, wrapped in
            :class:`Align` if its input order still differs from ``names``.
        """
        assert isinstance(names, tuple)
        assert all(name in self.inputs for name in names)
        new_terms = tuple(
            t.align(tuple(n for n in names if n in t.inputs)) for t in self.terms
        )
        result = Contraction(self.red_op, self.bin_op, self.reduced_vars, *new_terms)
        if not names == tuple(result.inputs):
            return Align(
                result, names
            )  # raise NotImplementedError("TODO align all terms")
        return result

    def _alpha_convert(self, alpha_subs):
        # Rename the reduced variables, then let the base class rename the
        # remaining arguments; its reduced_vars result is discarded.
        reduced_vars = frozenset(
            to_funsor(alpha_subs.get(var.name, var), var.output)
            for var in self.reduced_vars
        )
        alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()}
        red_op, bin_op, _, terms = super()._alpha_convert(alpha_subs)
        return red_op, bin_op, reduced_vars, terms
# Type pattern for a Contraction whose red_op is logaddexp or null, whose
# bin_op is add, and whose terms are a Tensor-or-Number followed by a Gaussian.
# It is used below as a dispatch pattern in eager and normalize registrations.
GaussianMixture = Contraction[
    Union[ops.LogaddexpOp, NullOp],
    ops.AddOp,
    frozenset,
    Tuple[Union[Tensor, Number], Gaussian],
]
@quote.register(Contraction)
def _(arg, indent, out):
    """
    Append quoted source lines for a Contraction to ``out``.

    ``out`` is a list of ``(indent, text)`` pairs. A header line with the
    class name and both ops is appended first; ``reduced_vars`` and ``terms``
    are then quoted one level deeper, and the last emitted line is suffixed
    with ``","`` and ``")"`` respectively.
    """
    header = "{}({}, {},".format(
        type(arg).__name__, repr(arg.red_op), repr(arg.bin_op)
    )
    out.append((indent, header))
    for field, suffix in ((arg.reduced_vars, ","), (arg.terms, ")")):
        quote.inplace(field, indent + 1, out)
        depth, text = out[-1]
        out[-1] = depth, text + suffix
@children.register(Contraction)
def children_contraction(x):
    """
    Return the children of a Contraction as a flat tuple.

    The result is ``(red_op, bin_op, reduced_vars)`` followed by each term.
    """
    return (x.red_op, x.bin_op, x.reduced_vars) + x.terms
@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor])
def eager_contraction_generic_to_tuple(red_op, bin_op, reduced_vars, *terms):
    """
    Re-interpret a variadic Contraction call with its terms packed as one tuple.
    """
    return eager.interpret(Contraction, red_op, bin_op, reduced_vars, terms)
@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
    """
    Try to evaluate a Contraction over a tuple of terms piece by piece.

    First, variables that occur in exactly one term are reduced inside that
    term. If that changed anything, a new Contraction over the updated terms
    is returned. Otherwise each pair of terms is tried in order, and the
    first pair whose contraction evaluates is fused into a single term.

    :return: A new Contraction result, or ``None`` when no step made progress.
    """
    # Count the number of terms in which each variable is reduced.
    counts = Counter()
    for term in terms:
        counts.update(reduced_vars & term.input_vars)

    # push down leaf reductions
    terms = list(terms)
    leaf_reduced = False
    reduced_once = frozenset(v for v, count in counts.items() if count == 1)
    if reduced_once:
        for i, term in enumerate(terms):
            unique_vars = reduced_once & term.input_vars
            if unique_vars:
                result = term.reduce(red_op, unique_vars)
                # Progress means the result differs (by identity) from the
                # purely normalized form of the same reduction.
                if result is not normalize.interpret(
                    Contraction, red_op, ops.null, unique_vars, (term,)
                ):
                    terms[i] = result
                    reduced_vars -= unique_vars
                    leaf_reduced = True

    if leaf_reduced:
        return Contraction(red_op, bin_op, reduced_vars, *terms)

    # exploit associativity to recursively evaluate this contraction
    # a bit expensive, but handles interpreter-imposed directionality constraints
    terms = tuple(terms)
    reduced_twice = frozenset(v for v, count in counts.items() if count == 2)
    for i, lhs in enumerate(terms[0:-1]):
        for j_, rhs in enumerate(terms[i + 1 :]):
            j = i + j_ + 1  # absolute index of rhs within terms
            unique_vars = reduced_twice.intersection(lhs.input_vars, rhs.input_vars)
            result = Contraction(red_op, bin_op, unique_vars, lhs, rhs)
            if result is not normalize.interpret(
                Contraction, red_op, bin_op, unique_vars, (lhs, rhs)
            ):  # did we make progress?
                # pick the first evaluable pair
                reduced_vars -= unique_vars
                # The fused result takes lhs's slot; rhs is removed.
                new_terms = terms[:i] + (result,) + terms[i + 1 : j] + terms[j + 1 :]
                return Contraction(red_op, bin_op, reduced_vars, *new_terms)

    return None
@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor)
def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, term):
    """
    Evaluate a single-term Contraction through the eager ``Reduce`` rule.

    ``bin_op`` is not used: with one term there is nothing to combine.
    """
    args = red_op, term, reduced_vars
    return eager.dispatch(Reduce, *args)(*args)
@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor)
def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs):
    """
    Evaluate a two-term Contraction.

    If some reduced variable is not shared by both terms, the tuple-based
    Contraction rule is tried first. Otherwise, or if that rule returns
    ``None``, the eager ``Binary`` rule combines the terms and the result is
    reduced over ``reduced_vars`` when there are any.

    :return: The evaluated funsor, or ``None`` if the ``Binary`` rule
        returned ``None``.
    """
    if not reduced_vars.issubset(lhs.input_vars & rhs.input_vars):
        args = red_op, bin_op, reduced_vars, (lhs, rhs)
        result = eager.dispatch(Contraction, *args)(*args)
        if result is not None:
            return result

    args = bin_op, lhs, rhs
    result = eager.dispatch(Binary, *args)(*args)
    if result is not None and reduced_vars:
        result = eager.interpret(Reduce, red_op, result, reduced_vars)
    return result
@eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset, Tensor, Tensor)
def eager_contraction_tensor(red_op, bin_op, reduced_vars, *terms):
    """
    Sum-product contraction of two real-valued Tensors via einsum.

    :raises NotImplementedError: if any term's dtype is not ``"real"``.
    """
    if any(term.dtype != "real" for term in terms):
        raise NotImplementedError("TODO")
    return _eager_contract_tensors(
        reduced_vars, terms, backend=BACKEND_TO_EINSUM_BACKEND[get_backend()]
    )
@eager.register(Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Tensor, Tensor)
def eager_contraction_tensor(red_op, bin_op, reduced_vars, *terms):
    """
    Logaddexp-add contraction of two real-valued Tensors.

    The einsum backend is looked up in ``BACKEND_TO_LOGSUMEXP_BACKEND`` for
    the active funsor backend.

    :raises NotImplementedError: if any term's dtype is not ``"real"``.
    """
    if not all(term.dtype == "real" for term in terms):
        raise NotImplementedError("TODO")
    backend = BACKEND_TO_LOGSUMEXP_BACKEND[get_backend()]
    return _eager_contract_tensors(reduced_vars, terms, backend=backend)
# TODO Consider using this for more than binary contractions.
def _eager_contract_tensors(reduced_vars, terms, backend):
    """
    Contract Tensor ``terms`` over ``reduced_vars`` with ``opt_einsum``.

    Each named input and each non-singleton event dim (keyed by its negative
    index) is assigned an einsum symbol on first use. Size-1 event dims are
    squeezed out of the operands and restored by the final reshape.

    :param reduced_vars: Variables to sum out; only their ``name`` is used.
    :param terms: Tensors to contract.
    :param str backend: Backend name passed to ``opt_einsum.contract``.
    :return: A :class:`Tensor` over the inputs that were not reduced.
    """
    symbol_stream = map(opt_einsum.get_symbol, itertools.count())
    # Looking up a missing key pulls the next unused symbol from the stream.
    symbol_of = defaultdict(functools.partial(next, symbol_stream))

    remaining_inputs = OrderedDict()
    subscripts = []
    operands = []
    for term in terms:
        remaining_inputs.update(term.inputs)
        rank = len(term.shape)
        batch_part = "".join(symbol_of[name] for name in term.inputs)
        event_part = "".join(
            symbol_of[pos - rank] for pos, size in enumerate(term.shape) if size != 1
        )
        subscripts.append(batch_part + event_part)

        # Squeeze absent event dims to be compatible with einsum.
        raw = term.data
        leading = raw.shape[: len(raw.shape) - rank]
        kept_event = tuple(size for size in term.shape if size != 1)
        operands.append(raw.reshape(leading + kept_event))

    for var in reduced_vars:
        remaining_inputs.pop(var.name, None)

    batch_shape = tuple(domain.size for domain in remaining_inputs.values())
    event_shape = broadcast_shape(*(term.shape for term in terms))
    output_batch = "".join(symbol_of[name] for name in remaining_inputs)
    # Only event dims that received a symbol appear in the output subscript.
    output_event = "".join(
        symbol_of[dim] for dim in range(-len(event_shape), 0) if dim in symbol_of
    )
    equation = ",".join(subscripts) + "->" + output_batch + output_event

    contracted = opt_einsum.contract(equation, *operands, backend=backend)
    return Tensor(contracted.reshape(batch_shape + event_shape), remaining_inputs)
# TODO(https://github.com/pyro-ppl/funsor/issues/238) Use a port of
# Pyro's gaussian_tensordot() here. Until then we must eagerly add the
# possibly-rank-deficient terms before reducing to avoid Cholesky errors.
@eager.register(
    Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, GaussianMixture, GaussianMixture
)
def eager_contraction_gaussian(red_op, bin_op, reduced_vars, x, y):
    """
    Contract two Gaussian mixtures by adding them first, then reducing.
    """
    return (x + y).reduce(red_op, reduced_vars)
@affine_inputs.register(Contraction)
def _(fn):
    """
    Compute the affine inputs of a Contraction.

    Under the ``reflect`` interpretation the terms are folded with ``bin_op``
    and reduced with ``red_op``; ``affine_inputs`` is then applied to that
    expression.
    """
    with reflect:
        combined = reduce(fn.bin_op, fn.terms)
        flattened = combined.reduce(fn.red_op, fn.reduced_vars)
    return affine_inputs(flattened)
##########################################
# Normalizing Contractions
##########################################
# Canonical rank of each ground term type; lower ranks sort first when the
# terms of a commutative contraction are reordered.
ORDERING = {
    Delta: 1,
    Number: 2,
    Tensor: 3,
    Gaussian: 4,
    Unary[ops.NegOp, Gaussian]: 5,
}
# The ranked types, in rank order, for use as a dispatch pattern.
GROUND_TERMS = tuple(ORDERING)
@normalize.register(
    Contraction, AssociativeOp, ops.AddOp, frozenset, GROUND_TERMS, GROUND_TERMS
)
def normalize_contraction_commutative_canonical_order(
    red_op, bin_op, reduced_vars, *terms
):
    """
    Put two ground terms under ``ops.add`` into canonical order.

    Terms are sorted by the ``ORDERING`` rank of their type's ``__origin__``,
    with the original position as a tie-breaker. Types not in ``ORDERING``
    get rank ``-1`` and so sort first.
    """
    # when bin_op is commutative, put terms into a canonical order for pattern matching
    new_terms = tuple(
        v
        for i, v in sorted(
            enumerate(terms),
            key=lambda t: (ORDERING.get(type(t[1]).__origin__, -1), t[0]),
        )
    )
    # Rebuild only if the order changed; otherwise hand off to the
    # tuple-based normalize rule.
    if any(v is not vv for v, vv in zip(terms, new_terms)):
        return Contraction(red_op, bin_op, reduced_vars, *new_terms)
    return normalize.interpret(Contraction, red_op, bin_op, reduced_vars, new_terms)
@normalize.register(
    Contraction, AssociativeOp, ops.AddOp, frozenset, GaussianMixture, GROUND_TERMS
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other):
    """
    Fuse a Gaussian mixture with a ground term into one flat Contraction.

    The mixture's terms come first and ``other`` is appended. The mixture's
    ``red_op`` is used when the outer ``red_op`` is null, and the reduced
    variables of both are merged.
    """
    if red_op is ops.null:
        fused_red_op = mixture.red_op
    else:
        fused_red_op = red_op
    fused_vars = reduced_vars | mixture.reduced_vars
    fused_terms = mixture.terms + (other,)
    return Contraction(fused_red_op, bin_op, fused_vars, *fused_terms)
@normalize.register(
    Contraction, AssociativeOp, ops.AddOp, frozenset, GROUND_TERMS, GaussianMixture
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture):
    """
    Fuse a ground term with a Gaussian mixture into one flat Contraction.

    Although ``other`` is the left argument here, it is still appended after
    the mixture's terms. The mixture's ``red_op`` is used when the outer
    ``red_op`` is null, and the reduced variables of both are merged.
    """
    return Contraction(
        mixture.red_op if red_op is ops.null else red_op,
        bin_op,
        reduced_vars | mixture.reduced_vars,
        *(mixture.terms + (other,))
    )
@normalize.register(
    Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor]
)
def normalize_contraction_generic_args(red_op, bin_op, reduced_vars, *terms):
    """
    Re-interpret a variadic Contraction call with its terms packed as one tuple.
    """
    return normalize.interpret(Contraction, red_op, bin_op, reduced_vars, tuple(terms))
@normalize.register(Contraction, NullOp, NullOp, frozenset, Funsor)
def normalize_trivial(red_op, bin_op, reduced_vars, term):
    """
    Unwrap a Contraction with both ops null: it is just its single term.
    """
    assert not reduced_vars
    return term
@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
    """
    Apply generic simplifications to a Contraction over a tuple of terms.

    Rules are tried in order and the first one that applies returns:

    1. No reduced variables: replace ``red_op`` with null.
    2. A single term: replace ``bin_op`` with null.
    3. Both ops null: return the single term.
    4. ``red_op is bin_op``: reduce each term separately.
    5. Drop ``Number`` terms equal to the unit of ``bin_op``.
    6. Flatten a nested Contraction term whose ops are compatible.

    :return: The simplified funsor, or ``None`` if no rule applies.
    """
    if not reduced_vars and red_op is not ops.null:
        return Contraction(ops.null, bin_op, reduced_vars, *terms)

    if len(terms) == 1 and bin_op is not ops.null:
        return Contraction(red_op, ops.null, reduced_vars, *terms)

    if red_op is ops.null and bin_op is ops.null:
        return terms[0]

    if red_op is bin_op:
        new_terms = tuple(v.reduce(red_op, reduced_vars) for v in terms)
        return Contraction(red_op, bin_op, frozenset(), *new_terms)

    def _is_unit(t):
        # Only called after bin_op is known to be a key of ops.UNITS.
        return isinstance(t, Number) and t.data == ops.UNITS[bin_op]

    if bin_op in ops.UNITS and any(_is_unit(t) for t in terms):
        new_terms = tuple(t for t in terms if not _is_unit(t))
        if not new_terms:  # everything was a unit
            new_terms = (terms[0],)
        return Contraction(red_op, bin_op, reduced_vars, *new_terms)

    for i, v in enumerate(terms):
        if not isinstance(v, Contraction):
            continue

        # fuse operations without distributing
        if (v.red_op is ops.null and bin_op is v.bin_op) or (
            bin_op is ops.null and v.red_op in (red_op, ops.null)
        ):
            red_op = v.red_op if red_op is ops.null else red_op
            bin_op = v.bin_op if bin_op is ops.null else bin_op
            # Splice the nested terms in place of the nested Contraction.
            new_terms = terms[:i] + v.terms + terms[i + 1 :]
            return Contraction(
                red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms
            )

    # nothing more to do, reflect
    return None
#########################################
# Creating Contractions from other terms
#########################################
@normalize.register(Binary, AssociativeOp, Funsor, Funsor)
def binary_to_contract(op, lhs, rhs):
    """
    Express an associative Binary as a Contraction with no reduction.
    """
    return Contraction(ops.null, op, frozenset(), lhs, rhs)
@normalize.register(Reduce, AssociativeOp, Funsor, frozenset)
def reduce_funsor(op, arg, reduced_vars):
    """
    Express an associative Reduce as a single-term Contraction.
    """
    return Contraction(op, ops.null, reduced_vars, arg)
@normalize.register(
    Unary,
    ops.NegOp,
    (Variable, Contraction[ops.AssociativeOp, ops.MulOp, frozenset, tuple]),
)
def unary_neg_variable(op, arg):
    """
    Rewrite negation of a Variable or a product Contraction as ``arg * -1``.
    """
    return arg * -1
#######################################################################
# Distributing Unary transformations (Subs, log, exp, neg, reciprocal)
#######################################################################
@normalize.register(Subs, Funsor, tuple)
def do_fresh_subs(arg, subs):
    """
    Apply a substitution eagerly when it only touches fresh variables.

    :param tuple subs: ``(name, value)`` pairs.
    :return: ``arg`` if ``subs`` is empty; ``arg.eager_subs(subs)`` if every
        substituted name is in ``arg.fresh``; otherwise ``None``.
    """
    if not subs:
        return arg
    if all(name in arg.fresh for name, sub in subs):
        return arg.eager_subs(subs)
    return None
@normalize.register(Subs, Contraction, tuple)
def distribute_subs_contraction(arg, subs):
    """
    Push a substitution down into the terms of a Contraction.

    Each term receives only the ``(name, value)`` pairs whose name is one of
    its inputs. Terms that mention none of the names are left untouched.
    """
    new_terms = []
    for v in arg.terms:
        # Filter once; an empty result means this term is unaffected.
        relevant = tuple((name, sub) for name, sub in subs if name in v.inputs)
        new_terms.append(Subs(v, relevant) if relevant else v)
    return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *new_terms)
@normalize.register(Subs, Subs, tuple)
def normalize_fuse_subs(arg, subs):
    """
    Fuse two nested substitutions into one.

    The outer ``subs`` come first, followed by the inner substitutions with
    the outer ``subs`` applied to each of their values.
    """
    # a(b)(c) -> a(b(c), c)
    arg_subs = (
        tuple(arg.subs.items()) if isinstance(arg.subs, OrderedDict) else arg.subs
    )
    new_subs = subs + tuple((k, Subs(v, subs)) for k, v in arg_subs)
    return Subs(arg.arg, new_subs)
@normalize.register(Binary, ops.SubOp, Funsor, Funsor)
def binary_subtract(op, lhs, rhs):
    """
    Rewrite subtraction as addition of the negated right operand.
    """
    return lhs + -rhs
@normalize.register(Binary, ops.TruedivOp, Funsor, Funsor)
def binary_divide(op, lhs, rhs):
    """
    Rewrite true division as multiplication by the reciprocal of ``rhs``.
    """
    return lhs * Unary(ops.reciprocal, rhs)
@normalize.register(Unary, ops.ExpOp, Unary[ops.LogOp, Funsor])
@normalize.register(Unary, ops.LogOp, Unary[ops.ExpOp, Funsor])
@normalize.register(Unary, ops.NegOp, Unary[ops.NegOp, Funsor])
@normalize.register(Unary, ops.ReciprocalOp, Unary[ops.ReciprocalOp, Funsor])
def unary_log_exp(op, arg):
    """
    Cancel a unary op applied directly to its inverse.

    Registered for exp-of-log, log-of-exp, neg-of-neg and
    reciprocal-of-reciprocal; returns the innermost argument.
    """
    return arg.arg
@normalize.register(
    Unary, ops.ReciprocalOp, Contraction[NullOp, ops.MulOp, frozenset, tuple]
)
@normalize.register(Unary, ops.NegOp, Contraction[NullOp, ops.AddOp, frozenset, tuple])
def unary_contract(op, arg):
    """
    Distribute a unary op over the terms of an unreduced Contraction.

    Registered for reciprocal over a product and negation over a sum.
    """
    return Contraction(
        arg.red_op, arg.bin_op, arg.reduced_vars, *(op(t) for t in arg.terms)
    )
# Maps each funsor backend name to the backend string used for plain
# sum-product einsum.
BACKEND_TO_EINSUM_BACKEND = dict(
    numpy="numpy",
    torch="torch",
    jax="jax.numpy",
)

# NB: numpy_log and numpy_map are backend-agnostic, so they also work for the
# torch backend; however, we might need to profile before making a switch.
BACKEND_TO_LOGSUMEXP_BACKEND = dict(
    numpy="funsor.einsum.numpy_log",
    torch="pyro.ops.einsum.torch_log",
    jax="funsor.einsum.numpy_log",
)

BACKEND_TO_MAP_BACKEND = dict(
    numpy="funsor.einsum.numpy_map",
    torch="pyro.ops.einsum.torch_map",
    jax="funsor.einsum.numpy_map",
)