# Source code for funsor.gaussian

import math
from collections import OrderedDict, defaultdict
from functools import reduce

import torch
from pyro.distributions.util import broadcast_shape

import funsor.ops as ops
from funsor.affine import affine_inputs, extract_affine, is_affine
from funsor.delta import Delta
from funsor.domains import reals
from funsor.ops import AddOp, NegOp, SubOp
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Slice, Subs, Unary, Variable, eager, reflect
from funsor.torch import Tensor, align_tensor, align_tensors, materialize
from funsor.util import lazy_property

def _log_det_tri(x):
    return x.diagonal(dim1=-1, dim2=-2).log().sum(-1)

def _vv(vec1, vec2):
    Computes the inner product ``< vec1 | vec 2 >``.
    return vec1.unsqueeze(-2).matmul(vec2.unsqueeze(-1)).squeeze(-1).squeeze(-1)

def _mv(mat, vec):
    return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1)

def _trace_mm(x, y):
    Computes ``trace(x.T @ y)``.
    assert x.dim() >= 2
    assert y.dim() >= 2
    return (x * y).sum([-1, -2])

def cholesky(u):
    """
    Like :func:`torch.cholesky` but uses sqrt for scalar matrices.

    For ``1 x 1`` matrices the lower Cholesky factor is just the elementwise
    square root, so that case bypasses :func:`torch.cholesky` entirely; this
    special case works around problems :func:`torch.cholesky` has with scalar
    matrices.

    :param torch.Tensor u: batched positive definite matrices of shape
        ``batch + (n, n)``.
    :return: the lower-triangular Cholesky factor, with the same shape as ``u``.
    :rtype: torch.Tensor
    """
    if u.size(-1) == 1:
        return u.sqrt()
    return u.cholesky()

def cholesky_inverse(u):
    """
    Like :func:`torch.cholesky_inverse` but supports batching and gradients.

    Given the lower Cholesky factor ``u`` of a matrix ``A``, returns ``A^{-1}``.

    :param torch.Tensor u: batched lower Cholesky factors of shape
        ``batch + (n, n)``.
    :return: batched inverses of the original matrices, same shape as ``u``.
    :rtype: torch.Tensor
    """
    if u.dim() == 2:
        return u.cholesky_inverse()
    # Solve A @ X = I using the factor. The identity must share u's dtype and
    # device, otherwise cholesky_solve rejects float64 / CUDA inputs.
    eye = torch.eye(u.size(-1), dtype=u.dtype, device=u.device)
    return eye.expand(u.size()).cholesky_solve(u)

def _compute_offsets(inputs):
    Compute offsets of real inputs into the concatenated Gaussian dims.
    This ignores all int inputs.

    :param OrderedDict inputs: A schema mapping variable name to domain.
    :return: a pair ``(offsets, total)``, where ``offsets`` is an OrderedDict
        mapping input name to integer offset, and ``total`` is the total event
    :rtype: tuple
    assert isinstance(inputs, OrderedDict)
    offsets = OrderedDict()
    total = 0
    for key, domain in inputs.items():
        if domain.dtype == 'real':
            offsets[key] = total
            total += domain.num_elements
    return offsets, total

def _find_intervals(intervals, end):
    Finds a complete set of intervals partitioning [0, end), given a partial
    set of non-overlapping intervals.
    cuts = list(sorted({0, end}.union(*intervals)))
    return list(zip(cuts[:-1], cuts[1:]))

def _parse_slices(index, value):
    if not isinstance(index, tuple):
        index = (index,)
    if index[0] is Ellipsis:
        index = index[1:]
    start_stops = []
    for pos, i in reversed(list(enumerate(index))):
        if isinstance(i, slice):
            start_stops.append((i.start, i.stop))
        elif isinstance(i, int):
            start_stops.append((i, i + 1))
            value = value.unsqueeze(pos - len(index))
            raise ValueError("invalid index: {}".format(i))
    return start_stops, value

class BlockVector(object):
    """
    Jit-compatible helper to build blockwise vectors.
    Syntax is similar to :func:`torch.zeros` ::

        x = BlockVector((100, 20))
        x[..., 0:4] = x1
        x[..., 6:10] = x2
        x = x.as_tensor()
        assert x.shape == (100, 20)

    Regions of the last dim not covered by any assigned block are filled with
    zeros. Slices must have explicit integer ``start`` and ``stop``.
    """
    def __init__(self, shape):
        self.shape = shape
        # Maps a (start, stop) interval of the last dim to its tensor block.
        self.parts = {}

    def __setitem__(self, index, value):
        (i,), value = _parse_slices(index, value)
        self.parts[i] = value

    def as_tensor(self):
        """
        Concatenates the assigned blocks, zero-filling any gaps, into a tensor
        of shape ``self.shape``. At least one block must have been assigned.
        Zero blocks are stored into ``self.parts`` as a side effect.
        """
        # Fill gaps with zeros, matching the dtype and device of an assigned block.
        prototype = next(iter(self.parts.values()))
        options = dict(dtype=prototype.dtype, device=prototype.device)
        for i in _find_intervals(self.parts.keys(), self.shape[-1]):
            if i not in self.parts:
                self.parts[i] = torch.zeros(self.shape[:-1] + (i[1] - i[0],), **options)

        # Concatenate parts in interval order.
        parts = [v for k, v in sorted(self.parts.items())]
        result = torch.cat(parts, dim=-1)
        if not torch._C._get_tracing_state():
            assert result.shape == self.shape
        return result
class BlockMatrix(object):
    """
    Jit-compatible helper to build blockwise matrices.
    Syntax is similar to :func:`torch.zeros` ::

        x = BlockMatrix((100, 20, 20))
        x[..., 0:4, 0:4] = x11
        x[..., 0:4, 6:10] = x12
        x[..., 6:10, 0:4] = x12.transpose(-1, -2)
        x[..., 6:10, 6:10] = x22
        x = x.as_tensor()
        assert x.shape == (100, 20, 20)

    Blocks not covered by any assignment are filled with zeros. Slices must
    have explicit integer ``start`` and ``stop``.
    """
    def __init__(self, shape):
        self.shape = shape
        # Maps a row interval to a dict mapping a column interval to its block.
        self.parts = defaultdict(dict)

    def __setitem__(self, index, value):
        (i, j), value = _parse_slices(index, value)
        self.parts[i][j] = value

    def as_tensor(self):
        """
        Concatenates the assigned blocks, zero-filling any gaps, into a tensor
        of shape ``self.shape``. At least one block must have been assigned.
        Zero blocks are stored into ``self.parts`` as a side effect.
        """
        # Fill gaps with zeros, matching the dtype and device of an assigned block.
        arbitrary_row = next(iter(self.parts.values()))
        prototype = next(iter(arbitrary_row.values()))
        options = dict(dtype=prototype.dtype, device=prototype.device)
        js = set().union(*(part.keys() for part in self.parts.values()))
        rows = _find_intervals(self.parts.keys(), self.shape[-2])
        cols = _find_intervals(js, self.shape[-1])
        for i in rows:
            for j in cols:
                if j not in self.parts[i]:
                    shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0])
                    self.parts[i][j] = torch.zeros(shape, **options)

        # Concatenate parts: first each row of blocks along the last dim, then
        # the resulting row strips along the second-to-last dim.
        # TODO This could be optimized into a single .reshape().cat().reshape() if
        #   all inputs are contiguous, thereby saving a memcopy.
        columns = {i: torch.cat([v for j, v in sorted(part.items())], dim=-1)
                   for i, part in self.parts.items()}
        result = torch.cat([v for i, v in sorted(columns.items())], dim=-2)
        if not torch._C._get_tracing_state():
            assert result.shape == self.shape
        return result
def align_gaussian(new_inputs, old):
    """
    Align data of a Gaussian distribution to a new ``inputs`` shape.

    Int inputs are aligned like :class:`~funsor.torch.Tensor` batch dims.
    Real inputs are re-laid-out in the concatenated event dim. Real inputs of
    ``new_inputs`` absent from ``old`` get zero blocks. Real inputs of ``old``
    absent from ``new_inputs`` are dropped.

    :param OrderedDict new_inputs: the target input schema.
    :param Gaussian old: the Gaussian whose data should be aligned.
    :return: a pair ``(info_vec, precision)`` laid out for ``new_inputs``.
    :rtype: tuple
    """
    assert isinstance(new_inputs, OrderedDict)
    assert isinstance(old, Gaussian)
    info_vec = old.info_vec
    precision = old.precision

    # Int inputs are managed exactly as in Tensor, so defer to align_tensor().
    new_ints = OrderedDict((k, d) for k, d in new_inputs.items() if d.dtype != 'real')
    old_ints = OrderedDict((k, d) for k, d in old.inputs.items() if d.dtype != 'real')
    if new_ints != old_ints:
        info_vec = align_tensor(new_ints, Tensor(info_vec, old_ints))
        precision = align_tensor(new_ints, Tensor(precision, old_ints))

    # Real inputs are all concatenated in the rightmost dim(s).
    new_offsets, new_dim = _compute_offsets(new_inputs)
    old_offsets, old_dim = _compute_offsets(old.inputs)
    assert info_vec.shape[-1:] == (old_dim,)
    assert precision.shape[-2:] == (old_dim, old_dim)
    if new_offsets == old_offsets:
        return info_vec, precision

    # Pair up (new_slice, old_slice) for every real input present in both schemas,
    # in the order of new_inputs.
    shared = []
    for name, new_offset in new_offsets.items():
        if name in old_offsets:
            size = old.inputs[name].num_elements
            old_offset = old_offsets[name]
            shared.append((slice(new_offset, new_offset + size),
                           slice(old_offset, old_offset + size)))

    batch_shape = info_vec.shape[:-1]
    new_info_vec = BlockVector(batch_shape + (new_dim,))
    new_precision = BlockMatrix(batch_shape + (new_dim, new_dim))
    for new_row, old_row in shared:
        new_info_vec[..., new_row] = info_vec[..., old_row]
        for new_col, old_col in shared:
            new_precision[..., new_row, new_col] = precision[..., old_row, old_col]
    return new_info_vec.as_tensor(), new_precision.as_tensor()
class GaussianMeta(FunsorMeta):
    """
    Metaclass that lets :class:`Gaussian` accept ``inputs`` either as an
    OrderedDict or as a tuple of ``(name, domain)`` pairs, normalizing to the
    tuple form before construction.
    """
    def __call__(cls, info_vec, precision, inputs):
        # Gaussian.__init__ asserts that inputs is a tuple, so convert here.
        normalized = tuple(inputs.items()) if isinstance(inputs, OrderedDict) else inputs
        assert isinstance(normalized, tuple)
        return super(GaussianMeta, cls).__call__(info_vec, precision, normalized)
class Gaussian(Funsor, metaclass=GaussianMeta):
    """
    Funsor representing a batched joint Gaussian distribution as a log-density
    function.

    Mathematically, a Gaussian represents the density function::

        f(x) = < x | info_vec > - 0.5 * < x | precision | x >
             = < x | info_vec - 0.5 * precision @ x >

    Note that :class:`Gaussian` s are not normalized, rather they are
    canonicalized to evaluate to zero log density at the origin: ``f(0) = 0``.
    This canonical form is useful in combination with the information filter
    representation because it allows :class:`Gaussian` s with incomplete
    information, i.e. zero eigenvalues in the precision matrix. These
    incomplete distributions arise when making low-dimensional observations on
    higher dimensional hidden state.

    :param torch.Tensor info_vec: A batched information vector,
        where ``info_vec = precision @ mean``.
    :param torch.Tensor precision: A batched positive semidefinite precision
        matrix.
    :param OrderedDict inputs: Mapping from name to
        :class:`~funsor.domains.Domain` .
    """
    def __init__(self, info_vec, precision, inputs):
        assert isinstance(info_vec, torch.Tensor)
        assert isinstance(precision, torch.Tensor)
        assert isinstance(inputs, tuple)
        inputs = OrderedDict(inputs)

        # Compute total dimension of all real inputs.
        dim = sum(d.num_elements for d in inputs.values() if d.dtype == 'real')
        if not torch._C._get_tracing_state():
            assert dim
            assert precision.dim() >= 2 and precision.shape[-2:] == (dim, dim)
            assert info_vec.dim() >= 1 and info_vec.size(-1) == dim

        # Compute total shape of all bint inputs.
        batch_shape = tuple(d.dtype for d in inputs.values()
                            if isinstance(d.dtype, int))
        if not torch._C._get_tracing_state():
            assert precision.shape == batch_shape + (dim, dim)
            assert info_vec.shape == batch_shape + (dim,)

        output = reals()
        fresh = frozenset(inputs.keys())
        bound = frozenset()
        super(Gaussian, self).__init__(inputs, output, fresh, bound)
        self.info_vec = info_vec
        self.precision = precision
        self.batch_shape = batch_shape
        self.event_shape = (dim,)

    @lazy_property
    def _precision_chol(self):
        # Lower Cholesky factor of the precision; shared by log_normalizer and
        # unscaled_sample.
        return cholesky(self.precision)

    @lazy_property
    def log_normalizer(self):
        """
        Log of the integral of ``exp(f(x))`` over all real inputs, returned as
        a :class:`~funsor.torch.Tensor` over the int inputs.

        With ``L = cholesky(precision)`` and ``d`` the total real dimension this
        is ``0.5 * d * log(2 pi) - log|L| + 0.5 * |L^{-1} info_vec|^2``, which
        requires a positive definite precision.
        """
        dim = self.precision.size(-1)
        log_det_term = _log_det_tri(self._precision_chol)
        loc_info_vec_term = 0.5 * self.info_vec.unsqueeze(-1).triangular_solve(
            self._precision_chol, upper=False).solution.squeeze(-1).pow(2).sum(-1)
        data = 0.5 * dim * math.log(2 * math.pi) - log_det_term + loc_info_vec_term
        inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        return Tensor(data, inputs)

    def __repr__(self):
        return 'Gaussian(..., ({}))'.format(' '.join(
            '({}, {}),'.format(*kv) for kv in self.inputs.items()))

    def align(self, names):
        """
        Reorder inputs so that ``names`` come first, followed by the remaining
        inputs in their existing order.

        :param tuple names: input names, each of which must be an input of self.
        :return: an equivalent :class:`Gaussian`, or ``self`` if no reordering
            is needed.
        """
        assert isinstance(names, tuple)
        assert all(name in self.inputs for name in names)
        if not names or names == tuple(self.inputs):
            return self

        inputs = OrderedDict((name, self.inputs[name]) for name in names)
        inputs.update(self.inputs)
        info_vec, precision = align_gaussian(inputs, self)
        return Gaussian(info_vec, precision, inputs)

    def eager_subs(self, subs):
        """
        Eagerly substitute into this Gaussian.

        Substitutions are dispatched in stages: variable renames, then int
        (batch) substitutions, then real constants, then affine expressions;
        any remaining substitutions are applied lazily.
        """
        assert isinstance(subs, tuple)
        subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                     for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Affine funsors are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple((k, v) for k, v in subs
                          if not isinstance(v, (Number, Tensor, Variable, Slice))
                          and not (is_affine(v) and affine_inputs(v)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        affine_subs = tuple((k, v) for k, v in subs
                            if is_affine(v) and affine_inputs(v) and not isinstance(v, Variable))
        if var_subs:
            return self._eager_subs_var(var_subs, int_subs + real_subs + affine_subs + lazy_subs)
        if int_subs:
            return self._eager_subs_int(int_subs, real_subs + affine_subs + lazy_subs)
        if real_subs:
            return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
        if affine_subs:
            return self._eager_subs_affine(affine_subs, lazy_subs)
        return reflect(Subs, self, lazy_subs)

    def _eager_subs_var(self, subs, remaining_subs):
        # Perform variable substitution, i.e. renaming of inputs.
        rename = {k: v.name for k, v in subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.info_vec, self.precision, inputs)
        return Subs(var_result, remaining_subs) if remaining_subs else var_result

    def _eager_subs_int(self, subs, remaining_subs):
        # Perform integer substitution, i.e. slicing into a batch.
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.info_vec, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, remaining_subs) if remaining_subs else int_result

    def _eager_subs_real(self, subs, remaining_subs):
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [Tensor(self.info_vec, int_inputs),
                   Tensor(self.precision, int_inputs)]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1,))
            if not torch._C._get_tracing_state():
                assert value.size(-1) == self.inputs[k].num_elements
            values[k] = value.expand(batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size,))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables, resulting in
        # a Gaussian over the remaining real inputs plus a Tensor log scale.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items() if d.dtype == 'real' and k not in b)
        prec_aa = torch.cat([torch.cat([
            precision[..., i1, i2]
            for k2, i2 in slices if k2 in a], dim=-1)
            for k1, i1 in slices if k1 in a], dim=-2)
        prec_ab = torch.cat([torch.cat([
            precision[..., i1, i2]
            for k2, i2 in slices if k2 in b], dim=-1)
            for k1, i1 in slices if k1 in a], dim=-2)
        prec_bb = torch.cat([torch.cat([
            precision[..., i1, i2]
            for k2, i2 in slices if k2 in b], dim=-1)
            for k1, i1 in slices if k1 in b], dim=-2)
        info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a], dim=-1)
        info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b], dim=-1)
        value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = prec_aa.expand(info_vec.shape + (-1,))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result

    def _eager_subs_affine(self, subs, remaining_subs):
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and
                    all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        tensors = [Tensor(self.info_vec, old_int_inputs),
                   Tensor(self.precision, old_int_inputs)]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = old_info_vec.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = reals(*new_shape)
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim,))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                # Unsubstituted real inputs map to themselves via an identity block,
                # created with the same dtype/device as the other blocks.
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                eye = torch.eye(old_size, dtype=old_info_vec.dtype, device=old_info_vec.device)
                subs_matrix[..., new_slice, old_slice] = eye.expand(batch_shape + (-1, -1))
                continue
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape + (old_size,))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = subs_matrix.transpose(-1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x>
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where  P' = At P A  and  i' = At (i - P B)  parametrize a new Gaussian
        # and  C = < B | i - 1/2 P B >  parametrize a new Tensor.
        # Here subs_matrix (shape new_dim x old_dim) plays the role of At and
        # subs_vector plays the role of B.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix, old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector, old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result

    def eager_reduce(self, op, reduced_vars):
        """
        Eagerly reduce over ``reduced_vars``.

        * ``ops.logaddexp`` marginalizes out real variables in closed form;
          int variables are then reduced on the resulting funsor.
        * ``ops.add`` sums over int (batch) variables by adding the
          ``info_vec`` and ``precision`` parameters; summing over a real
          variable raises ``ValueError``.

        Returns ``None`` to defer to the default implementation otherwise.
        """
        if op is ops.logaddexp:
            # Marginalize out real variables, but keep mixtures lazy.
            assert all(v in self.inputs for v in reduced_vars)
            real_vars = frozenset(k for k, d in self.inputs.items() if d.dtype == "real")
            reduced_reals = reduced_vars & real_vars
            reduced_ints = reduced_vars - real_vars
            if not reduced_reals:
                return None  # defer to default implementation

            inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals)
            if reduced_reals == real_vars:
                result = self.log_normalizer
            else:
                int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
                offsets, _ = _compute_offsets(self.inputs)
                a = []
                b = []
                for key, domain in self.inputs.items():
                    if domain.dtype == 'real':
                        block = range(offsets[key], offsets[key] + domain.num_elements)
                        (b if key in reduced_vars else a).extend(block)
                a = torch.tensor(a, device=self.precision.device)
                b = torch.tensor(b, device=self.precision.device)

                # Partition into kept (a) and marginalized (b) dims.
                prec_aa = self.precision[..., a.unsqueeze(-1), a]
                prec_ba = self.precision[..., b.unsqueeze(-1), a]
                prec_bb = self.precision[..., b.unsqueeze(-1), b]
                prec_b = cholesky(prec_bb)
                prec_a = prec_ba.triangular_solve(prec_b, upper=False).solution
                prec_at = prec_a.transpose(-1, -2)
                # Schur complement: prec_aa - prec_ab @ inv(prec_bb) @ prec_ba.
                precision = prec_aa - prec_at.matmul(prec_a)

                info_a = self.info_vec[..., a]
                info_b = self.info_vec[..., b]
                b_tmp = info_b.unsqueeze(-1).triangular_solve(prec_b, upper=False).solution
                info_vec = info_a - prec_at.matmul(b_tmp).squeeze(-1)

                log_prob = Tensor(0.5 * len(b) * math.log(2 * math.pi) - _log_det_tri(prec_b) +
                                  0.5 * b_tmp.squeeze(-1).pow(2).sum(-1),
                                  int_inputs)
                result = log_prob + Gaussian(info_vec, precision, inputs)

            return result.reduce(ops.logaddexp, reduced_ints)

        elif op is ops.add:
            for v in reduced_vars:
                if self.inputs[v].dtype == 'real':
                    raise ValueError("Cannot sum along a real dimension: {}".format(repr(v)))

            # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
            old_ints = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
            new_ints = OrderedDict((k, v) for k, v in old_ints.items() if k not in reduced_vars)
            inputs = OrderedDict((k, v) for k, v in self.inputs.items() if k not in reduced_vars)

            info_vec = Tensor(self.info_vec, old_ints).reduce(ops.add, reduced_vars)
            precision = Tensor(self.precision, old_ints).reduce(ops.add, reduced_vars)
            assert info_vec.inputs == new_ints
            assert precision.inputs == new_ints
            return Gaussian(info_vec.data, precision.data, inputs)

        return None  # defer to default implementation

    def unscaled_sample(self, sampled_vars, sample_inputs):
        """
        Jointly sample all real inputs.

        Only the case where ``sampled_vars`` covers every real input is
        implemented. The result is a sum of :class:`~funsor.delta.Delta`
        funsors at the sampled points plus :attr:`log_normalizer`.

        :raises ValueError: if any sampled variable is an int input.
        :raises NotImplementedError: for partial sampling of real inputs.
        """
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self
        if any(self.inputs[k].dtype != 'real' for k in sampled_vars):
            raise ValueError('Sampling from non-normalized Gaussian mixtures is intentionally '
                             'not implemented. You probably want to normalize. To work around, '
                             'add a zero Tensor with given inputs.')

        # Partition inputs into sample_inputs + int_inputs + real_inputs.
        sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items() if k not in self.inputs)
        sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        inputs = sample_inputs.copy()
        inputs.update(int_inputs)

        if sampled_vars == frozenset(real_inputs):
            shape = sample_shape + self.info_vec.shape
            # Draw noise matching the parameters' dtype and device.
            white_noise = torch.randn(shape + (1,), dtype=self.info_vec.dtype,
                                      device=self.info_vec.device)
            # With L = chol(precision): sample = L^{-T} (noise + L^{-1} info_vec),
            # i.e. mean + L^{-T} noise.
            white_vec = self.info_vec.unsqueeze(-1).triangular_solve(
                self._precision_chol, upper=False).solution
            sample = (white_noise + white_vec).triangular_solve(
                self._precision_chol, upper=False, transpose=True).solution.squeeze(-1)
            offsets, _ = _compute_offsets(real_inputs)
            results = []
            for key, domain in real_inputs.items():
                data = sample[..., offsets[key]: offsets[key] + domain.num_elements]
                data = data.reshape(shape[:-1] + domain.shape)
                point = Tensor(data, inputs)
                assert point.output == domain
                results.append(Delta(key, point))
            results.append(self.log_normalizer)
            return reduce(ops.add, results)

        raise NotImplementedError('TODO implement partial sampling of real variables')
@eager.register(Binary, AddOp, Gaussian, Gaussian)
def eager_add_gaussian_gaussian(op, lhs, rhs):
    # Fuse two Gaussians by adding their log-densities pointwise. Each
    # log-density is linear in (info_vec, precision), so after aligning both to
    # a common input schema the sum is a single Gaussian whose parameters are
    # the sums of the parameters. This is similar to a Kalman filter update.

    # Align data.
    inputs = lhs.inputs.copy()
    inputs.update(rhs.inputs)
    lhs_info_vec, lhs_precision = align_gaussian(inputs, lhs)
    rhs_info_vec, rhs_precision = align_gaussian(inputs, rhs)

    # Fuse aligned Gaussians.
    info_vec = lhs_info_vec + rhs_info_vec
    precision = lhs_precision + rhs_precision
    return Gaussian(info_vec, precision, inputs)


@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian))
@eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian)
def eager_sub(op, lhs, rhs):
    # Rewrite subtraction as addition of the negation.
    return lhs + -rhs


@eager.register(Unary, NegOp, Gaussian)
def eager_neg(op, arg):
    # Negating the log-density negates both parameters.
    info_vec = -arg.info_vec
    precision = -arg.precision
    return Gaussian(info_vec, precision, arg.inputs)


__all__ = [
    'BlockMatrix',
    'BlockVector',
    'Gaussian',
    'align_gaussian',
]