Source code for funsor.einsum

from functools import reduce

import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.interpreter import interpretation
from funsor.optimizer import apply_optimizer
from funsor.sum_product import sum_product
from funsor.terms import Funsor, lazy

BACKEND_OPS = {
    "torch": (ops.add, ops.mul),
    "pyro.ops.einsum.torch_log": (ops.logaddexp, ops.add),
    "pyro.ops.einsum.torch_marginal": (ops.logaddexp, ops.add),
    "pyro.ops.einsum.torch_map": (ops.max, ops.add),
    "pyro.ops.einsum.torch_sample": (ops.logaddexp, ops.add),
}

BACKEND_ADJOINT_OPS = {
    "pyro.ops.einsum.torch_marginal": (ops.logaddexp, ops.add),
    "pyro.ops.einsum.torch_map": (ops.max, ops.add),
}


[docs]def naive_contract_einsum(eqn, *terms, **kwargs): """ Use for testing Contract against einsum """ assert "plates" not in kwargs backend = kwargs.pop('backend', 'torch') if backend in BACKEND_OPS: sum_op, prod_op = BACKEND_OPS[backend] else: raise ValueError("{} backend not implemented".format(backend)) assert isinstance(eqn, str) assert all(isinstance(term, Funsor) for term in terms) inputs, output = eqn.split('->') inputs = inputs.split(',') assert len(inputs) == len(terms) assert len(output.split(',')) == 1 input_dims = frozenset(d for inp in inputs for d in inp) output_dims = frozenset(d for d in output) reduced_vars = input_dims - output_dims return Contraction(sum_op, prod_op, reduced_vars, *terms)
[docs]def naive_einsum(eqn, *terms, **kwargs): """ Implements standard variable elimination. """ backend = kwargs.pop('backend', 'torch') if backend in BACKEND_OPS: sum_op, prod_op = BACKEND_OPS[backend] else: raise ValueError("{} backend not implemented".format(backend)) assert isinstance(eqn, str) assert all(isinstance(term, Funsor) for term in terms) inputs, output = eqn.split('->') assert len(output.split(',')) == 1 input_dims = frozenset(d for inp in inputs.split(',') for d in inp) output_dims = frozenset(output) reduce_dims = input_dims - output_dims return reduce(prod_op, terms).reduce(sum_op, reduce_dims)
[docs]def naive_plated_einsum(eqn, *terms, **kwargs): """ Implements Tensor Variable Elimination (Algorithm 1 in [Obermeyer et al 2019]) [Obermeyer et al 2019] Obermeyer, F., Bingham, E., Jankowiak, M., Chiu, J., Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for Plated Factor Graphs, 2019 """ plates = kwargs.pop('plates', '') if not plates: return naive_einsum(eqn, *terms, **kwargs) backend = kwargs.pop('backend', 'torch') if backend in BACKEND_OPS: sum_op, prod_op = BACKEND_OPS[backend] else: raise ValueError("{} backend not implemented".format(backend)) assert isinstance(eqn, str) assert all(isinstance(term, Funsor) for term in terms) inputs, output = eqn.split('->') inputs = inputs.split(',') assert len(inputs) == len(terms) assert len(output.split(',')) == 1 input_dims = frozenset(d for inp in inputs for d in inp) output_dims = frozenset(d for d in output) plate_dims = frozenset(plates) - output_dims reduce_vars = input_dims - output_dims - frozenset(plates) output_plates = output_dims & frozenset(plates) if not all(output_plates.issubset(inp) for inp in inputs): raise NotImplementedError("TODO") eliminate = plate_dims | reduce_vars return sum_product(sum_op, prod_op, terms, eliminate, frozenset(plates))
[docs]def einsum(eqn, *terms, **kwargs): r""" Top-level interface for optimized tensor variable elimination. :param str equation: An einsum equation. :param funsor.terms.Funsor \*terms: One or more operands. :param set plates: Optional keyword argument denoting which funsor dimensions are plate dimensions. Among all input dimensions (from terms): dimensions in plates but not in outputs are product-reduced; dimensions in neither plates nor outputs are sum-reduced. """ with interpretation(lazy): naive_ast = naive_plated_einsum(eqn, *terms, **kwargs) return apply_optimizer(naive_ast)