Source code for funsor.integrate

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from functools import reduce
from typing import Tuple, Union

import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
from funsor.constant import Constant
from funsor.delta import Delta
from funsor.gaussian import Gaussian, _norm2, _vm, align_gaussian
from funsor.interpretations import eager, normalize
from funsor.tensor import Tensor
from funsor.terms import (
    Funsor,
    FunsorMeta,
    Number,
    Subs,
    Unary,
    Variable,
    _convert_reduced_vars,
    substitute,
    to_funsor,
)


class IntegrateMeta(FunsorMeta):
    """
    Wrapper to convert reduced_vars arg to a frozenset of str.
    """

    def __call__(cls, log_measure, integrand, reduced_vars):
        inputs = log_measure.inputs.copy()
        inputs.update(integrand.inputs)
        reduced_vars = _convert_reduced_vars(reduced_vars, inputs)
        return super().__call__(log_measure, integrand, reduced_vars)


[docs]class Integrate(Funsor, metaclass=IntegrateMeta): """ Funsor representing an integral wrt a log density funsor. :param Funsor log_measure: A log density funsor treated as a measure. :param Funsor integrand: An integrand funsor. :param reduced_vars: An input name or set of names to reduce. :type reduced_vars: str, Variable, or set or frozenset thereof. """ def __init__(self, log_measure, integrand, reduced_vars): assert isinstance(log_measure, Funsor) assert isinstance(integrand, Funsor) assert isinstance(reduced_vars, frozenset) assert all(isinstance(v, Variable) for v in reduced_vars) reduced_names = frozenset(v.name for v in reduced_vars) inputs = OrderedDict( (k, d) for term in (log_measure, integrand) for (k, d) in term.inputs.items() if k not in reduced_names ) output = integrand.output fresh = frozenset() bound = {v.name: v.output for v in reduced_vars} super(Integrate, self).__init__(inputs, output, fresh, bound) self.log_measure = log_measure self.integrand = integrand self.reduced_vars = reduced_vars def _alpha_convert(self, alpha_subs): assert set(self.bound).issuperset(alpha_subs) reduced_vars = frozenset( Variable(alpha_subs.get(v.name, v.name), v.output) for v in self.reduced_vars ) alpha_subs = { k: to_funsor( v, self.integrand.inputs.get(k, self.log_measure.inputs.get(k)) ) for k, v in alpha_subs.items() } log_measure = substitute(self.log_measure, alpha_subs) integrand = substitute(self.integrand, alpha_subs) return log_measure, integrand, reduced_vars
@normalize.register(Integrate, Funsor, Funsor, frozenset) def normalize_integrate(log_measure, integrand, reduced_vars): return Contraction(ops.add, ops.mul, reduced_vars, log_measure.exp(), integrand) @normalize.register( Integrate, Contraction[Union[ops.NullOp, ops.LogaddexpOp], ops.AddOp, frozenset, tuple], Funsor, frozenset, ) def normalize_integrate_contraction(log_measure, integrand, reduced_vars): reduced_names = frozenset(v.name for v in reduced_vars) delta_terms = [ t for t in log_measure.terms if isinstance(t, Delta) and t.fresh.intersection(reduced_names, integrand.inputs) ] for delta in delta_terms: delta_fresh = frozenset(Variable(k, delta.inputs[k]) for k in delta.fresh) args = delta, integrand, delta_fresh integrand = eager.dispatch(Integrate, *args)(*args) return normalize_integrate(log_measure, integrand, reduced_vars) EagerConstant = Constant[ Tuple, Union[ Variable, Delta, Gaussian, Unary[ops.NegOp, Gaussian], Number, Tensor, GaussianMixture, ], ] @eager.register( Contraction, ops.AddOp, ops.MulOp, frozenset, Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]], ( Variable, Delta, Gaussian, Unary[ops.NegOp, Gaussian], Number, Tensor, GaussianMixture, EagerConstant, ), ) def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs): reduced_names = frozenset(v.name for v in reduced_vars) if not (reduced_names.issubset(lhs.inputs) and reduced_names.issubset(rhs.inputs)): args = red_op, bin_op, reduced_vars, (lhs, rhs) result = eager.dispatch(Contraction, *args)(*args) if result is not None: return result args = lhs.log(), rhs, reduced_vars result = eager.dispatch(Integrate, *args)(*args) if result is not None: return result return None @eager.register(Integrate, GaussianMixture, Funsor, frozenset) def eager_integrate_gaussianmixture(log_measure, integrand, reduced_vars): real_vars = frozenset(v for v in reduced_vars if v.dtype == "real") if reduced_vars <= real_vars: discrete, gaussian = log_measure.terms return discrete.exp() * Integrate(gaussian, integrand, reduced_vars) return None ######################################## # Delta patterns ######################################## @eager.register(Integrate, Delta, Funsor, frozenset) def eager_integrate(delta, integrand, reduced_vars): delta_fresh = frozenset(Variable(k, delta.inputs[k]) for k in delta.fresh) if reduced_vars.isdisjoint(delta_fresh): return None reduced_names = frozenset(v.name for v in reduced_vars) subs = tuple( (name, point) for name, (point, log_density) in delta.terms if name in reduced_names ) new_integrand = Subs(integrand, subs) new_log_measure = Subs(delta, subs) result = Integrate(new_log_measure, new_integrand, reduced_vars - delta_fresh) return result ######################################## # Gaussian patterns ######################################## @eager.register(Integrate, Gaussian, Variable, frozenset) def eager_integrate_gaussian_variable(log_measure, integrand, reduced_vars): real_input_vars = frozenset(v for v in log_measure.input_vars if v.dtype == "real") real_vars = reduced_vars & real_input_vars if real_vars == frozenset([integrand]): if real_vars != real_input_vars: return None # TODO implement this loc = log_measure._mean data = loc * ops.unsqueeze(ops.exp(log_measure._log_normalizer), -1) data = data.reshape(loc.shape[:-1] + integrand.output.shape) inputs = OrderedDict( (k, d) for k, d in log_measure.inputs.items() if d.dtype != "real" ) result = Tensor(data, inputs) return result.reduce(ops.add, reduced_vars - real_vars) return None # defer to default implementation @eager.register(Integrate, Gaussian, Gaussian, frozenset) def eager_integrate_gaussian_gaussian(log_measure, integrand, reduced_vars): assert log_measure.is_full_rank reduced_names = frozenset(v.name for v in reduced_vars) real_vars = frozenset(v.name for v in reduced_vars if v.dtype == "real") if real_vars: lhs_reals = frozenset( k for k, d in log_measure.inputs.items() if d.dtype == "real" ) rhs_reals = frozenset( k for k, d in integrand.inputs.items() if d.dtype == "real" ) if lhs_reals == real_vars and rhs_reals <= real_vars: inputs = OrderedDict( (k, d) for t in (log_measure, integrand) for k, d in t.inputs.items() ) lhs_white_vec, lhs_prec_sqrt = align_gaussian(inputs, log_measure) rhs_white_vec, rhs_prec_sqrt = align_gaussian(inputs, integrand) lhs = Gaussian( white_vec=lhs_white_vec, prec_sqrt=lhs_prec_sqrt, inputs=inputs ) # Compute the expectation of a non-normalized quadratic form. # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380. # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf # If x ~ N(mean,cov) then # E[(x-m)' A (x-m)] = (m-mean)'A(m-mean) + Tr(A cov) # eq. 380 # To perform this computation in rhs's internal space, we first transform # lhs to rhs's whitened space mean = _vm(lhs._mean, rhs_prec_sqrt) norm = ops.exp(lhs._log_normalizer) # Then in rhs's whitened space, A = I so Tr(A cov) = Tr(cov). vmv_term = _norm2(rhs_white_vec - mean) trace_term = ( (ops.triangular_solve(rhs_prec_sqrt, lhs._precision_chol) ** 2) .sum(-1) .sum(-1) ) data = (-0.5) * norm * (vmv_term + trace_term) inputs = OrderedDict( (k, d) for k, d in inputs.items() if k not in reduced_names ) result = Tensor(data, inputs) return result.reduce(ops.add, reduced_names - real_vars) raise NotImplementedError("TODO implement partial integration") return None # defer to default implementation @eager.register(Integrate, Gaussian, Unary[ops.NegOp, Gaussian], frozenset) def eager_integrate_neg_gaussian(log_measure, integrand, reduced_vars): return -Integrate(log_measure, integrand.arg, reduced_vars) @eager.register( Integrate, Gaussian, Contraction[ ops.NullOp, ops.AddOp, frozenset, Tuple[Union[Gaussian, Unary[ops.NegOp, Gaussian]], ...], ], frozenset, ) def eager_distribute_integrate(log_measure, integrand, reduced_vars): return reduce( ops.add, [ -Integrate(log_measure, term.arg, reduced_vars) if isinstance(term, Unary) else Integrate(log_measure, term, reduced_vars) for term in integrand.terms ], ) __all__ = [ "Integrate", ]