"""
This module follows a convention for converting between funsors and PyTorch
distribution objects. This convention is compatible with NumPy/PyTorch-style
broadcasting. Following PyTorch distributions (and Tensorflow distributions),
we consider "event shapes" to be on the right and broadcast-compatible "batch
shapes" to be on the left.
This module also aims to be forgiving in inputs and pedantic in outputs:
methods accept either the superclass :class:`torch.distributions.Distribution`
objects or the subclass :class:`pyro.distributions.TorchDistribution` objects.
Methods return only the narrower subclass
:class:`pyro.distributions.TorchDistribution` objects.
"""
import math
from collections import OrderedDict
from functools import singledispatch
import pyro.distributions
import torch
import torch.distributions as dist
from pyro.distributions.torch_distribution import MaskedDistribution
from pyro.distributions.util import broadcast_shape
from funsor.cnf import Contraction
from funsor.delta import Delta
from funsor.distributions import BernoulliLogits, MultivariateNormal, Normal
from funsor.domains import bint, reals
from funsor.gaussian import Gaussian, align_tensors, cholesky
from funsor.interpreter import gensym
from funsor.terms import Funsor, Independent, Variable, eager
from funsor.torch import Tensor
# Conversion functions use fixed names for Pyro batch dims, but
# accept an event_inputs tuple for custom event dim names.
# DIM_TO_NAME maps negative batch dims -100..-1 to the canonical names
# "_pyro_dim_-100".."_pyro_dim_-1"; NAME_TO_DIM is the inverse mapping.
DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0)))
NAME_TO_DIM = dict(zip(DIM_TO_NAME, range(-100, 0)))
def tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype="real"):
    """
    Convert a :class:`torch.Tensor` to a :class:`funsor.torch.Tensor` .

    Note this should not touch data, but may trigger a
    :meth:`torch.Tensor.reshape` op.

    :param torch.Tensor tensor: A PyTorch tensor.
    :param tuple event_inputs: A tuple of names for rightmost tensor
        dimensions. If ``tensor`` has these names, they will be converted to
        ``result.inputs``.
    :param int event_output: The number of tensor dimensions assigned to
        ``result.output``. These must be on the right of any ``event_input``
        dimensions.
    :return: A funsor.
    :rtype: funsor.torch.Tensor
    """
    assert isinstance(tensor, torch.Tensor)
    assert isinstance(event_inputs, tuple)
    assert isinstance(event_output, int) and event_output >= 0
    split = tensor.dim() - event_output
    batch_shape = tensor.shape[:split]
    output_shape = tensor.shape[split:]
    dim_to_name = (DIM_TO_NAME + event_inputs) if event_inputs else DIM_TO_NAME

    # Drop all size-1 batch dims: named funsor inputs are created only for
    # nontrivial sizes, keyed by the dim's (negative) position in batch_shape.
    inputs = OrderedDict()
    kept_sizes = []
    for offset, size in enumerate(batch_shape):
        if size > 1:
            inputs[dim_to_name[offset - len(batch_shape)]] = bint(size)
            kept_sizes.append(size)
    kept_shape = torch.Size(kept_sizes)
    if kept_shape != batch_shape:
        tensor = tensor.reshape(kept_shape + output_shape)
    return Tensor(tensor, inputs, dtype)
def funsor_to_tensor(funsor_, ndims, event_inputs=()):
    """
    Convert a :class:`funsor.torch.Tensor` to a :class:`torch.Tensor` .

    Note this should not touch data, but may trigger a
    :meth:`torch.Tensor.reshape` op.

    :param funsor.torch.Tensor funsor_: A funsor.
    :param int ndims: The number of result dims, ``== result.dim()``.
    :param tuple event_inputs: Names assigned to rightmost dimensions.
    :return: A PyTorch tensor.
    :rtype: torch.Tensor
    """
    assert isinstance(funsor_, Tensor)
    assert all(k.startswith("_pyro_dim_") or k in event_inputs
               for k in funsor_.inputs)
    if event_inputs:
        dim_to_name = DIM_TO_NAME + event_inputs
        name_to_dim = dict(zip(dim_to_name, range(-len(dim_to_name), 0)))
    else:
        name_to_dim = NAME_TO_DIM
    # Order inputs by their canonical (negative) dim position.
    names = tuple(sorted(funsor_.inputs, key=name_to_dim.__getitem__))
    tensor = funsor_.align(names).data
    if names:
        # Re-insert the size-1 dims that tensor_to_funsor squeezed away.
        dims = [name_to_dim[name] for name in names]
        padded_shape = [1] * (-dims[0])
        for dim, size in zip(dims, tensor.shape):
            padded_shape[dim] = size
        tensor = tensor.reshape(torch.Size(padded_shape) + funsor_.output.shape)
    if tensor.dim() != ndims:
        # Left-pad with singleton dims up to the requested rank.
        tensor = tensor.reshape((1,) * (ndims - tensor.dim()) + tensor.shape)
    assert tensor.dim() == ndims
    return tensor
def mvn_to_funsor(pyro_dist, event_dims=(), real_inputs=OrderedDict()):
    """
    Convert a joint :class:`torch.distributions.MultivariateNormal`
    distribution into a :class:`~funsor.terms.Funsor` with multiple real
    inputs.

    This should satisfy::

        sum(d.num_elements for d in real_inputs.values())
          == pyro_dist.event_shape[0]

    :param torch.distributions.MultivariateNormal pyro_dist: A
        multivariate normal distribution over one or more variables
        of real or vector or tensor type.
    :param tuple event_dims: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``bint``.
    :param OrderedDict real_inputs: A dict mapping real variable name
        to appropriately sized ``reals()``. The sum of all ``.numel()``
        of all real inputs should be equal to the ``pyro_dist`` dimension.
    :return: A funsor with given ``real_inputs`` and possibly additional
        bint inputs.
    :rtype: funsor.terms.Funsor
    """
    assert isinstance(pyro_dist, torch.distributions.MultivariateNormal)
    assert isinstance(event_dims, tuple)
    # NOTE(review): real_inputs defaults to a shared mutable OrderedDict().
    # It is only read below (never mutated), so this is safe but fragile.
    assert isinstance(real_inputs, OrderedDict)
    loc = tensor_to_funsor(pyro_dist.loc, event_dims, 1)
    scale_tril = tensor_to_funsor(pyro_dist.scale_tril, event_dims, 2)
    precision = tensor_to_funsor(pyro_dist.precision_matrix, event_dims, 2)
    assert loc.inputs == scale_tril.inputs
    assert loc.inputs == precision.inputs
    # Information-form parameter: info_vec = precision @ loc.
    info_vec = precision.data.matmul(loc.data.unsqueeze(-1)).squeeze(-1)
    # Constant term so that Tensor(log_prob) + Gaussian(...) equals the full
    # normalized mvn log density:
    #   -0.5*D*log(2*pi) - log|scale_tril| - 0.5 * loc' @ precision @ loc
    log_prob = (-0.5 * loc.output.shape[0] * math.log(2 * math.pi)
                - scale_tril.data.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (info_vec * loc.data).sum(-1))
    # The Gaussian gets both the batch (bint) inputs and the real inputs.
    inputs = loc.inputs.copy()
    inputs.update(real_inputs)
    return Tensor(log_prob, loc.inputs) + Gaussian(info_vec, precision.data, inputs)
def funsor_to_mvn(gaussian, ndims, event_inputs=()):
    """
    Convert a :class:`~funsor.terms.Funsor` to a
    :class:`pyro.distributions.MultivariateNormal` , dropping the normalization
    constant.

    :param gaussian: A Gaussian funsor.
    :type gaussian: funsor.gaussian.Gaussian or funsor.joint.Joint
    :param int ndims: The number of batch dimensions in the result.
    :param tuple event_inputs: A tuple of names to assign to rightmost
        dimensions.
    :return: a multivariate normal distribution.
    :rtype: pyro.distributions.MultivariateNormal
    """
    assert sum(1 for d in gaussian.inputs.values() if d.dtype == "real") == 1
    if isinstance(gaussian, Contraction):
        # Extract the sole Gaussian term from the joint.
        gaussian = [t for t in gaussian.terms if isinstance(t, Gaussian)][0]
    assert isinstance(gaussian, Gaussian)
    precision = gaussian.precision
    # Recover the mean: loc = precision^{-1} @ info_vec, via a Cholesky solve.
    loc = (gaussian.info_vec.unsqueeze(-1)
           .cholesky_solve(cholesky(precision))
           .squeeze(-1))
    int_inputs = OrderedDict(
        (k, d) for k, d in gaussian.inputs.items() if d.dtype != "real")
    loc = Tensor(loc, int_inputs)
    precision = Tensor(precision, int_inputs)
    assert len(loc.output.shape) == 1
    assert precision.output.shape == loc.output.shape * 2
    # Flatten funsor inputs back into positional batch dims.
    loc = funsor_to_tensor(loc, ndims + 1, event_inputs)
    precision = funsor_to_tensor(precision, ndims + 2, event_inputs)
    return pyro.distributions.MultivariateNormal(loc, precision_matrix=precision)
def funsor_to_cat_and_mvn(funsor_, ndims, event_inputs):
    """
    Converts a labeled gaussian mixture model to a pair of distributions.

    :param funsor.joint.Joint funsor_: A Gaussian mixture funsor.
    :param int ndims: The number of batch dimensions in the result.
    :return: A pair ``(cat, mvn)``, where ``cat`` is a
        :class:`~pyro.distributions.Categorical` distribution over mixture
        components and ``mvn`` is a
        :class:`~pyro.distributions.MultivariateNormal` with rightmost batch
        dimension ranging over mixture components.
    """
    assert isinstance(funsor_, Contraction), funsor_
    assert sum(1 for d in funsor_.inputs.values() if d.dtype == "real") == 1
    assert event_inputs, "no components name found"
    assert not any(isinstance(t, Delta) for t in funsor_.terms)
    # Split the joint into its discrete (Tensor) and continuous (Gaussian) parts.
    discrete = [t for t in funsor_.terms if isinstance(t, Tensor)][0]
    gaussian = [t for t in funsor_.terms if isinstance(t, Gaussian)][0]
    assert isinstance(discrete, Tensor)
    assert isinstance(gaussian, Gaussian)
    # Fold the Gaussian normalizers into the mixture logits so that the
    # returned mvn can be a normalized distribution.
    logits = funsor_to_tensor(discrete + gaussian.log_normalizer,
                              ndims + 1, event_inputs)
    cat = pyro.distributions.Categorical(logits=logits)
    mvn = funsor_to_mvn(gaussian, ndims + 1, event_inputs)
    assert cat.batch_shape == mvn.batch_shape[:-1]
    return cat, mvn
class AffineNormal(Funsor):
    """
    Represents a conditional diagonal normal distribution over a random
    variable ``Y`` whose mean is an affine function of a random variable ``X``.
    The likelihood of ``X`` is thus::

        AffineNormal(matrix, loc, scale).condition(y).log_density(x)

    which is equivalent to::

        Normal(x @ matrix + loc, scale).to_event(1).log_prob(y)

    :param ~funsor.terms.Funsor matrix: A transformation from ``X`` to ``Y``.
        Should have rightmost shape ``(x_dim, y_dim)``.
    :param ~funsor.terms.Funsor loc: A constant offset for ``Y``'s mean.
        Should have output shape ``(y_dim,)``.
    :param ~funsor.terms.Funsor scale: Standard deviation for ``Y``.
        Should have output shape ``(y_dim,)``.
    :param ~funsor.terms.Funsor value_x: A value ``X``.
    :param ~funsor.terms.Funsor value_y: A value ``Y``.
    """
    def __init__(self, matrix, loc, scale, value_x, value_y):
        assert len(matrix.output.shape) == 2
        assert value_x.output == reals(matrix.output.shape[0])
        assert value_y.output == reals(matrix.output.shape[1])
        # Merge the inputs of all argument funsors, in argument order.
        inputs = OrderedDict()
        for term in (matrix, loc, scale, value_x, value_y):
            inputs.update(term.inputs)
        super().__init__(inputs, reals())
        self.matrix = matrix
        self.loc = loc
        self.scale = scale
        self.value_x = value_x
        self.value_y = value_y
@eager.register(AffineNormal, Tensor, Tensor, Tensor, Tensor, (Funsor, Tensor))
def eager_affine_normal(matrix, loc, scale, value_x, value_y):
    """
    Eager rule for the case where ``X`` is a ground Tensor: reduce to a
    diagonal Normal over ``Y`` with mean ``value_x @ matrix + loc``,
    evaluated at ``value_y``.
    """
    assert len(matrix.output.shape) == 2
    assert value_x.output == reals(matrix.output.shape[0])
    assert value_y.output == reals(matrix.output.shape[1])
    loc += value_x @ matrix
    # Broadcast loc and scale against each other so they share batch inputs.
    int_inputs, (loc, scale) = align_tensors(loc, scale, expand=True)
    # Fresh names avoid collisions with any user-visible variable names.
    i_name = gensym("i")
    y_name = gensym("y")
    y_i_name = gensym("y_i")
    # Treat the y_dim coordinates as a bint "plate" of independent Normals,
    # then bundle them into one event dim via Independent.
    int_inputs[i_name] = bint(value_y.output.shape[0])
    loc = Tensor(loc, int_inputs)
    scale = Tensor(scale, int_inputs)
    y_dist = Independent(Normal(loc, scale, y_i_name), y_name, i_name, y_i_name)
    return y_dist(**{y_name: value_y})
@eager.register(AffineNormal, Tensor, Tensor, Tensor, Funsor, Tensor)
def eager_affine_normal(matrix, loc, scale, value_x, value_y):
    """
    Eager rule for the case where ``Y`` is observed but ``X`` is lazy:
    produce a Gaussian over ``X`` by conditioning on ``value_y``.
    """
    assert len(matrix.output.shape) == 2
    assert value_x.output == reals(matrix.output.shape[0])
    assert value_y.output == reals(matrix.output.shape[1])
    tensors = (matrix, loc, scale, value_y)
    int_inputs, tensors = align_tensors(*tensors)
    matrix, loc, scale, value_y = tensors
    assert value_y.size(-1) == loc.size(-1)
    # Whiten by scale: with delta = (y - loc) / scale, the density as a
    # function of x has precision (matrix/scale) @ (matrix/scale)^T and
    # information vector (matrix/scale) @ delta.
    prec_sqrt = matrix / scale.unsqueeze(-2)
    precision = prec_sqrt.matmul(prec_sqrt.transpose(-1, -2))
    delta = (value_y - loc) / scale
    info_vec = prec_sqrt.matmul(delta.unsqueeze(-1)).squeeze(-1)
    # Constant term of the log density at x = 0.
    log_normalizer = (-0.5 * loc.size(-1) * math.log(2 * math.pi)
                      - 0.5 * delta.pow(2).sum(-1) - scale.log().sum(-1))
    # Expand to a common batch shape so Tensor and Gaussian inputs agree.
    precision = precision.expand(info_vec.shape + (-1,))
    log_normalizer = log_normalizer.expand(info_vec.shape[:-1])
    inputs = int_inputs.copy()
    x_name = gensym("x")
    inputs[x_name] = value_x.output
    x_dist = Tensor(log_normalizer, int_inputs) + Gaussian(info_vec, precision, inputs)
    # Substitute the lazy value_x into the fresh variable.
    return x_dist(**{x_name: value_x})
def matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(), x_name="value_x", y_name="value_y"):
    """
    Convert a noisy affine function to a Gaussian. The noisy affine function is
    defined as::

        y = x @ matrix + mvn.sample()

    The result is a non-normalized Gaussian funsor with two real inputs,
    ``x_name`` and ``y_name``, corresponding to a conditional distribution of
    real vector ``y`` given real vector ``x``.

    :param torch.Tensor matrix: A matrix with rightmost shape ``(x_size, y_size)``.
    :param mvn: A multivariate normal distribution with
        ``event_shape == (y_size,)``.
    :type mvn: torch.distributions.MultivariateNormal or
        torch.distributions.Independent of torch.distributions.Normal
    :param tuple event_dims: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``bint``.
    :param str x_name: The name of the ``x`` random variable.
    :param str y_name: The name of the ``y`` random variable.
    :return: A funsor with given ``real_inputs`` and possibly additional
        bint inputs.
    :rtype: funsor.terms.Funsor
    """
    assert (isinstance(mvn, torch.distributions.MultivariateNormal) or
            (isinstance(mvn, torch.distributions.Independent) and
             isinstance(mvn.base_dist, torch.distributions.Normal)))
    assert isinstance(matrix, torch.Tensor)
    x_size, y_size = matrix.shape[-2:]
    assert mvn.event_shape == (y_size,)

    # Handle diagonal normal distributions as an efficient special case.
    if isinstance(mvn, torch.distributions.Independent):
        return AffineNormal(tensor_to_funsor(matrix, event_dims, 2),
                            tensor_to_funsor(mvn.base_dist.loc, event_dims, 1),
                            tensor_to_funsor(mvn.base_dist.scale, event_dims, 1),
                            Variable(x_name, reals(x_size)),
                            Variable(y_name, reals(y_size)))

    # Information-form parameter of the noise: info_vec = precision @ loc,
    # computed stably via a Cholesky solve against scale_tril.
    info_vec = mvn.loc.unsqueeze(-1).cholesky_solve(mvn.scale_tril).squeeze(-1)
    # Normalization constant of the mvn density:
    # -0.5*y_size*log(2*pi) - log|scale_tril| - 0.5 * loc' @ precision @ loc.
    log_prob = (-0.5 * y_size * math.log(2 * math.pi)
                - mvn.scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (info_vec * mvn.loc).sum(-1))

    # Assemble the joint precision over the concatenated vector (x, y) of
    # the conditional density N(y | x @ matrix + loc). In block form with
    # M = matrix and P = mvn.precision_matrix:
    #   P_xx = M P M^T,  P_xy = -M P,  P_yx = P_xy^T,  P_yy = P.
    batch_shape = broadcast_shape(matrix.shape[:-2], mvn.batch_shape)
    P_yy = mvn.precision_matrix.expand(batch_shape + (y_size, y_size))
    neg_P_xy = matrix.matmul(P_yy)
    P_xy = -neg_P_xy
    P_yx = P_xy.transpose(-1, -2)
    P_xx = neg_P_xy.matmul(matrix.transpose(-1, -2))
    precision = torch.cat([torch.cat([P_xx, P_xy], -1),
                           torch.cat([P_yx, P_yy], -1)], -2)
    # Corresponding joint information vector: (-M @ info_y, info_y).
    info_y = info_vec.expand(batch_shape + (y_size,))
    info_x = -matrix.matmul(info_y.unsqueeze(-1)).squeeze(-1)
    info_vec = torch.cat([info_x, info_y], -1)

    info_vec = tensor_to_funsor(info_vec, event_dims, 1)
    precision = tensor_to_funsor(precision, event_dims, 2)
    inputs = info_vec.inputs.copy()
    inputs[x_name] = reals(x_size)
    inputs[y_name] = reals(y_size)
    return tensor_to_funsor(log_prob, event_dims) + Gaussian(info_vec.data, precision.data, inputs)
@singledispatch
def dist_to_funsor(pyro_dist, event_inputs=()):
    """
    Convert a PyTorch distribution to a Funsor.

    This is currently implemented for only a subset of distribution types;
    see the ``dist_to_funsor.register(...)`` implementations below.

    :param torch.distributions.Distribution pyro_dist: A PyTorch distribution.
    :return: A funsor.
    :rtype: funsor.terms.Funsor
    """
    assert isinstance(pyro_dist, torch.distributions.Distribution)
    # Fallback for unregistered distribution types.
    raise ValueError("Cannot convert {} distribution to a Funsor"
                     .format(type(pyro_dist).__name__))
@dist_to_funsor.register(dist.Independent)
def _independent_to_funsor(pyro_dist, event_inputs=()):
    # Name each reinterpreted batch dim, convert the base distribution with
    # those extra event names, then fold each name back into "value".
    num_dims = pyro_dist.reinterpreted_batch_ndims
    event_names = tuple("_event_{}".format(len(event_inputs) + i)
                        for i in range(num_dims))
    result = dist_to_funsor(pyro_dist.base_dist, event_inputs + event_names)
    for name in reversed(event_names):
        result = Independent(result, "value", name, "value")
    return result
@dist_to_funsor.register(MaskedDistribution)
def _masked_to_funsor(pyro_dist, event_inputs=()):
    # FIXME This is subject to NANs: where the mask is 0 the base log-prob
    # is multiplied out, but a NAN base log-prob still propagates
    # (0 * nan == nan).
    mask = tensor_to_funsor(pyro_dist._mask.float(), event_inputs)
    result = mask * dist_to_funsor(pyro_dist.base_dist, event_inputs)
    return result
@dist_to_funsor.register(dist.Categorical)
def _categorical_to_funsor(pyro_dist, event_inputs=()):
    # The rightmost dim of logits enumerates categories, so it becomes the
    # "value" input of the resulting funsor.
    names = event_inputs + ("value",)
    return tensor_to_funsor(pyro_dist.logits, names)
@dist_to_funsor.register(dist.Bernoulli)
def _bernoulli_to_funsor(pyro_dist, event_inputs=()):
    # Parameterize by logits to avoid converting through probs.
    return BernoulliLogits(tensor_to_funsor(pyro_dist.logits, event_inputs))
@dist_to_funsor.register(dist.Normal)
def _normal_to_funsor(pyro_dist, event_inputs=()):
    # Convert loc and scale separately, then wrap in a Normal funsor.
    return Normal(tensor_to_funsor(pyro_dist.loc, event_inputs),
                  tensor_to_funsor(pyro_dist.scale, event_inputs))
@dist_to_funsor.register(dist.MultivariateNormal)
def _mvn_to_funsor(pyro_dist, event_inputs=()):
    # loc has one event dim, scale_tril has two.
    mean = tensor_to_funsor(pyro_dist.loc, event_inputs, 1)
    tril = tensor_to_funsor(pyro_dist.scale_tril, event_inputs, 2)
    return MultivariateNormal(mean, tril)