# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import functools
import numbers
from typing import Tuple, Union
import pyro.distributions as dist
import pyro.distributions.testing.fakes as fakes
import torch
from pyro.distributions.torch_distribution import (
ExpandedDistribution,
MaskedDistribution,
)
import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.constant import Constant
from funsor.distribution import ( # noqa: F401
FUNSOR_DIST_NAMES,
Bernoulli,
LogNormal,
backenddist_to_funsor,
eager_beta,
eager_beta_bernoulli,
eager_binomial,
eager_categorical_funsor,
eager_categorical_tensor,
eager_delta_funsor_funsor,
eager_delta_funsor_variable,
eager_delta_tensor,
eager_delta_variable_variable,
eager_dirichlet_categorical,
eager_dirichlet_multinomial,
eager_dirichlet_posterior,
eager_gamma_gamma,
eager_gamma_poisson,
eager_multinomial,
eager_mvn,
eager_normal,
eager_plate_multinomial,
expandeddist_to_funsor,
indepdist_to_funsor,
make_dist,
maskeddist_to_funsor,
transformeddist_to_funsor,
)
from funsor.domains import Real, Reals
from funsor.interpretations import eager
from funsor.tensor import Tensor
from funsor.terms import Binary, Funsor, Reduce, Unary, Variable, to_data, to_funsor
from funsor.util import methodof
# Exported names: the first element (the distribution name) of every
# FUNSOR_DIST_NAMES entry.
__all__ = [entry[0] for entry in FUNSOR_DIST_NAMES]
################################################################################
# Distribution Wrappers
################################################################################
class _PyroWrapper_BernoulliProbs(dist.Bernoulli):
    """``dist.Bernoulli`` variant whose constructor takes ``probs`` only."""

    def __init__(self, probs, validate_args=None):
        super().__init__(probs=probs, validate_args=validate_args)

    # XXX: subclasses of Pyro distribution which defines a custom __init__ method
    # should also have `expand` implemented.
    def expand(self, batch_shape, _instance=None):
        """Expand to ``batch_shape``, producing an instance of this wrapper class."""
        checked = self._get_checked_instance(_PyroWrapper_BernoulliProbs, _instance)
        return super().expand(batch_shape, _instance=checked)
class _PyroWrapper_BernoulliLogits(dist.Bernoulli):
    """``dist.Bernoulli`` variant whose constructor takes ``logits`` only."""

    def __init__(self, logits, validate_args=None):
        super().__init__(logits=logits, validate_args=validate_args)

    # A custom __init__ means `expand` must be overridden too, so the expanded
    # object is created as this wrapper class.
    def expand(self, batch_shape, _instance=None):
        """Expand to ``batch_shape``, producing an instance of this wrapper class."""
        checked = self._get_checked_instance(_PyroWrapper_BernoulliLogits, _instance)
        return super().expand(batch_shape, _instance=checked)
class _PyroWrapper_CategoricalLogits(dist.Categorical):
    """``dist.Categorical`` variant whose constructor takes ``logits`` only."""

    def __init__(self, logits, validate_args=None):
        super().__init__(logits=logits, validate_args=validate_args)

    # A custom __init__ means `expand` must be overridden too, so the expanded
    # object is created as this wrapper class.
    def expand(self, batch_shape, _instance=None):
        """Expand to ``batch_shape``, producing an instance of this wrapper class."""
        checked = self._get_checked_instance(
            _PyroWrapper_CategoricalLogits, _instance
        )
        return super().expand(batch_shape, _instance=checked)
def _get_pyro_dist(dist_name):
    """Look up the backend distribution class registered under ``dist_name``.

    The three wrapper names resolve to the ``_PyroWrapper_*`` objects in this
    module's globals (``None`` when no such global exists). Names starting with
    ``"Nonreparameterized"`` are read from ``pyro.distributions.testing.fakes``;
    every other name is read from ``pyro.distributions``.

    :param str dist_name: name of the distribution.
    :raises AttributeError: if the selected module has no such attribute.
    """
    wrapped_names = ["BernoulliProbs", "BernoulliLogits", "CategoricalLogits"]
    if dist_name in wrapped_names:
        return globals().get("_PyroWrapper_" + dist_name)
    source_module = fakes if dist_name.startswith("Nonreparameterized") else dist
    return getattr(source_module, dist_name)
# Every distribution wrapped by this module: the shared FUNSOR_DIST_NAMES plus
# the extra entries below. Each entry is (distribution name, parameter names).
# An empty tuple is forwarded to make_dist unchanged -- presumably make_dist
# then infers the parameter names itself; confirm in funsor.distribution.
PYRO_DIST_NAMES = FUNSOR_DIST_NAMES + [
    ("ContinuousBernoulli", ("logits",)),
    ("FisherSnedecor", ()),
    # ("LogisticNormal", ()),  # TODO handle as transformed dist
    ("NegativeBinomial", ("total_count", "probs")),
    ("OneHotCategorical", ("probs",)),
    ("RelaxedBernoulli", ("temperature", "logits")),
    ("Weibull", ()),
]
# Build one funsor distribution class per entry and bind it under its own name.
# This runs at module scope, so `locals()` is the module namespace; the names
# created here (Delta, Categorical, ...) are the ones the rest of the file
# references behind `# noqa: F821`.
for dist_name, param_names in PYRO_DIST_NAMES:
    locals()[dist_name] = make_dist(_get_pyro_dist(dist_name), param_names=param_names)
# Delta has to be treated specially because of its weird shape inference semantics
@methodof(Delta)  # noqa: F821
@staticmethod
def _infer_value_domain(**kwargs):
    """Return the entry passed under the ``"v"`` keyword as the value domain.

    :raises KeyError: if no ``v`` keyword is supplied.
    """
    value_domain = kwargs["v"]
    return value_domain
@methodof(Categorical)  # noqa: F821
@methodof(CategoricalLogits)  # noqa: F821
@classmethod
def _infer_value_dtype(cls, domains):
    """Return the size of the last dimension of the ``logits`` or ``probs`` domain.

    ``logits`` takes precedence when both keys are present.

    :param domains: mapping from parameter name to its domain.
    :raises ValueError: if neither ``"logits"`` nor ``"probs"`` is a key.
    """
    for param in ("logits", "probs"):
        if param in domains:
            return domains[param].shape[-1]
    raise ValueError
# Multinomial and related dists have dependent Bint dtypes, so we just make them 'real'
# See issue: https://github.com/pyro-ppl/funsor/issues/322
@methodof(Binomial)  # noqa: F821
@methodof(Multinomial)  # noqa: F821
@methodof(DirichletMultinomial)  # noqa: F821
@classmethod
def _infer_value_dtype(cls, domains):
    """Return ``"real"`` as the value dtype; ``domains`` is not inspected."""
    return "real"
# TODO fix Delta.arg_constraints["v"] to be a
# constraints.independent[constraints.real]
@methodof(Delta)  # noqa: F821
@staticmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(name, raw_shape):
    """Return the domain of a ``Delta`` parameter.

    ``"v"`` keeps its entire raw shape (``Reals[raw_shape]``) and
    ``"log_density"`` is ``Real``. Results are memoized, so both arguments
    must be hashable.

    :param str name: parameter name.
    :param raw_shape: raw shape of the parameter.
    :raises ValueError: for any other parameter name.
    """
    if name == "log_density":
        return Real
    if name != "v":
        raise ValueError(name)
    return Reals[raw_shape]
# TODO fix Dirichlet.arg_constraints["concentration"] to be a
# constraints.independent[constraints.positive]
@methodof(Dirichlet)  # noqa: F821
@methodof(NonreparameterizedDirichlet)  # noqa: F821
@staticmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(name, raw_shape):
    """Return ``Reals[raw_shape[-1]]`` for the ``concentration`` parameter.

    ``concentration`` is the only name accepted; anything else trips the
    assertion. Results are memoized, so both arguments must be hashable.
    """
    assert name == "concentration"
    last_dim = raw_shape[-1]
    return Reals[last_dim]
# TODO fix DirichletMultinomial.arg_constraints["concentration"] to be a
# constraints.independent[constraints.positive]
@methodof(DirichletMultinomial)  # noqa: F821
@classmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(cls, name, raw_shape):
    """Return the domain of a ``DirichletMultinomial`` parameter.

    ``concentration`` maps to ``Reals[raw_shape[-1]]``; the only other accepted
    name, ``total_count``, maps to ``Real``. Results are memoized.
    """
    if name != "concentration":
        assert name == "total_count"
        return Real
    return Reals[raw_shape[-1]]
# TODO fix LowRankMultivariateNormal.arg_constraints upstream
@methodof(LowRankMultivariateNormal)  # noqa: F821
@classmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(cls, name, raw_shape):
    """Return the domain of a ``LowRankMultivariateNormal`` parameter.

    ``loc`` and ``cov_diag`` use the last raw dimension; ``cov_factor`` uses
    the last two. Results are memoized.

    :raises ValueError: for any other parameter name.
    """
    if name in ("loc", "cov_diag"):
        return Reals[raw_shape[-1]]
    if name == "cov_factor":
        return Reals[raw_shape[-2:]]
    raise ValueError(f"{name} invalid param for {cls}")
# TODO add temperature to RelaxedBernoulli.arg_constraints upstream
@methodof(RelaxedBernoulli)  # noqa: F821
@classmethod
@functools.lru_cache(maxsize=5000)
def _infer_param_domain(cls, name, raw_shape):
    """Return ``Real`` for every ``RelaxedBernoulli`` parameter.

    Neither ``name`` nor ``raw_shape`` influences the result.
    """
    # NOTE(review): `temperature` and every other parameter map to Real alike;
    # confirm the non-temperature parameters were not meant to fall back to the
    # generic domain inference.
    return Real
###########################################################
# Converting distribution funsors to PyTorch distributions
###########################################################
@to_data.register(Multinomial)  # noqa: F821
def multinomial_to_data(funsor_dist, name_to_dim=None):
    """Convert a ``Multinomial`` funsor into a ``dist.Multinomial``.

    :param funsor_dist: funsor distribution with ``probs`` and ``total_count``.
    :param name_to_dim: forwarded to ``to_data`` for both parameters.
    :raises NotImplementedError: if ``total_count`` converts to something that
        is neither a Python number nor an object with an empty ``shape``.
    """
    probs = to_data(funsor_dist.probs, name_to_dim)
    total_count = to_data(funsor_dist.total_count, name_to_dim)
    is_scalar = isinstance(total_count, numbers.Number) or len(total_count.shape) == 0
    if not is_scalar:
        raise NotImplementedError("inhomogeneous total_count not supported")
    return dist.Multinomial(int(total_count), probs=probs)
# Convert Delta **distribution** to raw data
@to_data.register(Delta)  # noqa: F821
def deltadist_to_data(funsor_dist, name_to_dim=None):
    """Convert a ``Delta`` funsor distribution into a ``dist.Delta``.

    ``event_dim`` is the number of dimensions in the output shape of the ``v``
    funsor.

    :param funsor_dist: funsor distribution with ``v`` and ``log_density``.
    :param name_to_dim: forwarded to ``to_data`` for both parameters.
    """
    point = to_data(funsor_dist.v, name_to_dim=name_to_dim)
    log_density = to_data(funsor_dist.log_density, name_to_dim=name_to_dim)
    event_dim = len(funsor_dist.v.output.shape)
    return dist.Delta(point, log_density, event_dim=event_dim)
@functools.singledispatch
def op_to_torch_transform(op, name_to_dim=None):
    """Convert a funsor op into a torch transform (single-dispatch on ``op``).

    This fallback handles every op type without a registered implementation.

    :raises NotImplementedError: always.
    """
    message = "cannot convert {} to a Transform".format(op)
    raise NotImplementedError(message)
@op_to_torch_transform.register(ops.TransformOp)
def transform_to_torch_transform(op, name_to_dim=None):
    """Reject ``TransformOp`` instances that have no more specific handler.

    :raises NotImplementedError: always.
    """
    message = "{} is not a currently supported transform".format(op)
    raise NotImplementedError(message)
@op_to_torch_transform.register(ops.WrappedTransformOp)
def transform_to_torch_transform(op, name_to_dim=None):
    """Return the object stored under ``"fn"`` in the wrapped op's defaults."""
    # NOTE: this function name is also used by the ops.TransformOp handler;
    # dispatch relies on the registration, not on the module-level name.
    wrapped = op.defaults["fn"]
    return wrapped
@op_to_torch_transform.register(ops.ExpOp)
def exp_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.ExpOp`` to a new torch ``ExpTransform``."""
    transforms = torch.distributions.transforms
    return transforms.ExpTransform()
@op_to_torch_transform.register(ops.LogOp)
def log_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.LogOp`` to the inverse of a torch ``ExpTransform``."""
    exp_transform = torch.distributions.transforms.ExpTransform()
    return exp_transform.inv
@op_to_torch_transform.register(ops.SigmoidOp)
def sigmoid_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.SigmoidOp`` to a new torch ``SigmoidTransform``."""
    transforms = torch.distributions.transforms
    return transforms.SigmoidTransform()
@op_to_torch_transform.register(ops.TanhOp)
def tanh_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.TanhOp`` to a new torch ``TanhTransform``."""
    transforms = torch.distributions.transforms
    return transforms.TanhTransform()
@op_to_torch_transform.register(ops.AtanhOp)
def atanh_to_torch_transform(op, name_to_dim=None):
    """Map ``ops.AtanhOp`` to the inverse of a torch ``TanhTransform``."""
    tanh_transform = torch.distributions.transforms.TanhTransform()
    return tanh_transform.inv
@to_data.register(Unary[ops.TransformOp, Union[Unary, Variable]])
def transform_to_data(expr, name_to_dim=None):
    """Convert a chain of unary transform ops into a torch transform.

    The outermost op is converted with ``op_to_torch_transform``. When the
    argument is itself a ``Unary``, it is converted recursively via ``to_data``
    and composed so that the inner transform is applied first.

    :param expr: unary funsor expression to convert.
    :param name_to_dim: forwarded to the nested conversions.
    :raises NotImplementedError: if ``expr.op`` is not an ``ops.TransformOp``.
    """
    if not isinstance(expr.op, ops.TransformOp):
        raise NotImplementedError("cannot convert to data: {}".format(expr))
    outer = op_to_torch_transform(expr.op, name_to_dim=name_to_dim)
    if not isinstance(expr.arg, Unary):
        return outer
    inner = to_data(expr.arg, name_to_dim=name_to_dim)
    return torch.distributions.transforms.ComposeTransform([inner, outer])
###############################################
# Converting PyTorch Distributions to funsors
###############################################
@to_funsor.register(torch.distributions.Transform)
def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Wrap a generic torch ``Transform`` as a funsor expression.

    The transform is wrapped in ``ops.WrappedTransformOp`` and applied to a
    ``Variable`` with domain ``output``.

    :param tfm: torch transform to wrap.
    :param output: domain given to the input ``Variable``.
    :param dim_to_name: not used by this function.
    :param real_inputs: optional mapping; its first key names the input
        variable. When empty or ``None`` the variable is named ``"value"``.
    :return: the wrapped op applied to the input variable.
    """
    op = ops.WrappedTransformOp(fn=tfm)
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    return op(Variable(name, output))
@to_funsor.register(torch.distributions.transforms.ExpTransform)
def exptransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert a torch ``ExpTransform`` to ``ops.exp`` applied to a ``Variable``.

    :param tfm: the transform instance (only its type matters here).
    :param output: domain given to the input ``Variable``.
    :param dim_to_name: not used by this function.
    :param real_inputs: optional mapping; its first key names the input
        variable. When empty or ``None`` the variable is named ``"value"``.
    """
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    return ops.exp(Variable(name, output))
@to_funsor.register(torch.distributions.transforms.TanhTransform)
def exptransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert a torch ``TanhTransform`` to ``ops.tanh`` applied to a ``Variable``.

    :param tfm: the transform instance (only its type matters here).
    :param output: domain given to the input ``Variable``.
    :param dim_to_name: not used by this function.
    :param real_inputs: optional mapping; its first key names the input
        variable. When empty or ``None`` the variable is named ``"value"``.
    """
    # NOTE: the function name is shared with the ExpTransform handler; dispatch
    # relies on the registration, so the name is kept unchanged.
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    return ops.tanh(Variable(name, output))
@to_funsor.register(torch.distributions.transforms.SigmoidTransform)
def exptransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert a torch ``SigmoidTransform`` to ``ops.sigmoid`` of a ``Variable``.

    :param tfm: the transform instance (only its type matters here).
    :param output: domain given to the input ``Variable``.
    :param dim_to_name: not used by this function.
    :param real_inputs: optional mapping; its first key names the input
        variable. When empty or ``None`` the variable is named ``"value"``.
    """
    # NOTE: the function name is shared with the ExpTransform handler; dispatch
    # relies on the registration, so the name is kept unchanged.
    # A dict keys view is iterable but not an iterator, so next() needs iter().
    name = next(iter(real_inputs)) if real_inputs else "value"
    return ops.sigmoid(Variable(name, output))
@to_funsor.register(torch.distributions.transforms._InverseTransform)
def inversetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
    """Convert an inverted torch transform to a funsor expression.

    The underlying transform (``tfm._inv``) is converted first. The result must
    be a ``Unary``; its op's ``inv`` is then applied to the same argument.

    :param tfm: the inverse-transform wrapper.
    :param output: forwarded to the nested ``to_funsor`` call.
    :param dim_to_name: forwarded to the nested ``to_funsor`` call.
    :param real_inputs: forwarded to the nested ``to_funsor`` call.
    """
    conversion_kwargs = dict(
        output=output, dim_to_name=dim_to_name, real_inputs=real_inputs
    )
    forward_expr = to_funsor(tfm._inv, **conversion_kwargs)
    assert isinstance(forward_expr, Unary)
    inverse_op = forward_expr.op.inv
    return inverse_op(forward_expr.arg)
@to_funsor.register(torch.distributions.transforms.ComposeTransform)
def composetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
name = next(real_inputs.keys()) if real_inputs else "value"
expr = Variable(name, output)
for part in tfm.parts:
expr = to_funsor(
part, output=output, dim_to_name=dim_to_name, real_inputs=real_inputs
)(**{name: expr})
return expr
# Register the converters imported from funsor.distribution as the to_funsor
# handlers for the torch/Pyro distribution wrapper types.
to_funsor.register(ExpandedDistribution)(expandeddist_to_funsor)
to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(torch.distributions.TransformedDistribution)(
    transformeddist_to_funsor
)
@to_funsor.register(torch.distributions.Bernoulli)
def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a torch ``Bernoulli`` to a ``BernoulliLogits`` funsor.

    The distribution is first rebuilt from its ``logits`` as a
    ``_PyroWrapper_BernoulliLogits``, which is then handed to
    ``backenddist_to_funsor``.

    :param pyro_dist: the Bernoulli distribution to convert.
    :param output: forwarded to ``backenddist_to_funsor``.
    :param dim_to_name: forwarded to ``backenddist_to_funsor``.
    """
    logits_dist = _PyroWrapper_BernoulliLogits(logits=pyro_dist.logits)
    funsor_cls = BernoulliLogits  # noqa: F821
    return backenddist_to_funsor(funsor_cls, logits_dist, output, dim_to_name)
@to_funsor.register(dist.Delta)  # Delta **distribution**
def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """Convert a ``dist.Delta`` distribution into a ``Delta`` funsor.

    ``v`` is converted with output domain ``Reals[pyro_dist.event_shape]`` and
    ``log_density`` with output domain ``Real``.

    :param pyro_dist: the Delta distribution to convert.
    :param output: not used by this function.
    :param dim_to_name: forwarded to both ``to_funsor`` calls.
    """
    raw_v = pyro_dist.v
    event_domain = Reals[pyro_dist.event_shape]
    v = to_funsor(raw_v, output=event_domain, dim_to_name=dim_to_name)
    log_density = to_funsor(pyro_dist.log_density, output=Real, dim_to_name=dim_to_name)
    return Delta(v, log_density)  # noqa: F821
# Contraction type pattern parameterized by (LogaddexpOp or NullOp, AddOp,
# frozenset, (Dirichlet, Multinomial)). It is the left-operand pattern of the
# Binary/SubOp rule registered with eager_dirichlet_posterior.
JointDirichletMultinomial = Contraction[
    Union[ops.LogaddexpOp, ops.NullOp],
    ops.AddOp,
    frozenset,
    Tuple[Dirichlet, Multinomial],  # noqa: F821
]
# Register the eager rules imported from funsor.distribution against the
# distribution classes generated in this module. The distribution names are
# created dynamically at import time, hence the `noqa: F821` markers.

# Per-distribution rules, keyed on the types of the distribution's arguments.
eager.register(Beta, Funsor, Funsor, Funsor)(eager_beta)  # noqa: F821
eager.register(Binomial, Funsor, Funsor, Funsor)(eager_binomial)  # noqa: F821
eager.register(Multinomial, Tensor, Tensor, Tensor)(eager_multinomial)  # noqa: F821
eager.register(Categorical, Funsor, Tensor)(eager_categorical_funsor)  # noqa: F821
eager.register(Categorical, Tensor, Variable)(eager_categorical_tensor)  # noqa: F821
eager.register(Categorical, Constant[Tuple, Tensor], Variable)(
    eager_categorical_tensor
)  # noqa: F821
eager.register(Delta, Tensor, Tensor, Tensor)(eager_delta_tensor)  # noqa: F821
eager.register(Delta, Funsor, Funsor, Variable)(
    eager_delta_funsor_variable
)  # noqa: F821
eager.register(Delta, Variable, Funsor, Variable)(
    eager_delta_funsor_variable
)  # noqa: F821
eager.register(Delta, Variable, Funsor, Funsor)(eager_delta_funsor_funsor)  # noqa: F821
eager.register(Delta, Variable, Variable, Variable)(
    eager_delta_variable_variable
)  # noqa: F821
eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal)  # noqa: F821
eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn)  # noqa: F821

# Rules keyed on a Contraction (LogaddexpOp / AddOp) over a pair of
# distributions.
eager.register(
    Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, BernoulliProbs
)(  # noqa: F821
    eager_beta_bernoulli
)
eager.register(
    Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Categorical
)(  # noqa: F821
    eager_dirichlet_categorical
)
eager.register(
    Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Dirichlet, Multinomial
)(  # noqa: F821
    eager_dirichlet_multinomial
)
eager.register(
    Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Gamma
)(  # noqa: F821
    eager_gamma_gamma
)
eager.register(
    Contraction, ops.LogaddexpOp, ops.AddOp, frozenset, Gamma, Poisson
)(  # noqa: F821
    eager_gamma_poisson
)

# Binary SubOp of a joint Dirichlet-Multinomial contraction and a
# DirichletMultinomial.
eager.register(
    Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial
)(  # noqa: F821
    eager_dirichlet_posterior
)

# Reduce with AddOp over a Multinomial whose first argument is a Tensor.
eager.register(
    Reduce, ops.AddOp, Multinomial[Tensor, Funsor, Funsor], frozenset
)(  # noqa: F821
    eager_plate_multinomial
)