Source code for funsor.pyro.distribution

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict

import torch
from torch.distributions import constraints
from pyro.distributions import TorchDistribution

from funsor.cnf import Contraction
from import Delta
from import Bint
from funsor.pyro.convert import DIM_TO_NAME, funsor_to_tensor, tensor_to_funsor
from funsor.terms import Funsor, to_funsor

[docs]class FunsorDistribution(TorchDistribution): """ :class:`~torch.distributions.Distribution` wrapper around a :class:`~funsor.terms.Funsor` for use in Pyro code. This is typically used as a base class for specific funsor inference algorithms wrapped in a distribution interface. :param funsor.terms.Funsor funsor_dist: A funsor with an input named "value" that is treated as a random variable. The distribution should be normalized over "value". :param torch.Size batch_shape: The distribution's batch shape. This must be in the same order as the input of the ``funsor_dist``, but may contain extra dims of size 1. :param event_shape: The distribution's event shape. """ arg_constraints = {} def __init__(self, funsor_dist, batch_shape=torch.Size(), event_shape=torch.Size(), dtype="real", validate_args=None): assert isinstance(funsor_dist, Funsor) assert isinstance(batch_shape, tuple) assert isinstance(event_shape, tuple) assert "value" in funsor_dist.inputs super(FunsorDistribution, self).__init__(batch_shape, event_shape, validate_args) self.funsor_dist = funsor_dist self.dtype = dtype @constraints.dependent_property def support(self): if self.dtype == "real": return constraints.real else: return constraints.integer_interval(0, self.dtype - 1)
[docs] def log_prob(self, value): if self._validate_args: self._validate_sample(value) ndims = max(len(self.batch_shape), value.dim() - self.event_dim) value = tensor_to_funsor(value, event_output=self.event_dim, dtype=self.dtype) log_prob = self.funsor_dist(value=value) log_prob = funsor_to_tensor(log_prob, ndims=ndims) return log_prob
def _sample_delta(self, sample_shape): sample_inputs = None if sample_shape: sample_inputs = OrderedDict() shape = sample_shape + self.batch_shape for dim in range(-len(shape), -len(self.batch_shape)): if shape[dim] > 1: sample_inputs[DIM_TO_NAME[dim]] = Bint[shape[dim]] delta = self.funsor_dist.sample(frozenset({"value"}), sample_inputs) if isinstance(delta, Contraction): assert len([d for d in delta.terms if isinstance(d, Delta)]) == 1 delta = delta.terms[0] assert isinstance(delta, Delta) return delta
[docs] @torch.no_grad() def sample(self, sample_shape=torch.Size()): delta = self._sample_delta(sample_shape) ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape) value = funsor_to_tensor(delta.terms[0][1][0], ndims=ndims) return value.detach()
[docs] def rsample(self, sample_shape=torch.Size()): delta = self._sample_delta(sample_shape) assert not delta.log_density.requires_grad, "distribution is not fully reparametrized" ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape) value = funsor_to_tensor(delta.terms[0][1][0], ndims=ndims) return value
[docs] def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(type(self), _instance) batch_shape = torch.Size(batch_shape) funsor_dist = self.funsor_dist + tensor_to_funsor(torch.zeros(batch_shape)) super(type(self), new).__init__( funsor_dist, batch_shape, self.event_shape, self.dtype, validate_args=False) new.validate_args = self.__dict__.get('_validate_args') return new
[docs]@to_funsor.register(FunsorDistribution) def funsordistribution_to_funsor(pyro_dist, output=None, dim_to_name=None): raise NotImplementedError("TODO implement conversion for FunsorDistribution")