Source code for funsor.gaussian

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import OrderedDict, defaultdict
from functools import reduce

import numpy as np

import funsor
import funsor.ops as ops
from funsor.affine import affine_inputs, extract_affine, is_affine
from funsor.delta import Delta
from funsor.domains import Real, Reals
from funsor.ops import AddOp, NegOp, SubOp
from funsor.tensor import Tensor, align_tensor, align_tensors
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Slice, Subs, Unary, Variable, eager, reflect
from funsor.util import broadcast_shape, get_backend, get_tracing_state, lazy_property


def _log_det_tri(x):
    return ops.log(ops.diagonal(x, -1, -2)).sum(-1)


def _vv(vec1, vec2):
    """
    Computes the inner product ``< vec1 | vec2 >``.
    """
    return ops.matmul(ops.unsqueeze(vec1, -2), ops.unsqueeze(vec2, -1)).squeeze(-1).squeeze(-1)


def _mv(mat, vec):
    return ops.matmul(mat, ops.unsqueeze(vec, -1)).squeeze(-1)


def _trace_mm(x, y):
    """
    Computes ``trace(x.T @ y)``.
    """
    assert len(x.shape) >= 2
    assert len(y.shape) >= 2
    return (x * y).sum((-1, -2))


def _compute_offsets(inputs):
    """
    Compute offsets of real inputs into the concatenated Gaussian dims.
    This ignores all int inputs.

    :param OrderedDict inputs: A schema mapping variable name to domain.
    :return: a pair ``(offsets, total)``, where ``offsets`` is an OrderedDict
        mapping input name to integer offset, and ``total`` is the total event
        size.
    :rtype: tuple
    """
    assert isinstance(inputs, OrderedDict)
    offsets = OrderedDict()
    total = 0
    for key, domain in inputs.items():
        if domain.dtype == 'real':
            offsets[key] = total
            total += domain.num_elements
    return offsets, total
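

def _example_compute_offsets():  # pragma: no cover
    # Illustrative sketch added for documentation only; this helper is
    # hypothetical and not part of the original module. It shows how
    # _compute_offsets() packs only the real inputs into a flat event vector,
    # skipping integer (batch) inputs; it assumes funsor.domains.Bint is
    # available.
    inputs = OrderedDict([("i", funsor.domains.Bint[3]),  # int input, ignored
                          ("x", Reals[2]),                # occupies dims 0:2
                          ("y", Real)])                   # occupies dim 2
    offsets, total = _compute_offsets(inputs)
    assert offsets == OrderedDict([("x", 0), ("y", 2)])
    assert total == 3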


def _find_intervals(intervals, end):
    """
    Finds a complete set of intervals partitioning [0, end), given a partial
    set of non-overlapping intervals.
    """
    cuts = list(sorted({0, end}.union(*intervals)))
    return list(zip(cuts[:-1], cuts[1:]))
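

def _example_find_intervals():  # pragma: no cover
    # Illustrative sketch added for documentation only (hypothetical helper,
    # not part of the original module): given the partial, non-overlapping
    # intervals [0, 4) and [6, 10), the gaps [4, 6) and [10, 20) are filled in
    # so that the result partitions [0, 20).
    assert _find_intervals([(0, 4), (6, 10)], 20) == [(0, 4), (4, 6), (6, 10), (10, 20)]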


def _parse_slices(index, value):
    if not isinstance(index, tuple):
        index = (index,)
    if index[0] is Ellipsis:
        index = index[1:]
    start_stops = []
    for pos, i in reversed(list(enumerate(index))):
        if isinstance(i, slice):
            start_stops.append((i.start, i.stop))
        elif isinstance(i, int):
            start_stops.append((i, i + 1))
            value = ops.unsqueeze(value, pos - len(index))
        else:
            raise ValueError("invalid index: {}".format(i))
    start_stops.reverse()
    return start_stops, value
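

def _example_parse_slices():  # pragma: no cover
    # Illustrative sketch added for documentation only (hypothetical helper,
    # not part of the original module): an index like ``[..., 0:4]``, as used
    # by the BlockVector/BlockMatrix helpers below, is parsed into a list of
    # (start, stop) pairs while the value passes through unchanged.
    value = np.zeros((100, 4))
    start_stops, value = _parse_slices((Ellipsis, slice(0, 4)), value)
    assert start_stops == [(0, 4)]
    assert value.shape == (100, 4)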


class BlockVector(object):
    """
    Jit-compatible helper to build blockwise vectors.
    Syntax is similar to :func:`torch.zeros` ::

        x = BlockVector((100, 20))
        x[..., 0:4] = x1
        x[..., 6:10] = x2
        x = x.as_tensor()
        assert x.shape == (100, 20)
    """
    def __init__(self, shape):
        self.shape = shape
        self.parts = {}

    def __setitem__(self, index, value):
        (i,), value = _parse_slices(index, value)
        self.parts[i] = value

    def as_tensor(self):
        # Fill gaps with zeros.
        prototype = next(iter(self.parts.values()))
        for i in _find_intervals(self.parts.keys(), self.shape[-1]):
            if i not in self.parts:
                self.parts[i] = ops.new_zeros(prototype, self.shape[:-1] + (i[1] - i[0],))

        # Concatenate parts.
        parts = [v for k, v in sorted(self.parts.items())]
        result = ops.cat(-1, *parts)
        if not get_tracing_state():
            assert result.shape == self.shape
        return result


class BlockMatrix(object):
    """
    Jit-compatible helper to build blockwise matrices.
    Syntax is similar to :func:`torch.zeros` ::

        x = BlockMatrix((100, 20, 20))
        x[..., 0:4, 0:4] = x11
        x[..., 0:4, 6:10] = x12
        x[..., 6:10, 0:4] = x12.transpose(-1, -2)
        x[..., 6:10, 6:10] = x22
        x = x.as_tensor()
        assert x.shape == (100, 20, 20)
    """
    def __init__(self, shape):
        self.shape = shape
        self.parts = defaultdict(dict)

    def __setitem__(self, index, value):
        (i, j), value = _parse_slices(index, value)
        self.parts[i][j] = value

    def as_tensor(self):
        # Fill gaps with zeros.
        arbitrary_row = next(iter(self.parts.values()))
        prototype = next(iter(arbitrary_row.values()))
        js = set().union(*(part.keys() for part in self.parts.values()))
        rows = _find_intervals(self.parts.keys(), self.shape[-2])
        cols = _find_intervals(js, self.shape[-1])
        for i in rows:
            for j in cols:
                if j not in self.parts[i]:
                    shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0])
                    self.parts[i][j] = ops.new_zeros(prototype, shape)

        # Concatenate parts.
        # TODO This could be optimized into a single .reshape().cat().reshape() if
        # all inputs are contiguous, thereby saving a memcopy.
        columns = {i: ops.cat(-1, *[v for j, v in sorted(part.items())])
                   for i, part in self.parts.items()}
        result = ops.cat(-2, *[v for i, v in sorted(columns.items())])
        if not get_tracing_state():
            assert result.shape == self.shape
        return result


def align_gaussian(new_inputs, old):
    """
    Align data of a Gaussian distribution to a new ``inputs`` shape.
    """
    assert isinstance(new_inputs, OrderedDict)
    assert isinstance(old, Gaussian)
    info_vec = old.info_vec
    precision = old.precision

    # Align int inputs.
    # Since these are managed as in Tensor, we can defer to align_tensor().
    new_ints = OrderedDict((k, d) for k, d in new_inputs.items() if d.dtype != 'real')
    old_ints = OrderedDict((k, d) for k, d in old.inputs.items() if d.dtype != 'real')
    if new_ints != old_ints:
        info_vec = align_tensor(new_ints, Tensor(info_vec, old_ints))
        precision = align_tensor(new_ints, Tensor(precision, old_ints))

    # Align real inputs, which are all concatenated in the rightmost dims.
    new_offsets, new_dim = _compute_offsets(new_inputs)
    old_offsets, old_dim = _compute_offsets(old.inputs)
    assert info_vec.shape[-1:] == (old_dim,)
    assert precision.shape[-2:] == (old_dim, old_dim)
    if new_offsets != old_offsets:
        old_info_vec = info_vec
        old_precision = precision
        info_vec = BlockVector(old_info_vec.shape[:-1] + (new_dim,))
        precision = BlockMatrix(old_info_vec.shape[:-1] + (new_dim, new_dim))
        for k1, new_offset1 in new_offsets.items():
            if k1 not in old_offsets:
                continue
            offset1 = old_offsets[k1]
            num_elements1 = old.inputs[k1].num_elements
            old_slice1 = slice(offset1, offset1 + num_elements1)
            new_slice1 = slice(new_offset1, new_offset1 + num_elements1)
            info_vec[..., new_slice1] = old_info_vec[..., old_slice1]
            for k2, new_offset2 in new_offsets.items():
                if k2 not in old_offsets:
                    continue
                offset2 = old_offsets[k2]
                num_elements2 = old.inputs[k2].num_elements
                old_slice2 = slice(offset2, offset2 + num_elements2)
                new_slice2 = slice(new_offset2, new_offset2 + num_elements2)
                precision[..., new_slice1, new_slice2] = old_precision[..., old_slice1, old_slice2]
        info_vec = info_vec.as_tensor()
        precision = precision.as_tensor()

    return info_vec, precision
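

def _example_align_gaussian():  # pragma: no cover
    # Illustrative sketch added for documentation only (hypothetical helper,
    # not part of the original module; assumes the default numpy backend and
    # refers to the Gaussian class defined below). Reordering the real inputs
    # permutes the corresponding blocks of info_vec and precision.
    old_inputs = OrderedDict([("x", Reals[2]), ("y", Real)])
    g = Gaussian(np.array([1., 2., 3.]), np.eye(3), old_inputs)
    new_inputs = OrderedDict([("y", Real), ("x", Reals[2])])
    info_vec, precision = align_gaussian(new_inputs, g)
    assert list(info_vec) == [3., 1., 2.]
    assert precision.shape == (3, 3)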


class GaussianMeta(FunsorMeta):
    """
    Wrapper to convert between OrderedDict and tuple.
    """
    def __call__(cls, info_vec, precision, inputs):
        if isinstance(inputs, OrderedDict):
            inputs = tuple(inputs.items())
        assert isinstance(inputs, tuple)
        return super(GaussianMeta, cls).__call__(info_vec, precision, inputs)


class Gaussian(Funsor, metaclass=GaussianMeta):
    """
    Funsor representing a batched joint Gaussian distribution as a log-density
    function.

    Mathematically, a Gaussian represents the density function::

        f(x) = < x | info_vec > - 0.5 * < x | precision | x >
             = < x | info_vec - 0.5 * precision @ x >

    Note that :class:`Gaussian` s are not normalized, rather they are
    canonicalized to evaluate to zero log density at the origin: ``f(0) = 0``.
    This canonical form is useful in combination with the information filter
    representation because it allows :class:`Gaussian` s with incomplete
    information, i.e. zero eigenvalues in the precision matrix. These
    incomplete distributions arise when making low-dimensional observations
    of a higher-dimensional hidden state.

    :param torch.Tensor info_vec: A batched information vector, where
        ``info_vec = precision @ mean``.
    :param torch.Tensor precision: A batched positive semidefinite precision
        matrix.
    :param OrderedDict inputs: Mapping from name to
        :class:`~funsor.domains.Domain` .
    """
    def __init__(self, info_vec, precision, inputs):
        assert ops.is_numeric_array(info_vec) and ops.is_numeric_array(precision)
        assert isinstance(inputs, tuple)
        inputs = OrderedDict(inputs)

        # Compute total dimension of all real inputs.
        dim = sum(d.num_elements for d in inputs.values() if d.dtype == 'real')
        if not get_tracing_state():
            assert dim
            assert len(precision.shape) >= 2 and precision.shape[-2:] == (dim, dim)
            assert len(info_vec.shape) >= 1 and info_vec.shape[-1] == dim

        # Compute total shape of all Bint inputs.
        batch_shape = tuple(d.dtype for d in inputs.values()
                            if isinstance(d.dtype, int))
        if not get_tracing_state():
            assert precision.shape == batch_shape + (dim, dim)
            assert info_vec.shape == batch_shape + (dim,)

        output = Real
        fresh = frozenset(inputs.keys())
        bound = frozenset()
        super(Gaussian, self).__init__(inputs, output, fresh, bound)
        self.info_vec = info_vec
        self.precision = precision
        self.batch_shape = batch_shape
        self.event_shape = (dim,)

    @lazy_property
    def _precision_chol(self):
        return ops.cholesky(self.precision)

    @lazy_property
    def log_normalizer(self):
        dim = self.precision.shape[-1]
        log_det_term = _log_det_tri(self._precision_chol)
        loc_info_vec_term = 0.5 * (ops.triangular_solve(
            self.info_vec[..., None], self._precision_chol)[..., 0] ** 2).sum(-1)
        data = 0.5 * dim * math.log(2 * math.pi) - log_det_term + loc_info_vec_term
        inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                             if v.dtype != 'real')
        return Tensor(data, inputs)

    def __repr__(self):
        return 'Gaussian(..., ({}))'.format(' '.join(
            '({}, {}),'.format(*kv) for kv in self.inputs.items()))

    def align(self, names):
        assert isinstance(names, tuple)
        assert all(name in self.inputs for name in names)
        if not names or names == tuple(self.inputs):
            return self

        inputs = OrderedDict((name, self.inputs[name]) for name in names)
        inputs.update(self.inputs)
        info_vec, precision = align_gaussian(inputs, self)
        return Gaussian(info_vec, precision, inputs)

    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        prototype = Tensor(self.info_vec)
        subs = tuple((k, v if isinstance(v, (Variable, Slice)) else prototype.materialize(v))
                     for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Affine funsors are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple((k, v) for k, v in subs
                          if not isinstance(v, (Number, Tensor, Variable, Slice))
                          and not (is_affine(v) and affine_inputs(v)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        affine_subs = tuple((k, v) for k, v in subs
                            if is_affine(v) and affine_inputs(v) and not isinstance(v, Variable))
        if var_subs:
            return self._eager_subs_var(var_subs, int_subs + real_subs + affine_subs + lazy_subs)
        if int_subs:
            return self._eager_subs_int(int_subs, real_subs + affine_subs + lazy_subs)
        if real_subs:
            return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
        if affine_subs:
            return self._eager_subs_affine(affine_subs, lazy_subs)
        return reflect(Subs, self, lazy_subs)

    def _eager_subs_var(self, subs, remaining_subs):
        # Perform variable substitution, i.e. renaming of inputs.
        rename = {k: v.name for k, v in subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.info_vec, self.precision, inputs)
        return Subs(var_result, remaining_subs) if remaining_subs else var_result

    def _eager_subs_int(self, subs, remaining_subs):
        # Perform integer substitution, i.e. slicing into a batch.
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.info_vec, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, remaining_subs) if remaining_subs else int_result

    def _eager_subs_real(self, subs, remaining_subs):
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [Tensor(self.info_vec, int_inputs),
                   Tensor(self.precision, int_inputs)]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = len(tensors[0].shape) - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1,))
            if not get_tracing_state():
                assert value.shape[-1] == self.inputs[k].num_elements
            values[k] = ops.expand(value, batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables,
        # resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size,))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))
            result = Tensor(result, int_inputs)
            assert result.output == Real
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables,
        # resulting in a Joint. We split real inputs into two sets:
        # a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items() if d.dtype == 'real' and k not in b)
        prec_aa = ops.cat(-2, *[ops.cat(-1, *[
            precision[..., i1, i2]
            for k2, i2 in slices if k2 in a])
            for k1, i1 in slices if k1 in a])
        prec_ab = ops.cat(-2, *[ops.cat(-1, *[
            precision[..., i1, i2]
            for k2, i2 in slices if k2 in b])
            for k1, i1 in slices if k1 in a])
        prec_bb = ops.cat(-2, *[ops.cat(-1, *[
            precision[..., i1, i2]
            for k2, i2 in slices if k2 in b])
            for k1, i1 in slices if k1 in b])
        info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
        info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
        value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])

        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result

    def _eager_subs_affine(self, subs, remaining_subs):
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and
                    all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        tensors = [Tensor(self.info_vec, old_int_inputs),
                   Tensor(self.precision, old_int_inputs)]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = old_info_vec.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = Reals[new_shape]
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim,))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                subs_matrix[..., new_slice, old_slice] = \
                    ops.new_eye(self.info_vec, batch_shape + (old_size,))
                continue
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape + (old_size,))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x >
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where P' = At P A and i' = At (i - P B) parametrize a new Gaussian,
        # and C = < B | i - 1/2 P B > parametrizes a new Tensor.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix, old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector, old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result

    def eager_reduce(self, op, reduced_vars):
        if op is ops.logaddexp:
            # Marginalize out real variables, but keep mixtures lazy.
            assert all(v in self.inputs for v in reduced_vars)
            real_vars = frozenset(k for k, d in self.inputs.items() if d.dtype == "real")
            reduced_reals = reduced_vars & real_vars
            reduced_ints = reduced_vars - real_vars
            if not reduced_reals:
                return None  # defer to default implementation

            inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals)
            if reduced_reals == real_vars:
                result = self.log_normalizer
            else:
                int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
                offsets, _ = _compute_offsets(self.inputs)
                a = []
                b = []
                for key, domain in self.inputs.items():
                    if domain.dtype == 'real':
                        block = ops.new_arange(self.info_vec, offsets[key],
                                               offsets[key] + domain.num_elements, 1)
                        (b if key in reduced_vars else a).append(block)
                a = ops.cat(-1, *a)
                b = ops.cat(-1, *b)
                prec_aa = self.precision[..., a[..., None], a]
                prec_ba = self.precision[..., b[..., None], a]
                prec_bb = self.precision[..., b[..., None], b]
                prec_b = ops.cholesky(prec_bb)
                prec_a = ops.triangular_solve(prec_ba, prec_b)
                prec_at = ops.transpose(prec_a, -1, -2)
                precision = prec_aa - ops.matmul(prec_at, prec_a)

                info_a = self.info_vec[..., a]
                info_b = self.info_vec[..., b]
                b_tmp = ops.triangular_solve(info_b[..., None], prec_b)
                info_vec = info_a - ops.matmul(prec_at, b_tmp)[..., 0]

                log_prob = Tensor(0.5 * len(b) * math.log(2 * math.pi)
                                  - _log_det_tri(prec_b)
                                  + 0.5 * (b_tmp[..., 0] ** 2).sum(-1),
                                  int_inputs)
                result = log_prob + Gaussian(info_vec, precision, inputs)

            return result.reduce(ops.logaddexp, reduced_ints)

        elif op is ops.add:
            for v in reduced_vars:
                if self.inputs[v].dtype == 'real':
                    raise ValueError("Cannot sum along a real dimension: {}".format(repr(v)))

            # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
            old_ints = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
            new_ints = OrderedDict((k, v) for k, v in old_ints.items() if k not in reduced_vars)
            inputs = OrderedDict((k, v) for k, v in self.inputs.items() if k not in reduced_vars)

            info_vec = Tensor(self.info_vec, old_ints).reduce(ops.add, reduced_vars)
            precision = Tensor(self.precision, old_ints).reduce(ops.add, reduced_vars)
            assert info_vec.inputs == new_ints
            assert precision.inputs == new_ints
            return Gaussian(info_vec.data, precision.data, inputs)

        return None  # defer to default implementation

    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self
        if any(self.inputs[k].dtype != 'real' for k in sampled_vars):
            raise ValueError('Sampling from non-normalized Gaussian mixtures is intentionally '
                             'not implemented. You probably want to normalize. To work around, '
                             'add a zero Tensor/Array with given inputs.')

        # Partition inputs into sample_inputs + int_inputs + real_inputs.
        sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                    if k not in self.inputs)
        sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        inputs = sample_inputs.copy()
        inputs.update(int_inputs)

        if sampled_vars == frozenset(real_inputs):
            shape = sample_shape + self.info_vec.shape
            backend = get_backend()
            if backend != "numpy":
                from importlib import import_module
                dist = import_module(funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
                sample_args = (shape,) if rng_key is None else (rng_key, shape)
                white_noise = dist.Normal.dist_class(0, 1).sample(*sample_args)
            else:
                white_noise = np.random.randn(*shape)

            white_noise = ops.unsqueeze(white_noise, -1)
            white_vec = ops.triangular_solve(self.info_vec[..., None], self._precision_chol)
            sample = ops.triangular_solve(white_noise + white_vec, self._precision_chol,
                                          transpose=True)[..., 0]
            offsets, _ = _compute_offsets(real_inputs)
            results = []
            for key, domain in real_inputs.items():
                data = sample[..., offsets[key]: offsets[key] + domain.num_elements]
                data = data.reshape(shape[:-1] + domain.shape)
                point = Tensor(data, inputs)
                assert point.output == domain
                results.append(Delta(key, point))
            results.append(self.log_normalizer)
            return reduce(ops.add, results)

        raise NotImplementedError('TODO implement partial sampling of real variables')
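

def _example_gaussian_density():  # pragma: no cover
    # Illustrative sketch added for documentation only (hypothetical helper,
    # not part of the original module; assumes the default numpy backend).
    # Substituting a value for every real input evaluates the canonicalized
    # log density  < x | info_vec > - 0.5 * < x | precision | x >,  which is
    # zero at the origin.
    inputs = OrderedDict([("x", Reals[2])])
    g = Gaussian(np.array([1., -1.]), np.eye(2), inputs)
    at_origin = g(x=Tensor(np.zeros(2)))
    assert at_origin.data == 0.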


@eager.register(Binary, AddOp, Gaussian, Gaussian)
def eager_add_gaussian_gaussian(op, lhs, rhs):
    # Fuse two Gaussians by adding their log-densities pointwise.
    # This is similar to a Kalman filter update, but also keeps track of
    # the marginal likelihood which accumulates into a Tensor.

    # Align data.
    inputs = lhs.inputs.copy()
    inputs.update(rhs.inputs)
    lhs_info_vec, lhs_precision = align_gaussian(inputs, lhs)
    rhs_info_vec, rhs_precision = align_gaussian(inputs, rhs)

    # Fuse aligned Gaussians.
    info_vec = lhs_info_vec + rhs_info_vec
    precision = lhs_precision + rhs_precision
    return Gaussian(info_vec, precision, inputs)


@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian))
@eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian)
def eager_sub(op, lhs, rhs):
    return lhs + -rhs


@eager.register(Unary, NegOp, Gaussian)
def eager_neg(op, arg):
    info_vec = -arg.info_vec
    precision = -arg.precision
    return Gaussian(info_vec, precision, arg.inputs)


__all__ = [
    'BlockMatrix',
    'BlockVector',
    'Gaussian',
    'align_gaussian',
]
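

def _example_add_gaussians():  # pragma: no cover
    # Illustrative sketch added for documentation only (hypothetical helper,
    # not part of the original module; assumes the default numpy backend).
    # Adding two Gaussians with identical inputs triggers
    # eager_add_gaussian_gaussian(), which sums the natural parameters.
    inputs = OrderedDict([("x", Real)])
    g1 = Gaussian(np.array([1.]), np.array([[1.]]), inputs)
    g2 = Gaussian(np.array([2.]), np.array([[3.]]), inputs)
    g = g1 + g2
    assert isinstance(g, Gaussian)
    assert (g.info_vec == np.array([3.])).all()
    assert (g.precision == np.array([[4.]])).all()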