Source code for funsor.delta

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from functools import reduce

from funsor.domains import Domain, Real
from funsor.instrument import debug_logged
from funsor.ops import AddOp, SubOp, TransformOp
from funsor.registry import KeyedRegistry
from funsor.terms import (
    Align,
    Binary,
    Funsor,
    FunsorMeta,
    Independent,
    Lambda,
    Number,
    Unary,
    Variable,
    eager,
    to_funsor,
)
from funsor.util import get_default_dtype

from . import ops


def solve(expr, value):
    """
    Attempt to invert ``expr == value`` for the free input of ``expr`` and
    compute the log-abs-det-Jacobian of the resulting substitution.

    The work is delegated to ``solve.dispatch``, keyed on ``type(expr)`` and
    called with the entries of ``expr._ast_values`` followed by ``value``.

    :param Funsor expr: An expression with a free variable.
    :param Funsor value: A target value.
    :return: A tuple ``(name, point, log_abs_det_jacobian)``
    :rtype: tuple
    :raises: ValueError if the dispatched solver returns ``None``.
    """
    assert isinstance(expr, Funsor)
    assert isinstance(value, Funsor)
    dispatch_args = expr._ast_values + (value,)
    solution = solve.dispatch(type(expr), *dispatch_args)
    if solution is not None:
        return solution
    raise ValueError("Cannot substitute into a Delta: {}".format(value))
# Registry backing ``solve``; its default handler returns None, which
# ``solve`` turns into a ValueError.
_solve = KeyedRegistry(lambda *args: None)
solve.dispatch = _solve.__call__
solve.register = _solve.register


@solve.register(Variable, str, Domain, Funsor)
@debug_logged
def solve_variable(name, output, y):
    """
    Solve ``Variable(name, output) == y``.

    :return: ``(name, y, Number(0))`` -- the target itself is the solution and
        the log-density correction is zero.
    """
    assert y.output == output
    return name, y, Number(0)


@solve.register(Unary, TransformOp, Funsor, Funsor)
@debug_logged
def solve_unary(op, arg, y):
    """
    Solve ``op(arg) == y`` by applying ``op.inv`` to ``y`` and recursively
    solving ``arg`` against that preimage.

    The value of ``op.log_abs_det_jacobian(preimage, y)`` is added to the
    log-density returned by the recursive solve.
    """
    preimage = op.inv(y)
    name, point, log_density = solve(arg, preimage)
    # Augmented assignment is kept on purpose so operator dispatch on the
    # log-density object is unchanged.
    log_density += op.log_abs_det_jacobian(preimage, y)
    return name, point, log_density


class DeltaMeta(FunsorMeta):
    """
    Makes Delta less of a pain to use by supporting Delta(name, point, log_density)
    """

    def __call__(cls, *args):
        # Several positional arguments select the single-name shorthand:
        # (name, point) or (name, point, log_density).
        if len(args) > 1:
            assert len(args) == 2 or len(args) == 3
            assert isinstance(args[0], str) and isinstance(args[1], Funsor)
            if len(args) == 2:
                # Default log-density when only (name, point) is given.
                args = args + (Number(0.0),)
            name, point, log_density = args[:3]
            # Normalise to the canonical one-term ``terms`` tuple.
            args = (((name, (to_funsor(point), to_funsor(log_density))),),)
        assert isinstance(args[0], tuple)
        return super().__call__(args[0])
class Delta(Funsor, metaclass=DeltaMeta):
    """
    Normalized delta distribution binding multiple variables.

    There are three syntaxes supported for constructing Deltas::

        Delta(((name1, (point1, log_density1)),
               (name2, (point2, log_density2)),
               (name3, (point3, log_density3))))

    or for a single name::

        Delta(name, point, log_density)

    or for default ``log_density == 0``::

        Delta(name, point)

    :param tuple terms: A tuple of tuples of the form
        ``(name, (point, log_density))``.
    """

    def __init__(self, terms):
        """
        Validate ``terms`` and derive this funsor's inputs from them.

        Each name becomes an input whose domain is its point's output,
        followed by the inputs of that point. The set of names is passed to
        the base class as ``fresh``; ``bound`` is empty.

        :param tuple terms: Non-empty tuple of ``(name, (point, log_density))``.
        """
        assert isinstance(terms, tuple) and len(terms) > 0
        inputs = OrderedDict()
        for name, (point, log_density) in terms:
            assert isinstance(name, str)
            assert isinstance(point, Funsor)
            assert isinstance(log_density, Funsor)
            assert log_density.output == Real
            assert name not in inputs
            assert name not in point.inputs
            inputs.update({name: point.output})
            inputs.update(point.inputs)

        output = Real
        fresh = frozenset(name for name, term in terms)
        bound = {}
        super(Delta, self).__init__(inputs, output, fresh, bound)
        self.terms = terms

    def align(self, names):
        """
        Return a Delta whose terms are ordered according to ``names``.

        :param tuple names: Names drawn from ``self.fresh``. The sort key looks
            up every term's name in ``names``, so a name missing from
            ``names`` makes ``tuple.index`` raise ``ValueError``.
        :return: ``self`` when ``names`` is empty or already matches the
            current term order, otherwise a new reordered :class:`Delta`.
        """
        assert isinstance(names, tuple)
        assert all(name in self.fresh for name in names)
        if not names or names == tuple(n for n, p in self.terms):
            return self

        new_terms = tuple(sorted(self.terms, key=lambda t: names.index(t[0])))
        return Delta(new_terms)

    def eager_subs(self, subs):
        """
        Substitute values for the names bound by this Delta.

        For each term whose name appears in ``subs``:

        * a :class:`Variable` value renames the term;
        * if neither the value nor the point has an input with dtype
          ``"real"``, the term is replaced by the log of an equality
          indicator plus the term's log-density;
        * otherwise the substitution is inverted with :func:`solve` and the
          returned log-density correction is added to the term's own.

        Terms not mentioned in ``subs`` are kept unchanged.

        :param subs: Name/value pairs accepted by ``OrderedDict``.
        :return: A :class:`Delta`, a sum of log-density funsors, or the sum of
            both, depending on which kinds of terms remain.
        """
        subs = OrderedDict(subs)
        new_terms = []
        log_densities = []
        for name, (point, log_density) in self.terms:
            if name in subs:
                value = subs[name]
                assert value.output == point.output
                if isinstance(value, Variable):
                    new_terms.append((value.name, (point, log_density)))
                    continue

                if not any(
                    d.dtype == "real"
                    for side in (value, point)
                    for d in side.inputs.values()
                ):
                    dtype = get_default_dtype()
                    is_equal = ops.astype((value == point).all(), dtype)
                    log_densities.append(is_equal.log() + log_density)
                    continue

                # Try to invert the substitution.
                soln = solve(value, point)
                # NOTE(review): ``solve`` in this module raises ValueError
                # rather than returning None, so this guard looks unreachable
                # -- confirm whether a lazy fallback was intended here.
                if soln is None:
                    return None  # lazily substitute
                new_name, new_point, point_log_density = soln
                new_terms.append(
                    (new_name, (new_point, log_density + point_log_density))
                )
            else:
                new_terms.append((name, (point, log_density)))

        if not log_densities:
            return Delta(tuple(new_terms))
        elif not new_terms:
            return reduce(ops.add, log_densities)
        else:
            return Delta(tuple(new_terms)) + reduce(ops.add, log_densities)

    def eager_reduce(self, op, reduced_vars):
        """
        Eagerly reduce over ``reduced_vars`` for ``ops.max`` / ``ops.logaddexp``.

        :param op: The reduction operation.
        :param reduced_vars: Set of input names to reduce; must be a subset of
            ``self.inputs``.
        :return: The reduced funsor, or ``None`` to defer to the default
            implementation.
        :raises NotImplementedError: If ``op`` is ``ops.add``.
        """
        assert reduced_vars.issubset(self.inputs)
        if op in (ops.max, ops.logaddexp):
            # Mixed case: some reduced vars are not fresh and some fresh vars
            # are not reduced. Reduce the fresh part first, then the rest.
            if reduced_vars - self.fresh and self.fresh - reduced_vars:
                result = self
                if not reduced_vars.isdisjoint(self.fresh):
                    result = result.eager_reduce(op, reduced_vars & self.fresh)
                if result is not self:
                    if not reduced_vars.issubset(self.fresh):
                        result = result.eager_reduce(op, reduced_vars - self.fresh)
                    if result is not self:
                        return result
                # NOTE(review): nesting of this return was reconstructed from
                # whitespace-collapsed source -- confirm against upstream.
                return None

            # Drop the reduced terms. ``scale`` collects zero-valued funsors
            # that still carry the dropped terms' inputs.
            result_terms, scale = [], Number(0)
            for name, (point, log_density) in self.terms:
                if name in reduced_vars:
                    # XXX obscenely wasteful - need a lazy Zero term
                    if point.inputs:
                        scale += (point == point).all().log()
                    if log_density.inputs:
                        scale += log_density * 0.0
                else:
                    result_terms.append((name, (point, log_density)))

            result = Delta(tuple(result_terms)) + scale if result_terms else scale
            return result.reduce(op, reduced_vars - self.fresh)

        if op is ops.add:
            raise NotImplementedError("TODO Implement ops.add to simulate .to_event().")

        return None  # defer to default implementation

    def _sample(self, sampled_vars, sample_inputs, rng_key):
        """Return ``self`` unchanged; all three arguments are ignored."""
        return self
@eager.register(Binary, AddOp, Delta, Delta)
def eager_add_multidelta(op, lhs, rhs):
    """
    Add two Deltas.

    If either side binds a name that is an input of the other, defer to the
    Delta/Funsor handlers; otherwise concatenate the two term tuples.
    """
    if not lhs.fresh.isdisjoint(rhs.inputs):
        return eager_add_delta_funsor(op, lhs, rhs)
    if not rhs.fresh.isdisjoint(lhs.inputs):
        return eager_add_funsor_delta(op, lhs, rhs)
    combined_terms = lhs.terms + rhs.terms
    return Delta(combined_terms)


@eager.register(Binary, (AddOp, SubOp), Delta, (Funsor, Align))
def eager_add_delta_funsor(op, lhs, rhs):
    """
    Apply ``op`` to a Delta and a funsor that depends on the Delta's names.

    Every name bound by ``lhs`` that is also an input of ``rhs`` is replaced
    in ``rhs`` by the corresponding point before ``op`` is applied.
    """
    if lhs.fresh.isdisjoint(rhs.inputs):
        return None  # defer to default implementation
    point_subs = {
        name: point
        for name, (point, _log_density) in lhs.terms
        if name in rhs.inputs
    }
    return op(lhs, rhs(**point_subs))


@eager.register(Binary, AddOp, (Funsor, Align), Delta)
def eager_add_funsor_delta(op, lhs, rhs):
    """
    Add a funsor that depends on a Delta's names to that Delta.

    Mirror image of ``eager_add_delta_funsor``: points from ``rhs`` are
    substituted into ``lhs`` before ``op`` is applied.
    """
    if rhs.fresh.isdisjoint(lhs.inputs):
        return None
    point_subs = {
        name: point
        for name, (point, _log_density) in rhs.terms
        if name in lhs.inputs
    }
    return op(lhs(**point_subs), rhs)


@eager.register(Independent, Delta, str, str, str)
def eager_independent_delta(delta, reals_var, bint_var, diag_var):
    """
    Rewrite the term named ``diag_var`` as a term named ``reals_var``.

    The point is wrapped in a ``Lambda`` over ``bint_var``. The log-density is
    summed over ``bint_var`` when it depends on it, and otherwise multiplied
    by the ``dtype`` of ``bint_var``'s domain (presumably the size of that
    bounded-integer domain -- TODO confirm).

    :return: The rewritten :class:`Delta`, or ``None`` if no term is named
        ``diag_var``.
    """
    for index, (name, (point, log_density)) in enumerate(delta.terms):
        if name != diag_var:
            continue
        batch_domain = delta.inputs[bint_var]
        point = Lambda(Variable(bint_var, batch_domain), point)
        if bint_var in log_density.inputs:
            log_density = log_density.reduce(ops.add, bint_var)
        else:
            log_density = log_density * batch_domain.dtype
        replacement = ((reals_var, (point, log_density)),)
        return Delta(delta.terms[:index] + replacement + delta.terms[index + 1:])
    return None


__all__ = [
    "Delta",
    "solve",
]