Pyro-Compatible Distributions
This interface provides a number of PyTorch-style distributions that use funsors internally to perform inference. These high-level objects are based on a wrapping class, FunsorDistribution, which wraps a funsor in a PyTorch-distributions-compatible interface. FunsorDistribution objects can be used directly in Pyro models (using the standard Pyro backend).
FunsorDistribution Base Class
- class FunsorDistribution(funsor_dist, batch_shape=torch.Size([]), event_shape=torch.Size([]), dtype='real', validate_args=None)
Bases: TorchDistribution
Distribution wrapper around a Funsor for use in Pyro code. This is typically used as a base class for specific funsor inference algorithms wrapped in a distribution interface.
- Parameters
funsor_dist (funsor.terms.Funsor) – A funsor with an input named “value” that is treated as a random variable. The distribution should be normalized over “value”.
batch_shape (torch.Size) – The distribution’s batch shape. This must be in the same order as the inputs of funsor_dist, but may contain extra dims of size 1.
event_shape – The distribution’s event shape.
- arg_constraints = {}
- property support
Hidden Markov Models
- class DiscreteHMM(initial_logits, transition_logits, observation_dist, validate_args=None)
Bases: FunsorDistribution
Hidden Markov Model with discrete latent state and arbitrary observation distribution. This uses [1] to parallelize over time, achieving O(log(time)) parallel complexity.
The event_shape of this distribution includes time on the left:
event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_logits and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:
# homogeneous + homogeneous case:
event_shape = (1,) + observation_dist.event_shape
This class should be interchangeable with pyro.distributions.hmm.DiscreteHMM.
References:
- [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019) “Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
- Parameters
initial_logits (Tensor) – A logits tensor for an initial categorical distribution over latent states. Should have rightmost size state_dim and be broadcastable to batch_shape + (state_dim,).
transition_logits (Tensor) – A logits tensor for transition conditional distributions between latent states. Should have rightmost shape (state_dim, state_dim) (old, new), and be broadcastable to batch_shape + (num_steps, state_dim, state_dim).
observation_dist (Distribution) – A conditional distribution of observed data conditioned on latent state. The .batch_shape should have rightmost size state_dim and be broadcastable to batch_shape + (num_steps, state_dim). The .event_shape may be arbitrary.
- property has_rsample
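The logarithmic parallel complexity claimed above comes from the fact that per-step HMM factors compose associatively, so they can be combined pairwise in a tree rather than one after another. The following is a minimal NumPy sketch of that idea, not the library's implementation. The function names are hypothetical, the observation log-likelihoods are assumed to be already evaluated at the data, and the initial state is assumed to take one transition before the first observation (matching the generative models listed for the Gaussian classes in this section).

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def log_matmul(a, b):
    """Batched matrix product in log space: (..., i, j) @ (..., j, k)."""
    return logsumexp(a[..., :, :, None] + b[..., None, :, :], axis=-2)


def hmm_log_prob_sequential(log_init, log_trans, log_obs):
    """Forward algorithm, using num_steps sequential updates."""
    alpha = log_init
    for t in range(log_trans.shape[0]):
        # Transition first, then weight by the likelihood of observation t.
        alpha = logsumexp(alpha[:, None] + log_trans[t], axis=0) + log_obs[t]
    return logsumexp(alpha, axis=0)


def hmm_log_prob_tree(log_init, log_trans, log_obs):
    """Same quantity via pairwise reduction, using O(log(num_steps)) rounds."""
    # One (old, new) factor per time step.
    factors = log_trans + log_obs[:, None, :]
    while factors.shape[0] > 1:
        t = factors.shape[0]
        # Combine neighbours (0,1), (2,3), ... in one batched operation.
        paired = log_matmul(factors[0:t - t % 2:2], factors[1:t:2])
        if t % 2 == 1:
            # An unpaired last factor is carried to the next round.
            paired = np.concatenate([paired, factors[-1:]], axis=0)
        factors = paired
    joint = log_init[:, None] + factors[0]
    return logsumexp(joint.reshape(-1), axis=0)
```

Here log_init has shape (state_dim,), log_trans has shape (num_steps, state_dim, state_dim) ordered (old, new), and log_obs has shape (num_steps, state_dim). Each round of the while loop is a single batched operation, which is what allows the work to run in parallel across time.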
- class GaussianHMM(initial_dist, transition_matrix, transition_dist, observation_matrix, observation_dist, validate_args=None)
Bases: FunsorDistribution
Hidden Markov Model with Gaussians for initial, transition, and observation distributions. This adapts [1] to parallelize over time to achieve O(log(time)) parallel complexity; however, it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.
This corresponds to the generative model:
z = initial_dist.sample()
x = []
for t in range(num_steps):
    z = z @ transition_matrix + transition_dist.sample()
    x.append(z @ observation_matrix + observation_dist.sample())
The event_shape of this distribution includes time on the left:
event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:
event_shape = (1, obs_dim)  # homogeneous + homogeneous case
This class should be compatible with pyro.distributions.hmm.GaussianHMM, but additionally supports funsor adjoint algorithms.
References:
- [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019) “Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
- Variables
- Parameters
initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
transition_matrix (Tensor) – A linear transformation of hidden state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, hidden_dim) where the rightmost dims are ordered (old, new).
transition_dist (MultivariateNormal) – A process noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim,).
observation_matrix (Tensor) – A linear transformation from hidden to observed state. This should have shape broadcastable to self.batch_shape + (num_steps, hidden_dim, obs_dim).
observation_dist (MultivariateNormal or Independent of Normal) – An observation noise distribution. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (obs_dim,).
- has_rsample = True
- arg_constraints = {}
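To make the shapes in the GaussianHMM generative model concrete, here is a small NumPy simulation of the homogeneous case. It is only an illustrative sketch: the function name and its arguments are hypothetical, the noise terms are assumed to be zero-mean, and no funsor or Pyro objects are involved.

```python
import numpy as np


def sample_gaussian_hmm(rng, init_mean, init_cov, transition_matrix,
                        transition_cov, observation_matrix, observation_cov,
                        num_steps):
    """Draw one observation sequence of shape (num_steps, obs_dim)."""
    hidden_dim, obs_dim = observation_matrix.shape
    z = rng.multivariate_normal(init_mean, init_cov)
    x = []
    for _ in range(num_steps):
        # Hidden state update: row vector times an (old, new) matrix plus noise.
        z = z @ transition_matrix + rng.multivariate_normal(
            np.zeros(hidden_dim), transition_cov)
        # Observation: project hidden state to obs_dim and add noise.
        x.append(z @ observation_matrix + rng.multivariate_normal(
            np.zeros(obs_dim), observation_cov))
    return np.stack(x)


rng = np.random.default_rng(0)
hidden_dim, obs_dim, num_steps = 3, 2, 5
data = sample_gaussian_hmm(
    rng,
    init_mean=np.zeros(hidden_dim),
    init_cov=np.eye(hidden_dim),
    transition_matrix=0.9 * np.eye(hidden_dim),
    transition_cov=0.1 * np.eye(hidden_dim),
    observation_matrix=np.ones((hidden_dim, obs_dim)),
    observation_cov=0.5 * np.eye(obs_dim),
    num_steps=num_steps,
)
```

The result has time as its leftmost dimension, which mirrors the event_shape = (num_steps,) + observation_dist.event_shape layout described above.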
- class GaussianMRF(initial_dist, transition_dist, observation_dist, validate_args=None)
Bases: FunsorDistribution
Temporal Markov Random Field with Gaussian factors for initial, transition, and observation distributions. This adapts [1] to parallelize over time to achieve O(log(time)) parallel complexity; however, it differs in that it tracks the log normalizer to ensure log_prob() is differentiable.
The event_shape of this distribution includes time on the left:
event_shape = (num_steps,) + observation_dist.event_shape
This distribution supports any combination of homogeneous/heterogeneous time dependency of transition_dist and observation_dist. However, because time is included in this distribution’s event_shape, the homogeneous+homogeneous case will have a broadcastable event_shape with num_steps = 1, allowing log_prob() to work with arbitrary length data:
event_shape = (1, obs_dim)  # homogeneous + homogeneous case
This class should be compatible with pyro.distributions.hmm.GaussianMRF, but additionally supports funsor adjoint algorithms.
References:
- [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019) “Temporal Parallelization of Bayesian Filters and Smoothers” https://arxiv.org/pdf/1905.13002.pdf
- Variables
- Parameters
initial_dist (MultivariateNormal) – A distribution over initial states. This should have batch_shape broadcastable to self.batch_shape. This should have event_shape (hidden_dim,).
transition_dist (MultivariateNormal) – A joint distribution factor over a pair of successive time steps. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + hidden_dim,) (old+new).
observation_dist (MultivariateNormal) – A joint distribution factor over a hidden and an observed state. This should have batch_shape broadcastable to self.batch_shape + (num_steps,). This should have event_shape (hidden_dim + obs_dim,).
- has_rsample = True
- class SwitchingLinearHMM(initial_logits, initial_mvn, transition_logits, transition_matrix, transition_mvn, observation_matrix, observation_mvn, exact=False, validate_args=None)
Bases: FunsorDistribution
Switching Linear Dynamical System represented as a Hidden Markov Model.
This corresponds to the generative model:
z = Categorical(logits=initial_logits).sample()
y = initial_mvn[z].sample()
x = []
for t in range(num_steps):
    z = Categorical(logits=transition_logits[t, z]).sample()
    y = y @ transition_matrix[t, z] + transition_mvn[t, z].sample()
    x.append(y @ observation_matrix[t, z] + observation_mvn[t, z].sample())
Viewed as a dynamic Bayesian network:
z[t-1] ----> z[t] ---> z[t+1]        Discrete latent class
   |  \       |  \       |   \
   | y[t-1] ----> y[t] ----> y[t+1]  Gaussian latent state
   |  /       |  /       |   /
   V /        V /        V  /
x[t-1]       x[t]      x[t+1]        Gaussian observation
Let class be the latent class, state be the latent multivariate normal state, and value be the observed multivariate normal value.
- Parameters
initial_logits (Tensor) – Represents p(class[0]).
initial_mvn (MultivariateNormal) – Represents p(state[0] | class[0]).
transition_logits (Tensor) – Represents p(class[t+1] | class[t]).
transition_matrix (Tensor) –
transition_mvn (MultivariateNormal) – Together with transition_matrix, this represents p(state[t], state[t+1] | class[t]).
observation_matrix (Tensor) –
observation_mvn (MultivariateNormal) – Together with observation_matrix, this represents p(value[t+1], state[t+1] | class[t+1]).
exact (bool) – If True, perform exact inference at cost exponential in num_steps. If False, use a moment_matching() approximation and use a parallel scan algorithm to reduce parallel complexity to logarithmic in num_steps. Defaults to False.
- has_rsample = True
- arg_constraints = {}
- filter(value)
Compute posterior over final state given a sequence of observations.
- Parameters
value (Tensor) – A sequence of observations.
- Returns
A posterior distribution over latent states at the final time step, represented as a pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components. This can then be used to initialize a sequential Pyro model for prediction.
- Return type
tuple
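The exact=False option above relies on moment matching, which keeps the number of Gaussian components from growing exponentially in num_steps. As a rough illustration of the general technique (the textbook collapse of a Gaussian mixture to a single Gaussian with the same mean and covariance), the NumPy sketch below may help. It is not the library's moment_matching() routine, whose details may differ, and the function name is hypothetical.

```python
import numpy as np


def moment_match(weights, means, covs):
    """Collapse a K-component Gaussian mixture to one Gaussian.

    weights: (K,) nonnegative mixture weights
    means:   (K, D) component means
    covs:    (K, D, D) component covariances
    """
    weights = weights / weights.sum()
    mean = np.einsum("k,kd->d", weights, means)
    centered = means - mean
    # Total covariance = expected within-component covariance
    #                  + covariance of the component means.
    cov = (np.einsum("k,kde->de", weights, covs)
           + np.einsum("k,kd,ke->de", weights, centered, centered))
    return mean, cov
```

Applied once per time step and per latent class, a collapse of this kind leaves a fixed number of Gaussians to propagate, which is what makes a parallel scan over time possible.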
Conversion Utilities
This module follows a convention for converting between funsors and PyTorch distribution objects. This convention is compatible with NumPy/PyTorch-style broadcasting. Following PyTorch distributions (and TensorFlow distributions), we consider “event shapes” to be on the right and broadcast-compatible “batch shapes” to be on the left.
This module also aims to be forgiving in inputs and pedantic in outputs:
- Methods accept either the superclass torch.distributions.Distribution objects or the subclass pyro.distributions.TorchDistribution objects.
- Methods return only the narrower subclass pyro.distributions.TorchDistribution objects.
- tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype='real')
Convert a torch.Tensor to a funsor.tensor.Tensor.
Note this should not touch data, but may trigger a torch.Tensor.reshape() op.
- Parameters
tensor (torch.Tensor) – A PyTorch tensor.
event_inputs (tuple) – A tuple of names for rightmost tensor dimensions. If tensor has these names, they will be converted to result.inputs.
event_output (int) – The number of tensor dimensions assigned to result.output. These must be on the right of any event_inputs dimensions.
- Returns
A funsor.
- Return type
funsor.tensor.Tensor
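The layout that tensor_to_funsor expects can be summarized as: batch dimensions on the left, then one dimension per name in event_inputs, then event_output dimensions on the far right. The pure-Python sketch below only partitions a shape tuple according to that description. The helper name is hypothetical, and it says nothing about how the library names or stores the remaining batch dimensions.

```python
def split_tensor_shape(shape, event_inputs=(), event_output=0):
    """Partition a shape into (batch_shape, named_input_sizes, output_shape)."""
    shape = tuple(shape)
    n_named = len(event_inputs)
    if n_named + event_output > len(shape):
        raise ValueError("shape has too few dimensions for the requested split")
    # Output dims are rightmost; named input dims sit immediately to their left.
    out_start = len(shape) - event_output
    named_start = out_start - n_named
    batch_shape = shape[:named_start]
    named_sizes = dict(zip(event_inputs, shape[named_start:out_start]))
    output_shape = shape[out_start:]
    return batch_shape, named_sizes, output_shape


parts = split_tensor_shape((5, 4, 3, 2), event_inputs=("time",), event_output=1)
```

In this example the dimension of size 3 is the one labeled "time" and the trailing dimension of size 2 is the output; the two leading dimensions are batch dimensions.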
- funsor_to_tensor(funsor_, ndims, event_inputs=())
Convert a funsor.tensor.Tensor to a torch.Tensor.
Note this should not touch data, but may trigger a torch.Tensor.reshape() op.
- Parameters
funsor_ (funsor.tensor.Tensor) – A funsor.
ndims (int) – The number of result dims, == result.dim().
event_inputs (tuple) – Names assigned to rightmost dimensions.
- Returns
A PyTorch tensor.
- Return type
torch.Tensor
- dist_to_funsor(pyro_dist, event_inputs=())
Convert a PyTorch distribution to a Funsor.
- Parameters
pyro_dist (torch.distributions.Distribution) – A PyTorch distribution.
- Returns
A funsor.
- Return type
Funsor
- mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs={})
Convert a joint torch.distributions.MultivariateNormal distribution into a Funsor with multiple real inputs.
This should satisfy:
sum(d.num_elements for d in real_inputs.values()) == pyro_dist.event_shape[0]
- Parameters
pyro_dist (torch.distributions.MultivariateNormal) – A multivariate normal distribution over one or more variables of real or vector or tensor type.
event_inputs (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
real_inputs (OrderedDict) – A dict mapping real variable name to appropriately sized Real. The sum of all .numel() of all real inputs should be equal to the pyro_dist dimension.
- Returns
A funsor with given real_inputs and possibly additional Bint inputs.
- Return type
Funsor
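The size constraint on real_inputs for mvn_to_funsor says that the named real variables must together account for every coordinate of the multivariate normal's event dimension. The stdlib-only sketch below checks that constraint and shows which coordinates each name would cover, assuming the variables are laid out contiguously in dictionary order. That layout is an assumption made for illustration, the helper name is hypothetical, and plain shape tuples stand in for the library's Real domains.

```python
from collections import OrderedDict
from math import prod


def split_event_dim(event_size, real_inputs):
    """Map each variable name to the slice of the event dimension it covers.

    real_inputs: OrderedDict of name -> shape tuple (() for a scalar).
    """
    sizes = OrderedDict((name, prod(shape)) for name, shape in real_inputs.items())
    if sum(sizes.values()) != event_size:
        raise ValueError(
            "real_inputs cover %d elements but the event dimension has %d"
            % (sum(sizes.values()), event_size))
    slices, start = OrderedDict(), 0
    for name, size in sizes.items():
        slices[name] = slice(start, start + size)
        start += size
    return slices


# A 7-dimensional MVN viewed as a vector, a scalar and a 2x2 matrix.
layout = split_event_dim(7, OrderedDict([("x", (2,)), ("y", ()), ("z", (2, 2))]))
```

A mismatch between the requested variables and the event size raises a ValueError here, which corresponds to violating the equality stated above.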
- funsor_to_mvn(gaussian, ndims, event_inputs=())
Convert a Funsor to a pyro.distributions.MultivariateNormal, dropping the normalization constant.
- Parameters
gaussian (funsor.gaussian.Gaussian or funsor.joint.Joint) – A Gaussian funsor.
ndims (int) – The number of batch dimensions in the result.
event_inputs (tuple) – A tuple of names to assign to rightmost dimensions.
- Returns
A multivariate normal distribution.
- Return type
pyro.distributions.MultivariateNormal
- funsor_to_cat_and_mvn(funsor_, ndims, event_inputs)
Converts a labeled Gaussian mixture model to a pair of distributions.
- Parameters
funsor_ (funsor.joint.Joint) – A Gaussian mixture funsor.
ndims (int) – The number of batch dimensions in the result.
- Returns
A pair (cat, mvn), where cat is a Categorical distribution over mixture components and mvn is a MultivariateNormal with rightmost batch dimension ranging over mixture components.
- matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(), x_name='value_x', y_name='value_y')
Convert a noisy affine function to a Gaussian. The noisy affine function is defined as:
y = x @ matrix + mvn.sample()
The result is a non-normalized Gaussian funsor with two real inputs, x_name and y_name, corresponding to a conditional distribution of real vector y given real vector x.
- Parameters
matrix (torch.Tensor) – A matrix with rightmost shape (x_size, y_size).
mvn (torch.distributions.MultivariateNormal or torch.distributions.Independent of torch.distributions.Normal) – A multivariate normal distribution with event_shape == (y_size,).
event_dims (tuple) – A tuple of names for rightmost dimensions. These will be assigned to result.inputs of type Bint.
x_name (str) – The name of the x random variable.
y_name (str) – The name of the y random variable.
- Returns
A funsor with real inputs x_name and y_name and possibly additional Bint inputs.
- Return type
Funsor
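To see why a noisy affine function yields a non-normalized Gaussian over the pair (x, y), it helps to write the conditional density of y given x as a quadratic form in the concatenated vector [x, y]. The NumPy sketch below derives the information vector and precision matrix of that quadratic form. It is a worked illustration of the algebra under the y = x @ matrix + noise convention above, with hypothetical function names, and it is not the library's internal representation.

```python
import numpy as np


def affine_gaussian_factor(matrix, noise_mean, noise_cov):
    """Return (info_vec, precision) of a Gaussian factor over [x, y]."""
    x_size, y_size = matrix.shape
    noise_precision = np.linalg.inv(noise_cov)
    # The residual y - x @ matrix equals A @ [x, y], with A = [-matrix^T, I].
    A = np.concatenate([-matrix.T, np.eye(y_size)], axis=1)
    precision = A.T @ noise_precision @ A
    info_vec = A.T @ noise_precision @ noise_mean
    return info_vec, precision


def factor_log_density(info_vec, precision, x, y):
    """Evaluate the factor up to an additive constant."""
    v = np.concatenate([x, y])
    return v @ info_vec - 0.5 * v @ precision @ v
```

The joint precision has rank at most y_size, so it is singular whenever x_size > 0. The factor therefore cannot be normalized over x and y jointly, which is consistent with the result being described as a conditional distribution of y given x.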