Source code for funsor.affine
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import OrderedDict
from functools import reduce, singledispatch
import opt_einsum
from funsor.domains import Bint
from funsor.interpreter import gensym
from funsor.tensor import Tensor, get_default_prototype
from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable
from . import ops
[docs]def is_affine(fn):
"""
A sound but incomplete test to determine whether a funsor is affine with
respect to all of its real inputs.
:param Funsor fn: A funsor.
:rtype: bool
"""
return affine_inputs(fn) == _real_inputs(fn)
def _real_inputs(fn):
return frozenset(k for k, d in fn.inputs.items() if d.dtype == "real")
[docs]def affine_inputs(fn):
"""
Returns a [sound sub]set of real inputs of ``fn``
wrt which ``fn`` is known to be affine.
:param Funsor fn: A funsor.
:return: A set of input names wrt which ``fn`` is affine.
:rtype: frozenset
"""
result = getattr(fn, "_affine_inputs", None)
if result is None:
result = fn._affine_inputs = _affine_inputs(fn)
return result
@singledispatch
def _affine_inputs(fn):
assert isinstance(fn, Funsor)
return frozenset()
# Make registration public.
affine_inputs.register = _affine_inputs.register
@affine_inputs.register(Variable)
def _(fn):
return _real_inputs(fn)
@affine_inputs.register(Unary)
def _(fn):
if fn.op in (ops.neg, ops.sum) or isinstance(
fn.op, (ops.ReshapeOp, ops.GetsliceOp)
):
return affine_inputs(fn.arg)
return frozenset()
@affine_inputs.register(Binary)
def _(fn):
if fn.op in (ops.add, ops.sub):
return affine_inputs(fn.lhs) | affine_inputs(fn.rhs)
if fn.op is ops.truediv:
return affine_inputs(fn.lhs) - _real_inputs(fn.rhs)
if isinstance(fn.op, ops.GetitemOp):
return affine_inputs(fn.lhs)
if fn.op in (ops.mul, ops.matmul):
lhs_affine = affine_inputs(fn.lhs) - _real_inputs(fn.rhs)
rhs_affine = affine_inputs(fn.rhs) - _real_inputs(fn.lhs)
if not lhs_affine:
return rhs_affine
if not rhs_affine:
return lhs_affine
# This multilinear case introduces incompleteness, since some vars
# could later be reduced, making remaining vars affine.
return frozenset()
return frozenset()
@affine_inputs.register(Reduce)
def _(fn):
return affine_inputs(fn.arg) - fn.reduced_vars
@affine_inputs.register(Finitary[ops.EinsumOp, tuple])
def _(fn):
# This is simply a multiary version of the above Binary(ops.mul, ...) case.
results = []
for i, x in enumerate(fn.args):
others = fn.args[:i] + fn.args[i + 1 :]
other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset())
results.append(affine_inputs(x) - other_inputs)
# This multilinear case introduces incompleteness, since some vars
# could later be reduced, making remaining vars affine.
if sum(map(bool, results)) == 1:
for result in results:
if result:
return result
return frozenset()
[docs]def extract_affine(fn):
"""
Extracts an affine representation of a funsor, satisfying::
x = ...
const, coeffs = extract_affine(x)
y = sum(Einsum(eqn, coeff, Variable(var, coeff.output))
for var, (coeff, eqn) in coeffs.items())
assert_close(y, x)
assert frozenset(coeffs) == affine_inputs(x)
The ``coeffs`` will have one key per input wrt which ``fn`` is known to be
affine (via :func:`affine_inputs` ), and ``const`` and ``coeffs.values``
will all be constant wrt these inputs.
The affine approximation is computed by ev evaluating ``fn`` at
zero and each basis vector. To improve performance, users may want to run
under the :func:`~funsor.interpretations.Memoize` interpretation.
:param Funsor fn: A funsor that is affine wrt the (add,mul) semiring in
some subset of its inputs.
:return: A pair ``(const, coeffs)`` where const is a funsor with no real
inputs and ``coeffs`` is an OrderedDict mapping input name to a
``(coefficient, eqn)`` pair in einsum form.
:rtype: tuple
"""
# NB: this depends on the global default backend.
prototype = get_default_prototype()
# Determine constant part by evaluating fn at zero.
inputs = affine_inputs(fn)
inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if k in inputs)
zeros = {k: Tensor(ops.new_zeros(prototype, v.shape)) for k, v in inputs.items()}
const = fn(**zeros)
# Determine linear coefficients by evaluating fn on basis vectors.
name = gensym("probe")
coeffs = OrderedDict()
for k, v in inputs.items():
dim = v.num_elements
var = Variable(name, Bint[dim])
subs = zeros.copy()
subs[k] = Tensor(ops.new_eye(prototype, (dim,)).reshape((dim,) + v.shape))[var]
coeff = Lambda(var, fn(**subs) - const).reshape(v.shape + const.shape)
inputs1 = "".join(map(opt_einsum.get_symbol, range(len(coeff.shape))))
inputs2 = inputs1[: len(v.shape)]
output = inputs1[len(v.shape) :]
eqn = "{},{}->{}".format(inputs1, inputs2, output)
coeffs[k] = coeff, eqn
return const, coeffs
__all__ = [
"affine_inputs",
"extract_affine",
"is_affine",
]