# Source code for funsor.integrate

# Copyright Contributors to the Pyro project.

from collections import OrderedDict
from typing import Union

import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.gaussian import Gaussian, align_gaussian, _mv, _trace_mm, _vv
from funsor.tensor import Tensor
from funsor.terms import (
Funsor,
FunsorMeta,
Number,
Subs,
Unary,
Variable,
_convert_reduced_vars,
eager,
normalize,
substitute,
to_funsor
)

class IntegrateMeta(FunsorMeta):
    """
    Metaclass for :class:`Integrate`.

    Normalizes the user-facing ``reduced_vars`` argument (a name, a
    ``Variable``, or a set of either) into a frozenset of str before the
    instance is constructed, so ``Integrate.__init__`` and the dispatch
    patterns registered on ``frozenset`` always receive that form.
    """
    def __call__(cls, log_measure, integrand, reduced_vars):
        return super(IntegrateMeta, cls).__call__(
            log_measure, integrand, _convert_reduced_vars(reduced_vars))

class Integrate(Funsor, metaclass=IntegrateMeta):
    """
    Funsor representing an integral wrt a log density funsor.

    The free inputs of the result are the inputs of ``log_measure`` followed
    by those of ``integrand`` (first occurrence wins the position), excluding
    ``reduced_vars``. The output domain is that of ``integrand``.

    :param Funsor log_measure: A log density funsor treated as a measure.
    :param Funsor integrand: An integrand funsor.
    :param reduced_vars: An input name or set of names to reduce.
    :type reduced_vars: str, Variable, or set or frozenset thereof.
    """
    def __init__(self, log_measure, integrand, reduced_vars):
        assert isinstance(log_measure, Funsor)
        assert isinstance(integrand, Funsor)
        assert isinstance(reduced_vars, frozenset)
        assert all(isinstance(v, str) for v in reduced_vars)

        # Collect every input that is not integrated out, preserving the
        # order in which names are first seen.
        inputs = OrderedDict()
        for term in (log_measure, integrand):
            for name, domain in term.inputs.items():
                if name not in reduced_vars:
                    inputs[name] = domain

        # No fresh variables are introduced; the reduced variables are bound.
        super().__init__(inputs, integrand.output, frozenset(), reduced_vars)
        self.log_measure = log_measure
        self.integrand = integrand
        self.reduced_vars = reduced_vars

    def _alpha_convert(self, alpha_subs):
        """
        Rename bound (reduced) variables according to ``alpha_subs`` and
        return the renamed ``(log_measure, integrand, reduced_vars)``.
        """
        assert self.bound.issuperset(alpha_subs)
        reduced_vars = frozenset(alpha_subs.get(name, name) for name in self.reduced_vars)

        def _domain(name):
            # Prefer the integrand's domain, falling back to the measure's.
            return self.integrand.inputs.get(name, self.log_measure.inputs.get(name))

        renames = {name: to_funsor(new_name, _domain(name))
                   for name, new_name in alpha_subs.items()}
        log_measure = substitute(self.log_measure, renames)
        integrand = substitute(self.integrand, renames)
        return log_measure, integrand, reduced_vars

@normalize.register(Integrate, Funsor, Funsor, frozenset)
def normalize_integrate(log_measure, integrand, reduced_vars):
    """
    Lower a generic :class:`Integrate` to a sum-product contraction.

    The measure ``exp(log_measure)`` is multiplied (``ops.mul``) by
    ``integrand`` and the product is summed (``ops.add``) over
    ``reduced_vars``.
    """
    measure = log_measure.exp()
    return Contraction(ops.add, ops.mul, reduced_vars, measure, integrand)

@normalize.register(Integrate,
                    Contraction[Union[ops.NullOp, ops.LogAddExpOp], ops.AddOp, frozenset, tuple],
                    Funsor, frozenset)
def normalize_integrate_contraction(log_measure, integrand, reduced_vars):
    """
    Normalize an integral whose log measure is a log-space product of terms.

    The dispatch pattern restricts ``log_measure`` to a :class:`Contraction`
    whose ``bin_op`` is ``ops.add``. Each of its terms is therefore a factor
    of the measure, and substituting a :class:`Delta` factor's point values
    into the integrand is valid. Every such Delta whose fresh variables are
    both reduced and used by ``integrand`` has its points substituted into
    ``integrand``. The result is then lowered by :func:`normalize_integrate`.
    """
    delta_terms = [t for t in log_measure.terms
                   if isinstance(t, Delta) and t.fresh.intersection(reduced_vars, integrand.inputs)]
    for delta in delta_terms:
        # Recompute against the current integrand: an earlier substitution may
        # already have removed some of its inputs.
        targets = reduced_vars.intersection(integrand.inputs)
        integrand = integrand(**{name: point for name, (point, _) in delta.terms
                                 if name in targets})
    return normalize_integrate(log_measure, integrand, reduced_vars)

@eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset,
                Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]],
                (Variable, Delta, Gaussian, Number, Tensor, GaussianMixture))
def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs):
    """
    Recognize ``sum_{reduced_vars} exp(x) * rhs`` as an :class:`Integrate`.

    The dispatch pattern requires ``lhs`` to have the form ``exp(x)``, with a
    sum (``ops.add``) over a product (``ops.mul``). Only then is the
    contraction equivalent to ``Integrate(lhs.log(), rhs, reduced_vars)``.

    Returns an eager result if one of the delegated patterns produces one,
    otherwise None.
    """
    # Some reduced variables are not shared by both factors: give the
    # generic Contraction patterns the first chance to eliminate them.
    if reduced_vars - reduced_vars.intersection(lhs.inputs, rhs.inputs):
        args = red_op, bin_op, reduced_vars, (lhs, rhs)
        result = eager.dispatch(Contraction, *args)(*args)
        if result is not None:
            return result

    # Undo the exp to recover the log measure, then try the Integrate patterns.
    args = lhs.log(), rhs, reduced_vars
    return eager.dispatch(Integrate, *args)(*args)

@eager.register(Integrate, GaussianMixture, Funsor, frozenset)
def eager_integrate_gaussianmixture(log_measure, integrand, reduced_vars):
    """
    Integrate against a Gaussian mixture when only its real inputs are reduced.

    If every reduced variable is a real input of ``log_measure``, the
    mixture's ``(discrete, gaussian)`` terms are split and the result is
    ``exp(discrete) * Integrate(gaussian, integrand, reduced_vars)``.
    Otherwise returns None to defer to other implementations.
    """
    # Guard the lookup: a reduced variable may appear only in the integrand,
    # in which case this pattern does not apply (rather than raising KeyError).
    real_vars = frozenset(k for k in reduced_vars
                          if k in log_measure.inputs and log_measure.inputs[k].dtype == 'real')
    if reduced_vars <= real_vars:
        discrete, gaussian = log_measure.terms
        return discrete.exp() * Integrate(gaussian, integrand, reduced_vars)
    return None

########################################
# Delta patterns
########################################

@eager.register(Integrate, Delta, Funsor, frozenset)
def eager_integrate(delta, integrand, reduced_vars):
    """
    Integrate against a :class:`Delta` measure by substitution.

    Every reduced variable bound by the delta is replaced by its point value
    in both the integrand and the delta itself. The remaining integral is
    over ``reduced_vars - delta.fresh``. Returns None when no reduced
    variable is bound by the delta.
    """
    if reduced_vars.isdisjoint(delta.fresh):
        return None
    subs = tuple((name, point) for name, (point, _) in delta.terms if name in reduced_vars)
    integrand_at_point = Subs(integrand, subs)
    delta_at_point = Subs(delta, subs)
    return Integrate(delta_at_point, integrand_at_point, reduced_vars - delta.fresh)

########################################
# Gaussian patterns
########################################

@eager.register(Integrate, Gaussian, Variable, frozenset)
def eager_integrate(log_measure, integrand, reduced_vars):
    """
    Integrate a real :class:`Variable` against an unnormalized Gaussian.

    Applies only when the Gaussian's real inputs are exactly
    ``{integrand.name}`` and that is the only real variable being reduced.
    The integral is then ``loc * exp(log_normalizer)``, where ``loc`` solves
    ``precision @ loc = info_vec``. Remaining non-real reduced variables are
    summed out with ``ops.add``. In every other case returns None to defer
    to the default implementation.
    """
    # Guard the lookup: a reduced variable may be absent from log_measure.
    real_vars = frozenset(k for k in reduced_vars
                          if k in log_measure.inputs and log_measure.inputs[k].dtype == 'real')
    lhs_reals = frozenset(k for k, d in log_measure.inputs.items() if d.dtype == 'real')
    # Partial integration (extra unreduced real inputs) is not handled here:
    # loc would span all real inputs and the reshape below would fail.
    if real_vars == frozenset([integrand.name]) and lhs_reals == real_vars:
        loc = ops.cholesky_solve(ops.unsqueeze(log_measure.info_vec, -1),
                                 log_measure._precision_chol).squeeze(-1)
        data = loc * ops.unsqueeze(ops.exp(log_measure.log_normalizer.data), -1)
        data = data.reshape(loc.shape[:-1] + integrand.output.shape)
        inputs = OrderedDict((k, d) for k, d in log_measure.inputs.items() if d.dtype != 'real')
        result = Tensor(data, inputs)
        # Previously this result was computed and then discarded.
        return result.reduce(ops.add, reduced_vars - real_vars)
    return None  # defer to default implementation

@eager.register(Integrate, Gaussian, Gaussian, frozenset)
def eager_integrate(log_measure, integrand, reduced_vars):
real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
if real_vars:

lhs_reals = frozenset(k for k, d in log_measure.inputs.items() if d.dtype == 'real')
rhs_reals = frozenset(k for k, d in integrand.inputs.items() if d.dtype == 'real')
if lhs_reals == real_vars and rhs_reals <= real_vars:
inputs = OrderedDict((k, d) for t in (log_measure, integrand)
for k, d in t.inputs.items())
lhs_info_vec, lhs_precision = align_gaussian(inputs, log_measure)
rhs_info_vec, rhs_precision = align_gaussian(inputs, integrand)
lhs = Gaussian(lhs_info_vec, lhs_precision, inputs)

# Compute the expectation of a non-normalized quadratic form.
# See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
# http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
norm = ops.exp(lhs.log_normalizer.data)
lhs_cov = ops.cholesky_inverse(lhs._precision_chol)
lhs_loc = ops.cholesky_solve(ops.unsqueeze(lhs.info_vec, -1), lhs._precision_chol).squeeze(-1)
vmv_term = _vv(lhs_loc, rhs_info_vec - 0.5 * _mv(rhs_precision, lhs_loc))
data = norm * (vmv_term - 0.5 * _trace_mm(rhs_precision, lhs_cov))
inputs = OrderedDict((k, d) for k, d in inputs.items() if k not in reduced_vars)
result = Tensor(data, inputs)