# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
import importlib
import inspect
import math
import typing
import warnings
from collections import OrderedDict
from importlib import import_module

import makefun

import funsor.delta
import funsor.ops as ops
from funsor.affine import is_affine
from funsor.cnf import Contraction, GaussianMixture
from funsor.domains import Array, Real, Reals
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
from funsor.tensor import (
    Tensor,
    align_tensors,
    dummy_numeric_array,
    get_default_prototype,
    ignore_jit_warnings,
    numeric_array,
)
from funsor.terms import (
    Finitary,
    Funsor,
    FunsorMeta,
    Independent,
    Lambda,
    Number,
    Variable,
    eager,
    reflect,
    to_data,
    to_funsor,
)
from funsor.typing import deep_isinstance
from funsor.util import (
    broadcast_shape,
    get_backend,
    get_default_dtype,
    getargspec,
    lazy_property,
)

BACKEND_TO_DISTRIBUTIONS_BACKEND = {
    "torch": "funsor.torch.distributions",
    "jax": "funsor.jax.distributions",
}
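
# For orientation: backend-specific wrappers are looked up lazily through this
# mapping, as done throughout this module. A minimal sketch, assuming the torch
# backend has been selected via funsor.set_backend("torch"):
#
#     backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
#     Normal = backend_dist.Normal  # the funsor wrapper of torch's Normal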


def numbers_to_tensors(*args):
    """
    Convert :class:`~funsor.terms.Number` s to :class:`funsor.tensor.Tensor` s,
    using any provided tensor as a prototype, if available.
    """
    if any(isinstance(x, Number) for x in args):
        prototype = get_default_prototype()
        options = dict(dtype=prototype.dtype)
        for x in args:
            if isinstance(x, Tensor):
                options = dict(
                    dtype=x.data.dtype, device=getattr(x.data, "device", None)
                )
                break
        with ignore_jit_warnings():
            args = tuple(
                Tensor(numeric_array(x.data, **options), dtype=x.dtype)
                if isinstance(x, Number)
                else x
                for x in args
            )
    return args
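
# Usage sketch for numbers_to_tensors (assumes the torch backend; the names
# below are illustrative):
#
#     x = Tensor(torch.tensor([1.0, 2.0]))
#     n = Number(0.5)
#     x2, n2 = numbers_to_tensors(x, n)
#     assert isinstance(n2, Tensor)  # n was converted using x's dtype and device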


class DistributionMeta(FunsorMeta):
    """
    Wrapper to fill in default values and convert Numbers to Tensors.
    """

    def __call__(cls, *args, **kwargs):
        kwargs.update(zip(cls._ast_fields, args))
        kwargs["value"] = kwargs.get("value", "value")
        kwargs = OrderedDict(
            (k, kwargs[k]) for k in cls._ast_fields
        )  # make sure args are sorted

        domains = OrderedDict()
        for k, v in kwargs.items():
            if k == "value":
                continue

            # compute unbroadcasted param domains
            domain = cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))
            # use to_funsor to infer output dimensions of e.g. tensors
            domains[k] = domain if domain is not None else to_funsor(v).output

            # broadcast individual param domains with Funsor inputs
            # this avoids .expand-ing underlying parameter tensors
            dtype = domains[k].dtype
            if isinstance(v, Funsor):
                domains[k] = Array[dtype, broadcast_shape(v.shape, domains[k].shape)]
            elif ops.is_numeric_array(v):
                domains[k] = Array[dtype, broadcast_shape(v.shape, domains[k].shape)]

        # now use the broadcasted parameter shapes to infer the event_shape
        domains["value"] = cls._infer_value_domain(**domains)

        # finally, perform conversions to funsors
        kwargs = OrderedDict(
            (k, to_funsor(v, output=domains[k])) for k, v in kwargs.items()
        )
        args = numbers_to_tensors(*kwargs.values())

        return super(DistributionMeta, cls).__call__(*args)
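
# Illustrative effect of DistributionMeta.__call__ (a sketch, assuming a Normal
# wrapper generated by make_dist below): positional args are matched against
# _ast_fields, a default value="value" is filled in, and all params are coerced
# to funsors with broadcasted domains:
#
#     Normal(0.0, 1.0)
#     # ~ Normal(loc=Tensor(...), scale=Tensor(...), value=Variable("value", Real))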


class Distribution(Funsor, metaclass=DistributionMeta):
    r"""
    Funsor backed by a PyTorch/JAX distribution object.

    :param \*args: Distribution-dependent parameters.  These can be either
        funsors or objects that can be coerced to funsors via
        :func:`~funsor.terms.to_funsor` . See derived classes for details.
    """

    dist_class = "defined by derived classes"

    def __init__(self, *args):
        params = tuple(zip(self._ast_fields, args))
        assert any(k == "value" for k, v in params)
        inputs = OrderedDict()
        for name, value in params:
            assert isinstance(name, str)
            assert isinstance(value, Funsor)
            inputs.update(value.inputs)
        inputs = OrderedDict(inputs)
        output = Real
        super(Distribution, self).__init__(inputs, output)
        self.params = OrderedDict(params)

    def __repr__(self):
        return "{}({})".format(
            type(self).__name__,
            ", ".join("{}={}".format(*kv) for kv in self.params.items()),
        )

    def eager_reduce(self, op, reduced_vars):
        assert reduced_vars.issubset(self.inputs)
        if (
            op is ops.logaddexp
            and isinstance(self.value, Variable)
            and self.value.name in reduced_vars
        ):
            return Number(0.0)  # distributions are normalized
        return super(Distribution, self).eager_reduce(op, reduced_vars)
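
    # For example (a sketch, assuming the torch backend's Normal wrapper):
    # integrating a normalized density over its free value variable eagerly
    # yields zero log-mass,
    #
    #     Normal(0.0, 1.0).reduce(ops.logaddexp, "value")  # => Number(0.0)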

    def _get_raw_dist(self):
        """
        Internal method for working with underlying distribution attributes
        """
        value_name = [
            name
            for name, domain in self.value.inputs.items()  # TODO is this right?
            if domain == self.value.output
        ][0]
        # arbitrary name-dim mapping, since we're converting back to a funsor anyway
        name_to_dim = {
            name: -dim - 1
            for dim, (name, domain) in enumerate(self.inputs.items())
            if isinstance(domain.dtype, int) and name != value_name
        }
        raw_dist = to_data(self, name_to_dim=name_to_dim)
        dim_to_name = {dim: name for name, dim in name_to_dim.items()}
        # also return value output, dim_to_name for converting results back to funsor
        value_output = self.inputs[value_name]
        return raw_dist, value_name, value_output, dim_to_name

    @property
    def has_enumerate_support(self):
        return getattr(self.dist_class, "has_enumerate_support", False)

    @classmethod
    def eager_log_prob(cls, *params):
        params, value = params[:-1], params[-1]
        params = params + (Variable("value", value.output),)
        instance = reflect.interpret(cls, *params)
        raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist()
        assert value.output == value_output
        name_to_dim = {v: k for k, v in dim_to_name.items()}
        dim_to_name.update(
            {
                -1 - d - len(raw_dist.batch_shape): name
                for d, name in enumerate(value.inputs)
                if name not in name_to_dim
            }
        )
        name_to_dim.update(
            {v: k for k, v in dim_to_name.items() if v not in name_to_dim}
        )
        raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim))
        log_prob = to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name)
        # this logic ensures that the inputs have the canonical order
        # implied by align_tensors, which is assumed pervasively in tests
        inputs = OrderedDict()
        for x in params[:-1] + (value,):
            inputs.update(x.inputs)
        return log_prob.align(tuple(inputs))

    def _sample(self, sampled_vars, sample_inputs, rng_key):
        # note this should handle transforms correctly via distribution_to_data
        raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
        for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
            dim_to_name[-d - len(raw_dist.batch_shape)] = name

        if value_name not in sampled_vars:
            return self

        sample_shape = tuple(v.size for v in sample_inputs.values())
        sample_args = (
            (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape)
        )
        if raw_dist.has_rsample:
            raw_value = raw_dist.rsample(*sample_args)
        else:
            raw_value = ops.detach(raw_dist.sample(*sample_args))

        funsor_value = to_funsor(
            raw_value, output=value_output, dim_to_name=dim_to_name
        )
        funsor_value = funsor_value.align(
            tuple(sample_inputs)
            + tuple(inp for inp in self.inputs if inp in funsor_value.inputs)
        )
        if not raw_dist.has_rsample:
            # scaling of dice_factor by num samples should already be handled by Funsor.sample
            raw_log_prob = raw_dist.log_prob(raw_value)
            dice_factor = to_funsor(
                raw_log_prob - ops.detach(raw_log_prob),
                output=self.output,
                dim_to_name=dim_to_name,
            )
            result = funsor.delta.Delta(value_name, funsor_value, dice_factor)
        else:
            result = funsor.delta.Delta(value_name, funsor_value)
        return result
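
    # Note on the non-rsample branch above: the term
    # ``raw_log_prob - ops.detach(raw_log_prob)`` is identically zero in value
    # but carries the score-function ("dice") gradient, keeping downstream
    # Monte Carlo estimates differentiable when reparameterized sampling is
    # unavailable.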

    def enumerate_support(self, expand=False):
        assert self.has_enumerate_support and isinstance(self.value, Variable)
        raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
        raw_value = raw_dist.enumerate_support(expand=expand)
        dim_to_name[min(dim_to_name.keys(), default=0) - 1] = value_name
        return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
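
    # Usage sketch (assumes the torch backend and its Categorical wrapper):
    #
    #     d = Categorical(probs=Tensor(torch.ones(3) / 3))
    #     support = d.enumerate_support()  # funsor with a new enumeration input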

    def entropy(self):
        raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
        raw_value = raw_dist.entropy()
        return to_funsor(raw_value, output=self.output, dim_to_name=dim_to_name)

    def mean(self):
        raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
        raw_value = raw_dist.mean
        return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

    def variance(self):
        raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
        raw_value = raw_dist.variance
        return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

    def __getattribute__(self, attr):
        if attr in type(self)._ast_fields and attr != "name":
            return self.params[attr]
        return super().__getattribute__(attr)

    @classmethod
    def _infer_value_dtype(cls, domains):
        try:
            support = cls.dist_class.support
        except NotImplementedError:
            raise NotImplementedError(
                f"Failed to infer dtype of {cls.dist_class.__name__}"
            )
        while hasattr(support, "base_constraint"):
            support = support.base_constraint
        if type(support).__name__ == "_IntegerInterval":
            return int(support.upper_bound + 1)
        return "real"

    @classmethod
    @functools.lru_cache(maxsize=5000)
    def _infer_value_domain(cls, **domains):
        dtype = cls._infer_value_dtype(domains)
        # TODO implement .infer_shapes() methods on each distribution
        # TODO fix distribution constraints by wrapping in _Independent
        batch_shape, event_shape = infer_shapes(cls.dist_class, domains)
        shape = batch_shape + event_shape
        if "value" in domains:
            shape = broadcast_shape(shape, domains["value"].shape)
        return Array[dtype, shape]

    @classmethod
    @functools.lru_cache(maxsize=5000)
    def _infer_param_domain(cls, name, raw_shape):
        support = cls.dist_class.arg_constraints.get(name, None)
        # XXX: if the backend does not have the same definition of constraints, we should
        # define backend-specific distributions and override these `infer_value_domain`,
        # `infer_param_domain` methods.
        # Because NumPyro and Pyro have the same pattern, we use name check for simplicity.
        event_dim = 0
        while hasattr(support, "base_constraint"):
            event_dim += support.reinterpreted_batch_ndims
            support = support.base_constraint
        support_name = type(support).__name__.lstrip("_")
        if support_name == "Simplex":
            output = Reals[raw_shape[-1 - event_dim :]]
        elif support_name == "RealVector":
            output = Reals[raw_shape[-1 - event_dim :]]
        elif support_name in ["LowerCholesky", "PositiveDefinite"]:
            output = Reals[raw_shape[-2 - event_dim :]]
        # resolve the issue: logits's constraints are real (instead of real_vector)
        # for discrete multivariate distributions in Pyro
        elif support_name == "Real":
            if name == "logits" and (
                "probs" in cls.dist_class.arg_constraints
                and type(cls.dist_class.arg_constraints["probs"]).__name__.lstrip("_")
                == "Simplex"
            ):
                output = Reals[raw_shape[-1 - event_dim :]]
            else:
                output = Reals[raw_shape[len(raw_shape) - event_dim :]]
        elif support_name in (
            "Interval",
            "GreaterThan",
            "LessThan",
            "UnitInterval",
            "Positive",
        ):
            output = Reals[raw_shape[len(raw_shape) - event_dim :]]
        else:
            output = None

        return output
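

# Domain-inference sketch (assumes the torch backend): for the wrapped
# MultivariateNormal, whose scale_tril has a LowerCholesky constraint,
#
#     MultivariateNormal._infer_param_domain("scale_tril", (5, 3, 3))
#     # => Reals[3, 3]  (batch dim 5 is dropped; the event is the last two dims)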


def infer_shapes(dist_class, domains):
    arg_shapes = {k: domain.shape for k, domain in domains.items() if k != "value"}
    try:
        return dist_class.infer_shapes(**arg_shapes)
    except (AttributeError, NotImplementedError):
        pass
        # warnings.warn(f"Failed to infer shape for {dist_class.__name__}, "
        #               "falling back to expensive instance construction")

    # Rely on the underlying distribution's logic to infer the event_shape
    # given param domains.
    args = {
        k: dummy_numeric_array(domain) for k, domain in domains.items() if k != "value"
    }
    instance = dist_class(**args, validate_args=False)
    return instance.batch_shape, instance.event_shape


################################################################################
# Distribution Wrappers
################################################################################


def make_dist(
    backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True
):
    if not param_names:
        param_names = tuple(
            name
            for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:]
            if name in backend_dist_class.arg_constraints
        )

    @makefun.with_signature(
        "__init__(self, {}, value='value')".format(", ".join(param_names))
    )
    def dist_init(self, **kwargs):
        return Distribution.__init__(self, *tuple(kwargs[k] for k in self._ast_fields))

    dist_class = DistributionMeta(
        backend_dist_class.__name__.split("Wrapper_")[-1],
        (Distribution,),
        {"dist_class": backend_dist_class, "__init__": dist_init},
    )

    if generate_eager:
        eager.register(dist_class, *((Tensor,) * (len(param_names) + 1)))(
            dist_class.eager_log_prob
        )

    if generate_to_funsor:
        to_funsor.register(backend_dist_class)(
            functools.partial(backenddist_to_funsor, dist_class)
        )

    return dist_class


FUNSOR_DIST_NAMES = [
    ("Beta", ("concentration1", "concentration0")),
    ("Cauchy", ()),
    ("Chi2", ()),
    ("BernoulliProbs", ("probs",)),
    ("BernoulliLogits", ("logits",)),
    ("Binomial", ("total_count", "probs")),
    ("Categorical", ("probs",)),
    ("CategoricalLogits", ("logits",)),
    ("Delta", ("v", "log_density")),
    ("Dirichlet", ("concentration",)),
    ("DirichletMultinomial", ("concentration", "total_count")),
    ("Exponential", ()),
    ("Gamma", ("concentration", "rate")),
    ("GammaPoisson", ("concentration", "rate")),
    ("Geometric", ("probs",)),
    ("Gumbel", ()),
    ("HalfCauchy", ()),
    ("HalfNormal", ()),
    ("Laplace", ()),
    ("Logistic", ()),
    ("LowRankMultivariateNormal", ()),
    ("Multinomial", ("total_count", "probs")),
    ("MultivariateNormal", ("loc", "scale_tril")),
    ("NonreparameterizedBeta", ("concentration1", "concentration0")),
    ("NonreparameterizedDirichlet", ("concentration",)),
    ("NonreparameterizedGamma", ("concentration", "rate")),
    ("NonreparameterizedNormal", ("loc", "scale")),
    ("Normal", ("loc", "scale")),
    ("Pareto", ()),
    ("Poisson", ()),
    ("StudentT", ()),
    ("Uniform", ()),
    ("VonMises", ()),
]


###############################################
# Converting backend Distributions to funsors
###############################################


def backenddist_to_funsor(
    funsor_dist_class, backend_dist, output=None, dim_to_name=None
):
    params = [
        to_funsor(
            getattr(backend_dist, param_name),
            output=funsor_dist_class._infer_param_domain(
                param_name, getattr(getattr(backend_dist, param_name), "shape", ())
            ),
            dim_to_name=dim_to_name,
        )
        for param_name in funsor_dist_class._ast_fields
        if param_name != "value"
    ]
    return funsor_dist_class(*params)


def indepdist_to_funsor(backend_dist, output=None, dim_to_name=None):
    if dim_to_name is None:
        dim_to_name = {}
    event_dim_to_name = OrderedDict(
        (i, "_pyro_event_dim_{}".format(i))
        for i in range(-backend_dist.reinterpreted_batch_ndims, 0)
    )
    dim_to_name = OrderedDict(
        (dim - backend_dist.reinterpreted_batch_ndims, name)
        for dim, name in dim_to_name.items()
    )
    dim_to_name.update(event_dim_to_name)
    result = to_funsor(backend_dist.base_dist, dim_to_name=dim_to_name)

    if isinstance(result, Distribution) and not deep_isinstance(
        result.value, Finitary[ops.StackOp, tuple]
    ):  # ops.stack() used in some eager patterns
        params = tuple(result.params.values())[:-1]
        for dim, name in reversed(event_dim_to_name.items()):
            dim_var = to_funsor(name, result.inputs[name])
            params = tuple(Lambda(dim_var, param) for param in params)
        if isinstance(result.value, Variable):
            # broadcasting logic in Distribution will compute correct value domain
            result = type(result)(*(params + (result.value.name,)))
        else:
            raise NotImplementedError("TODO support converting Indep(Transform)")
    else:
        # this handles the output of eager rewrites, e.g. Normal->Gaussian or Beta->Dirichlet
        for dim, name in reversed(event_dim_to_name.items()):
            result = funsor.terms.Independent(result, "value", name, "value")
    return result


def expandeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
    funsor_base_dist = to_funsor(
        backend_dist.base_dist, output=output, dim_to_name=dim_to_name
    )
    if not dim_to_name:
        assert not backend_dist.batch_shape
        return funsor_base_dist

    name_to_dim = {name: dim for dim, name in dim_to_name.items()}
    raw_expanded_params = {}
    for name, funsor_param in funsor_base_dist.params.items():
        if name == "value":
            continue
        raw_param = to_data(funsor_param, name_to_dim=name_to_dim)
        raw_expanded_params[name] = ops.expand(
            raw_param, backend_dist.batch_shape + funsor_param.shape
        )

    raw_expanded_dist = type(backend_dist.base_dist)(**raw_expanded_params)
    return to_funsor(raw_expanded_dist, output, dim_to_name)


def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
    mask = to_funsor(
        ops.astype(backend_dist._mask, get_default_dtype()),
        output=output,
        dim_to_name=dim_to_name,
    )
    funsor_base_dist = to_funsor(
        backend_dist.base_dist, output=output, dim_to_name=dim_to_name
    )
    return mask * funsor_base_dist


# TODO make this work with transforms with nontrivial event_dim logic
# converts TransformedDistributions
def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
    base_dist, transforms = backend_dist, []
    while isinstance(base_dist, dist_module.TransformedDistribution):
        transforms = base_dist.transforms + transforms
        base_dist = base_dist.base_dist
    funsor_base_dist = to_funsor(base_dist, output=output, dim_to_name=dim_to_name)
    # TODO make this work with transforms that change the output type
    transform = to_funsor(
        dist_module.transforms.ComposeTransform(transforms),
        funsor_base_dist.inputs["value"],
        dim_to_name,
    )
    _, inv_transform, ldj = funsor.delta.solve(
        transform, to_funsor("value", funsor_base_dist.inputs["value"])
    )
    return -ldj + funsor_base_dist(value=inv_transform)


class CoerceDistributionToFunsor:
    """
    Handler to reinterpret a backend distribution ``D`` as a corresponding
    funsor during ``type(D).__call__()`` in case any constructor args are
    funsors rather than backend tensors.

    Example usage::

        # in foo/distribution.py
        coerce_to_funsor = CoerceDistributionToFunsor("foo")

        class DistributionMeta(type):
            def __call__(cls, *args, **kwargs):
                result = coerce_to_funsor(cls, args, kwargs)
                if result is not None:
                    return result
                return super().__call__(*args, **kwargs)

        class Distribution(metaclass=DistributionMeta):
            ...
    :param str backend: Name of a funsor backend.
    """

    def __init__(self, backend):
        self.backend = backend

    @lazy_property
    def module(self):
        funsor.set_backend(self.backend)
        module_name = BACKEND_TO_DISTRIBUTIONS_BACKEND[self.backend]
        return importlib.import_module(module_name)

    def __call__(self, cls, args, kwargs):
        # Check whether distribution class takes any tensor inputs.
        arg_constraints = getattr(cls, "arg_constraints", None)
        if not arg_constraints:
            return

        # Check whether any tensor inputs are actually funsors.
        try:
            ast_fields = cls._funsor_ast_fields
        except AttributeError:
            ast_fields = cls._funsor_ast_fields = getargspec(cls.__init__)[0][1:]
        kwargs = {
            name: value
            for pairs in (zip(ast_fields, args), kwargs.items())
            for name, value in pairs
        }
        if not any(
            isinstance(value, (str, Funsor))
            for name, value in kwargs.items()
            if name in arg_constraints
        ):
            return

        # Check for a corresponding funsor class.
        try:
            funsor_cls = cls._funsor_cls
        except AttributeError:
            funsor_cls = getattr(self.module, cls.__name__, None)
            # resolve the issue: Binomial/Multinomial are functions in NumPyro,
            # which fall back to either BinomialProbs or BinomialLogits
            if funsor_cls is None and cls.__name__.endswith("Probs"):
                funsor_cls = getattr(self.module, cls.__name__[:-5], None)
            cls._funsor_cls = funsor_cls
        if funsor_cls is None:
            warnings.warn("missing funsor for {}".format(cls.__name__), RuntimeWarning)
            return

        # Coerce to funsor.
        return funsor_cls(**kwargs)


###############################################################
# Converting distribution funsors to backend distributions
###############################################################


@to_data.register(Distribution)
def distribution_to_data(funsor_dist, name_to_dim=None):
    funsor_event_shape = funsor_dist.value.output.shape

    # attempt to generically infer the independent output dimensions
    domains = {k: v.output for k, v in funsor_dist.params.items()}
    indep_shape, _ = infer_shapes(funsor_dist.dist_class, domains)

    params = []
    for param_name, funsor_param in zip(
        funsor_dist._ast_fields, funsor_dist._ast_values[:-1]
    ):
        param = to_data(funsor_param, name_to_dim=name_to_dim)

        # infer the independent dimensions of each parameter separately,
        # since we chose to keep them unbroadcasted
        param_event_shape = getattr(
            funsor_dist._infer_param_domain(param_name, funsor_param.output.shape),
            "shape",
            (),
        )
        param_indep_shape = funsor_param.output.shape[
            : len(funsor_param.output.shape) - len(param_event_shape)
        ]
        for i in range(max(0, len(indep_shape) - len(param_indep_shape))):
            # add singleton event dimensions, leave broadcasting/expanding to backend
            param = ops.unsqueeze(param, -1 - len(funsor_param.output.shape))
        params.append(param)

    pyro_dist = funsor_dist.dist_class(
        **dict(zip(funsor_dist._ast_fields[:-1], params))
    )
    pyro_dist = pyro_dist.to_event(
        max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0)
    )

    # TODO get this working for all backends
    if not isinstance(funsor_dist.value, Variable):
        if get_backend() != "torch":
            raise NotImplementedError(
                "transformed distributions not yet supported under this backend, "
                "try set_backend('torch')"
            )
        inv_value = funsor.delta.solve(
            funsor_dist.value, Variable("value", funsor_dist.value.output)
        )[1]
        transforms = to_data(inv_value, name_to_dim=name_to_dim)
        backend_dist = import_module(
            BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]
        ).dist
        pyro_dist = backend_dist.TransformedDistribution(pyro_dist, transforms)

    if pyro_dist.event_shape != funsor_event_shape:
        raise ValueError("Event shapes don't match, something went wrong")
    return pyro_dist
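

# Round-trip sketch for distribution_to_data (assumes the torch backend):
# named funsor inputs are mapped to negative batch dimensions of the backend
# distribution via name_to_dim,
#
#     loc = Tensor(torch.zeros(2), OrderedDict(i=Bint[2]))  # Bint from funsor.domains
#     raw = to_data(Normal(loc, 1.0), name_to_dim={"i": -1})
#     # raw is a torch.distributions object with batch_shape (2,)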


@to_data.register(Independent[typing.Union[Independent, Distribution], str, str, str])
def indep_to_data(funsor_dist, name_to_dim=None):
    if not isinstance(funsor_dist.fn, (Independent, Distribution, Gaussian)):
        raise NotImplementedError(f"cannot convert {funsor_dist} to data")
    name_to_dim = OrderedDict((name, dim - 1) for name, dim in name_to_dim.items())
    name_to_dim.update({funsor_dist.bint_var: -1})
    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
    result = to_data(funsor_dist.fn, name_to_dim=name_to_dim)

    # collapse nested Independents into a single Independent for conversion
    reinterpreted_batch_ndims = 1
    while isinstance(result, backend_dist.Independent):
        result = result.base_dist
        reinterpreted_batch_ndims += 1

    return backend_dist.Independent(result, reinterpreted_batch_ndims)


@to_data.register(Gaussian)
def gaussian_to_data(funsor_dist, name_to_dim=None):
    int_inputs = OrderedDict(
        (k, d) for k, d in funsor_dist.inputs.items() if d.dtype != "real"
    )
    loc = to_data(Tensor(funsor_dist._mean, int_inputs), name_to_dim)
    precision = to_data(Tensor(funsor_dist._precision, int_inputs), name_to_dim)
    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return backend_dist.MultivariateNormal.dist_class(loc, precision_matrix=precision)


@to_data.register(GaussianMixture)
def gaussianmixture_to_data(funsor_dist, name_to_dim=None):
    discrete, gaussian = funsor_dist.terms
    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    cat = backend_dist.CategoricalLogits.dist_class(
        logits=to_data(discrete + gaussian.log_normalizer, name_to_dim=name_to_dim)
    )
    mvn = to_data(gaussian, name_to_dim=name_to_dim)
    return cat, mvn


################################################
# Backend-agnostic distribution patterns
################################################


def Bernoulli(probs=None, logits=None, value="value"):
    """
    Wraps backend `Bernoulli` distributions.

    This dispatches to either `BernoulliProbs` or `BernoulliLogits`
    to accept either ``probs`` or ``logits`` args.

    :param Funsor probs: Probability of 1.
    :param Funsor logits: Log-odds of 1.
    :param Funsor value: Optional observation in ``{0,1}``.
    """
    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    if probs is not None:
        probs = to_funsor(probs, output=Real)
        return backend_dist.BernoulliProbs(probs, value)  # noqa: F821
    if logits is not None:
        logits = to_funsor(logits, output=Real)
        return backend_dist.BernoulliLogits(logits, value)  # noqa: F821
    raise ValueError("Either probs or logits must be specified")


def LogNormal(loc, scale, value="value"):
    """
    Wraps backend `LogNormal` distributions.

    :param Funsor loc: Mean of the untransformed Normal distribution.
    :param Funsor scale: Standard deviation of the untransformed Normal
        distribution.
    :param Funsor value: Optional real observation.
""" loc, scale = to_funsor(loc), to_funsor(scale) y = to_funsor(value, output=loc.output) t = ops.exp x = t.inv(y) log_abs_det_jacobian = t.log_abs_det_jacobian(x, y) backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Normal(loc, scale, x) - log_abs_det_jacobian # noqa: F821 def eager_beta(concentration1, concentration0, value): concentration = ops.stack((concentration0, concentration1)) value = ops.stack((1 - value, value)) backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Dirichlet(concentration, value=value) # noqa: F821 def eager_binomial(total_count, probs, value): probs = ops.stack((1 - probs, probs)) value = ops.stack((total_count - value, value)) backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Multinomial(total_count, probs, value=value) # noqa: F821 def eager_multinomial(total_count, probs, value): # Multinomial.log_prob() supports inhomogeneous total_count only by # avoiding passing total_count to the constructor. inputs, (total_count, probs, value) = align_tensors(total_count, probs, value) shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape) probs = Tensor(ops.expand(probs, shape), inputs) value = Tensor(ops.expand(value, shape), inputs) if get_backend() == "torch": total_count = Number( ops.amax(total_count, None).item() ) # Used by distributions validation code. else: total_count = Tensor(ops.expand(total_count, shape[:-1]), inputs) backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Multinomial.eager_log_prob( total_count, probs, value ) # noqa: F821 def eager_categorical_funsor(probs, value): return probs[value].log() def eager_categorical_tensor(probs, value): value = probs.materialize(value) backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Categorical(probs=probs, value=value) # noqa: F821 def eager_delta_tensor(v, log_density, value): # This handles event_dim specially, and hence cannot use the # generic Delta.eager_log_prob() method. 
    assert v.output == value.output
    event_dim = len(v.output.shape)
    inputs, (v, log_density, value) = align_tensors(v, log_density, value)
    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    data = backend_dist.Delta.dist_class(  # noqa: F821
        v, log_density, event_dim
    ).log_prob(value)
    return Tensor(data, inputs)


def eager_delta_funsor_variable(v, log_density, value):
    assert v.output == value.output
    return funsor.delta.Delta(value.name, v, log_density)


def eager_delta_funsor_funsor(v, log_density, value):
    assert v.output == value.output
    return funsor.delta.Delta(v.name, value, log_density)


def eager_delta_variable_variable(v, log_density, value):
    return None


def eager_normal(loc, scale, value):
    assert loc.output == Real
    assert scale.output == Real
    assert value.output == Real
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    white_vec = ops.new_zeros(scale.data, scale.data.shape + (1,))
    prec_sqrt = (1 / scale.data)[..., None, None]
    log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale)
    inputs = scale.inputs.copy()
    var = gensym("value")
    inputs[var] = Real
    gaussian = log_prob + Gaussian(
        white_vec=white_vec,
        prec_sqrt=prec_sqrt,
        inputs=inputs,
    )
    return gaussian(**{var: value - loc})


def eager_mvn(loc, scale_tril, value):
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    white_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
    prec_sqrt = ops.transpose(ops.triangular_inv(scale_tril.data), -1, -2)
    scale_diag = Tensor(ops.diagonal(scale_tril.data, -1, -2), scale_tril.inputs)
    log_prob = (
        -0.5 * scale_diag.shape[0] * math.log(2 * math.pi) - ops.log(scale_diag).sum()
    )
    inputs = scale_tril.inputs.copy()
    var = gensym("value")
    inputs[var] = Reals[scale_diag.shape[0]]
    gaussian = log_prob + Gaussian(
        white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs
    )
    return gaussian(**{var: value - loc})


def eager_beta_bernoulli(red_op, bin_op, reduced_vars, x, y):
    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return eager_dirichlet_multinomial(
        red_op,
        bin_op,
        reduced_vars,
        x,
        backend_dist.Binomial(total_count=1, probs=y.probs, value=y.value),
    )


def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, x, y):
    dirichlet_reduction = x.input_vars & reduced_vars
    if dirichlet_reduction:
        backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        identity = Tensor(
            ops.new_eye(funsor.tensor.get_default_prototype(), x.concentration.shape)
        )
        return backend_dist.DirichletMultinomial(
            concentration=x.concentration, total_count=1, value=identity[y.value]
        )
    else:
        return eager.interpret(Contraction, red_op, bin_op, reduced_vars, (x, y))


def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y):
    dirichlet_reduction = x.input_vars & reduced_vars
    if dirichlet_reduction:
        backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        return backend_dist.DirichletMultinomial(
            concentration=x.concentration, total_count=y.total_count, value=y.value
        )
    else:
        return eager.interpret(Contraction, red_op, bin_op, reduced_vars, (x, y))


def eager_plate_multinomial(op, x, reduced_vars):
    if not reduced_vars.isdisjoint(x.probs.input_vars):
        return None
    if not reduced_vars.issubset(x.value.input_vars):
        return None

    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    total_count = x.total_count
    for v in reduced_vars:
        if v.name in total_count.inputs:
            total_count = total_count.reduce(ops.add, v)
        else:
            total_count = total_count * v.output.size
    return backend_dist.Multinomial(
        total_count=total_count,
        probs=x.probs,
        value=x.value.reduce(ops.add, reduced_vars),
    )


def _log_beta(x, y):
    return ops.lgamma(x) + ops.lgamma(y) - ops.lgamma(x + y)


def eager_gamma_gamma(red_op, bin_op, reduced_vars, x, y):
    gamma_reduction = x.input_vars & reduced_vars
    if gamma_reduction:
        unnormalized = (y.concentration - 1) * ops.log(y.value) - (
            y.concentration + x.concentration
        ) * ops.log(y.value + x.rate)
        const = -x.concentration * ops.log(x.rate) + _log_beta(
            y.concentration, x.concentration
        )
        return unnormalized - const
    else:
        return eager.interpret(Contraction, red_op, bin_op, reduced_vars, (x, y))


def eager_gamma_poisson(red_op, bin_op, reduced_vars, x, y):
    gamma_reduction = x.input_vars & reduced_vars
    if gamma_reduction:
        backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        return backend_dist.GammaPoisson(
            concentration=x.concentration, rate=x.rate, value=y.value
        )
    else:
        return eager.interpret(Contraction, red_op, bin_op, reduced_vars, (x, y))


def eager_dirichlet_posterior(op, c, z):
    if (z.concentration is c.terms[0].concentration) and (
        c.terms[1].total_count is z.total_count
    ):
        backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        return backend_dist.Dirichlet(
            concentration=z.concentration + c.terms[1].value, value=c.terms[0].value
        )
    else:
        return None
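

# Conjugacy sketch: the eager_dirichlet_* and eager_gamma_* patterns above
# implement collapsed conjugate pairs. Schematically (names illustrative, and
# assuming the backend wrappers are registered for these patterns), reducing
# the latent variable of a Dirichlet-Multinomial product eagerly produces a
# DirichletMultinomial funsor instead of a lazy Contraction:
#
#     latent = Dirichlet(concentration, value="probs")
#     likelihood = Multinomial(total_count, "probs", value)
#     (latent + likelihood).reduce(ops.logaddexp, "probs")
#     # => log density of DirichletMultinomial(concentration, total_count, value)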