# Source code for funsor.gaussian

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from functools import reduce

import funsor.ops as ops
from funsor.affine import affine_inputs, extract_affine, is_affine
from funsor.delta import Delta
from funsor.domains import Real, Reals
from funsor.interpretations import compress_gaussians
from funsor.ops import AddOp, SubOp
from funsor.tensor import Tensor, align_tensor, align_tensors
from funsor.terms import (
    Binary,
    Funsor,
    FunsorMeta,
    Number,
    Slice,
    Subs,
    Variable,
    eager,
    reflect,
)
from funsor.util import broadcast_shape, get_tracing_state, lazy_property

def _log_det_tri(x):
    return ops.log(ops.diagonal(x, -1, -2)).sum(-1)

def _vv(vec1, vec2):
    """
    Computes the inner product ``< vec1 | vec2 >``.
    """
    return (vec1[..., None, :] @ vec2[..., None])[..., 0, 0]

def _norm2(vec):
    return _vv(vec, vec)

def _mv(mat, vec):
    return (mat @ vec[..., None])[..., 0]

def _vm(vec, mat):
    return (vec[..., None, :] @ mat)[..., 0, :]

def _mmt(mat1, mat2=None):
    if mat2 is None:
        mat2 = mat1
    return mat1 @ ops.transpose(mat2, -1, -2)

def _mtm(mat1, mat2=None):
    if mat2 is None:
        mat2 = mat1
    return ops.transpose(mat1, -1, -2) @ mat2

def _inverse_cholesky(P):
    """
    Computes a Cholesky decomposition of the inverse of a posdef matrix.
    """
    # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    Lf = ops.cholesky(ops.flip(P, (-2, -1)))
    L_inv = ops.transpose(ops.flip(Lf, (-2, -1)), -2, -1)
    L = ops.triangular_inv(L_inv)
    return L
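
# Illustrative self-check (not part of the original module; assumes the
# active backend implements ops.cholesky, ops.flip, and ops.triangular_inv
# for plain numpy arrays):
#
#   >>> import numpy as np
#   >>> A = np.random.randn(5, 5)
#   >>> P = A @ A.T + 5.0 * np.eye(5)  # a positive definite matrix
#   >>> L = _inverse_cholesky(P)
#   >>> np.allclose(L @ L.T, np.linalg.inv(P))
#   True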

def _compress_rank(white_vec, prec_sqrt, assume_full_rank=False):
    """
    Compress a wide representation ``(white_vec, prec_sqrt)`` while preserving
    the quadratic function ``||x @ prec_sqrt - white_vec||^2`` up to a
    constant; the constant ``shift`` is returned together with the compressed
    representation.
    """
    dim, rank = prec_sqrt.shape[-2:]
    assert rank >= dim
    old_norm2 = _norm2(white_vec)

    # Let P = prec_sqrt and w = white_vec define the original Gaussian
    #
    #   G(x;w,P) = -1/2 || x P - w ||^2
    #            = -1/2 x P P' x' + x P w' - 1/2 w w'
    #
    # We seek a compressed Gaussian G(x;wc,Pc) and constant C such that
    #
    #   G(x;w,P) = G(x;wc,Pc) + C
    #            = -1/2 x Pc Pc' x' + x Pc wc' - 1/2 wc wc' + C
    if assume_full_rank:
        # Cholesky factorizing Pc = chol(P P'), we match remaining coefficients
        #
        #   Pc wc' = P w'  ==>  wc' = Pc \ P w'
        info_vec_ = prec_sqrt @ white_vec[..., None]
        precision = prec_sqrt @ ops.transpose(prec_sqrt, -1, -2)
        prec_sqrt = ops.cholesky(precision)
        white_vec = ops.triangular_solve(info_vec_, prec_sqrt)[..., 0]
    else:
        # Compute a reduced QR decomposition of P' of shape (rank, dim)
        #
        #   P' = Q [ R ]     P = [ R'  0 ] Q'
        #          [ 0 ]
        #
        # where Q is orthogonal and R is upper triangular of shape (dim, dim).
        # Then splitting along the new dimension,
        #
        #   G(x;w,P) = -1/2 || x [R' 0] Q' - w ||^2
        #            = -1/2 || x [R' 0] - w Q ||^2
        #            = -1/2 || x R' - (w Q)[...,:dim] ||^2
        #              - 1/2 || (w Q)[...,dim:] ||^2
        #            =: G(x;wc,Pc) + C
        #
        # where
        #
        #   wc = (w Q)[...,:dim]
        #   Pc = R'
        Q, R = ops.qr(ops.transpose(prec_sqrt, -1, -2))
        assert Q.shape[-2:] == (rank, dim)  # note only part of Q is returned
        assert R.shape[-2:] == (dim, dim)
        prec_sqrt = ops.transpose(R, -1, -2)
        white_vec = _vm(white_vec, Q)
    # Finally the shifting constant is
    #
    #   C = 1/2 (wc wc' - w w')
    new_norm2 = _norm2(white_vec)
    shift = 0.5 * (new_norm2 - old_norm2)
    return white_vec, prec_sqrt, shift
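
# Illustrative numerical check of the invariant (not part of the original
# module; assumes the active backend implements ops.qr etc. for plain numpy
# arrays):
#
#   >>> import numpy as np
#   >>> dim, rank = 2, 5
#   >>> P = np.random.randn(dim, rank)  # wide prec_sqrt
#   >>> w = np.random.randn(rank)  # white_vec
#   >>> wc, Pc, shift = _compress_rank(w, P)
#   >>> x = np.random.randn(dim)
#   >>> lhs = -0.5 * ((x @ P - w) ** 2).sum()
#   >>> rhs = -0.5 * ((x @ Pc - wc) ** 2).sum() + shift
#   >>> np.allclose(lhs, rhs)
#   True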

def _compute_offsets(inputs):
    """
    Compute offsets of real inputs into the concatenated Gaussian dims.
    This ignores all int inputs.

    :param OrderedDict inputs: A schema mapping variable name to domain.
    :return: a pair ``(offsets, total)``, where ``offsets`` is an OrderedDict
        mapping input name to integer offset, and ``total`` is the total
        event size.
    :rtype: tuple
    """
    assert isinstance(inputs, OrderedDict)
    offsets = OrderedDict()
    total = 0
    for key, domain in inputs.items():
        if domain.dtype == "real":
            offsets[key] = total
            total += domain.num_elements
    return offsets, total
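
# Example (illustrative; assumes funsor.domains.Bint in addition to the
# imports above):
#
#   >>> from funsor.domains import Bint
#   >>> inputs = OrderedDict(i=Bint[3], x=Reals[2], y=Real, z=Reals[4])
#   >>> _compute_offsets(inputs)
#   (OrderedDict([('x', 0), ('y', 2), ('z', 3)]), 7)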

def _split_real_inputs(inputs, lhs_keys, prototype):
    """
    Finds a splitting set of indices ``(lhs, rhs)`` into the flat real
    dimension such that ``lhs`` indexes into real inputs in ``lhs_keys`` and
    ``rhs`` indexes into everything else.
    """
    lhs_blocks = []
    rhs_blocks = []
    start = 0
    for key, domain in inputs.items():
        if domain.dtype == "real":
            stop = start + domain.num_elements
            (lhs_blocks if key in lhs_keys else rhs_blocks).append(slice(start, stop))
            start = stop

    # There are three cases: lhs left of rhs (cheap slices), lhs right of rhs
    # (cheap slices), and interleaved (expensive advanced indexing tensors).
    lhs_start = min(b.start for b in lhs_blocks)
    rhs_start = min(b.start for b in rhs_blocks)
    lhs_stop = max(b.stop for b in lhs_blocks)
    rhs_stop = max(b.stop for b in rhs_blocks)
    if lhs_stop <= rhs_start or rhs_stop <= lhs_start:
        # Construct cheap slices.
        lhs = slice(lhs_start, lhs_stop)
        rhs = slice(rhs_start, rhs_stop)
        return lhs, rhs

    # Construct interleaving indices.
    lhs = ops.cat([ops.new_arange(prototype, b.start, b.stop) for b in lhs_blocks])
    rhs = ops.cat([ops.new_arange(prototype, b.start, b.stop) for b in rhs_blocks])
    return lhs, rhs
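
# Example (illustrative): with real inputs x:Reals[2], y:Reals[3], z:Reals[1]
# in that order and lhs_keys = {"x", "z"}, the lhs blocks [0, 2) and [5, 6)
# interleave with y's block [2, 5), so advanced indexing arrays are returned:
#
#   lhs, rhs = _split_real_inputs(inputs, {"x", "z"}, prototype)
#   # lhs selects positions [0, 1, 5] and rhs selects positions [2, 3, 4]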

def _find_intervals(intervals, end):
    """
    Finds a complete set of intervals partitioning ``[0, end)``, given a
    partial set of non-overlapping intervals.
    """
    cuts = list(sorted({0, end}.union(*intervals)))
    return list(zip(cuts[:-1], cuts[1:]))
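
# Example (illustrative):
#
#   >>> _find_intervals([(2, 5), (8, 10)], 12)
#   [(0, 2), (2, 5), (5, 8), (8, 10), (10, 12)]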

def _parse_slices(index, value):
    if not isinstance(index, tuple):
        index = (index,)
    if index[0] is Ellipsis:
        index = index[1:]
    start_stops = []
    for pos, i in reversed(list(enumerate(index))):
        if isinstance(i, slice):
            start_stops.append((i.start, i.stop))
        elif isinstance(i, int):
            start_stops.append((i, i + 1))
            value = ops.unsqueeze(value, pos - len(index))
        else:
            raise ValueError("invalid index: {}".format(i))
    start_stops.reverse()
    return start_stops, value
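
# Example (illustrative): slices pass through as (start, stop) pairs, and an
# integer index is widened to a unit slice while the value gains a matching
# singleton dim:
#
#   >>> start_stops, v = _parse_slices((Ellipsis, slice(0, 4), 7), value)
#   >>> start_stops
#   [(0, 4), (7, 8)]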

class BlockVector(object):
    """
    Jit-compatible helper to build blockwise vectors.
    Syntax is similar to :func:`torch.zeros` ::

        x = BlockVector((100, 20))
        x[..., 0:4] = x1
        x[..., 6:10] = x2
        x = x.as_tensor()
        assert x.shape == (100, 20)
    """

    def __init__(self, shape):
        self.shape = shape
        self.parts = {}

    def __setitem__(self, index, value):
        (i,), value = _parse_slices(index, value)
        self.parts[i] = value

    def as_tensor(self):
        # TODO optimize this to use backend-specific block setters:
        # .__setitem__ for numpy and torch; .at(...).set(...) for jax.

        # Fill gaps with zeros.
        prototype = next(iter(self.parts.values()))
        for i in _find_intervals(self.parts.keys(), self.shape[-1]):
            if i not in self.parts:
                self.parts[i] = ops.new_zeros(
                    prototype, self.shape[:-1] + (i[1] - i[0],)
                )

        # Concatenate parts.
        parts = [v for k, v in sorted(self.parts.items())]
        result = ops.cat(parts, -1)
        if not get_tracing_state():
            assert result.shape == self.shape
        return result

class BlockMatrix(object):
    """
    Jit-compatible helper to build blockwise matrices.
    Syntax is similar to :func:`torch.zeros` ::

        x = BlockMatrix((100, 20, 20))
        x[..., 0:4, 0:4] = x11
        x[..., 0:4, 6:10] = x12
        x[..., 6:10, 0:4] = x12.transpose(-1, -2)
        x[..., 6:10, 6:10] = x22
        x = x.as_tensor()
        assert x.shape == (100, 20, 20)
    """

    def __init__(self, shape):
        self.shape = shape
        self.parts = defaultdict(dict)

    def __setitem__(self, index, value):
        (i, j), value = _parse_slices(index, value)
        self.parts[i][j] = value

    def as_tensor(self):
        # TODO optimize this to use backend-specific block setters:
        # .__setitem__ for numpy and torch; .at(...).set(...) for jax.

        # Fill gaps with zeros.
        arbitrary_row = next(iter(self.parts.values()))
        prototype = next(iter(arbitrary_row.values()))
        js = set().union(*(part.keys() for part in self.parts.values()))
        rows = _find_intervals(self.parts.keys(), self.shape[-2])
        cols = _find_intervals(js, self.shape[-1])
        for i in rows:
            for j in cols:
                if j not in self.parts[i]:
                    shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0])
                    self.parts[i][j] = ops.new_zeros(prototype, shape)

        # Concatenate parts.
        columns = {
            i: ops.cat([v for j, v in sorted(part.items())], -1)
            for i, part in self.parts.items()
        }
        result = ops.cat([v for i, v in sorted(columns.items())], -2)
        if not get_tracing_state():
            assert result.shape == self.shape
        return result

def align_gaussian(new_inputs, old, expand=False):
    """
    Align data of a Gaussian distribution to a new ``inputs`` shape.
    """
    assert isinstance(new_inputs, OrderedDict)
    assert isinstance(old, Gaussian)
    white_vec = old.white_vec
    prec_sqrt = old.prec_sqrt

    # Align int inputs.
    # Since these are managed as in Tensor, we can defer to align_tensor().
    new_ints = OrderedDict((k, d) for k, d in new_inputs.items() if d.dtype != "real")
    old_ints = OrderedDict((k, d) for k, d in old.inputs.items() if d.dtype != "real")
    if new_ints != old_ints:
        white_vec = align_tensor(new_ints, Tensor(white_vec, old_ints), expand=expand)
        prec_sqrt = align_tensor(new_ints, Tensor(prec_sqrt, old_ints), expand=expand)

    # Align real inputs, which are all concatenated in the rightmost dims.
    new_offsets, new_dim = _compute_offsets(new_inputs)
    old_offsets, old_dim = _compute_offsets(old.inputs)
    assert prec_sqrt.shape[-2:-1] == (old_dim,)
    if new_offsets != old_offsets:
        old_prec_sqrt = ops.transpose(prec_sqrt, -1, -2)
        prec_sqrt = BlockVector(old_prec_sqrt.shape[:-1] + (new_dim,))
        for k, new_offset in new_offsets.items():
            if k not in old_offsets:
                continue
            offset = old_offsets[k]
            num_elements = old.inputs[k].num_elements
            old_slice = slice(offset, offset + num_elements)
            new_slice = slice(new_offset, new_offset + num_elements)
            prec_sqrt[..., new_slice] = old_prec_sqrt[..., old_slice]
        prec_sqrt = prec_sqrt.as_tensor()
        prec_sqrt = ops.transpose(prec_sqrt, -1, -2)

    return white_vec, prec_sqrt

class GaussianMeta(FunsorMeta):
    """
    Wrapper to convert from external to internal compressed representation.

    This may return either a Gaussian or a Gaussian + Tensor, where the Tensor
    represents byproducts of compression.
    """

    def __call__(
        cls,
        white_vec=None,
        prec_sqrt=None,
        inputs=None,
        *,
        mean=None,
        info_vec=None,
        precision=None,
        scale_tril=None,
        covariance=None,
    ):
        # Convert inputs.
        assert inputs is not None
        if isinstance(inputs, OrderedDict):
            inputs = tuple(inputs.items())
        assert isinstance(inputs, tuple)

        # Convert scale parameter to prec_sqrt.
        if prec_sqrt is None and white_vec is not None:
            raise ValueError("Cannot specify white_vec without prec_sqrt")
        if prec_sqrt is not None:
            is_tril = False
        elif precision is not None:
            prec_sqrt = ops.cholesky(precision)
            is_tril = True
        elif covariance is not None:
            prec_sqrt = _inverse_cholesky(covariance)
            is_tril = True
        elif scale_tril is not None:
            prec_sqrt = ops.transpose(ops.triangular_inv(scale_tril), -1, -2)
            is_tril = False
        else:
            raise ValueError(
                "At least one of prec_sqrt, precision, scale_tril, or covariance "
                "must be specified"
            )

        # Convert location parameter to white_vec.
        if white_vec is not None:
            pass
        elif mean is not None:
            white_vec = _vm(mean, prec_sqrt)
        elif info_vec is not None:
            if not is_tril:
                prec_sqrt = ops.cholesky(_mmt(prec_sqrt))  # triangularize
                is_tril = True
            white_vec = ops.triangular_solve(info_vec[..., None], prec_sqrt)[..., 0]
        else:
            raise ValueError(
                "At least one of white_vec, mean, or info_vec must be specified"
            )

        # Compress wide representations.
        shift = None
        dim, rank = prec_sqrt.shape[-2:]
        if rank > dim * cls.compression_threshold:
            white_vec, prec_sqrt, shift = _compress_rank(white_vec, prec_sqrt)

        # Create a Gaussian.
        result = super().__call__(white_vec, prec_sqrt, inputs)

        if shift is not None:
            int_inputs = OrderedDict((k, v) for k, v in inputs if v.dtype != "real")
            result += Tensor(shift, int_inputs)

        return result

class Gaussian(Funsor, metaclass=GaussianMeta):
    r"""
    Funsor representing a batched Gaussian log-density function.

    Gaussians are the internal representation for joint and conditional
    multivariate normal distributions and multivariate normal likelihoods.
    Mathematically, a Gaussian represents the quadratic log density function::

        f(x) = -0.5 * || x @ prec_sqrt - white_vec ||^2
             = -0.5 * < x @ prec_sqrt - white_vec | x @ prec_sqrt - white_vec >
             = -0.5 * < x | prec_sqrt @ prec_sqrt.T | x >
               + < x | prec_sqrt | white_vec > - 0.5 * ||white_vec||^2

    Internally Gaussians use a square root information filter (SRIF)
    representation consisting of a square root of the precision matrix
    ``prec_sqrt`` and a vector in the whitened space ``white_vec``. This
    representation allows space-efficient construction of Gaussians with
    incomplete information, i.e. with zero eigenvalues in the precision
    matrix. These incomplete log densities arise when making low-dimensional
    observations of higher-dimensional hidden state. Sampling and
    marginalization are supported only for full-rank Gaussians or full-rank
    subsets of Gaussians. See the :meth:`rank` and :meth:`is_full_rank`
    properties.

    .. note:: :class:`Gaussian` s are not normalized probability
        distributions, rather they are canonicalized to evaluate to zero log
        density at their maximum: ``f(prec_sqrt \ white_vec) = 0``. Not only
        are Gaussians non-normalized, but they may be rank deficient and
        non-normalizable, in which case sampling and marginalization are
        supported only on full-rank subsets of variables.

    :param torch.Tensor white_vec: A batched white noise vector, where
        ``white_vec = prec_sqrt.T @ mean``. Alternatively you can specify one
        of the kwargs ``mean`` or ``info_vec``, which will be converted to
        ``white_vec``.
    :param torch.Tensor prec_sqrt: A batched square root of the positive
        semidefinite precision matrix. This need not be square, and typically
        has shape ``prec_sqrt.shape == white_vec.shape[:-1] + (dim, rank)``,
        where ``dim`` is the total flattened size of real inputs and
        ``rank = white_vec.shape[-1]``. Alternatively you can specify one of
        the kwargs ``precision``, ``covariance``, or ``scale_tril``, which
        will be converted to ``prec_sqrt``.
    :param OrderedDict inputs: Mapping from name to
        :class:`~funsor.domains.Domain` .
    """

    compression_threshold = 2
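
    # Construction sketch (illustrative, not part of the original module;
    # assumes the torch backend and these hypothetical values):
    #
    #   import torch
    #   from funsor.domains import Reals
    #   inputs = OrderedDict(x=Reals[3])
    #   g = Gaussian(mean=torch.zeros(3), precision=torch.eye(3), inputs=inputs)
    #   g(x=torch.zeros(3))  # Tensor(0.0): zero log density at the mode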

    def __init__(self, white_vec, prec_sqrt, inputs):
        assert ops.is_numeric_array(white_vec) and ops.is_numeric_array(prec_sqrt)
        assert isinstance(inputs, tuple)
        inputs = OrderedDict(inputs)

        # Compute total dimension of all real inputs.
        dim = sum(d.num_elements for d in inputs.values() if d.dtype == "real")
        if not get_tracing_state():
            assert dim
            assert len(prec_sqrt.shape) >= 2 and prec_sqrt.shape[-2] == dim
            rank = prec_sqrt.shape[-1]
            assert len(white_vec.shape) >= 1 and white_vec.shape[-1] == rank
            # This should be true but weirdly fails in pytest tests that use
            # set_compression_threshold(large_value).
            # assert rank <= dim * self.compression_threshold

        # Compute total shape of all Bint inputs.
        batch_shape = tuple(
            d.dtype for d in inputs.values() if isinstance(d.dtype, int)
        )
        if not get_tracing_state():
            assert prec_sqrt.shape[:-2] == batch_shape
            assert white_vec.shape[:-1] == batch_shape

        output = Real
        fresh = frozenset(inputs.keys())
        bound = {}
        super().__init__(inputs, output, fresh, bound)
        self.white_vec = white_vec
        self.prec_sqrt = prec_sqrt
        self.batch_shape = batch_shape
        self.event_shape = (dim,)

    @classmethod
    @contextmanager
    def set_compression_threshold(cls, threshold: float):
        """
        Context manager to set rank compression threshold.

        To save space Gaussians compress wide ``prec_sqrt`` matrices down to
        square. However compression uses a QR decomposition which can be
        expensive and which has unstable gradients when the resulting
        precision matrix is rank deficient. To balance space and time costs
        and numerical stability, compression is triggered only on
        ``prec_sqrt`` matrices whose width to height ratio is greater than
        ``threshold``.

        :param float threshold: Defaults to 2. To optimize for space and
            deterministic computations, set ``threshold = 1``. To optimize
            for fewest QR decompositions and numerical stability, set
            ``threshold = math.inf``.
        """
        assert isinstance(threshold, (int, float))
        assert threshold >= 1
        old = cls.compression_threshold
        try:
            cls.compression_threshold = threshold
            yield
        finally:
            cls.compression_threshold = old
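
    # Usage sketch (illustrative):
    #
    #   with Gaussian.set_compression_threshold(1):
    #       ...  # Gaussian arithmetic here compresses prec_sqrt eagerly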

    def __repr__(self):
        return "Gaussian(..., ({}))".format(
            " ".join("({}, {}),".format(*kv) for kv in self.inputs.items())
        )

    @property
    def rank(self):
        return self.prec_sqrt.shape[-1]

    @property
    def is_full_rank(self):
        dim, rank = self.prec_sqrt.shape[-2:]
        return rank >= dim

    # TODO Consider weak-memoizing these so they persist through alpha conversion.
    # https://github.com/pyro-ppl/pyro/blob/ac3c588/pyro/distributions/coalescent.py#L412
    @lazy_property
    def _precision(self):
        return self.prec_sqrt @ ops.transpose(self.prec_sqrt, -1, -2)

    @lazy_property
    def _precision_chol(self):
        # Note self.prec_sqrt may be neither lower triangular nor square.
        assert self.is_full_rank
        return ops.cholesky(self._precision)

    @lazy_property
    def _covariance(self):
        return ops.cholesky_inverse(self._precision_chol)

    @lazy_property
    def _scale_tril(self):
        return _inverse_cholesky(self._precision)

    @lazy_property
    def _mean(self):
        return ops.cholesky_solve(self._info_vec[..., None], self._precision_chol)[
            ..., 0
        ]

    @lazy_property
    def _info_vec(self):
        return _mv(self.prec_sqrt, self.white_vec)

    @lazy_property
    def _log_normalizer(self):
        dim = self.prec_sqrt.shape[-2]
        log_det_term = _log_det_tri(self._precision_chol)
        result = 0.5 * dim * math.log(2 * math.pi) - log_det_term
        if self.rank == dim:
            return result
        # Shift, following the logic in _compress_rank().
        old_norm2 = _norm2(self.white_vec)
        white_vec = ops.triangular_solve(
            self._info_vec[..., None], self._precision_chol
        )[..., 0]
        new_norm2 = _norm2(white_vec)
        shift = 0.5 * (new_norm2 - old_norm2)
        return result + shift

    @lazy_property
    def log_normalizer(self):
        inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != "real"
        )
        return Tensor(self._log_normalizer, inputs)

    def align(self, names):
        assert isinstance(names, tuple)
        assert all(name in self.inputs for name in names)
        if not names or names == tuple(self.inputs):
            return self

        inputs = OrderedDict((name, self.inputs[name]) for name in names)
        inputs.update(self.inputs)
        white_vec, prec_sqrt = align_gaussian(inputs, self)
        return Gaussian(white_vec, prec_sqrt, inputs)

    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        prototype = Tensor(self.white_vec)
        subs = tuple(
            (k, v if isinstance(v, (Variable, Slice)) else prototype.materialize(v))
            for k, v in subs
            if k in self.inputs
        )
        if not subs:
            return self

        # Constants and Affine funsors are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple(
            (k, v)
            for k, v in subs
            if not isinstance(v, (Number, Tensor, Variable, Slice))
            and not (is_affine(v) and affine_inputs(v))
        )
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple(
            (k, v)
            for k, v in subs
            if isinstance(v, (Number, Tensor, Slice))
            if v.dtype != "real"
        )
        real_subs = tuple(
            (k, v)
            for k, v in subs
            if isinstance(v, (Number, Tensor))
            if v.dtype == "real"
        )
        affine_subs = tuple(
            (k, v)
            for k, v in subs
            if is_affine(v) and affine_inputs(v) and not isinstance(v, Variable)
        )
        if var_subs:
            return self._eager_subs_var(
                var_subs, int_subs + real_subs + affine_subs + lazy_subs
            )
        if int_subs:
            return self._eager_subs_int(int_subs, real_subs + affine_subs + lazy_subs)
        if real_subs:
            return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
        if affine_subs:
            return self._eager_subs_affine(affine_subs, lazy_subs)
        return reflect.interpret(Subs, self, lazy_subs)

    def _eager_subs_var(self, subs, remaining_subs):
        # Perform variable substitution, i.e. renaming of inputs.
        rename = {k: v.name for k, v in subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.white_vec, self.prec_sqrt, inputs)
        return Subs(var_result, remaining_subs) if remaining_subs else var_result

    def _eager_subs_int(self, subs, remaining_subs):
        # Perform integer substitution, i.e. slicing into a batch.
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != "real"
        )
        real_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype == "real"
        )
        tensors = [self.white_vec, self.prec_sqrt]
        funsors = [Subs(Tensor(x, int_inputs), subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, remaining_subs) if remaining_subs else int_result

    def _eager_subs_real(self, subs, remaining_subs):
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != "real"
        )
        tensors = [
            Tensor(self.white_vec, int_inputs),
            Tensor(self.prec_sqrt, int_inputs),
        ]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = len(tensors[0].shape) - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (white_vec, prec_sqrt), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [
            (k, slice(offset, offset + self.inputs[k].num_elements))
            for k, offset in offsets.items()
        ]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1,))
            if not get_tracing_state():
                assert value.shape[-1] == self.inputs[k].num_elements
            values[k] = ops.expand(value, batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables,
        # resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == "real"):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size,))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = -0.5 * _norm2(_vm(value, prec_sqrt) - white_vec)
            result = Tensor(result, int_inputs)
            assert result.output == Real
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables,
        # resulting in a Joint. We split real inputs into two sets: a for the
        # preserved and b for the substituted.
        #
        #   G([xa xb]; w, [ Pa ]) = -1/2 || xa Pa + xb Pb - w ||^2
        #                 [ Pb ]
        #                         = G(xa; w - xb Pb, Pa)
        #
        # where  wa := w - xb Pb  is the new white_vec.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(
            k for k, d in self.inputs.items() if d.dtype == "real" and k not in b
        )
        prec_sqrt_a = ops.cat([prec_sqrt[..., i, :] for k, i in slices if k in a], -2)
        prec_sqrt_b = ops.cat([prec_sqrt[..., i, :] for k, i in slices if k in b], -2)
        value_b = ops.cat([values[k] for k, i in slices if k in b], -1)
        white_vec_a = white_vec - _vm(value_b, prec_sqrt_b)
        prec_sqrt_a = ops.expand(prec_sqrt_a, white_vec_a.shape[:-1] + (-1, -1))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(white_vec_a, prec_sqrt_a, inputs)
        return Subs(result, remaining_subs) if remaining_subs else result

    def _eager_subs_affine(self, subs, remaining_subs):
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if isinstance(const, Tensor) and all(
                isinstance(coeff, Tensor) for coeff, _ in coeffs.values()
            ):
                affine[k] = const, coeffs
            else:
                remaining_subs += ((k, v),)
        if not affine:
            return reflect.interpret(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != "real"
        )
        tensors = [
            Tensor(self.white_vec, old_int_inputs),
            Tensor(self.prec_sqrt, old_int_inputs),
        ]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        white_vec = next(tensors).data
        prec_sqrt = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = white_vec.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype == "real"
        )
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                new_shape = coeff.shape[: len(eqn.split("->")[0].split(",")[1])]
                new_real_inputs[new_k] = Reals[new_shape]
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim,))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                subs_matrix[..., new_slice, old_slice] = ops.new_eye(
                    self.white_vec, batch_shape + (old_size,)
                )
                continue
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape + (old_size,))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = coeff.data.reshape(
                        batch_shape + (new_size, old_size)
                    )
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()

        # Construct the new Gaussian. Suppose the old Gaussian funsor g has
        # density
        #
        #   G(x;w,P) = -1/2 || x P - w ||^2
        #
        # Now define a new Gaussian by substituting x = y A + b:
        #
        #   G(y;w',P') = G(y A + b; w, P)
        #              = -1/2 || (y A + b) P - w ||^2
        #              = -1/2 || y A P - (w - b P) ||^2
        #              =  G(y; w - b P, A P)
        #
        # where  P' = A P  and  w' = w - b P  parametrize the new Gaussian.
        white_vec = white_vec - _vm(subs_vector, prec_sqrt)
        prec_sqrt = subs_matrix @ prec_sqrt
        result = Gaussian(white_vec, prec_sqrt, new_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result

    def eager_reduce(self, op, reduced_vars):
        assert reduced_vars.issubset(self.inputs)
        if op is ops.logaddexp:
            # Marginalize out real variables, but keep mixtures lazy.
            assert all(v in self.inputs for v in reduced_vars)
            real_vars = frozenset(
                k for k, d in self.inputs.items() if d.dtype == "real"
            )
            reduced_reals = reduced_vars & real_vars
            reduced_ints = reduced_vars - real_vars
            if not reduced_reals:
                return None  # defer to default implementation

            inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if k not in reduced_reals
            )
            int_inputs = OrderedDict(
                (k, v) for k, v in inputs.items() if v.dtype != "real"
            )
            if reduced_reals == real_vars:
                result = self.log_normalizer
            else:
                # Let x = [xa xb] where xb will be marginalized out and xa
                # will remain. Following the formula in _compress_rank, we can
                # rewrite the joint Gaussian as a Gaussian in xb plus a term C
                # that does not depend on xb, so that marginalizing out xb
                # yields
                #
                #   integral dxb G([xa xb]; w, [ Pa ])
                #                              [ Pb ]
                #     =: G(xb; wb, Qb).log_normalizer + C
                #
                # In normalizable models, rank >= dim(xb), so we can choose Qb
                # to be the Cholesky square root, making it easy to compute a
                # determinant and solve for wb.
                #
                #    Qb = chol(Pb Pb')
                #   wb' = Qb \ Pb (w - xa Pa)'
                #
                # Next we match moments of C to a Gaussian in xa:
                #
                #   C = 1/2 (wb wb' - w w')  # from _compress_rank
                #     = 1/2 (xa Pa - w) Pb' inv(Qb Qb') Pb (xa Pa - w)'
                #       - 1/2 (xa Pa - w) (xa Pa - w)'
                #     =: G(xa; wa, Qa)
                #
                # whence  Qa = Pa S  and  wa = w S, where S is a square root
                # of the rank-by-rank projection matrix (S = S S' by
                # idempotence):
                #
                #   S S' = I - Pb' inv(Pb Pb') Pb = I - (Qb\Pb)' (Qb\Pb) = S
                #
                # Note if rank == dim(xb), then the projection matrix is zero,
                # and the Gaussian G(xa; wa, Qa) is zero and can be dropped.
                b, a = _split_real_inputs(self.inputs, reduced_vars, self.white_vec)
                prec_sqrt_a = self.prec_sqrt[..., a, :]
                prec_sqrt_b = self.prec_sqrt[..., b, :]
                dim_b = prec_sqrt_b.shape[-2]
                if self.rank < dim_b:
                    raise ValueError(
                        f"Too little information to marginalize over {set(reduced_vars)}. "
                    )
                precision_chol_b = ops.cholesky(_mmt(prec_sqrt_b))  # assume full rank
                result = self._marginalize_after_split(
                    inputs, int_inputs, prec_sqrt_b, prec_sqrt_a, precision_chol_b
                )
            return result.reduce(ops.logaddexp, reduced_ints)

        if op is ops.add:
            # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
            inputs = OrderedDict()
            old_ints = OrderedDict()
            new_ints = OrderedDict()
            kept_perm = []
            reduced_perm = []
            for i, (k, v) in enumerate(self.inputs.items()):
                if k not in reduced_vars:
                    inputs[k] = v
                if v.dtype == "real":
                    if k in reduced_vars:
                        raise ValueError(
                            f"Cannot sum along a real dimension: {repr(k)}"
                        )
                else:
                    old_ints[k] = v
                    if k in reduced_vars:
                        reduced_perm.append(i)
                    else:
                        kept_perm.append(i)
                        new_ints[k] = v
            n = len(kept_perm) + len(reduced_perm)

            # The square root information filter fuses via transpose and reshape.
            perm = kept_perm + reduced_perm + [n]
            white_vec = ops.permute(self.white_vec, perm)
            white_vec = white_vec.reshape(white_vec.shape[: len(kept_perm)] + (-1,))
            perm = kept_perm + [n] + reduced_perm + [n + 1]
            prec_sqrt = ops.permute(self.prec_sqrt, perm)
            prec_sqrt = prec_sqrt.reshape(prec_sqrt.shape[: len(kept_perm) + 1] + (-1,))
            assert prec_sqrt.shape[:-2] == white_vec.shape[:-1]
            assert prec_sqrt.shape[-1] == white_vec.shape[-1]

            return Gaussian(white_vec, prec_sqrt, inputs)

        return None  # defer to default implementation

    def _sample(self, sampled_vars, sample_inputs, rng_key):
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self
        if any(self.inputs[k].dtype != "real" for k in sampled_vars):
            raise ValueError(
                "Sampling from non-normalized Gaussian mixtures is intentionally "
                "not implemented. You probably want to normalize. To work around, "
                "add a zero Tensor/Array with given inputs."
            )

        # Partition inputs into sample_inputs + int_inputs + real_inputs.
        sample_inputs = OrderedDict(
            (k, d) for k, d in sample_inputs.items() if k not in self.inputs
        )
        int_inputs = OrderedDict()
        sampled_real_inputs = OrderedDict()
        remaining_real_inputs = OrderedDict()
        for k, d in self.inputs.items():
            if d.dtype != "real":
                int_inputs[k] = d
            elif k in sampled_vars:
                sampled_real_inputs[k] = d
            else:
                remaining_real_inputs[k] = d
        if self.rank < sum(d.num_elements for d in sampled_real_inputs.values()):
            raise ValueError(
                f"Too little information to sample over {set(sampled_vars)}. "
            )

        if not remaining_real_inputs:  # Sample all variables.
            # Triangularize via _compress_rank().
            white_vec, prec_sqrt, _ = _compress_rank(
                self.white_vec, self.prec_sqrt, assume_full_rank=True
            )

            # Jointly sample.
            # This section may involve either Funsors or backend arrays.
            dim = prec_sqrt.shape[-1]
            white_noise = _sample_white_noise(
                sample_inputs, int_inputs, dim, self.white_vec, rng_key
            )
            if isinstance(white_noise, Funsor):
                white_vec = Tensor(white_vec, int_inputs)
                prec_sqrt = Tensor(prec_sqrt, int_inputs)
            sample = ops.triangular_solve(
                (white_noise + white_vec)[..., None], prec_sqrt, transpose=True
            )[..., 0]

            # Compute the remaining Tensor.
            remaining = self.log_normalizer

        else:  # Sample only a subset of real variables.
            # Split into sampled variables a and remaining variables b.
            a, b = _split_real_inputs(self.inputs, sampled_vars, self.white_vec)
            prec_sqrt_a = self.prec_sqrt[..., a, :]
            prec_sqrt_b = self.prec_sqrt[..., b, :]
            dim_a = prec_sqrt_a.shape[-2]

            # Compute white_vec of a, lazily conditioned on b's variables.
            # This requires Funsors rather than backend arrays.
            flat = ops.cat(
                [
                    Variable(k, d).reshape((d.num_elements,))
                    for k, d in remaining_real_inputs.items()
                ]
            )
            white_vec_a = (
                Tensor(self.white_vec, int_inputs)
                - (flat[None] @ Tensor(prec_sqrt_b, int_inputs))[0]
            )

            # Triangularize.
            precision_chol_a = Tensor(ops.cholesky(_mmt(prec_sqrt_a)), int_inputs)
            white_vec_a = ops.triangular_solve(
                Tensor(prec_sqrt_a, int_inputs) @ white_vec_a[..., None],
                precision_chol_a,
            )[..., 0]

            # Jointly sample.
            white_noise = _sample_white_noise(
                sample_inputs, int_inputs, dim_a, self.white_vec, rng_key
            )
            if not isinstance(white_noise, Funsor):
                inputs = sample_inputs.copy()
                inputs.update(int_inputs)
                white_noise = Tensor(white_noise, inputs)
            sample = ops.triangular_solve(
                (white_noise + white_vec_a)[..., None], precision_chol_a, transpose=True
            )[..., 0]

            # Compute the remaining Gaussian, equivalent to
            # self.reduce(ops.logaddexp, sampled_vars), but avoiding duplicate work.
            inputs = int_inputs.copy()
            inputs.update(remaining_real_inputs)
            remaining = self._marginalize_after_split(
                inputs, int_inputs, prec_sqrt_a, prec_sqrt_b, precision_chol_a.data
            )

        # Extract shaped components of the flat concatenated sample.
        results = [remaining]
        offsets, _ = _compute_offsets(sampled_real_inputs)
        for key, domain in sampled_real_inputs.items():
            point = sample[..., offsets[key] : offsets[key] + domain.num_elements]
            point = point.reshape(point.shape[:-1] + domain.shape)
            if not isinstance(point, Funsor):  # I.e. when eagerly sampling.
                inputs = sample_inputs.copy()
                inputs.update(int_inputs)
                point = Tensor(point, inputs)
            assert point.output == domain
            results.append(Delta(key, point))
        return reduce(ops.add, results)

    def _marginalize_after_split(
        self, inputs, int_inputs, prec_sqrt_a, prec_sqrt_b, precision_chol_a
    ):
        """
        Helper used in partial reduction and partial sampling.
        This marginalizes over a and returns a shifted Gaussian over b.
        """
        dim_a = prec_sqrt_a.shape[-2]
        dim_b = prec_sqrt_b.shape[-2]
        result = Tensor(
            dim_a * math.log(2 * math.pi) / 2 - _log_det_tri(precision_chol_a),
            int_inputs,
        )
        if self.rank > dim_a:
            proj_a = _mtm(ops.triangular_solve(prec_sqrt_a, precision_chol_a))
            prec_sqrt = prec_sqrt_b - prec_sqrt_b @ proj_a
            white_vec = self.white_vec - _vm(self.white_vec, proj_a)
            result += Gaussian(white_vec, prec_sqrt, inputs)
        else:  # The Gaussian over xb is zero.
            # TODO switch from an empty Gaussian to a Constant once this works:
            # from .constant import Constant
            # const_inputs = OrderedDict(
            #     (k, v) for k, v in inputs.items() if k not in result.inputs
            # )
            # result = Constant(const_inputs, result)
            batch_shape = self.white_vec.shape[:-1]
            white_vec = ops.new_zeros(self.white_vec, batch_shape + (0,))
            prec_sqrt = ops.new_zeros(self.white_vec, batch_shape + (dim_b, 0))
            result += Gaussian(white_vec, prec_sqrt, inputs)
        return result

def _sample_white_noise(sample_inputs, int_inputs, dim, prototype, rng_key):
    if [v.dtype for v in sample_inputs.values()] == ["real"]:
        # Lazily compute a sample as a function of white noise.
        k, d = next(iter(sample_inputs.items()))
        return Variable(k, d)[tuple(int_inputs)]

    # Eagerly draw noise.
    shape = tuple(d.size for d in sample_inputs.values() if d.dtype != "real")
    shape += tuple(d.size for d in int_inputs.values())
    shape += (dim,)
    assert ops.is_numeric_array(prototype)
    return ops.randn(prototype, shape, rng_key)

@compress_gaussians.register(Gaussian, object, object, tuple)
def _compress_gaussians(white_vec, prec_sqrt, inputs):
    dim, rank = prec_sqrt.shape[-2:]
    if rank <= dim:
        return None
    white_vec, prec_sqrt, shift = _compress_rank(white_vec, prec_sqrt)
    int_inputs = OrderedDict((k, v) for k, v in inputs if v.dtype != "real")
    return Gaussian(white_vec, prec_sqrt, inputs) + Tensor(shift, int_inputs)

@eager.register(Binary, AddOp, Gaussian, Gaussian)
def eager_add_gaussian_gaussian(op, lhs, rhs):
    # Fuse two Gaussians by adding their log-densities pointwise.
    # This is similar to a Kalman filter update, but also keeps track of
    # the marginal likelihood which accumulates into a Tensor.

    # Align data.
    inputs = lhs.inputs.copy()
    inputs.update(rhs.inputs)
    lhs_white_vec, lhs_prec_sqrt = align_gaussian(inputs, lhs, expand=True)
    rhs_white_vec, rhs_prec_sqrt = align_gaussian(inputs, rhs, expand=True)

    # Fuse aligned Gaussians via concatenation.
    white_vec = ops.cat([lhs_white_vec, rhs_white_vec], -1)
    prec_sqrt = ops.cat([lhs_prec_sqrt, rhs_prec_sqrt], -1)
    return Gaussian(white_vec, prec_sqrt, inputs)
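
# Why concatenation fuses log-densities (sketch): once both terms are aligned,
#
#   -1/2 ||x P1 - w1||^2 - 1/2 ||x P2 - w2||^2 = -1/2 ||x [P1 P2] - [w1 w2]||^2
#
# so pointwise addition of two SRIF Gaussians is exactly blockwise
# concatenation along the rank dimension, with no linear algebra performed
# until a later compression or marginalization step.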

@eager.register(Binary, SubOp, Gaussian, Gaussian)
def eager_sub(op, lhs, rhs):
    return lhs + (-rhs)

__all__ = [
    "BlockMatrix",
    "BlockVector",
    "Gaussian",
    "align_gaussian",
]