Source code for funsor.pyro.hmm

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict

import torch

import funsor.ops as ops
from funsor.domains import Bint, Reals
from funsor.interpretations import eager, lazy, moment_matching
from funsor.pyro.convert import (
    dist_to_funsor,
    funsor_to_cat_and_mvn,
    funsor_to_tensor,
    matrix_and_mvn_to_funsor,
    mvn_to_funsor,
    tensor_to_funsor,
)
from funsor.pyro.distribution import FunsorDistribution
from funsor.sum_product import (
    MarkovProduct,
    naive_sequential_sum_product,
    sequential_sum_product,
)
from funsor.terms import Variable
from funsor.util import broadcast_shape


class DiscreteHMM(FunsorDistribution):
    r"""
    Hidden Markov Model with discrete latent state and arbitrary observation
    distribution. This uses [1] to parallelize over time, achieving
    O(log(time)) parallel complexity.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_logits`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        # homogeneous + homogeneous case:
        event_shape = (1,) + observation_dist.event_shape

    This class should be interchangeable with
    :class:`pyro.distributions.hmm.DiscreteHMM`.

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :param ~torch.Tensor initial_logits: A logits tensor for an initial
        categorical distribution over latent states. Should have rightmost
        size ``state_dim`` and be broadcastable to
        ``batch_shape + (state_dim,)``.
    :param ~torch.Tensor transition_logits: A logits tensor for transition
        conditional distributions between latent states. Should have rightmost
        shape ``(state_dim, state_dim)`` (old, new), and be broadcastable to
        ``batch_shape + (num_steps, state_dim, state_dim)``.
    :param ~torch.distributions.Distribution observation_dist: A conditional
        distribution of observed data conditioned on latent state. The
        ``.batch_shape`` should have rightmost size ``state_dim`` and be
        broadcastable to ``batch_shape + (num_steps, state_dim)``. The
        ``.event_shape`` may be arbitrary.
    """

    def __init__(
        self, initial_logits, transition_logits, observation_dist, validate_args=None
    ):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(observation_dist, torch.distributions.Distribution)
        assert initial_logits.dim() >= 1
        assert transition_logits.dim() >= 2
        assert len(observation_dist.batch_shape) >= 1
        shape = broadcast_shape(
            initial_logits.shape[:-1] + (1,),
            transition_logits.shape[:-2],
            observation_dist.batch_shape[:-1],
        )
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + observation_dist.event_shape
        self._has_rsample = observation_dist.has_rsample

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(-1, True)

        # Convert tensors and distributions to funsors.
        init = tensor_to_funsor(initial_logits, ("state",))
        trans = tensor_to_funsor(transition_logits, ("time", "state", "state(time=1)"))
        obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
        dtype = obs.inputs["value"].dtype

        # Construct the joint funsor.
        with lazy:
            # TODO perform math here once sequential_sum_product has been
            # implemented as a first-class funsor.
            funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for
            # hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(DiscreteHMM, self).__init__(
            funsor_dist, batch_shape, event_shape, dtype, validate_args
        )

    @torch.distributions.constraints.dependent_property
    def has_rsample(self):
        return self._has_rsample

    # TODO remove this once self.funsor_dist is defined.
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(
            value, ("time",), event_output=self.event_dim - 1, dtype=self.dtype
        )

        # Compare with pyro.distributions.hmm.DiscreteHMM.log_prob().
        obs = self._obs(value=value)
        result = self._trans + obs
        result = sequential_sum_product(
            ops.logaddexp, ops.add, result, time, {"state": "state(time=1)"}
        )
        result = self._init + result.reduce(ops.logaddexp, "state(time=1)")
        result = result.reduce(ops.logaddexp, "state")

        result = funsor_to_tensor(result, ndims=ndims)
        return result
    # TODO remove this once self.funsor_dist is defined.
    def _sample_delta(self, sample_shape):
        raise NotImplementedError("TODO")
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(DiscreteHMM, _instance)
        batch_shape = torch.Size(batch_shape)
        new._has_rsample = self._has_rsample
        new._init = self._init + tensor_to_funsor(torch.zeros(batch_shape))
        new._trans = self._trans
        new._obs = self._obs
        super(DiscreteHMM, new).__init__(
            self.funsor_dist,
            batch_shape,
            self.event_shape,
            self.dtype,
            validate_args=False,
        )
        # Restore validation state; torch checks self._validate_args.
        new._validate_args = self.__dict__.get("_validate_args")
        return new
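
# A minimal usage sketch, not part of the original module: construct a
# DiscreteHMM with time-dependent transitions and score one sequence. All
# shapes and the choice of a Categorical observation model are illustrative
# assumptions.
def _example_discrete_hmm():
    state_dim, num_steps, num_obs_values = 3, 10, 4
    initial_logits = torch.randn(state_dim)
    # Time-dependent transitions make event_shape == (num_steps,).
    transition_logits = torch.randn(num_steps, state_dim, state_dim)
    # One categorical observation distribution per latent state.
    observation_dist = torch.distributions.Categorical(
        logits=torch.randn(state_dim, num_obs_values)
    )
    hmm = DiscreteHMM(initial_logits, transition_logits, observation_dist)
    data = torch.randint(0, num_obs_values, (num_steps,))
    return hmm.log_prob(data)  # a scalar tensor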
class GaussianHMM(FunsorDistribution):
    r"""
    Hidden Markov Model with Gaussians for initial, transition, and observation
    distributions. This adapts [1] to parallelize over time to achieve
    O(log(time)) parallel complexity, however it differs in that it tracks the
    log normalizer to ensure :meth:`log_prob` is differentiable.

    This corresponds to the generative model::

        z = initial_distribution.sample()
        x = []
        for t in range(num_steps):
            z = z @ transition_matrix + transition_dist.sample()
            x.append(z @ observation_matrix + observation_dist.sample())

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_dist`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        event_shape = (1, obs_dim)  # homogeneous + homogeneous case

    This class should be compatible with
    :class:`pyro.distributions.hmm.GaussianHMM`, but additionally supports
    funsor :mod:`~funsor.adjoint` algorithms.

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :ivar int hidden_dim: The dimension of the hidden state.
    :ivar int obs_dim: The dimension of the observed state.
    :param ~torch.distributions.MultivariateNormal initial_dist: A distribution
        over initial states. This should have batch_shape broadcastable to
        ``self.batch_shape``. This should have event_shape ``(hidden_dim,)``.
    :param ~torch.Tensor transition_matrix: A linear transformation of hidden
        state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, hidden_dim)`` where the
        rightmost dims are ordered ``(old, new)``.
    :param ~torch.distributions.MultivariateNormal transition_dist: A process
        noise distribution. This should have batch_shape broadcastable to
        ``self.batch_shape + (num_steps,)``. This should have event_shape
        ``(hidden_dim,)``.
    :param ~torch.Tensor observation_matrix: A linear transformation from
        hidden to observed state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, obs_dim)``.
    :param observation_dist: An observation noise distribution. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(obs_dim,)``.
    :type observation_dist: ~torch.distributions.MultivariateNormal or
        ~torch.distributions.Independent of ~torch.distributions.Normal
    """

    has_rsample = True
    arg_constraints = {}

    def __init__(
        self,
        initial_dist,
        transition_matrix,
        transition_dist,
        observation_matrix,
        observation_dist,
        validate_args=None,
    ):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert initial_dist.event_shape == (hidden_dim,)
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim,)
        assert observation_dist.event_shape == (obs_dim,)
        shape = broadcast_shape(
            initial_dist.batch_shape + (1,),
            transition_matrix.shape[:-2],
            transition_dist.batch_shape,
            observation_matrix.shape[:-2],
            observation_dist.batch_shape,
        )
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim,)

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(
            transition_matrix, transition_dist, ("time",), "state", "state(time=1)"
        )
        obs = matrix_and_mvn_to_funsor(
            observation_matrix, observation_dist, ("time",), "state(time=1)", "value"
        )
        dtype = "real"

        # Construct the joint funsor.
        with lazy:
            value = Variable("value", Reals[time_shape[0], obs_dim])
            result = trans + obs(value=value["time"])
            result = MarkovProduct(
                ops.logaddexp, ops.add, result, "time", {"state": "state(time=1)"}
            )
            result = init + result.reduce(ops.logaddexp, "state(time=1)")
            funsor_dist = result.reduce(ops.logaddexp, "state")

        super(GaussianHMM, self).__init__(
            funsor_dist, batch_shape, event_shape, dtype, validate_args
        )
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
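
# A minimal usage sketch, not part of the original module: a GaussianHMM with
# time-dependent transition matrices, so that event_shape == (num_steps,
# obs_dim) and log_prob expects data of exactly that shape. Shapes and the
# identity noise covariances are illustrative assumptions.
def _example_gaussian_hmm():
    hidden_dim, obs_dim, num_steps = 2, 3, 10
    initial_dist = torch.distributions.MultivariateNormal(
        torch.zeros(hidden_dim), torch.eye(hidden_dim)
    )
    transition_matrix = torch.randn(num_steps, hidden_dim, hidden_dim)
    transition_dist = torch.distributions.MultivariateNormal(
        torch.zeros(hidden_dim), torch.eye(hidden_dim)
    )
    observation_matrix = torch.randn(hidden_dim, obs_dim)
    observation_dist = torch.distributions.MultivariateNormal(
        torch.zeros(obs_dim), torch.eye(obs_dim)
    )
    hmm = GaussianHMM(
        initial_dist,
        transition_matrix,
        transition_dist,
        observation_matrix,
        observation_dist,
    )
    data = torch.randn(num_steps, obs_dim)
    return hmm.log_prob(data)  # a scalar tensor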
class GaussianMRF(FunsorDistribution):
    r"""
    Temporal Markov Random Field with Gaussian factors for initial, transition,
    and observation distributions. This adapts [1] to parallelize over time to
    achieve O(log(time)) parallel complexity, however it differs in that it
    tracks the log normalizer to ensure :meth:`log_prob` is differentiable.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_dist`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        event_shape = (1, obs_dim)  # homogeneous + homogeneous case

    This class should be compatible with
    :class:`pyro.distributions.hmm.GaussianMRF`, but additionally supports
    funsor :mod:`~funsor.adjoint` algorithms.

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :ivar int hidden_dim: The dimension of the hidden state.
    :ivar int obs_dim: The dimension of the observed state.
    :param ~torch.distributions.MultivariateNormal initial_dist: A distribution
        over initial states. This should have batch_shape broadcastable to
        ``self.batch_shape``. This should have event_shape ``(hidden_dim,)``.
    :param ~torch.distributions.MultivariateNormal transition_dist: A joint
        distribution factor over a pair of successive time steps. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(hidden_dim + hidden_dim,)`` (old+new).
    :param ~torch.distributions.MultivariateNormal observation_dist: A joint
        distribution factor over a hidden and an observed state. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(hidden_dim + obs_dim,)``.
    """

    has_rsample = True

    def __init__(
        self, initial_dist, transition_dist, observation_dist, validate_args=None
    ):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
        assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
        hidden_dim = initial_dist.event_shape[0]
        assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
        obs_dim = observation_dist.event_shape[0] - hidden_dim
        shape = broadcast_shape(
            initial_dist.batch_shape + (1,),
            transition_dist.batch_shape,
            observation_dist.batch_shape,
        )
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim,)

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = mvn_to_funsor(
            transition_dist,
            ("time",),
            OrderedDict(
                [("state", Reals[hidden_dim]), ("state(time=1)", Reals[hidden_dim])]
            ),
        )
        obs = mvn_to_funsor(
            observation_dist,
            ("time",),
            OrderedDict(
                [("state(time=1)", Reals[hidden_dim]), ("value", Reals[obs_dim])]
            ),
        )

        # Construct the joint funsor.
        # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
        with lazy:
            time = Variable("time", Bint[time_shape[0]])
            value = Variable("value", Reals[time_shape[0], obs_dim])
            logp_oh = trans + obs(value=value["time"])
            logp_oh = MarkovProduct(
                ops.logaddexp, ops.add, logp_oh, time, {"state": "state(time=1)"}
            )
            logp_oh += init
            logp_oh = logp_oh.reduce(
                ops.logaddexp, frozenset({"state", "state(time=1)"})
            )
            logp_h = trans + obs.reduce(ops.logaddexp, "value")
            logp_h = MarkovProduct(
                ops.logaddexp, ops.add, logp_h, time, {"state": "state(time=1)"}
            )
            logp_h += init
            logp_h = logp_h.reduce(
                ops.logaddexp, frozenset({"state", "state(time=1)"})
            )
            funsor_dist = logp_oh - logp_h

        dtype = "real"
        super(GaussianMRF, self).__init__(
            funsor_dist, batch_shape, event_shape, dtype, validate_args
        )
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
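
# A minimal usage sketch, not part of the original module: a GaussianMRF whose
# transition factor is a joint Gaussian over (state[t], state[t+1]) and whose
# observation factor is a joint Gaussian over (state[t+1], value[t]). Shapes
# and identity covariances are illustrative assumptions.
def _example_gaussian_mrf():
    hidden_dim, obs_dim, num_steps = 2, 3, 5
    initial_dist = torch.distributions.MultivariateNormal(
        torch.zeros(hidden_dim), torch.eye(hidden_dim)
    )
    # Joint factor over (state[t], state[t+1]), one per time step.
    transition_dist = torch.distributions.MultivariateNormal(
        torch.zeros(num_steps, 2 * hidden_dim), torch.eye(2 * hidden_dim)
    )
    # Joint factor over (state[t+1], value[t]), one per time step.
    observation_dist = torch.distributions.MultivariateNormal(
        torch.zeros(num_steps, hidden_dim + obs_dim),
        torch.eye(hidden_dim + obs_dim),
    )
    mrf = GaussianMRF(initial_dist, transition_dist, observation_dist)
    data = torch.randn(num_steps, obs_dim)
    return mrf.log_prob(data)  # a scalar tensor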
class SwitchingLinearHMM(FunsorDistribution):
    r"""
    Switching Linear Dynamical System represented as a Hidden Markov Model.

    This corresponds to the generative model::

        z = Categorical(logits=initial_logits).sample()
        y = initial_mvn[z].sample()
        x = []
        for t in range(num_steps):
            z = Categorical(logits=transition_logits[t, z]).sample()
            y = y @ transition_matrix[t, z] + transition_mvn[t, z].sample()
            x.append(y @ observation_matrix[t, z] + observation_mvn[t, z].sample())

    Viewed as a dynamic Bayesian network::

        z[t-1] ----> z[t] ---> z[t+1]         Discrete latent class
          |  \         |  \        |  \
          | y[t-1] ----> y[t] ----> y[t+1]    Gaussian latent state
          |   /        |   /       |   /
          V  /         V  /        V  /
        x[t-1]       x[t]      x[t+1]         Gaussian observation

    Let ``class`` be the latent class, ``state`` be the latent multivariate
    normal state, and ``value`` be the observed multivariate normal value.

    :param ~torch.Tensor initial_logits: Represents ``p(class[0])``.
    :param ~torch.distributions.MultivariateNormal initial_mvn: Represents
        ``p(state[0] | class[0])``.
    :param ~torch.Tensor transition_logits: Represents
        ``p(class[t+1] | class[t])``.
    :param ~torch.Tensor transition_matrix:
    :param ~torch.distributions.MultivariateNormal transition_mvn: Together
        with ``transition_matrix``, this represents
        ``p(state[t], state[t+1] | class[t])``.
    :param ~torch.Tensor observation_matrix:
    :param ~torch.distributions.MultivariateNormal observation_mvn: Together
        with ``observation_matrix``, this represents
        ``p(value[t+1], state[t+1] | class[t+1])``.
    :param bool exact: If True, perform exact inference at cost exponential in
        ``num_steps``. If False, use a
        :func:`~funsor.interpretations.moment_matching` approximation and use a
        parallel scan algorithm to reduce parallel complexity to logarithmic in
        ``num_steps``. Defaults to False.
    """

    has_rsample = True
    arg_constraints = {}

    def __init__(
        self,
        initial_logits,
        initial_mvn,
        transition_logits,
        transition_matrix,
        transition_mvn,
        observation_matrix,
        observation_mvn,
        exact=False,
        validate_args=None,
    ):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(initial_mvn, torch.distributions.MultivariateNormal)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_mvn, torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_mvn, torch.distributions.MultivariateNormal)
        hidden_cardinality = initial_logits.size(-1)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert initial_mvn.event_shape[0] == hidden_dim
        assert transition_logits.size(-1) == hidden_cardinality
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_mvn.event_shape[0] == hidden_dim
        assert observation_mvn.event_shape[0] == obs_dim
        init_shape = broadcast_shape(initial_logits.shape, initial_mvn.batch_shape)
        shape = broadcast_shape(
            init_shape[:-1] + (1, init_shape[-1]),
            transition_logits.shape[:-1],
            transition_matrix.shape[:-2],
            transition_mvn.batch_shape,
            observation_matrix.shape[:-2],
            observation_mvn.batch_shape,
        )
        assert shape[-1] == hidden_cardinality
        batch_shape, time_shape = shape[:-2], shape[-2:-1]
        event_shape = time_shape + (obs_dim,)

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(-1, True)

        # Convert tensors and distributions to funsors.
        init = tensor_to_funsor(initial_logits, ("class",)) + dist_to_funsor(
            initial_mvn, ("class",)
        )(value="state")
        trans = tensor_to_funsor(
            transition_logits, ("time", "class", "class(time=1)")
        ) + matrix_and_mvn_to_funsor(
            transition_matrix,
            transition_mvn,
            ("time", "class(time=1)"),
            "state",
            "state(time=1)",
        )
        obs = matrix_and_mvn_to_funsor(
            observation_matrix,
            observation_mvn,
            ("time", "class(time=1)"),
            "state(time=1)",
            "value",
        )
        if "class(time=1)" not in set(trans.inputs).union(obs.inputs):
            raise ValueError(
                "neither transition nor observation depend on discrete state"
            )
        dtype = "real"

        # Construct the joint funsor.
        with lazy:
            # TODO perform math here once sequential_sum_product has been
            # implemented as a first-class funsor.
            funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for
            # hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(SwitchingLinearHMM, self).__init__(
            funsor_dist, batch_shape, event_shape, dtype, validate_args
        )
        self.exact = exact

    # TODO remove this once self.funsor_dist is defined.
    def log_prob(self, value):
        ndims = max(len(self.batch_shape), value.dim() - 2)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(value, ("time",), 1)

        seq_sum_prod = (
            naive_sequential_sum_product if self.exact else sequential_sum_product
        )
        with eager if self.exact else moment_matching:
            result = self._trans + self._obs(value=value)
            result = seq_sum_prod(
                ops.logaddexp,
                ops.add,
                result,
                time,
                {"class": "class(time=1)", "state": "state(time=1)"},
            )
            result += self._init
            result = result.reduce(
                ops.logaddexp,
                frozenset(["class", "state", "class(time=1)", "state(time=1)"]),
            )

        result = funsor_to_tensor(result, ndims=ndims)
        return result
    # TODO remove this once self.funsor_dist is defined.
    def _sample_delta(self, sample_shape):
        raise NotImplementedError("TODO")
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(SwitchingLinearHMM, _instance)
        batch_shape = torch.Size(batch_shape)
        new._init = self._init + tensor_to_funsor(torch.zeros(batch_shape))
        new._trans = self._trans
        new._obs = self._obs
        new.exact = self.exact
        super(SwitchingLinearHMM, new).__init__(
            self.funsor_dist,
            batch_shape,
            self.event_shape,
            self.dtype,
            validate_args=False,
        )
        # Restore validation state; torch checks self._validate_args.
        new._validate_args = self.__dict__.get("_validate_args")
        return new
    def filter(self, value):
        """
        Compute posterior over final state given a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final time
            step, represented as a pair ``(cat, mvn)``, where ``cat`` is a
            :class:`~pyro.distributions.Categorical` distribution over mixture
            components and ``mvn`` is a
            :class:`~pyro.distributions.MultivariateNormal` with rightmost
            batch dimension ranging over mixture components. This can then be
            used to initialize a sequential Pyro model for prediction.
        :rtype: tuple
        """
        ndims = max(len(self.batch_shape), value.dim() - 2)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(value, ("time",), 1)

        seq_sum_prod = (
            naive_sequential_sum_product if self.exact else sequential_sum_product
        )
        with eager if self.exact else moment_matching:
            logp = self._trans + self._obs(value=value)
            logp = seq_sum_prod(
                ops.logaddexp,
                ops.add,
                logp,
                time,
                {"class": "class(time=1)", "state": "state(time=1)"},
            )
            logp += self._init
            logp = logp.reduce(ops.logaddexp, frozenset(["class", "state"]))

        cat, mvn = funsor_to_cat_and_mvn(logp, ndims, ("class(time=1)",))
        cat = cat.expand(self.batch_shape)
        mvn = mvn.expand(self.batch_shape + cat.logits.shape[-1:])
        return cat, mvn
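
# A minimal usage sketch, not part of the original module: a switching linear
# HMM with time- and class-dependent dynamics, scored with the default
# moment-matching approximation and then filtered to a posterior over the
# final state. Shapes and identity covariances are illustrative assumptions.
def _example_switching_linear_hmm():
    num_classes, hidden_dim, obs_dim, num_steps = 2, 2, 3, 5
    initial_logits = torch.randn(num_classes)
    initial_mvn = torch.distributions.MultivariateNormal(
        torch.zeros(hidden_dim), torch.eye(hidden_dim)
    )
    transition_logits = torch.randn(num_steps, num_classes, num_classes)
    transition_matrix = torch.randn(num_steps, num_classes, hidden_dim, hidden_dim)
    transition_mvn = torch.distributions.MultivariateNormal(
        torch.zeros(num_steps, num_classes, hidden_dim), torch.eye(hidden_dim)
    )
    observation_matrix = torch.randn(num_steps, num_classes, hidden_dim, obs_dim)
    observation_mvn = torch.distributions.MultivariateNormal(
        torch.zeros(num_steps, num_classes, obs_dim), torch.eye(obs_dim)
    )
    hmm = SwitchingLinearHMM(
        initial_logits,
        initial_mvn,
        transition_logits,
        transition_matrix,
        transition_mvn,
        observation_matrix,
        observation_mvn,
    )
    data = torch.randn(num_steps, obs_dim)
    cat, mvn = hmm.filter(data)  # posterior over the final (class, state)
    return hmm.log_prob(data), cat, mvn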