Source code for funsor.einsum

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import reduce

import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.optimizer import apply_optimizer
from funsor.sum_product import sum_product
from funsor.terms import Funsor, Variable, lazy

# TODO: add numpy einsum here
BACKEND_OPS = {
    "torch": (ops.add, ops.mul),
    "pyro.ops.einsum.torch_log": (ops.logaddexp, ops.add),
    "pyro.ops.einsum.torch_marginal": (ops.logaddexp, ops.add),
    "pyro.ops.einsum.torch_map": (ops.max, ops.add),
    "pyro.ops.einsum.torch_sample": (ops.logaddexp, ops.add),
    "numpy": (ops.add, ops.mul),
    "funsor.einsum.numpy_log": (ops.logaddexp, ops.add),
    "funsor.einsum.numpy_map": (ops.max, ops.add),
    "jax.numpy": (ops.add, ops.mul),
}

BACKEND_ADJOINT_OPS = {
    "pyro.ops.einsum.torch_marginal": (ops.logaddexp, ops.add),
    "pyro.ops.einsum.torch_map": (ops.max, ops.add),
}


[docs]def naive_contract_einsum(eqn, *terms, **kwargs): """ Use for testing Contract against einsum """ assert "plates" not in kwargs backend = kwargs.pop("backend", "torch") if backend in BACKEND_OPS: sum_op, prod_op = BACKEND_OPS[backend] else: raise ValueError("{} backend not implemented".format(backend)) assert isinstance(eqn, str) assert all(isinstance(term, Funsor) for term in terms) inputs, output = eqn.split("->") inputs = inputs.split(",") assert len(inputs) == len(terms) assert len(output.split(",")) == 1 input_dims = frozenset(d for inp in inputs for d in inp) output_dims = frozenset(d for d in output) all_inputs = {k: v for term in terms for k, v in term.inputs.items()} reduced_vars = frozenset( Variable(k, all_inputs[k]) for k in input_dims - output_dims ) return Contraction(sum_op, prod_op, reduced_vars, *terms)
[docs]def naive_einsum(eqn, *terms, **kwargs): """ Implements standard variable elimination. """ backend = kwargs.pop("backend", "torch") if backend in BACKEND_OPS: sum_op, prod_op = BACKEND_OPS[backend] else: raise ValueError("{} backend not implemented".format(backend)) assert isinstance(eqn, str) assert all(isinstance(term, Funsor) for term in terms) inputs, output = eqn.split("->") assert len(output.split(",")) == 1 input_dims = frozenset(d for inp in inputs.split(",") for d in inp) output_dims = frozenset(output) reduce_dims = input_dims - output_dims return reduce(prod_op, terms).reduce(sum_op, reduce_dims)
[docs]def naive_plated_einsum(eqn, *terms, **kwargs): """ Implements Tensor Variable Elimination (Algorithm 1 in [Obermeyer et al 2019]) [Obermeyer et al 2019] Obermeyer, F., Bingham, E., Jankowiak, M., Chiu, J., Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for Plated Factor Graphs, 2019 """ plates = kwargs.pop("plates", "") if not plates: return naive_einsum(eqn, *terms, **kwargs) backend = kwargs.pop("backend", "torch") if backend in BACKEND_OPS: sum_op, prod_op = BACKEND_OPS[backend] else: raise ValueError("{} backend not implemented".format(backend)) assert isinstance(eqn, str) assert all(isinstance(term, Funsor) for term in terms) inputs, output = eqn.split("->") inputs = inputs.split(",") assert len(inputs) == len(terms) assert len(output.split(",")) == 1 input_dims = frozenset(d for inp in inputs for d in inp) output_dims = frozenset(d for d in output) plate_dims = frozenset(plates) - output_dims reduce_vars = input_dims - output_dims - frozenset(plates) output_plates = output_dims & frozenset(plates) if not all(output_plates.issubset(inp) for inp in inputs): raise NotImplementedError("TODO") eliminate = plate_dims | reduce_vars return sum_product(sum_op, prod_op, terms, eliminate, frozenset(plates))
[docs]def einsum(eqn, *terms, **kwargs): r""" Top-level interface for optimized tensor variable elimination. :param str equation: An einsum equation. :param funsor.terms.Funsor \*terms: One or more operands. :param set plates: Optional keyword argument denoting which funsor dimensions are plate dimensions. Among all input dimensions (from terms): dimensions in plates but not in outputs are product-reduced; dimensions in neither plates nor outputs are sum-reduced. """ with lazy: naive_ast = naive_plated_einsum(eqn, *terms, **kwargs) return apply_optimizer(naive_ast)