Pyro-Compatible Distributions

This interface provides a number of PyTorch-style distributions that use funsors internally to perform inference. These high-level objects are all based on the wrapper class FunsorDistribution, which exposes a funsor through a PyTorch-distributions-compatible interface. FunsorDistribution objects can be used directly in Pyro models (using the standard Pyro backend).
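
For example, a FunsorDistribution subclass such as DiscreteHMM (documented below) can be sampled or observed like any other distribution. The following is a minimal sketch, assuming the HMM classes live in funsor.pyro.hmm:

import torch
import pyro
import pyro.distributions as dist
from funsor.pyro.hmm import DiscreteHMM

def model(data, state_dim=3):
    init_logits = pyro.param("init_logits", torch.zeros(state_dim))
    trans_logits = pyro.param("trans_logits", torch.zeros(state_dim, state_dim))
    locs = pyro.param("locs", torch.randn(state_dim))
    obs_dist = dist.Normal(locs, 1.)              # one observation dist per state
    hmm = DiscreteHMM(init_logits, trans_logits, obs_dist)
    pyro.sample("x", hmm, obs=data)               # used like any TorchDistribution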

FunsorDistribution Base Class

class FunsorDistribution(funsor_dist, batch_shape=torch.Size([]), event_shape=torch.Size([]), dtype='real', validate_args=None)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

Distribution wrapper around a Funsor for use in Pyro code. This is typically used as a base class for specific funsor inference algorithms wrapped in a distribution interface.

Parameters:
  • funsor_dist (funsor.terms.Funsor) – A funsor with an input named “value” that is treated as a random variable. The distribution should be normalized over “value”.
  • batch_shape (torch.Size) – The distribution’s batch shape. This must be in the same order as the inputs of funsor_dist, but may contain extra dims of size 1.
  • event_shape – The distribution’s event shape.
arg_constraints = {}
support
log_prob(value)[source]
sample(sample_shape=torch.Size([]))[source]
rsample(sample_shape=torch.Size([]))[source]
expand(batch_shape, _instance=None)[source]
funsordistribution_to_funsor(pyro_dist, output=None, dim_to_name=None)[source]
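
As an illustration of the wrapper itself (a sketch, not an officially documented pattern), one can wrap any funsor that has a normalized “value” input, here obtained from a Normal via dist_to_funsor() from the conversion utilities below:

import torch
import pyro.distributions as dist
from funsor.pyro.convert import dist_to_funsor
from funsor.pyro.distribution import FunsorDistribution

f = dist_to_funsor(dist.Normal(0., 1.))  # funsor with an input named "value"
d = FunsorDistribution(f)                # exposes a TorchDistribution interface
log_p = d.log_prob(torch.tensor(0.5))    # log density at 0.5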

Hidden Markov Models

class DiscreteHMM(initial_logits, transition_logits, observation_dist, validate_args=None)[source]

Bases: funsor.pyro.distribution.FunsorDistribution

Hidden Markov Model with discrete latent state and arbitrary observation distribution. This uses [1] to parallelize over time, achieving O(log(time)) parallel complexity.

The event_shape of this distribution includes time on the left:

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_logits and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:

# homogeneous + homogeneous case:
event_shape = (1,) + observation_dist.event_shape

This class should be interchangeable with pyro.distributions.hmm.DiscreteHMM.

References:

[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
“Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf

Parameters:
  • initial_logits (Tensor) – A logits tensor for an initial categorical distribution over latent states. Should have rightmost size state_dim and be broadcastable to batch_shape + (state_dim,).
  • transition_logits (Tensor) – A logits tensor for transition conditional distributions between latent states. Should have rightmost shape (state_dim, state_dim) (old, new), and be broadcastable to batch_shape + (num_steps, state_dim, state_dim).
  • observation_dist (Distribution) – A conditional distribution of observed data conditioned on latent state. The .batch_shape should have rightmost size state_dim and be broadcastable to batch_shape + (num_steps, state_dim). The .event_shape may be arbitrary.
has_rsample
log_prob(value)[source]
expand(batch_shape, _instance=None)[source]
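
The following sketch of the homogeneous+homogeneous case shows log_prob() accepting sequences of any length; the funsor.pyro.hmm import path is an assumption:

import torch
import pyro.distributions as dist
from funsor.pyro.hmm import DiscreteHMM

state_dim = 3
initial_logits = torch.zeros(state_dim)                # uniform initial state
transition_logits = torch.zeros(state_dim, state_dim)  # homogeneous: no time dim
observation_dist = dist.Normal(torch.randn(state_dim), 1.)

hmm = DiscreteHMM(initial_logits, transition_logits, observation_dist)
assert hmm.event_shape == (1,)            # num_steps = 1 broadcasts against data
log_prob = hmm.log_prob(torch.randn(10))  # ten time steps
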
class GaussianHMM(initial_dist, transition_matrix, transition_dist, observation_matrix, observation_dist, validate_args=None)[source]

Bases: funsor.pyro.distribution.FunsorDistribution

Hidden Markov Model with Gaussians for initial, transition, and observation distributions. This adapts [1] to parallelize over time, achieving O(log(time)) parallel complexity; however, it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.

This corresponds to the generative model:

z = initial_dist.sample()
x = []
for t in range(num_steps):
    z = z @ transition_matrix + transition_dist.sample()
    x.append(z @ observation_matrix + observation_dist.sample())

The event_shape of this distribution includes time on the left:

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:

event_shape = (1, obs_dim)  # homogeneous + homogeneous case

This class should be compatible with pyro.distributions.hmm.GaussianHMM, but additionally supports funsor adjoint algorithms.

References:

[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
“Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf

Variables:
  • hidden_dim (int) – The dimension of the hidden state.
  • obs_dim (int) – The dimension of the observed state.
Parameters:
  • initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
  • transition_matrix (Tensor) – A linear transformation of hidden state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, hidden_dim) where the rightmost dims are ordered (old, new).
  • transition_dist (MultivariateNormal) – A process noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim,).
  • observation_matrix (Tensor) – A linear transformation from hidden to observed state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, obs_dim).
  • observation_dist (MultivariateNormal or Normal) – An observation noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (obs_dim,).
has_rsample = True
arg_constraints = {}
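
A usage sketch with homogeneous dynamics, following the parameter shapes documented above (the funsor.pyro.hmm import path is an assumption):

import torch
import pyro.distributions as dist
from funsor.pyro.hmm import GaussianHMM

hidden_dim, obs_dim, num_steps = 2, 3, 5
initial_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
transition_matrix = torch.randn(hidden_dim, hidden_dim)
transition_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
observation_matrix = torch.randn(hidden_dim, obs_dim)
observation_dist = dist.MultivariateNormal(torch.zeros(obs_dim), torch.eye(obs_dim))

hmm = GaussianHMM(initial_dist, transition_matrix, transition_dist,
                  observation_matrix, observation_dist)
log_prob = hmm.log_prob(torch.randn(num_steps, obs_dim))  # differentiable
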
class GaussianMRF(initial_dist, transition_dist, observation_dist, validate_args=None)[source]

Bases: funsor.pyro.distribution.FunsorDistribution

Temporal Markov Random Field with Gaussian factors for initial, transition, and observation distributions. This adapts [1] to parallelize over time, achieving O(log(time)) parallel complexity; however, it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.

The event_shape of this distribution includes time on the left:

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:

event_shape = (1, obs_dim)  # homogeneous + homogeneous case

This class should be compatible with pyro.distributions.hmm.GaussianMRF, but additionally supports funsor adjoint algorithms.

References:

[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
“Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf

Variables:
  • hidden_dim (int) – The dimension of the hidden state.
  • obs_dim (int) – The dimension of the observed state.
Parameters:
  • initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
  • transition_dist (MultivariateNormal) – A joint distribution factor over a pair of successive time steps. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + hidden_dim,) (old+new).
  • observation_dist (MultivariateNormal) – A joint distribution factor over a hidden and an observed state. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + obs_dim,).
has_rsample = True
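
A usage sketch; note the transition factor is a joint MVN over the concatenated (old, new) hidden states and the observation factor is joint over (hidden, observed):

import torch
import pyro.distributions as dist
from funsor.pyro.hmm import GaussianMRF

hidden_dim, obs_dim, num_steps = 2, 3, 5
initial_dist = dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim))
transition_dist = dist.MultivariateNormal(torch.zeros(2 * hidden_dim),
                                          torch.eye(2 * hidden_dim))
observation_dist = dist.MultivariateNormal(torch.zeros(hidden_dim + obs_dim),
                                           torch.eye(hidden_dim + obs_dim))

mrf = GaussianMRF(initial_dist, transition_dist, observation_dist)
log_prob = mrf.log_prob(torch.randn(num_steps, obs_dim))
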
class SwitchingLinearHMM(initial_logits, initial_mvn, transition_logits, transition_matrix, transition_mvn, observation_matrix, observation_mvn, exact=False, validate_args=None)[source]

Bases: funsor.pyro.distribution.FunsorDistribution

Switching Linear Dynamical System represented as a Hidden Markov Model.

This corresponds to the generative model:

z = Categorical(logits=initial_logits).sample()
y = initial_mvn[z].sample()
x = []
for t in range(num_steps):
    z = Categorical(logits=transition_logits[t, z]).sample()
    y = y @ transition_matrix[t, z] + transition_mvn[t, z].sample()
    x.append(y @ observation_matrix[t, z] + observation_mvn[t, z].sample())

Viewed as a dynamic Bayesian network:

z[t-1] ----> z[t] ---> z[t+1]         Discrete latent class
   |  \       |  \       |   \
   | y[t-1] ----> y[t] ----> y[t+1]   Gaussian latent state
   |   /      |   /      |   /
   V  /       V  /       V  /
x[t-1]       x[t]      x[t+1]         Gaussian observation

Let class be the latent class, state be the latent multivariate normal state, and value be the observed multivariate normal value.

Parameters:
  • initial_logits (Tensor) – Represents p(class[0]).
  • initial_mvn (MultivariateNormal) – Represents p(state[0] | class[0]).
  • transition_logits (Tensor) – Represents p(class[t+1] | class[t]).
  • transition_matrix (Tensor) –
  • transition_mvn (MultivariateNormal) – Together with transition_matrix, this represents p(state[t], state[t+1] | class[t]).
  • observation_matrix (Tensor) –
  • observation_mvn (MultivariateNormal) – Together with observation_matrix, this represents p(value[t+1], state[t+1] | class[t+1]).
  • exact (bool) – If True, perform exact inference at cost exponential in num_steps. If False, use a moment_matching() approximation and a parallel scan algorithm to reduce parallel complexity to logarithmic in num_steps. Defaults to False.
has_rsample = True
arg_constraints = {}
log_prob(value)[source]
expand(batch_shape, _instance=None)[source]
filter(value)[source]

Compute posterior over final state given a sequence of observations.

Parameters: value (Tensor) – A sequence of observations.
Returns: A posterior distribution over latent states at the final time step, represented as a pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components. This can then be used to initialize a sequential Pyro model for prediction.
Return type: tuple
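
A filtering sketch; the per-class parameter batch shapes follow common SLDS conventions, and the funsor.pyro.hmm import path is an assumption:

import torch
import pyro.distributions as dist
from funsor.pyro.hmm import SwitchingLinearHMM

num_classes, hidden_dim, obs_dim, num_steps = 2, 3, 4, 5
initial_logits = torch.zeros(num_classes)
initial_mvn = dist.MultivariateNormal(torch.zeros(num_classes, hidden_dim),
                                      torch.eye(hidden_dim))
transition_logits = torch.zeros(num_classes, num_classes)
transition_matrix = torch.randn(num_classes, hidden_dim, hidden_dim)
transition_mvn = dist.MultivariateNormal(torch.zeros(num_classes, hidden_dim),
                                         torch.eye(hidden_dim))
observation_matrix = torch.randn(num_classes, hidden_dim, obs_dim)
observation_mvn = dist.MultivariateNormal(torch.zeros(num_classes, obs_dim),
                                          torch.eye(obs_dim))

slds = SwitchingLinearHMM(initial_logits, initial_mvn, transition_logits,
                          transition_matrix, transition_mvn,
                          observation_matrix, observation_mvn)
cat, mvn = slds.filter(torch.randn(num_steps, obs_dim))  # posterior at final step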

Conversion Utilities

This module follows a convention for converting between funsors and PyTorch distribution objects. This convention is compatible with NumPy/PyTorch-style broadcasting. Following PyTorch distributions (and TensorFlow distributions), we consider “event shapes” to be on the right and broadcast-compatible “batch shapes” to be on the left.
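
For example, under this convention a batch of multivariate normals keeps its event dimension on the right:

import torch
import pyro.distributions as dist

d = dist.MultivariateNormal(torch.zeros(4, 3), torch.eye(3))
assert d.batch_shape == (4,)   # broadcastable, on the left
assert d.event_shape == (3,)   # event dims, on the right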

This module also aims to be forgiving in inputs and pedantic in outputs: methods accept either the superclass torch.distributions.Distribution objects or the subclass pyro.distributions.TorchDistribution objects. Methods return only the narrower subclass pyro.distributions.TorchDistribution objects.

tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype='real')[source]

Convert a torch.Tensor to a funsor.tensor.Tensor.

Note this should not touch data, but may trigger a torch.Tensor.reshape() op.

Parameters:
  • tensor (torch.Tensor) – A PyTorch tensor.
  • event_inputs (tuple) – A tuple of names for rightmost tensor dimensions. If tensor has these names, they will be converted to result.inputs.
  • event_output (int) – The number of tensor dimensions assigned to result.output. These must be on the right of any event_input dimensions.
Returns:

A funsor.

Return type:

funsor.tensor.Tensor
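
An illustrative call, naming two rightmost dims as funsor inputs and keeping one dim as the output; the funsor.pyro.convert import path is an assumption:

import torch
from funsor.pyro.convert import tensor_to_funsor

t = torch.randn(4, 5, 3)
f = tensor_to_funsor(t, event_inputs=("i", "j"), event_output=1)
# f should have inputs "i" and "j" and a real output of shape (3,)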

funsor_to_tensor(funsor_, ndims, event_inputs=())[source]

Convert a funsor.tensor.Tensor to a torch.Tensor.

Note this should not touch data, but may trigger a torch.Tensor.reshape() op.

Parameters:
  • funsor_ (funsor.tensor.Tensor) – A funsor.
  • ndims (int) – The number of result dims, == result.dim().
  • event_inputs (tuple) – Names assigned to rightmost dimensions.
Returns:

A PyTorch tensor.

Return type:

torch.Tensor
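
A round-trip sketch: funsor_to_tensor should invert tensor_to_funsor when given matching event_inputs and the target number of dims:

import torch
from funsor.pyro.convert import funsor_to_tensor, tensor_to_funsor

t = torch.randn(4, 5, 3)
f = tensor_to_funsor(t, event_inputs=("i", "j"), event_output=1)
t2 = funsor_to_tensor(f, ndims=3, event_inputs=("i", "j"))
assert t2.shape == (4, 5, 3)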

dist_to_funsor(pyro_dist, event_inputs=())[source]

Convert a PyTorch distribution to a Funsor.

Parameters: pyro_dist (torch.distributions.Distribution) – A PyTorch distribution.
Returns: A funsor.
Return type: funsor.terms.Funsor
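
A sketch: the result is a lazy log-density funsor whose random variable appears as an input named “value”, matching the FunsorDistribution convention above:

import torch
import pyro.distributions as dist
from funsor.pyro.convert import dist_to_funsor

d = dist.Normal(torch.zeros(3), torch.ones(3))
f = dist_to_funsor(d, event_inputs=("batch",))
# f should have a Bint[3] input "batch" and a real input "value"
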
mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs={})[source]

Convert a joint torch.distributions.MultivariateNormal distribution into a Funsor with multiple real inputs.

This should satisfy:

sum(d.num_elements for d in real_inputs.values())
  == pyro_dist.event_shape[0]
Parameters:
  • pyro_dist (torch.distributions.MultivariateNormal) – A multivariate normal distribution over one or more variables of real or vector or tensor type.
  • event_inputs (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
  • real_inputs (OrderedDict) – A dict mapping each real variable name to an appropriately sized Real domain. The num_elements of all real inputs should sum to the pyro_dist event dimension.
Returns:

A funsor with given real_inputs and possibly additional Bint inputs.

Return type:

funsor.terms.Funsor
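
A sketch splitting a 5-dimensional MVN into two named real inputs whose sizes sum to the event dimension, per the identity above (the Reals[...] domain syntax is assumed):

from collections import OrderedDict

import torch
import pyro.distributions as dist
from funsor.domains import Reals
from funsor.pyro.convert import mvn_to_funsor

mvn = dist.MultivariateNormal(torch.zeros(5), torch.eye(5))
f = mvn_to_funsor(mvn, real_inputs=OrderedDict(x=Reals[3], y=Reals[2]))
# f has real inputs "x" (size 3) and "y" (size 2); 3 + 2 == 5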

funsor_to_mvn(gaussian, ndims, event_inputs=())[source]

Convert a Funsor to a pyro.distributions.MultivariateNormal, dropping the normalization constant.

Parameters:
  • gaussian (funsor.gaussian.Gaussian or funsor.joint.Joint) – A Gaussian funsor.
  • ndims (int) – The number of batch dimensions in the result.
  • event_inputs (tuple) – A tuple of names to assign to rightmost dimensions.
Returns:

a multivariate normal distribution.

Return type:

pyro.distributions.MultivariateNormal
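
A round-trip sketch, under the assumption that dist_to_funsor() applied to an MVN yields a Gaussian funsor with a single real input named “value”:

import torch
import pyro.distributions as dist
from funsor.pyro.convert import dist_to_funsor, funsor_to_mvn

mvn = dist.MultivariateNormal(torch.zeros(3), torch.eye(3))
f = dist_to_funsor(mvn)           # single real input "value"
mvn2 = funsor_to_mvn(f, ndims=0)  # normalization constant is dropped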

funsor_to_cat_and_mvn(funsor_, ndims, event_inputs)[source]

Convert a labeled Gaussian mixture model to a pair of distributions.

Parameters:
  • funsor_ (funsor.joint.Joint) – A Gaussian mixture funsor.
  • ndims (int) – The number of batch dimensions in the result.
  • event_inputs (tuple) – A tuple of names assigned to rightmost dimensions.
Returns:

A pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components.

matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(), x_name='value_x', y_name='value_y')[source]

Convert a noisy affine function to a Gaussian. The noisy affine function is defined as:

y = x @ matrix + mvn.sample()

The result is a non-normalized Gaussian funsor with two real inputs, x_name and y_name, corresponding to a conditional distribution of real vector y given real vector x.

Parameters:
  • matrix (torch.Tensor) – A matrix with rightmost shape (x_size, y_size).
  • mvn (torch.distributions.MultivariateNormal or torch.distributions.Independent of torch.distributions.Normal) – A multivariate normal distribution with event_shape == (y_size,).
  • event_dims (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
  • x_name (str) – The name of the x random variable.
  • y_name (str) – The name of the y random variable.
Returns:

A funsor with given real_inputs and possibly additional Bint inputs.

Return type:

funsor.terms.Funsor
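
A sketch encoding the noisy affine function above as a funsor over two named real inputs:

import torch
import pyro.distributions as dist
from funsor.pyro.convert import matrix_and_mvn_to_funsor

x_size, y_size = 3, 2
matrix = torch.randn(x_size, y_size)
noise = dist.MultivariateNormal(torch.zeros(y_size), torch.eye(y_size))
f = matrix_and_mvn_to_funsor(matrix, noise, x_name="x", y_name="y")
# f is a non-normalized Gaussian with real inputs "x" (size 3) and "y" (size 2)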