# Source code for funsor.compiler

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools

import funsor

from .cnf import Contraction
from .ops.program import OpProgram, make_tuple
from .tensor import Tensor
from .terms import Binary, Funsor, Number, Tuple, Unary, Variable


def compile_funsor(expr: Funsor) -> OpProgram:
    """
    Compiles a symbolic :class:`~funsor.terms.Funsor` to an
    :class:`~funsor.ops.program.OpProgram` that runs on backend values.

    Example::

        # Create a lazy expression.
        a = Variable("a", Reals[3, 3])
        b = Variable("b", Reals[3])
        x = Variable("x", Reals[3])
        expr = a @ x + b

        # Evaluate via Funsor substitution.
        data = dict(a=randn(3, 3), b=randn(3), x=randn(3))
        expected = expr(**data).data

        # Alternatively evaluate via a program.
        program = compile_funsor(expr)
        actual = program(**data)
        assert (actual == expected).all()

    :param Funsor expr: A funsor expression to evaluate.
    :returns: An op program.
    :rtype: ~funsor.ops.program.OpProgram
    :raises NotImplementedError: if :func:`lower` cannot lower ``expr``, or if
        the lowered expression contains a node other than a constant
        (``Number``/``Tensor``), a free ``Variable``, ``Unary``, ``Binary``
        or ``Tuple``.
    """
    assert isinstance(expr, Funsor)

    # Lower and convert to A-normal form. Each subterm appears after its
    # children, which the ``ids[...]`` lookups below rely on.
    lowered_expr = lower(expr)
    anf = list(funsor.interpreter.anf(lowered_expr))

    # Maps each value-producing node to its index among the program's values.
    # Indices are handed out consecutively in the same order the values are
    # passed to OpProgram: constants, then inputs, then one per operation.
    # Only nodes that actually contribute a value may receive an index;
    # otherwise every later index would be off by one.
    ids = {}

    # Collect constants (leaves).
    constants = []
    for f in anf:
        if f in ids:
            continue  # never register the same constant twice
        if isinstance(f, (Number, Tensor)):
            ids[f] = len(ids)
            constants.append(f.data)

    # Collect input variables (leaves), in the order of expr.inputs.
    inputs = []
    for k, d in expr.inputs.items():
        f = Variable(k, d)
        ids[f] = len(ids)
        inputs.append(k)

    # Collect operations to be computed (internal nodes).
    operations = []
    for f in anf:
        if f in ids:
            continue  # constant or free variable
        if isinstance(f, tuple):
            # A raw Python tuple is the ``args`` of a Tuple funsor. It produces
            # no value, so it must not consume an index; the Tuple node refers
            # directly to its elements instead.
            continue
        if isinstance(f, Unary):
            operation = (f.op, (ids[f.arg],))
        elif isinstance(f, Binary):
            operation = (f.op, (ids[f.lhs], ids[f.rhs]))
        elif isinstance(f, Tuple):
            operation = (make_tuple, tuple(ids[arg] for arg in f.args))
        else:
            raise NotImplementedError(type(f).__name__)
        # Assign the index only once the node is known to emit an operation,
        # keeping ids aligned with constants + inputs + operations.
        ids[f] = len(ids)
        operations.append(operation)

    return OpProgram(constants, inputs, operations)


def lower(expr: Funsor) -> Funsor:
    """
    Lower a funsor expression:

    - eliminate bound variables
    - convert Contraction to Binary

    Currently only ``Number``, ``Tensor``, ``Variable``, ``Tuple``, ``Unary``,
    ``Binary`` and ``Contraction`` terms without reduced variables can be
    lowered; any other term raises :class:`NotImplementedError`.

    :param Funsor expr: An arbitrary funsor expression.
    :returns: A lowered funsor expression.
    :rtype: Funsor
    :raises NotImplementedError: if ``expr`` contains a term that cannot be
        lowered yet.
    """
    # FIXME should this be lazy? What about Lambda?
    # Rebuild terms under the reflect interpretation; presumably this keeps
    # the Unary/Binary/Tuple constructors from normalizing or evaluating the
    # lowered terms -- confirm against funsor.interpretations.
    with funsor.interpretations.reflect:
        return _lower(expr)
@functools.singledispatch
def _lower(x):
    """Lower a single term; types without a registered handler are unsupported."""
    raise NotImplementedError(type(x).__name__)


@_lower.register(Number)
@_lower.register(Tensor)
@_lower.register(Variable)
def _lower_atom(x):
    """Constants and free variables are already in lowered form."""
    return x


@_lower.register(Tuple)
def _lower_tuple(x):
    """Lower each element of a Tuple and rebuild it."""
    args = tuple(_lower(arg) for arg in x.args)
    return Tuple(args)


@_lower.register(Unary)
def _lower_unary(x):
    """Lower the argument of a Unary and rebuild it with the same op."""
    arg = _lower(x.arg)
    return Unary(x.op, arg)


@_lower.register(Binary)
def _lower_binary(x):
    """Lower both operands of a Binary and rebuild it with the same op."""
    lhs = _lower(x.lhs)
    rhs = _lower(x.rhs)
    return Binary(x.op, lhs, rhs)


@_lower.register(Contraction)
def _lower_contraction(x):
    """
    Lower a Contraction with no reduced variables into a left-nested chain of
    ``Binary(x.bin_op, ...)`` terms over its lowered terms.

    :raises NotImplementedError: if the contraction reduces any variables.
    """
    if x.reduced_vars:
        # Eliminating bound variables is not implemented yet.
        raise NotImplementedError(
            "lowering a Contraction with reduced_vars is not yet supported"
        )
    terms = [_lower(term) for term in x.terms]
    bin_op = functools.partial(Binary, x.bin_op)
    # Folds left: ((t0 op t1) op t2) ...; a single term is returned unchanged.
    return functools.reduce(bin_op, terms)


__all__ = [
    "compile_funsor",
    "lower",
]