# Source code for funsor.pyro.distribution

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict

import torch
from pyro.distributions import TorchDistribution
from torch.distributions import constraints

from funsor.cnf import Contraction
from funsor.delta import Delta
from funsor.domains import Bint
from funsor.interpreter import reinterpret
from funsor.pyro.convert import DIM_TO_NAME, funsor_to_tensor, tensor_to_funsor
from funsor.terms import Funsor, to_funsor

class FunsorDistribution(TorchDistribution):
    """
    :class:`~torch.distributions.Distribution` wrapper around a
    :class:`~funsor.terms.Funsor` for use in Pyro code. This is typically used
    as a base class for specific funsor inference algorithms wrapped in a
    distribution interface.

    :param funsor.terms.Funsor funsor_dist: A funsor with an input named
        "value" that is treated as a random variable. The distribution should
        be normalized over "value".
    :param torch.Size batch_shape: The distribution's batch shape. This must
        be in the same order as the input of the ``funsor_dist``, but may
        contain extra dims of size 1.
    :param event_shape: The distribution's event shape.
    :param dtype: Either ``"real"`` for real-valued support, or an integer
        ``n`` for support on the integers ``0, ..., n - 1``.
    :param validate_args: Forwarded to the base distribution constructor.
    """

    arg_constraints = {}

    def __init__(
        self,
        funsor_dist,
        batch_shape=torch.Size(),
        event_shape=torch.Size(),
        dtype="real",
        validate_args=None,
    ):
        assert isinstance(funsor_dist, Funsor)
        assert isinstance(batch_shape, tuple)
        assert isinstance(event_shape, tuple)
        assert "value" in funsor_dist.inputs
        super(FunsorDistribution, self).__init__(
            batch_shape, event_shape, validate_args
        )
        self.funsor_dist = funsor_dist
        self.dtype = dtype

    @constraints.dependent_property
    def support(self):
        """Return ``constraints.real`` when ``dtype == "real"``, otherwise the
        integer interval ``[0, dtype - 1]``."""
        if self.dtype == "real":
            return constraints.real
        else:
            return constraints.integer_interval(0, self.dtype - 1)

    def log_prob(self, value):
        """Evaluate the log density of ``value``.

        ``value`` is converted to a funsor (its rightmost ``event_dim`` dims
        are treated as the event), substituted for the ``"value"`` input of
        ``funsor_dist``, reinterpreted, and converted back to a tensor.

        :param torch.Tensor value: A sample-shaped tensor.
        :returns: Log probabilities with shape equal to the batch portion.
        :rtype: torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)
        # Result rank is the larger of this distribution's batch rank and the
        # batch rank of ``value`` (which may carry extra leading sample dims).
        ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
        value = tensor_to_funsor(value, event_output=self.event_dim, dtype=self.dtype)
        log_prob = reinterpret(self.funsor_dist(value=value))
        log_prob = funsor_to_tensor(log_prob, ndims=ndims)
        return log_prob

    def _sample_delta(self, sample_shape):
        """Draw a :class:`~funsor.delta.Delta` over ``"value"``.

        Leading ``sample_shape`` dims of size > 1 are passed to
        ``funsor_dist.sample`` as named ``Bint`` sample inputs.

        :param torch.Size sample_shape: Shape of the samples to draw.
        :returns: The single Delta term produced by sampling.
        :rtype: ~funsor.delta.Delta
        """
        sample_inputs = None
        if sample_shape:
            sample_inputs = OrderedDict()
            shape = sample_shape + self.batch_shape
            # Only the sample dims (to the left of the batch dims) become
            # sample inputs; size-1 dims are left to broadcasting.
            for dim in range(-len(shape), -len(self.batch_shape)):
                if shape[dim] > 1:
                    sample_inputs[DIM_TO_NAME[dim]] = Bint[shape[dim]]
        delta = self.funsor_dist.sample(frozenset({"value"}), sample_inputs)
        if isinstance(delta, Contraction):
            # Sampling may yield a Contraction containing the Delta plus other
            # terms. Select the Delta itself instead of assuming it is first.
            deltas = [d for d in delta.terms if isinstance(d, Delta)]
            assert len(deltas) == 1
            delta = deltas[0]
        assert isinstance(delta, Delta)
        return delta

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """Draw a non-differentiable sample of shape
        ``sample_shape + batch_shape + event_shape``."""
        delta = self._sample_delta(sample_shape)
        ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
        # terms[0] is (name, (point, log_density)); take the sampled point.
        value = funsor_to_tensor(delta.terms[0][1][0], ndims=ndims)
        return value.detach()

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparametrized sample of shape
        ``sample_shape + batch_shape + event_shape``.

        Asserts that the sampled Delta's log density does not require grad,
        i.e. that the distribution is fully reparametrized.
        """
        delta = self._sample_delta(sample_shape)
        assert (
            not delta.log_density.requires_grad
        ), "distribution is not fully reparametrized"
        ndims = len(sample_shape) + len(self.batch_shape) + len(self.event_shape)
        value = funsor_to_tensor(delta.terms[0][1][0], ndims=ndims)
        return value

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution broadcast to ``batch_shape``.

        The funsor is broadcast by adding a zero tensor funsor of shape
        ``batch_shape``. Argument validation is disabled while constructing
        the new instance, and then the original validation setting is copied over.
        """
        new = self._get_checked_instance(type(self), _instance)
        batch_shape = torch.Size(batch_shape)
        funsor_dist = self.funsor_dist + tensor_to_funsor(torch.zeros(batch_shape))
        # Call FunsorDistribution.__init__ explicitly. The previous
        # super(type(self), new) form resolved to TorchDistribution.__init__
        # when type(self) is FunsorDistribution, and to the wrong __init__
        # for deeper subclasses, in both cases with misaligned arguments.
        FunsorDistribution.__init__(
            new,
            funsor_dist,
            batch_shape,
            self.event_shape,
            self.dtype,
            validate_args=False,
        )
        # Restore validation on the private attribute the base class reads.
        # Assigning to the public ``validate_args`` had no effect.
        new._validate_args = self._validate_args
        return new
@to_funsor.register(FunsorDistribution)
def funsordistribution_to_funsor(pyro_dist, output=None, dim_to_name=None):
    """``to_funsor`` handler registered for :class:`FunsorDistribution`.

    Conversion is not implemented yet, so this handler unconditionally raises.

    :raises NotImplementedError: on every call.
    """
    message = "TODO implement conversion for FunsorDistribution"
    raise NotImplementedError(message)