Source code for funsor.distributions

import math
from collections import OrderedDict

import pyro.distributions as dist
import torch
from pyro.distributions.util import broadcast_shape

import funsor.delta
import funsor.ops as ops
from funsor.affine import is_affine
from funsor.domains import bint, reals
from funsor.gaussian import Gaussian, cholesky_inverse
from funsor.interpreter import gensym, interpretation
from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, lazy, to_funsor
from funsor.torch import Tensor, align_tensors, ignore_jit_warnings, materialize, torch_stack


def numbers_to_tensors(*args):
    """
    Convert :class:`~funsor.terms.Number`s to :class:`funsor.torch.Tensor`s,
    using any provided tensor as a prototype, if available.

    :param \\*args: Arbitrary arguments; only ``Number`` instances are converted.
    :return: ``args`` unchanged when it contains no ``Number``, otherwise a
        tuple in which each ``Number`` has been replaced by a ``Tensor``.
    :rtype: tuple
    """
    if not any(isinstance(arg, Number) for arg in args):
        return args

    # The first Tensor argument (if any) dictates dtype and device of the
    # newly created tensors; otherwise fall back to torch's default dtype.
    prototype = next((arg for arg in args if isinstance(arg, Tensor)), None)
    if prototype is None:
        tensor_kwargs = {"dtype": torch.get_default_dtype()}
    else:
        tensor_kwargs = {"dtype": prototype.data.dtype,
                         "device": prototype.data.device}

    converted = []
    with ignore_jit_warnings():
        for arg in args:
            if isinstance(arg, Number):
                arg = Tensor(torch.tensor(arg.data, **tensor_kwargs), dtype=arg.dtype)
            converted.append(arg)
    return tuple(converted)


class DistributionMeta(FunsorMeta):
    """
    Metaclass wrapper that fills in default values and converts Numbers to
    Tensors before a distribution is constructed.
    """
    def __call__(cls, *args, **kwargs):
        # Fold positional arguments into kwargs, keyed by the class's field names.
        for field, arg in zip(cls._ast_fields, args):
            kwargs[field] = arg
        filled = numbers_to_tensors(*cls._fill_defaults(**kwargs))

        if 'value' not in kwargs:
            # No explicit value: lazily construct a distribution instance.
            # This makes it cheaper to construct observations in minipyro.
            with interpretation(lazy):
                return super().__call__(*filled)

        # Value was explicitly specified: evaluate under the current interpretation.
        return super().__call__(*filled)


class Distribution(Funsor, metaclass=DistributionMeta):
    r"""
    Funsor backed by a PyTorch distribution object.

    :param \*args: Distribution-dependent parameters.  These can be either
        funsors or objects that can be coerced to funsors via
        :func:`~funsor.terms.to_funsor` . See derived classes for details.
    """
    dist_class = "defined by derived classes"

    def __init__(self, *args):
        # Pair each positional argument with its field name.
        params = tuple(zip(self._ast_fields, args))
        field_names = [name for name, _ in params]
        assert 'value' in field_names

        # The funsor's inputs are the union of the inputs of all parameters.
        inputs = OrderedDict()
        for name, param in params:
            assert isinstance(name, str)
            assert isinstance(param, Funsor)
            inputs.update(param.inputs)

        super().__init__(inputs, reals())
        self.params = params

    def __repr__(self):
        rendered = ('{}={}'.format(name, param) for name, param in self.params)
        return '{}({})'.format(type(self).__name__, ', '.join(rendered))

    def eager_reduce(self, op, reduced_vars):
        """
        Reduce this distribution, short-circuiting a ``logaddexp`` reduction
        over the ``value`` variable to ``Number(0.)``.
        """
        if op is not ops.logaddexp:
            return super().eager_reduce(op, reduced_vars)
        value = self.value
        if isinstance(value, Variable) and value.name in reduced_vars:
            return Number(0.)  # distributions are normalized
        return super().eager_reduce(op, reduced_vars)

    @classmethod
    def eager_log_prob(cls, **params):
        """
        Align the given parameters, build ``cls.dist_class`` from everything
        except ``value`` and return its ``log_prob(value)`` as a ``Tensor``.

        :param \*\*params: Named parameters, which must include ``value``.
        """
        names = list(params)
        inputs, tensors = align_tensors(*params.values())
        aligned = dict(zip(names, tensors))
        value = aligned.pop('value')
        log_prob = cls.dist_class(**aligned).log_prob(value)
        return Tensor(log_prob, inputs)
################################################################################
# Distribution Wrappers
################################################################################


class BernoulliProbs(Distribution):
    """
    Wraps :class:`pyro.distributions.Bernoulli` , parameterized by ``probs``.

    :param Funsor probs: Probability of 1.
    :param Funsor value: Optional observation in ``{0,1}``.
    """
    dist_class = dist.Bernoulli

    @staticmethod
    def _fill_defaults(probs, value='value'):
        # Coerce both arguments to funsors; ``value`` gets a real output domain.
        probs = to_funsor(probs)
        assert probs.dtype == "real"
        return probs, to_funsor(value, reals())

    def __init__(self, probs, value=None):
        super().__init__(probs, value)


@eager.register(BernoulliProbs, Tensor, Tensor)
def eager_bernoulli(probs, value):
    """Evaluate the log-probability when both arguments are tensors."""
    params = dict(probs=probs, value=value)
    return BernoulliProbs.eager_log_prob(**params)
class BernoulliLogits(Distribution):
    """
    Wraps :class:`pyro.distributions.Bernoulli` , parameterized by ``logits``.

    :param Funsor logits: Log likelihood ratio of 1. This should equal
        ``log(p1 / p0)``.
    :param Funsor value: Optional observation in ``{0,1}``.
    """
    dist_class = dist.Bernoulli

    @staticmethod
    def _fill_defaults(logits, value='value'):
        # Coerce both arguments to funsors; ``value`` gets a real output domain.
        logits = to_funsor(logits)
        assert logits.dtype == "real"
        return logits, to_funsor(value, reals())

    def __init__(self, logits, value=None):
        super().__init__(logits, value)
@eager.register(BernoulliLogits, Tensor, Tensor)
def eager_bernoulli_logits(logits, value):
    """Evaluate the log-probability when both arguments are tensors."""
    params = dict(logits=logits, value=value)
    return BernoulliLogits.eager_log_prob(**params)
def Bernoulli(probs=None, logits=None, value='value'):
    """
    Wraps :class:`pyro.distributions.Bernoulli` .

    This dispatches to either :class:`BernoulliProbs` or
    :class:`BernoulliLogits` to accept either ``probs`` or ``logits`` args.
    Exactly one of the two must be given.

    :param Funsor probs: Probability of 1.
    :param Funsor logits: Log likelihood ratio of 1, i.e. ``log(p1 / p0)``.
    :param Funsor value: Optional observation in ``{0,1}``.
    :raises ValueError: If neither or both of ``probs`` and ``logits`` are
        specified.
    """
    has_probs = probs is not None
    has_logits = logits is not None
    # Reject the ambiguous case up front instead of silently dropping logits.
    if has_probs and has_logits:
        raise ValueError('Either probs or logits must be specified, but not both')
    if has_probs:
        return BernoulliProbs(probs, value)
    if has_logits:
        return BernoulliLogits(logits, value)
    raise ValueError('Either probs or logits must be specified')
class Beta(Distribution):
    """
    Wraps :class:`pyro.distributions.Beta` .

    :param Funsor concentration1: Positive concentration parameter.
    :param Funsor concentration0: Positive concentration parameter.
    :param Funsor value: Optional observation in ``(0,1)``.
    """
    dist_class = dist.Beta

    @staticmethod
    def _fill_defaults(concentration1, concentration0, value='value'):
        # All three arguments are coerced to funsors with a real output domain.
        raw_args = (concentration1, concentration0, value)
        return tuple(to_funsor(arg, reals()) for arg in raw_args)

    def __init__(self, concentration1, concentration0, value=None):
        super().__init__(concentration1, concentration0, value)
@eager.register(Beta, Tensor, Tensor, Tensor)
def eager_beta(concentration1, concentration0, value):
    """Evaluate the log-probability when all arguments are tensors."""
    log_prob = Beta.eager_log_prob(
        concentration1=concentration1,
        concentration0=concentration0,
        value=value,
    )
    return log_prob


@eager.register(Beta, Funsor, Funsor, Funsor)
def eager_beta(concentration1, concentration0, value):  # noqa: F811
    """Rewrite a Beta as a two-component :class:`Dirichlet` ."""
    # Stack into length-2 vectors, complement first.
    stacked_concentration = torch_stack((concentration0, concentration1))
    stacked_value = torch_stack((1 - value, value))
    return Dirichlet(stacked_concentration, value=stacked_value)
class Binomial(Distribution):
    """
    Wraps :class:`pyro.distributions.Binomial` .

    :param Funsor total_count: Total number of trials.
    :param Funsor probs: Probability of each positive trial.
    :param Funsor value: Optional integer observation (encoded as "real").
    """
    dist_class = dist.Binomial

    @staticmethod
    def _fill_defaults(total_count, probs, value='value'):
        # ``total_count`` and ``value`` are coerced with a real output domain;
        # ``probs`` is coerced without one and must have dtype "real".
        total_count = to_funsor(total_count, reals())
        probs = to_funsor(probs)
        assert probs.dtype == "real"
        return total_count, probs, to_funsor(value, reals())

    def __init__(self, total_count, probs, value=None):
        super().__init__(total_count, probs, value)
@eager.register(Binomial, Tensor, Tensor, Tensor)
def eager_binomial(total_count, probs, value):
    """Evaluate the log-probability when all arguments are tensors."""
    log_prob = Binomial.eager_log_prob(
        total_count=total_count,
        probs=probs,
        value=value,
    )
    return log_prob


@eager.register(Binomial, Funsor, Funsor, Funsor)
def eager_binomial(total_count, probs, value):  # noqa: F811
    """Rewrite a Binomial as a two-outcome :class:`Multinomial` ."""
    # Stack (failure, success) probabilities and (failure, success) counts.
    stacked_probs = torch_stack((1 - probs, probs))
    stacked_value = torch_stack((total_count - value, value))
    return Multinomial(total_count, stacked_probs, value=stacked_value)
class Categorical(Distribution):
    """
    Wraps :class:`pyro.distributions.Categorical` .

    :param Funsor probs: Probability vector over outcomes.
    :param Funsor value: Optional bounded integer observation.
    """
    dist_class = dist.Categorical

    @staticmethod
    def _fill_defaults(probs, value='value'):
        probs = to_funsor(probs)
        assert probs.dtype == "real"
        # The value domain is a bounded integer sized by the leading
        # dimension of ``probs``' output shape.
        num_outcomes = probs.output.shape[0]
        return probs, to_funsor(value, bint(num_outcomes))

    def __init__(self, probs, value='value'):
        super().__init__(probs, value)
@eager.register(Categorical, Funsor, Tensor)
def eager_categorical(probs, value):
    """Index ``probs`` by the observed ``value`` and take the log."""
    selected = probs[value]
    return selected.log()


@eager.register(Categorical, Tensor, Tensor)
def eager_categorical(probs, value):  # noqa: F811
    """Evaluate the log-probability when both arguments are tensors."""
    params = dict(probs=probs, value=value)
    return Categorical.eager_log_prob(**params)


@eager.register(Categorical, Tensor, Variable)
def eager_categorical(probs, value):  # noqa: F811
    """Materialize the ``value`` variable, then evaluate the log-probability."""
    materialized = materialize(value)
    return Categorical.eager_log_prob(probs=probs, value=materialized)
class Delta(Distribution):
    """
    Wraps :class:`pyro.distributions.Delta` .

    :param Funsor v: The unique point of concentration.
    :param Funsor log_density: Optional density (used by transformed
        distributions).
    :param Funsor value: Optional observation of similar domain as ``v``.
    """
    dist_class = dist.Delta

    @staticmethod
    def _fill_defaults(v, log_density=0, value='value'):
        v = to_funsor(v)
        log_density = to_funsor(log_density, reals())
        # ``value`` shares the output domain of ``v``.
        return v, log_density, to_funsor(value, v.output)

    def __init__(self, v, log_density=0, value='value'):
        super().__init__(v, log_density, value)
@eager.register(Delta, Tensor, Tensor, Tensor)
def eager_delta(v, log_density, value):
    """Evaluate the log-probability when all arguments are tensors."""
    # This handles event_dim specially, and hence cannot use the
    # generic Delta.eager_log_prob() method.
    assert v.output == value.output
    event_dim = len(v.output.shape)
    inputs, aligned = align_tensors(v, log_density, value)
    point, density, observed = aligned
    log_prob = dist.Delta(point, density, event_dim).log_prob(observed)
    return Tensor(log_prob, inputs)


@eager.register(Delta, Funsor, Funsor, Variable)
@eager.register(Delta, Variable, Funsor, Variable)
def eager_delta(v, log_density, value):  # noqa: F811
    """Build a :class:`funsor.delta.Delta` named after the ``value`` variable."""
    assert v.output == value.output
    name, point = value.name, v
    return funsor.delta.Delta(name, point, log_density)


@eager.register(Delta, Variable, Funsor, Funsor)
def eager_delta(v, log_density, value):  # noqa: F811
    """Build a :class:`funsor.delta.Delta` named after the ``v`` variable."""
    assert v.output == value.output
    name, point = v.name, value
    return funsor.delta.Delta(name, point, log_density)
class Dirichlet(Distribution):
    """
    Wraps :class:`pyro.distributions.Dirichlet` .

    :param Funsor concentration: Positive concentration vector.
    :param Funsor value: Optional observation in the unit simplex.
    """
    dist_class = dist.Dirichlet

    @staticmethod
    def _fill_defaults(concentration, value='value'):
        concentration = to_funsor(concentration)
        assert concentration.dtype == "real"
        output_shape = concentration.output.shape
        assert len(output_shape) == 1
        # ``value`` is a real vector of the same length as ``concentration``.
        dim = concentration.output.shape[0]
        return concentration, to_funsor(value, reals(dim))

    def __init__(self, concentration, value='value'):
        super().__init__(concentration, value)
@eager.register(Dirichlet, Tensor, Tensor)
def eager_dirichlet(concentration, value):
    """Evaluate the log-probability when both arguments are tensors."""
    params = dict(concentration=concentration, value=value)
    return Dirichlet.eager_log_prob(**params)
class DirichletMultinomial(Distribution):
    """
    Wraps :class:`pyro.distributions.DirichletMultinomial` .

    :param Funsor concentration: Positive concentration vector.
    :param Funsor total_count: Total number of trials.
    :param Funsor value: Optional real vector observation with the same
        length as ``concentration`` (presumably per-category counts, as in
        the wrapped Pyro distribution).
    """
    dist_class = dist.DirichletMultinomial

    @staticmethod
    def _fill_defaults(concentration, total_count=1, value='value'):
        concentration = to_funsor(concentration)
        assert concentration.dtype == "real"
        output_shape = concentration.output.shape
        assert len(output_shape) == 1
        total_count = to_funsor(total_count, reals())
        dim = concentration.output.shape[0]
        value = to_funsor(value, reals(dim))  # Should this be bint(total_count)?
        return concentration, total_count, value

    def __init__(self, concentration, total_count, value='value'):
        super().__init__(concentration, total_count, value)
@eager.register(DirichletMultinomial, Tensor, Tensor, Tensor)
def eager_dirichlet_multinomial(concentration, total_count, value):
    """Evaluate the log-probability when all arguments are tensors."""
    params = dict(concentration=concentration, total_count=total_count, value=value)
    return DirichletMultinomial.eager_log_prob(**params)
def LogNormal(loc, scale, value='value'):
    """
    Wraps :class:`pyro.distributions.LogNormal` .

    Implemented as a :class:`Normal` evaluated at the inverse of ``ops.exp``
    applied to ``value``, minus the transform's log-abs-det-Jacobian.

    :param Funsor loc: Mean of the untransformed Normal distribution.
    :param Funsor scale: Standard deviation of the untransformed Normal
        distribution.
    :param Funsor value: Optional real observation.
    """
    loc, scale, observed = Normal._fill_defaults(loc, scale, value)
    transform = ops.exp
    # Pull the observation back through the transform.
    pre_image = transform.inv(observed)
    jacobian_term = transform.log_abs_det_jacobian(pre_image, observed)
    base_log_prob = Normal(loc, scale, pre_image)
    return base_log_prob - jacobian_term
class Multinomial(Distribution):
    """
    Wraps :class:`pyro.distributions.Multinomial` .

    :param Funsor total_count: Total number of trials.
    :param Funsor probs: Probability vector over outcomes.
    :param Funsor value: Optional observation with the same output domain as
        ``probs`` (presumably a vector of per-outcome counts).
    """
    dist_class = dist.Multinomial

    @staticmethod
    def _fill_defaults(total_count, probs, value='value'):
        total_count = to_funsor(total_count, reals())
        probs = to_funsor(probs)
        assert probs.dtype == "real"
        output_shape = probs.output.shape
        assert len(output_shape) == 1
        # ``value`` shares the output domain of ``probs``.
        return total_count, probs, to_funsor(value, probs.output)

    def __init__(self, total_count, probs, value=None):
        super().__init__(total_count, probs, value)
@eager.register(Multinomial, Tensor, Tensor, Tensor)
def eager_multinomial(total_count, probs, value):
    """Evaluate the log-probability when all arguments are tensors."""
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, aligned = align_tensors(total_count, probs, value)
    raw_count, raw_probs, raw_value = aligned
    # Broadcast probs and value against total_count with a trailing event dim.
    shape = broadcast_shape(raw_count.shape + (1,), raw_probs.shape, raw_value.shape)
    expanded_probs = Tensor(raw_probs.expand(shape), inputs)
    expanded_value = Tensor(raw_value.expand(shape), inputs)
    max_count = Number(raw_count.max().item())  # Used by distributions validation code.
    return Multinomial.eager_log_prob(
        total_count=max_count,
        probs=expanded_probs,
        value=expanded_value,
    )
class Normal(Distribution):
    """
    Wraps :class:`pyro.distributions.Normal` .

    :param Funsor loc: Mean.
    :param Funsor scale: Standard deviation.
    :param Funsor value: Optional real observation.
    """
    dist_class = dist.Normal

    @staticmethod
    def _fill_defaults(loc, scale, value='value'):
        # All three arguments are coerced to funsors with a real output domain.
        raw_args = (loc, scale, value)
        return tuple(to_funsor(arg, reals()) for arg in raw_args)

    def __init__(self, loc, scale, value='value'):
        super().__init__(loc, scale, value)
@eager.register(Normal, Tensor, Tensor, Tensor)
def eager_normal(loc, scale, value):
    """Evaluate the log-probability when all arguments are tensors."""
    params = dict(loc=loc, scale=scale, value=value)
    return Normal.eager_log_prob(**params)


@eager.register(Normal, Funsor, Tensor, Funsor)
def eager_normal(loc, scale, value):  # noqa: F811
    """
    Express a Normal with tensor ``scale`` as a :class:`Gaussian` funsor
    evaluated at ``value - loc``; returns ``None`` (stay lazy) unless both
    ``loc`` and ``value`` are affine.
    """
    for arg in (loc, scale, value):
        assert arg.output == reals()
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    scale_data = scale.data
    batch_shape = scale_data.shape
    # Zero-mean Gaussian over a fresh one-dimensional variable.
    info_vec = scale_data.new_zeros(batch_shape + (1,))
    precision = scale_data.pow(-2).reshape(batch_shape + (1, 1))
    log_normalizer = -0.5 * math.log(2 * math.pi) - scale.log().sum()

    gaussian_inputs = scale.inputs.copy()
    residual_name = gensym('value')
    gaussian_inputs[residual_name] = reals()
    gaussian = log_normalizer + Gaussian(info_vec, precision, gaussian_inputs)
    return gaussian(**{residual_name: value - loc})
class MultivariateNormal(Distribution):
    """
    Wraps :class:`pyro.distributions.MultivariateNormal` .

    :param Funsor loc: Mean vector.
    :param Funsor scale_tril: Lower Cholesky factor of the covariance matrix.
    :param Funsor value: Optional real vector observation.
    """
    dist_class = dist.MultivariateNormal

    @staticmethod
    def _fill_defaults(loc, scale_tril, value='value'):
        loc = to_funsor(loc)
        scale_tril = to_funsor(scale_tril)
        assert loc.dtype == 'real'
        assert scale_tril.dtype == 'real'
        assert len(loc.output.shape) == 1
        # ``scale_tril`` must be square with side equal to the length of ``loc``.
        dim = loc.output.shape[0]
        assert scale_tril.output.shape == (dim, dim)
        return loc, scale_tril, to_funsor(value, loc.output)

    def __init__(self, loc, scale_tril, value='value'):
        super().__init__(loc, scale_tril, value)
@eager.register(MultivariateNormal, Tensor, Tensor, Tensor)
def eager_mvn(loc, scale_tril, value):
    """Evaluate the log-probability when all arguments are tensors."""
    params = dict(loc=loc, scale_tril=scale_tril, value=value)
    return MultivariateNormal.eager_log_prob(**params)


@eager.register(MultivariateNormal, Funsor, Tensor, Funsor)
def eager_mvn(loc, scale_tril, value):  # noqa: F811
    """
    Express a multivariate Normal with tensor ``scale_tril`` as a
    :class:`Gaussian` funsor evaluated at ``value - loc``; returns ``None``
    (stay lazy) unless both ``loc`` and ``value`` are affine.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    tril_data = scale_tril.data
    # Zero-mean Gaussian whose precision comes from the Cholesky factor.
    info_vec = tril_data.new_zeros(tril_data.shape[:-1])
    precision = cholesky_inverse(tril_data)
    scale_diag = Tensor(tril_data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
    log_normalizer = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                      - scale_diag.log().sum())

    gaussian_inputs = scale_tril.inputs.copy()
    residual_name = gensym('value')
    gaussian_inputs[residual_name] = reals(scale_diag.shape[0])
    gaussian = log_normalizer + Gaussian(info_vec, precision, gaussian_inputs)
    return gaussian(**{residual_name: value - loc})
class Poisson(Distribution):
    """
    Wraps :class:`pyro.distributions.Poisson` .

    :param Funsor rate: Mean parameter.
    :param Funsor value: Optional integer observation (coded as "real").
    """
    dist_class = dist.Poisson

    @staticmethod
    def _fill_defaults(rate, value='value'):
        # Coerce both arguments to funsors; ``value`` gets a real output domain.
        rate = to_funsor(rate)
        assert rate.dtype == "real"
        return rate, to_funsor(value, reals())

    def __init__(self, rate, value=None):
        super(Poisson, self).__init__(rate, value)
@eager.register(Poisson, Tensor, Tensor)
def eager_poisson(rate, value):
    """Evaluate the log-probability when both arguments are tensors."""
    params = dict(rate=rate, value=value)
    return Poisson.eager_log_prob(**params)
class Gamma(Distribution):
    """
    Wraps :class:`pyro.distributions.Gamma` .

    :param Funsor concentration: Positive concentration parameter.
    :param Funsor rate: Positive rate parameter.
    :param Funsor value: Optional positive observation.
    """
    dist_class = dist.Gamma

    @staticmethod
    def _fill_defaults(concentration, rate, value='value'):
        # Each parameter is coerced and checked before the next is touched.
        concentration = to_funsor(concentration)
        assert concentration.dtype == "real"
        rate = to_funsor(rate)
        assert rate.dtype == "real"
        return concentration, rate, to_funsor(value, reals())

    def __init__(self, concentration, rate, value=None):
        super(Gamma, self).__init__(concentration, rate, value)
@eager.register(Gamma, Tensor, Tensor, Tensor)
def eager_gamma(concentration, rate, value):
    """Evaluate the log-probability when all arguments are tensors."""
    params = dict(concentration=concentration, rate=rate, value=value)
    return Gamma.eager_log_prob(**params)
class VonMises(Distribution):
    """
    Wraps :class:`pyro.distributions.VonMises` .

    :param Funsor loc: A location angle.
    :param Funsor concentration: Positive concentration parameter.
    :param Funsor value: Optional angular observation.
    """
    dist_class = dist.VonMises

    @staticmethod
    def _fill_defaults(loc, concentration, value='value'):
        # Each parameter is coerced and checked before the next is touched.
        loc = to_funsor(loc)
        assert loc.dtype == "real"
        concentration = to_funsor(concentration)
        assert concentration.dtype == "real"
        return loc, concentration, to_funsor(value, reals())

    def __init__(self, loc, concentration, value=None):
        super(VonMises, self).__init__(loc, concentration, value)
@eager.register(VonMises, Tensor, Tensor, Tensor)
def eager_vonmises(loc, concentration, value):
    """Evaluate the log-probability when all arguments are tensors."""
    params = dict(loc=loc, concentration=concentration, value=value)
    return VonMises.eager_log_prob(**params)


# Public names exported by this module.
__all__ = [
    'Bernoulli',
    'BernoulliLogits',
    'Beta',
    'Binomial',
    'Categorical',
    'Delta',
    'Dirichlet',
    'DirichletMultinomial',
    'Distribution',
    'Gamma',
    'LogNormal',
    'Multinomial',
    'MultivariateNormal',
    'Normal',
    'Poisson',
    'VonMises',
]