Source code for funsor.integrate

from collections import OrderedDict
from typing import Union

import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian, cholesky_inverse
from funsor.terms import (
    Funsor,
    FunsorMeta,
    Number,
    Subs,
    Unary,
    Variable,
    _convert_reduced_vars,
    eager,
    normalize,
    substitute,
    to_funsor
)
from funsor.torch import Tensor


class IntegrateMeta(FunsorMeta):
    """
    Metaclass that normalizes the ``reduced_vars`` argument before construction.

    The raw argument is passed through ``_convert_reduced_vars`` so that the
    constructor receives the converted value (a frozenset of str) rather than
    whatever form the caller supplied.
    """
    def __call__(cls, log_measure, integrand, reduced_vars):
        return super().__call__(
            log_measure, integrand, _convert_reduced_vars(reduced_vars))


class Integrate(Funsor, metaclass=IntegrateMeta):
    """
    Funsor representing an integral of an integrand with respect to a measure
    given as a log density funsor.

    :param Funsor log_measure: A log density funsor treated as a measure.
    :param Funsor integrand: An integrand funsor.
    :param reduced_vars: An input name or set of names to reduce.
    :type reduced_vars: str, Variable, or set or frozenset thereof.
    """

    def __init__(self, log_measure, integrand, reduced_vars):
        assert isinstance(log_measure, Funsor)
        assert isinstance(integrand, Funsor)
        assert isinstance(reduced_vars, frozenset)
        assert all(isinstance(v, str) for v in reduced_vars)

        # Free inputs are those of either term that are not integrated out;
        # log_measure is visited first, so its inputs come first in the order.
        inputs = OrderedDict()
        for term in (log_measure, integrand):
            for name, domain in term.inputs.items():
                if name not in reduced_vars:
                    inputs[name] = domain

        # The output domain is the integrand's; no fresh variables are
        # introduced and the reduced variables are the bound ones.
        super().__init__(inputs, integrand.output, frozenset(), reduced_vars)
        self.log_measure = log_measure
        self.integrand = integrand
        self.reduced_vars = reduced_vars

    def _alpha_convert(self, alpha_subs):
        """
        Apply a renaming of bound variables to all three constructor arguments.

        :param dict alpha_subs: Substitutions keyed by bound variable name.
        :return: A ``(log_measure, integrand, reduced_vars)`` tuple with the
            substitutions applied.
        :rtype: tuple
        """
        assert self.bound.issuperset(alpha_subs)
        renamed_vars = frozenset(
            alpha_subs.get(name, name) for name in self.reduced_vars)

        # Convert each replacement to a funsor, taking the domain from the
        # integrand's inputs and falling back to the log measure's inputs.
        integrand_inputs = self.integrand.inputs
        measure_inputs = self.log_measure.inputs
        funsor_subs = {}
        for name, value in alpha_subs.items():
            domain = integrand_inputs.get(name, measure_inputs.get(name))
            funsor_subs[name] = to_funsor(value, domain)

        new_log_measure = substitute(self.log_measure, funsor_subs)
        new_integrand = substitute(self.integrand, funsor_subs)
        return new_log_measure, new_integrand, renamed_vars
@normalize.register(Integrate, Funsor, Funsor, frozenset)
def normalize_integrate(log_measure, integrand, reduced_vars):
    """
    Rewrite an integral as an ``ops.add`` / ``ops.mul`` :class:`Contraction`
    of ``log_measure.exp()`` with ``integrand`` over ``reduced_vars``.
    """
    measure = log_measure.exp()
    return Contraction(ops.add, ops.mul, reduced_vars, measure, integrand)


@normalize.register(Integrate,
                    Contraction[Union[ops.NullOp, ops.LogAddExpOp], ops.AddOp, frozenset, tuple],
                    Funsor, frozenset)
def normalize_integrate_contraction(log_measure, integrand, reduced_vars):
    """
    Normalize an integral whose measure is a :class:`Contraction` of terms.

    The points of any :class:`Delta` terms of the measure are first substituted
    into the integrand for those reduced variables the integrand depends on;
    the result is then handed to :func:`normalize_integrate`.
    """
    # Select the deltas up front, against the inputs of the incoming integrand.
    relevant_deltas = [
        term for term in log_measure.terms
        if isinstance(term, Delta)
        and term.fresh.intersection(reduced_vars, integrand.inputs)
    ]
    for delta in relevant_deltas:
        # Recomputed per delta because the integrand changes on each pass.
        eliminated = reduced_vars.intersection(integrand.inputs)
        point_subs = {}
        for name, (point, _log_density) in delta.terms:
            if name in eliminated:
                point_subs[name] = point
        integrand = integrand(**point_subs)
    return normalize_integrate(log_measure, integrand, reduced_vars)


@eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset,
                Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]],
                (Variable, Delta, Gaussian, Number, Tensor, GaussianMixture))
def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs):
    """
    Eagerly evaluate a binary sum-product contraction by dispatching to the
    eager :class:`Integrate` rules.

    :return: The eager result, or ``None`` if no rule produced one.
    """
    # Reduced variables that are not inputs of both operands.
    not_shared = reduced_vars - reduced_vars.intersection(lhs.inputs, rhs.inputs)
    if not_shared:
        # Try the Contraction rule registered for operands packed as a tuple.
        contraction_args = (red_op, bin_op, reduced_vars, (lhs, rhs))
        contracted = eager.dispatch(Contraction, *contraction_args)(*contraction_args)
        if contracted is not None:
            return contracted

    # lhs matched the Unary[ops.ExpOp, ...] pattern above; its ``.log()`` is
    # passed to Integrate in the log-measure position.
    integrate_args = (lhs.log(), rhs, reduced_vars)
    return eager.dispatch(Integrate, *integrate_args)(*integrate_args)


@eager.register(Integrate, GaussianMixture, Funsor, frozenset)
def eager_integrate_gaussianmixture(log_measure, integrand, reduced_vars):
    """
    Integrate against a Gaussian mixture when every reduced variable is real.

    The mixture is split into its two terms; the first is exponentiated and
    used to scale the integral taken against the second.

    :return: The product funsor, or ``None`` if some reduced variable is not
        real-valued.
    """
    real_vars = frozenset(
        name for name in reduced_vars
        if log_measure.inputs[name].dtype == 'real')
    if not reduced_vars <= real_vars:
        return None
    discrete, gaussian = log_measure.terms
    return discrete.exp() * Integrate(gaussian, integrand, reduced_vars)


########################################
# Delta patterns
########################################

@eager.register(Integrate, Delta, Funsor, frozenset)
def eager_integrate(delta, integrand, reduced_vars):
    """
    Integrate against a :class:`Delta` by substituting its points for the
    reduced variables, in both the integrand and the delta itself.

    :return: An :class:`Integrate` over the reduced variables that are not
        fresh in the delta, or ``None`` if no reduced variable is fresh in it.
    """
    if reduced_vars & delta.fresh:
        subs = tuple(
            (name, point)
            for name, (point, _log_density) in delta.terms
            if name in reduced_vars)
        new_integrand = Subs(integrand, subs)
        new_log_measure = Subs(delta, subs)
        remaining_vars = reduced_vars - delta.fresh
        return Integrate(new_log_measure, new_integrand, remaining_vars)
    return None


########################################
# Gaussian patterns
########################################

@eager.register(Integrate, Gaussian, Variable, frozenset)
def eager_integrate(log_measure, integrand, reduced_vars):
    """
    Integrate a bare :class:`Variable` against a :class:`Gaussian` measure.

    Applies only when the real-valued reduced variables are exactly the
    integrand's own variable. The result is the location vector (``info_vec``
    solved against ``_precision_chol``) scaled by ``exp(log_normalizer)``,
    with the remaining reduced variables summed out.

    :return: The reduced :class:`Tensor`, or ``None`` to defer.
    """
    real_vars = frozenset(
        name for name in reduced_vars
        if log_measure.inputs[name].dtype == 'real')
    if real_vars != frozenset((integrand.name,)):
        return None  # defer to default implementation

    info_column = log_measure.info_vec.unsqueeze(-1)
    loc = info_column.cholesky_solve(log_measure._precision_chol).squeeze(-1)
    normalizer = log_measure.log_normalizer.data.exp()
    data = loc * normalizer.unsqueeze(-1)
    data = data.reshape(loc.shape[:-1] + integrand.output.shape)

    # Only the non-real inputs of the measure remain as inputs of the result.
    inputs = OrderedDict()
    for name, domain in log_measure.inputs.items():
        if domain.dtype != 'real':
            inputs[name] = domain
    return Tensor(data, inputs).reduce(ops.add, reduced_vars - real_vars)


@eager.register(Integrate, Gaussian, Gaussian, frozenset)
def eager_integrate(log_measure, integrand, reduced_vars):
    """
    Integrate a :class:`Gaussian` integrand against a :class:`Gaussian`
    measure over all of the measure's real inputs.

    :return: The reduced :class:`Tensor`, or ``None`` to defer when no reduced
        variable is real-valued.
    :raises NotImplementedError: If the real reduced variables are not exactly
        the measure's real inputs, or the integrand has real inputs outside
        them (partial integration).
    """
    real_vars = frozenset(
        name for name in reduced_vars
        if log_measure.inputs[name].dtype == 'real')
    if not real_vars:
        return None  # defer to default implementation

    lhs_reals = frozenset(
        name for name, domain in log_measure.inputs.items()
        if domain.dtype == 'real')
    rhs_reals = frozenset(
        name for name, domain in integrand.inputs.items()
        if domain.dtype == 'real')
    if not (lhs_reals == real_vars and rhs_reals <= real_vars):
        raise NotImplementedError('TODO implement partial integration')

    # Align both Gaussians to the union of their inputs (measure first).
    inputs = OrderedDict()
    for term in (log_measure, integrand):
        for name, domain in term.inputs.items():
            inputs[name] = domain
    lhs_info_vec, lhs_precision = align_gaussian(inputs, log_measure)
    rhs_info_vec, rhs_precision = align_gaussian(inputs, integrand)
    lhs = Gaussian(lhs_info_vec, lhs_precision, inputs)

    # Compute the expectation of a non-normalized quadratic form.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    norm = lhs.log_normalizer.data.exp()
    lhs_cov = cholesky_inverse(lhs._precision_chol)
    lhs_info_column = lhs.info_vec.unsqueeze(-1)
    lhs_loc = lhs_info_column.cholesky_solve(lhs._precision_chol).squeeze(-1)
    vmv_term = _vv(lhs_loc, rhs_info_vec - 0.5 * _mv(rhs_precision, lhs_loc))
    trace_term = 0.5 * _trace_mm(rhs_precision, lhs_cov)
    data = norm * (vmv_term - trace_term)

    remaining_inputs = OrderedDict()
    for name, domain in inputs.items():
        if name not in reduced_vars:
            remaining_inputs[name] = domain
    return Tensor(data, remaining_inputs).reduce(ops.add, reduced_vars - real_vars)


__all__ = [
    'Integrate',
]