# Source code for funsor.torch.distributions

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
from typing import Tuple, Union

import pyro.distributions as dist
import pyro.distributions.testing.fakes as fakes
from pyro.distributions.torch_distribution import MaskedDistribution
import torch

from funsor.cnf import Contraction
from funsor.distribution import (  # noqa: F401
    Bernoulli,
    FUNSOR_DIST_NAMES,
    LogNormal,
    backenddist_to_funsor,
    eager_beta,
    eager_beta_bernoulli,
    eager_binomial,
    eager_categorical_funsor,
    eager_categorical_tensor,
    eager_delta_funsor_funsor,
    eager_delta_funsor_variable,
    eager_delta_tensor,
    eager_dirichlet_multinomial,
    eager_dirichlet_posterior,
    eager_delta_variable_variable,
    eager_gamma_poisson,
    eager_multinomial,
    eager_mvn,
    eager_normal,
    indepdist_to_funsor,
    make_dist,
    maskeddist_to_funsor,
    mvndist_to_funsor,
    transformeddist_to_funsor,
)
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor, dummy_numeric_array
from funsor.terms import Binary, Funsor, Variable, eager, to_funsor
from funsor.util import methodof


# Public API: the name of every distribution listed in FUNSOR_DIST_NAMES, each a
# (name, param_names) pair; the classes themselves are created by the make_dist loop below.
__all__ = [name for name, _ in FUNSOR_DIST_NAMES]


################################################################################
# Distribution Wrappers
################################################################################


class _PyroWrapper_BernoulliProbs(dist.Bernoulli):
    """``dist.Bernoulli`` whose constructor accepts only ``probs``.

    ``_get_pyro_dist`` returns this class for the funsor name ``BernoulliProbs``.
    """

    def __init__(self, probs, validate_args=None):
        # __init__ must not return a value; simply delegate to the base class.
        super().__init__(probs=probs, validate_args=validate_args)

    # XXX: subclasses of Pyro distributions that define a custom __init__ method
    # should also implement `expand`.
    def expand(self, batch_shape, _instance=None):
        """Return this distribution expanded to ``batch_shape``.

        The instance to populate is obtained from ``_get_checked_instance`` using
        this wrapper class explicitly, then filled in by the base-class ``expand``.
        """
        new = self._get_checked_instance(_PyroWrapper_BernoulliProbs, _instance)
        return super().expand(batch_shape, _instance=new)


class _PyroWrapper_BernoulliLogits(dist.Bernoulli):
    """``dist.Bernoulli`` whose constructor accepts only ``logits``.

    ``_get_pyro_dist`` returns this class for the funsor name ``BernoulliLogits``.
    """

    def __init__(self, logits, validate_args=None):
        # __init__ must not return a value; simply delegate to the base class.
        super().__init__(logits=logits, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return this distribution expanded to ``batch_shape``.

        Required because this subclass defines a custom ``__init__``; the instance
        to populate is obtained from ``_get_checked_instance`` using this wrapper
        class explicitly, then filled in by the base-class ``expand``.
        """
        new = self._get_checked_instance(_PyroWrapper_BernoulliLogits, _instance)
        return super().expand(batch_shape, _instance=new)


class _PyroWrapper_CategoricalLogits(dist.Categorical):
    """``dist.Categorical`` whose constructor accepts only ``logits``.

    ``_get_pyro_dist`` returns this class for the funsor name ``CategoricalLogits``.
    """

    def __init__(self, logits, validate_args=None):
        # __init__ must not return a value; simply delegate to the base class.
        super().__init__(logits=logits, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return this distribution expanded to ``batch_shape``.

        Required because this subclass defines a custom ``__init__``; the instance
        to populate is obtained from ``_get_checked_instance`` using this wrapper
        class explicitly, then filled in by the base-class ``expand``.
        """
        new = self._get_checked_instance(_PyroWrapper_CategoricalLogits, _instance)
        return super().expand(batch_shape, _instance=new)


def _get_pyro_dist(dist_name):
    """Look up the Pyro distribution class backing the funsor distribution ``dist_name``.

    * ``BernoulliProbs``, ``BernoulliLogits`` and ``CategoricalLogits`` resolve to the
      ``_PyroWrapper_<dist_name>`` globals of this module (``None`` if absent).
    * Names starting with ``Nonreparameterized`` are taken from
      ``pyro.distributions.testing.fakes``.
    * Every other name is looked up on ``pyro.distributions``.
    """
    wrapped_names = ('BernoulliProbs', 'BernoulliLogits', 'CategoricalLogits')
    if dist_name in wrapped_names:
        return globals().get('_PyroWrapper_' + dist_name)
    source_module = fakes if dist_name.startswith('Nonreparameterized') else dist
    return getattr(source_module, dist_name)


PYRO_DIST_NAMES = FUNSOR_DIST_NAMES


# Create one funsor distribution class per (dist_name, param_names) entry and bind it
# as a module global named dist_name; these are exactly the names listed in __all__.
# globals() is used explicitly instead of module-level locals(), whose returned dict
# the Python docs advise against modifying. The code below refers to the generated
# classes (Delta, Dirichlet, ...) by name, hence its `noqa: F821` markers.
for dist_name, param_names in PYRO_DIST_NAMES:
    globals()[dist_name] = make_dist(_get_pyro_dist(dist_name), param_names)


# Delta has to be treated specially because of its weird shape inference semantics
@methodof(Delta)  # noqa: F821
@staticmethod
def _infer_value_domain(**kwargs):
    """Return the domain of a Delta's value, which is simply the domain given for ``v``.

    :raises KeyError: if no ``v`` domain is supplied.
    """
    v_domain = kwargs['v']
    return v_domain


# Multinomial and related dists have dependent Bint dtypes, so we just make them 'real'
# See issue: https://github.com/pyro-ppl/funsor/issues/322
@methodof(Binomial)  # noqa: F821
@methodof(Multinomial)  # noqa: F821
@methodof(DirichletMultinomial)  # noqa: F821
@classmethod
@functools.lru_cache(maxsize=5000)
def _infer_value_domain(cls, **kwargs):
    """Infer the value domain by instantiating the backend distribution on dummy inputs.

    Each keyword argument maps a parameter name to its funsor domain. Every domain
    is replaced by the array returned by ``dummy_numeric_array``, the result is
    passed to ``cls.dist_class`` with ``validate_args=False``, and the instance's
    ``event_shape`` is returned as a ``Reals`` domain. Results are memoized per
    ``(cls, kwargs)``.
    """
    dummy_params = {}
    for param_name, param_domain in kwargs.items():
        dummy_params[param_name] = dummy_numeric_array(param_domain)
    dummy_instance = cls.dist_class(validate_args=False, **dummy_params)
    return Reals[dummy_instance.event_shape]


# TODO fix Delta.arg_constraints["v"] to be a
# constraints.independent[constraints.real]
@methodof(Delta)  # noqa: F821
@staticmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(name, raw_shape):
    """Return the funsor domain of the Delta parameter ``name``.

    ``log_density`` is a scalar ``Real``; ``v`` keeps its whole ``raw_shape``.

    :raises ValueError: for any other parameter name.
    """
    if name == "log_density":
        return Real
    if name != "v":
        raise ValueError(name)
    return Reals[raw_shape]


# TODO fix Dirichlet.arg_constraints["concentration"] to be a
# constraints.independent[constraints.positive]
@methodof(Dirichlet)  # noqa: F821
@methodof(NonreparameterizedDirichlet)  # noqa: F821
@staticmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(name, raw_shape):
    """Return the funsor domain of the Dirichlet parameter ``name``.

    The only parameter is ``concentration``; its domain keeps just the rightmost
    dimension of ``raw_shape``.

    :raises ValueError: if ``name`` is not ``"concentration"``. This is an explicit
        raise rather than ``assert`` so the check survives ``python -O``, and it
        matches ``Delta._infer_param_domain``.
    """
    if name != "concentration":
        raise ValueError(name)
    return Reals[raw_shape[-1]]


# TODO fix DirichletMultinomial.arg_constraints["concentration"] to be a
# constraints.independent[constraints.positive]
@methodof(DirichletMultinomial)  # noqa: F821
@classmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(cls, name, raw_shape):
    """Return the funsor domain of the DirichletMultinomial parameter ``name``.

    ``concentration`` keeps only the rightmost dimension of ``raw_shape``;
    ``total_count`` is a scalar ``Real``.

    :raises ValueError: for any other parameter name. This is an explicit raise
        rather than ``assert`` so the check survives ``python -O``, and it matches
        ``Delta._infer_param_domain``.
    """
    if name == "concentration":
        return Reals[raw_shape[-1]]
    if name == "total_count":
        return Real
    raise ValueError(name)


###############################################
# Converting PyTorch Distributions to funsors
###############################################

# Register the generic converters imported from funsor.distribution for each
# torch/pyro distribution type, in the same order as listed here.
for _backend_dist_type, _converter in (
    (torch.distributions.Distribution, backenddist_to_funsor),
    (torch.distributions.Independent, indepdist_to_funsor),
    (MaskedDistribution, maskeddist_to_funsor),
    (torch.distributions.TransformedDistribution, transformeddist_to_funsor),
    (torch.distributions.MultivariateNormal, mvndist_to_funsor),
):
    to_funsor.register(_backend_dist_type)(_converter)
# Keep the module namespace free of the loop variables.
del _backend_dist_type, _converter


@to_funsor.register(torch.distributions.Bernoulli)
def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a torch ``Bernoulli`` to a funsor via its logits parametrization.

    The distribution is rebuilt as a ``_PyroWrapper_BernoulliLogits`` from
    ``pyro_dist.logits`` (with the wrapper's default ``validate_args``) and passed to
    ``backenddist_to_funsor`` along with the unchanged ``output`` and ``dim_to_name``.
    """
    logits = pyro_dist.logits
    logits_dist = _PyroWrapper_BernoulliLogits(logits=logits)
    return backenddist_to_funsor(logits_dist, output, dim_to_name)


# Type pattern for a Contraction whose first op is LogAddExpOp or NullOp, whose
# second op is AddOp, whose third field is a frozenset (presumably the reduced
# variables — confirm in funsor.cnf), and whose terms are (Dirichlet, Multinomial).
# It is used as an argument pattern in the eager_dirichlet_posterior registration
# at the end of this module.
JointDirichletMultinomial = Contraction[
    Union[ops.LogAddExpOp, ops.NullOp],
    ops.AddOp,
    frozenset,
    Tuple[Dirichlet, Multinomial],  # noqa: F821
]


# Eager rules: each `eager.register(Term, *arg_types)(handler)` call registers
# `handler` with `eager` for `Term` instances whose arguments match `arg_types`
# (presumably matched positionally — see funsor.terms). The `noqa: F821` markers
# silence flake8 for the distribution classes created by the make_dist loop above.
eager.register(Beta, Funsor, Funsor, Funsor)(eager_beta)  # noqa: F821
eager.register(Binomial, Funsor, Funsor, Funsor)(eager_binomial)  # noqa: F821
eager.register(Multinomial, Tensor, Tensor, Tensor)(eager_multinomial)  # noqa: F821
eager.register(Categorical, Funsor, Tensor)(eager_categorical_funsor)  # noqa: F821
eager.register(Categorical, Tensor, Variable)(eager_categorical_tensor)  # noqa: F821
# Delta is registered for several argument-type combinations; the same handler
# eager_delta_funsor_variable serves both (Funsor, Funsor, Variable) and
# (Variable, Funsor, Variable).
eager.register(Delta, Tensor, Tensor, Tensor)(eager_delta_tensor)  # noqa: F821
eager.register(Delta, Funsor, Funsor, Variable)(eager_delta_funsor_variable)  # noqa: F821
eager.register(Delta, Variable, Funsor, Variable)(eager_delta_funsor_variable)  # noqa: F821
eager.register(Delta, Variable, Funsor, Funsor)(eager_delta_funsor_funsor)  # noqa: F821
eager.register(Delta, Variable, Variable, Variable)(eager_delta_variable_variable)  # noqa: F821
eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal)  # noqa: F821
eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn)  # noqa: F821
# Rules for LogAddExp/Add contractions over specific pairs of distributions
# (presumably collapsing conjugate pairs, judging by the handler names).
# NOTE(review): eager_beta_bernoulli is keyed on Dirichlet rather than Beta,
# presumably because eager_beta rewrites Beta terms into Dirichlet form — confirm
# in funsor.distribution before changing.
eager.register(Contraction, ops.LogAddExpOp, ops.AddOp, frozenset, Dirichlet, BernoulliProbs)(  # noqa: F821
    eager_beta_bernoulli)
eager.register(Contraction, ops.LogAddExpOp, ops.AddOp, frozenset, Dirichlet, Multinomial)(  # noqa: F821
    eager_dirichlet_multinomial)
eager.register(Contraction, ops.LogAddExpOp, ops.AddOp, frozenset, Gamma, Poisson)(  # noqa: F821
    eager_gamma_poisson)
# Subtracting a DirichletMultinomial from a joint Dirichlet-Multinomial contraction.
eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)(  # noqa: F821
    eager_dirichlet_posterior)