# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from collections.abc import Hashable
from funsor.cnf import Contraction
from funsor.interpretations import Interpretation, reflect
from funsor.interpreter import stack_reinterpret
from funsor.ops import AssociativeOp
from funsor.registry import KeyedRegistry
from funsor.terms import (
Approximate,
Binary,
Cat,
Funsor,
Reduce,
Scatter,
Slice,
Subs,
substitute,
to_funsor,
)
from . import instrument, interpreter, ops
def _alpha_unmangle(expr):
    """
    Undo ``__BOUND`` name mangling on the bound variables of ``expr``.

    :param expr: Object exposing ``bound`` (an iterable of variable names),
        ``_ast_values`` and ``_alpha_convert``.
    :return: ``tuple(expr._ast_values)`` when no bound name contains
        ``"__BOUND"``; otherwise the result of calling
        ``expr._alpha_convert`` with a mapping from each mangled name to the
        text preceding its first ``"__BOUND"``.
    """
    marker = "__BOUND"
    renames = {}
    for bound_name in expr.bound:
        if marker in bound_name:
            # Keep only the part of the name that precedes the marker.
            renames[bound_name] = bound_name.split(marker)[0]
    if renames:
        return expr._alpha_convert(renames)
    return tuple(expr._ast_values)
class AdjointTape(Interpretation):
    """
    Interpretation that records evaluated terms on a tape for a later
    backward (adjoint) pass.

    While the tape is active, each term whose class is registered in
    ``adjoint_ops`` is appended to ``self.tape`` as ``(result, cls, args)``.
    For every interpreted term a lazy counterpart, built with ``reflect``,
    is stored in ``self._eager_to_lazy`` keyed by the result.
    """

    def __init__(self):
        super().__init__("adjoint")
        # List of (result, cls, args) triples, in evaluation order.
        self.tape = []
        # Interpretation that was active when the tape was entered.
        self._old_interpretation = None
        # Maps each interpreted result to its reflected (lazy) counterpart.
        self._eager_to_lazy = {}

    def interpret(self, cls, *args):
        """
        Evaluate ``cls(*args)`` and remember a lazy counterpart of the result.

        :param cls: Term class being interpreted.
        :param args: Arguments of the term.
        :return: The result computed under the interpretation that was active
            when this tape was entered.
        """
        if cls in adjoint_ops:  # atomic op, don't trace internals
            with self._old_interpretation:
                result = cls(*args)
            self.tape.append((result, cls, args))
        else:
            result = self._old_interpretation.interpret(cls, *args)
        # Swap each argument for its recorded lazy counterpart when there is
        # one. Numeric arrays and unhashable arguments cannot be used as dict
        # keys directly, so they are looked up by id() instead.
        # NOTE(review): entries are stored keyed by result objects, so an id()
        # key presumably never matches and such arguments pass through
        # unchanged -- confirm.
        lazy_args = [
            self._eager_to_lazy.get(
                id(arg)
                if ops.is_numeric_array(arg) or not isinstance(arg, Hashable)
                else arg,
                arg,
            )
            for arg in args
        ]
        with self._old_interpretation:
            self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args)
        return result

    def __enter__(self):
        # Start from an empty tape and remember the interpretation that this
        # tape is about to replace, so ``interpret`` can delegate to it.
        self.tape = []
        self._old_interpretation = interpreter.get_interpretation()
        return super().__enter__()

    def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset()):
        """
        Walk the tape backwards from ``root`` and accumulate adjoint values.

        The tape is consumed: entries are popped until it is empty. Entries
        recorded after ``root`` (found by identity) are skipped.

        :param sum_op: Op used to accumulate adjoints and to marginalize out
            variables; its unit from ``ops.UNITS`` is the default adjoint.
        :param bin_op: Op whose unit from ``ops.UNITS`` seeds the adjoint of
            ``root``.
        :param root: Result whose adjoint is seeded; must be the same object
            as an output recorded on the tape for anything to be propagated.
        :param targets: Optional iterable of keys to extract from the result.
        :param batch_vars: Variables that are never marginalized out of
            messages. Only read, never mutated; the default is an immutable
            empty ``frozenset`` rather than a shared mutable ``set``.
        :return: When ``targets`` is None, a ``defaultdict`` mapping terms
            (replaced by their lazy counterparts where one is recorded) to
            adjoint values, defaulting to the unit of ``sum_op``. Otherwise a
            plain dict with one entry per target.
        """
        zero = to_funsor(ops.UNITS[sum_op])
        one = to_funsor(ops.UNITS[bin_op])
        adjoint_values = defaultdict(lambda: zero)
        adjoint_values[root] = one
        reached_root = False
        while self.tape:
            output, fn, inputs = self.tape.pop()
            if not reached_root:
                if output is root:
                    reached_root = True
                else:
                    continue
            # reverse the effects of alpha-renaming
            with reflect:
                lazy_output = self._eager_to_lazy[output]
                lazy_fn = type(lazy_output)
                lazy_inputs = lazy_output._ast_values
                # TODO abstract this into a helper function
                # FIXME make lazy_output linear instead of quadratic in the size of the tape
                lazy_other_subs = tuple(
                    (name, to_funsor(name.split("__BOUND")[0], domain))
                    for name, domain in lazy_output.inputs.items()
                    if "__BOUND" in name
                )
                lazy_inputs = _alpha_unmangle(
                    substitute(lazy_fn(*lazy_inputs), lazy_other_subs)
                )
                lazy_output = type(lazy_output)(
                    *_alpha_unmangle(substitute(lazy_output, lazy_other_subs))
                )
                other_subs = tuple(
                    (name, to_funsor(name.split("__BOUND")[0], domain))
                    for name, domain in output.inputs.items()
                    if "__BOUND" in name
                )
                inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs))
                output = type(output)(*_alpha_unmangle(substitute(output, other_subs)))
                # Re-register the unmangled output so the final key
                # translation below can find its lazy counterpart.
                self._eager_to_lazy[output] = lazy_output
            in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
            for v, adjv in in_adjs:
                # Marginalize out message variables that don't appear in recipients.
                agg_vars = adjv.input_vars - v.input_vars - root.input_vars - batch_vars
                assert "particle" not in {var.name for var in agg_vars}  # DEBUG FIXME
                old_value = adjoint_values[v]
                adjoint_values[v] = sum_op(old_value, adjv.reduce(sum_op, agg_vars))
        # Re-key the accumulated adjoints by lazy terms where available.
        result = defaultdict(lambda: zero)
        for key, value in adjoint_values.items():
            lazy_key = self._eager_to_lazy.get(key, key)
            result[lazy_key] = value
        if targets is None:
            return result
        return {target: result[target] for target in targets}
def forward_backward(sum_op, bin_op, expr, *, batch_vars=frozenset()):
    """
    Evaluate ``expr`` under a fresh :class:`AdjointTape`, then run the
    backward pass from the forward result.

    :param sum_op: Forwarded to :meth:`AdjointTape.adjoint`.
    :param bin_op: Forwarded to :meth:`AdjointTape.adjoint`.
    :param expr: Expression handed to ``stack_reinterpret`` while the tape is
        active.
    :param batch_vars: Forwarded to :meth:`AdjointTape.adjoint` by keyword.
    :return: A pair ``(forward, backward)`` of the reinterpreted expression
        and the value returned by :meth:`AdjointTape.adjoint`.
    """
    recorder = AdjointTape()
    with recorder as tape:
        # TODO fix traversal order in AdjointTape instead of using stack_reinterpret
        forward = stack_reinterpret(expr)
    # The backward pass runs after the tape's context has exited.
    adjoints = tape.adjoint(sum_op, bin_op, forward, batch_vars=batch_vars)
    return forward, adjoints
def adjoint(sum_op, bin_op, expr):
    """
    Return only the backward half of :func:`forward_backward`.

    :param sum_op: Forwarded to :func:`forward_backward`.
    :param bin_op: Forwarded to :func:`forward_backward`.
    :param expr: Forwarded to :func:`forward_backward`.
    :return: The second element of the pair returned by
        :func:`forward_backward`.
    """
    _, adjoints = forward_backward(sum_op, bin_op, expr)
    return adjoints
# logaddexp/add
def _fail_default(*args):
    """
    Fallback handler that always fails.

    :param args: Whatever positional arguments the caller supplied; they are
        included in the error message.
    :raises NotImplementedError: Always.
    """
    message = "Should not be here! {}".format(args)
    raise NotImplementedError(message)
# Registry of adjoint rules. ``_fail_default`` is supplied as the registry's
# default handler.
adjoint_ops = KeyedRegistry(default=_fail_default)
if instrument.DEBUG:
    # Keep a handle on the original ``register`` and replace it with a
    # wrapper that passes every registered function through
    # ``instrument.debug_logged`` before registering it.
    adjoint_ops_register = adjoint_ops.register
    adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register(*args)(
        instrument.debug_logged(fn)
    )
@adjoint_ops.register(
    Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor
)
def adjoint_binary(adj_sum_op, adj_prod_op, out_adj, op, lhs, rhs):
    """
    Adjoint rule for a ``Binary`` term ``op(lhs, rhs)``.

    :param adj_sum_op: Sum op of the adjoint computation.
    :param adj_prod_op: Product op of the adjoint computation.
    :param out_adj: Adjoint of the term's output.
    :param op: The binary op of the term.
    :param lhs: Left operand.
    :param rhs: Right operand.
    :return: A pair of ``(operand, operand_adjoint)`` tuples, ``lhs`` first.
    :raises ValueError: If ``op`` is neither ``adj_prod_op`` nor
        ``adj_sum_op``.
    """
    is_product = op is adj_prod_op
    if not is_product and op is not adj_sum_op:
        raise ValueError("should not be here!")
    if is_product:
        # Each operand's adjoint is the output adjoint combined with the
        # *other* operand.
        lhs_adj = adj_prod_op(out_adj, rhs)
        rhs_adj = adj_prod_op(out_adj, lhs)
        return ((lhs, lhs_adj), (rhs, rhs_adj))
    # Sum op: the output adjoint is handed to both operands unchanged.
    return ((lhs, out_adj), (rhs, out_adj))
@adjoint_ops.register(
    Reduce, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, frozenset
)
def adjoint_reduce(adj_sum_op, adj_prod_op, out_adj, op, arg, reduced_vars):
    """
    Adjoint rule for a ``Reduce`` term ``arg.reduce(op, reduced_vars)``.

    :param adj_sum_op: Sum op of the adjoint computation.
    :param adj_prod_op: Product op of the adjoint computation.
    :param out_adj: Adjoint of the term's output.
    :param op: The reduction op of the term.
    :param arg: The reduced funsor.
    :param reduced_vars: Variables that were reduced out.
    :return: A one-element tuple ``((arg, arg_adjoint),)``.
    :raises ValueError: If ``op`` is neither ``adj_sum_op`` nor
        ``adj_prod_op``.
    """
    if op is adj_sum_op:
        weighted = adj_prod_op(out_adj, arg)
        arg_adj = Approximate(adj_sum_op, out_adj, weighted, reduced_vars)
        return ((arg, arg_adj),)
    if op is adj_prod_op:  # plate!
        # Combine the output adjoint with the full product-reduction, then
        # apply the safe inverse of the product op against ``arg``.
        total = arg.reduce(adj_prod_op, reduced_vars)
        inverse_op = ops.SAFE_BINARY_INVERSES[adj_prod_op]
        numerator = adj_prod_op(out_adj, total)
        return ((arg, inverse_op(numerator, arg)),)
    raise ValueError("should not be here!")
@adjoint_ops.register(
    Contraction,
    AssociativeOp,
    AssociativeOp,
    Funsor,
    AssociativeOp,
    AssociativeOp,
    frozenset,
    Funsor,
)
def adjoint_contract_unary(
    adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, arg
):
    """
    Adjoint rule for a ``Contraction`` with a single term.

    Delegates to :func:`adjoint_reduce`, using the contraction's ``sum_op``
    as the reduction op. ``prod_op`` is accepted but not used.

    :return: Whatever :func:`adjoint_reduce` returns for ``arg``.
    """
    reduction_op = sum_op
    return adjoint_reduce(
        adj_sum_op, adj_prod_op, out_adj, reduction_op, arg, reduced_vars
    )
@adjoint_ops.register(
    Contraction,
    AssociativeOp,
    AssociativeOp,
    Funsor,
    AssociativeOp,
    AssociativeOp,
    frozenset,
    tuple,
)
def adjoint_contract_generic(
    adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms
):
    """
    Adjoint rule for a ``Contraction`` whose terms arrive as one tuple.

    Unpacks ``terms`` and dispatches again through ``adjoint_ops`` with the
    terms passed as separate positional arguments.

    :param terms: Tuple of one or two terms (checked by an assertion).
    :return: Whatever the re-dispatched rule returns.
    """
    assert len(terms) in (1, 2)
    leading_args = (
        Contraction,
        adj_sum_op,
        adj_prod_op,
        out_adj,
        sum_op,
        prod_op,
        reduced_vars,
    )
    return adjoint_ops(*leading_args, *terms)
@adjoint_ops.register(
    Contraction,
    AssociativeOp,
    AssociativeOp,
    Funsor,
    AssociativeOp,
    AssociativeOp,
    frozenset,
    Funsor,
    Funsor,
)
def adjoint_contract(
    adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs
):
    """
    Adjoint rule for a ``Contraction`` of two terms.

    :param adj_sum_op: Sum op of the adjoint computation.
    :param adj_prod_op: Product op of the adjoint computation.
    :param out_adj: Adjoint of the contraction's output.
    :param sum_op: Reduction op of the contraction.
    :param prod_op: Binary op of the contraction.
    :param reduced_vars: Variables reduced by the contraction.
    :param lhs: First term.
    :param rhs: Second term.
    :return: A pair of ``(term, term_adjoint)`` tuples, ``lhs`` first.
    :raises NotImplementedError: If ``prod_op`` is ``adj_sum_op`` and
        ``reduced_vars`` is non-empty.
    :raises ValueError: If the combination of ops matches neither supported
        case.
    """
    if prod_op is adj_prod_op and sum_op in (ops.null, adj_sum_op):
        joint = adj_prod_op(lhs, rhs)
        weighted = adj_prod_op(out_adj, joint)
        out_adj = Approximate(adj_sum_op, out_adj, weighted, reduced_vars)
        # Each term's adjoint pairs the updated output adjoint with the
        # other term.
        lhs_adj = adj_prod_op(out_adj, rhs)
        rhs_adj = adj_prod_op(lhs, out_adj)
        return ((lhs, lhs_adj), (rhs, rhs_adj))
    if prod_op is adj_sum_op:
        if reduced_vars:
            raise NotImplementedError("TODO implement sum Contraction")
        return ((lhs, out_adj), (rhs, out_adj))
    raise ValueError("should not be here!")
@adjoint_ops.register(Cat, AssociativeOp, AssociativeOp, Funsor, str, tuple, str)
def adjoint_cat(adj_sum_op, adj_prod_op, out_adj, name, parts, part_name):
    """
    Adjoint rule for a ``Cat`` term.

    :param adj_sum_op: Sum op of the adjoint computation (unused here).
    :param adj_prod_op: Product op of the adjoint computation (unused here).
    :param out_adj: Adjoint of the concatenated output.
    :param name: Input name that is sliced in ``out_adj``.
    :param parts: Tuple of the concatenated funsors.
    :param part_name: Input name looked up in each part to get its size.
    :return: A tuple of ``(part, part_adjoint)`` pairs in the order of
        ``parts``. If ``part_name`` is not an input of ``out_adj``, every
        part receives ``out_adj`` unchanged.
    """
    # NOTE(review): the guard tests ``part_name`` while the slicing below
    # substitutes ``name``; these agree when both names are equal -- confirm
    # the intended behavior when they differ.
    if part_name not in out_adj.inputs:
        return tuple((part, out_adj) for part in parts)
    widths = [part.inputs[part_name].dtype for part in parts]
    total = sum(widths)
    in_adjs = []
    offset = 0
    for part, width in zip(parts, widths):
        stop = offset + width
        # Select this part's contiguous range out of the full extent.
        part_slice = Slice(name, offset, stop, 1, total)
        in_adjs.append((part, out_adj(**{name: part_slice})))
        offset = stop
    return tuple(in_adjs)
@adjoint_ops.register(Subs, AssociativeOp, AssociativeOp, Funsor, Funsor, tuple)
def adjoint_subs(adj_sum_op, adj_prod_op, out_adj, arg, subs):
    """
    Adjoint rule for a ``Subs`` term ``arg(**dict(subs))``.

    :param adj_sum_op: Sum op of the adjoint computation; passed to
        ``Scatter``.
    :param adj_prod_op: Product op of the adjoint computation (unused here).
    :param out_adj: Adjoint of the substituted output.
    :param arg: Funsor that was substituted into.
    :param subs: Tuple of ``(name, value)`` substitution pairs.
    :return: A one-element tuple ``((arg, arg_adjoint),)``.
    """
    # detect fresh variable collisions that should be relabeled and reduced
    relabel = {}
    for key, _ in subs:
        relabel[key] = interpreter.gensym(key)
    relabeled_subs = tuple((relabel[key], value) for key, value in subs)
    relabeled_arg = arg(**relabel)
    kept_vars = relabeled_arg.input_vars
    # Reduce every variable of the output adjoint or of a substituted value
    # that the relabeled argument does not have.
    reduced_vars = out_adj.input_vars - kept_vars
    for _, value in subs:
        reduced_vars |= value.input_vars - kept_vars
    relabeled_arg_adj = Scatter(adj_sum_op, relabeled_subs, out_adj, reduced_vars)
    # Map the fresh names back to the original ones.
    restore = {fresh: original for original, fresh in relabel.items()}
    arg_adj = relabeled_arg_adj(**restore)
    return ((arg, arg_adj),)
@adjoint_ops.register(
    Scatter,
    AssociativeOp,
    AssociativeOp,
    Funsor,
    AssociativeOp,
    tuple,
    Funsor,
    frozenset,
)
def adjoint_scatter(adj_sum_op, adj_prod_op, out_adj, op, subs, source, reduced_vars):
    """
    Adjoint rule for a ``Scatter`` term.

    Applies ``subs`` to the output adjoint and reduces ``reduced_vars`` with
    ``adj_sum_op``. ``adj_prod_op`` and ``op`` are accepted but not used.

    :return: A one-element tuple ``((source, source_adjoint),)``.
    """
    substituted = out_adj(**dict(subs))
    source_adj = substituted.reduce(adj_sum_op, reduced_vars)
    return ((source, source_adj),)