Source code for funsor.adjoint

from collections import OrderedDict, defaultdict

import torch

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture, nullop
from funsor.domains import bint
from funsor.gaussian import Gaussian, align_gaussian
from funsor.interpreter import interpretation
from funsor.ops import AssociativeOp
from funsor.registry import KeyedRegistry
from funsor.terms import Binary, Cat, Funsor, Number, Reduce, Slice, Subs, Variable, reflect, substitute, to_funsor
from funsor.torch import Tensor, materialize


def _alpha_unmangle(expr):
    alpha_subs = {name: name.split("__BOUND")[0]
                  for name in expr.bound if "__BOUND" in name}
    if not alpha_subs:
        return tuple(expr._ast_values)

    return expr._alpha_convert(alpha_subs)


class AdjointTape(object):

    def __init__(self):
        self.tape = []
        self._old_interpretation = None

    def __call__(self, cls, *args):
        if cls in adjoint_ops:  # atomic op, don't trace internals
            with interpretation(self._old_interpretation):
                result = cls(*args)
            self.tape.append((result, cls, args))
        else:
            result = self._old_interpretation(cls, *args)
        return result

    def __enter__(self):
        self.tape = []
        self._old_interpretation = interpreter._INTERPRETATION
        interpreter.set_interpretation(self)
        return self

    def __exit__(self, *args):
        interpreter.set_interpretation(self._old_interpretation)
        self._old_interpretation = None

    def adjoint(self, red_op, bin_op, root, targets):
        bin_unit = to_funsor(ops.UNITS[bin_op])
        adjoint_values = defaultdict(lambda: bin_unit)

        reached_root = False
        while self.tape:
            output, fn, inputs = self.tape.pop()
            if not reached_root:
                if output is root:
                    reached_root = True
                else:
                    continue

            # reverse the effects of alpha-renaming
            with interpretation(reflect):
                other_subs = tuple((name, to_funsor(name.split("__BOUND")[0], domain))
                                   for name, domain in output.inputs.items() if "__BOUND" in name)
                inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs))
                output = type(output)(*_alpha_unmangle(substitute(output, other_subs)))

            in_adjs = adjoint_ops(fn, red_op, bin_op, adjoint_values[output], *inputs)
            for v, adjv in in_adjs.items():
                adjoint_values[v] = bin_op(adjoint_values[v], adjv)

        target_adjs = {}
        for v in targets:
            target_adjs[v] = adjoint_values[v]
            if not isinstance(v, Variable):
                target_adjs[v] = bin_op(target_adjs[v], v)
        return target_adjs
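
# A minimal usage sketch (illustrative only, not executed here): under the
# (ops.logaddexp, ops.add) semiring, replaying the tape backward from a
# log-partition root yields unnormalized log-marginals of the traced factors.
#
#     x = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
#     y = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
#     with AdjointTape() as tape:
#         root = (x + y).reduce(ops.logaddexp, "i")
#     marginals = tape.adjoint(ops.logaddexp, ops.add, root, (x, y))
#     # marginals[x] is the backward message to x, combined with x itself,
#     # because x is not a Variable (see the final loop in adjoint()).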


# logaddexp/add
def _fail_default(*args):
    raise NotImplementedError("Should not be here! {}".format(args))


adjoint_ops = KeyedRegistry(default=_fail_default)
if interpreter._DEBUG:
    adjoint_ops_register = adjoint_ops.register
    adjoint_ops.register = lambda *args: lambda fn: adjoint_ops_register(*args)(interpreter.debug_logged(fn))
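
# Each handler below is keyed on (term type, adjoint sum op, adjoint product op,
# type of the output adjoint, *types of the term's constructor args) and returns
# a dict mapping each input funsor to its incoming backward message. For example,
# the Reduce handler is looked up and invoked as
#
#     in_adjs = adjoint_ops(Reduce, ops.logaddexp, ops.add,
#                           out_adj, ops.logaddexp, arg, reduced_vars)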


@adjoint_ops.register(Tensor, AssociativeOp, AssociativeOp, Funsor, torch.Tensor, tuple, object)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
    return {}


@adjoint_ops.register(Binary, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, Funsor)
def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs):
    assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS

    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
    lhs_adj = op(out_adj, rhs).reduce(adj_redop, lhs_reduced_vars)

    rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
    rhs_adj = op(out_adj, lhs).reduce(adj_redop, rhs_reduced_vars)

    return {lhs: lhs_adj, rhs: rhs_adj}
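
# Worked instance (sketch): with adj_redop=ops.logaddexp and op=ops.add, the
# message to lhs eliminates (in log space) the variables that appear in rhs but
# not in lhs, i.e. roughly
#
#     lhs_adj = logsumexp_{vars private to rhs} (out_adj + rhs)
#
# which is the familiar sum-product backward message of message passing.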


@adjoint_ops.register(Reduce, AssociativeOp, AssociativeOp, Funsor, AssociativeOp, Funsor, frozenset)
def adjoint_reduce(adj_redop, adj_binop, out_adj, op, arg, reduced_vars):
    assert adj_binop is op or (op, adj_binop) in ops.DISTRIBUTIVE_OPS

    if op is adj_redop:
        # XXX using a hack to simulate "expand"
        return {arg: adj_binop(out_adj, Binary(ops.PRODUCT_INVERSES[adj_binop], arg, arg))}
    elif op is adj_binop:  # plate!
        out = arg.reduce(op, reduced_vars)
        return {arg: adj_binop(out_adj, Binary(ops.PRODUCT_INVERSES[op], out, arg))}
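
# Note on the "expand" hack above: Binary(ops.PRODUCT_INVERSES[adj_binop], arg, arg)
# evaluates to the unit of adj_binop (e.g. a safe subtraction gives arg - arg == 0
# when adj_binop is ops.add) while still carrying all of arg's inputs, so combining
# it with out_adj broadcasts the adjoint back out to arg's full shape.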


@adjoint_ops.register(Contraction, AssociativeOp, AssociativeOp, Funsor,
                      AssociativeOp, AssociativeOp, frozenset, Funsor)
def adjoint_contract_unary(adj_redop, adj_binop, out_adj, sum_op, prod_op, reduced_vars, arg):
    return adjoint_reduce(adj_redop, adj_binop, out_adj, sum_op, arg, reduced_vars)


@adjoint_ops.register(Contraction, AssociativeOp, AssociativeOp, Funsor,
                      AssociativeOp, AssociativeOp, frozenset, tuple)
def adjoint_contract_generic(adj_redop, adj_binop, out_adj, sum_op, prod_op, reduced_vars, terms):
    assert len(terms) == 1 or len(terms) == 2
    return adjoint_ops(Contraction, adj_redop, adj_binop, out_adj, sum_op, prod_op, reduced_vars, *terms)


@adjoint_ops.register(Contraction, AssociativeOp, AssociativeOp, Funsor,
                      AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor)
def adjoint_contract(adj_redop, adj_binop, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs):
    assert sum_op is nullop or (sum_op, prod_op) in ops.DISTRIBUTIVE_OPS

    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
    lhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop,
                          prod_op, lhs_reduced_vars, out_adj, rhs)

    rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
    rhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop,
                          prod_op, rhs_reduced_vars, out_adj, lhs)

    return {lhs: lhs_adj, rhs: rhs_adj}
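
# Note (sketch): when the traced Contraction was a pure product (sum_op is nullop,
# nothing was reduced), the backward message must still eliminate the variables
# private to the opposite term, so adj_redop is substituted as the sum op,
# mirroring adjoint_binary above.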


@adjoint_ops.register(Cat, AssociativeOp, AssociativeOp, Funsor, str, tuple, str)
def adjoint_cat(adj_redop, adj_binop, out_adj, name, parts, part_name):
    in_adjs = {}
    start = 0
    size = sum(part.inputs[part_name].dtype for part in parts)
    for i, part in enumerate(parts):
        if part_name in out_adj.inputs:
            in_adjs[part] = out_adj(**{name: Slice(name, start, start + part.inputs[part_name].dtype, 1, size)})
            start += part.inputs[part_name].dtype
        else:
            in_adjs[part] = adj_binop(out_adj, Binary(ops.PRODUCT_INVERSES[adj_binop], part, part))
    return in_adjs
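
# Sketch of the slicing above: when out_adj retains the concatenated input, each
# part receives the matching Slice of out_adj, e.g. a part of size 3 at offset 2
# in a concatenation of total size 10 would receive
#
#     out_adj(**{name: Slice(name, 2, 5, 1, 10)})
#
# otherwise every part gets a unit message expanded to its own shape.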


@adjoint_ops.register(Subs, AssociativeOp, AssociativeOp, (Number, Tensor), Tensor, tuple)
def adjoint_subs_tensor(adj_redop, adj_binop, out_adj, arg, subs):

    assert all(isinstance(v, Funsor) for k, v in subs)

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    # TODO avoid reifying these zero/one tensors by using symbolic constants

    # ones for things that weren't sliced away
    ones_like_out = Subs(Tensor(torch.full_like(arg.data, ops.UNITS[adj_binop]),
                                arg.inputs.copy(), arg.output.dtype),
                         slices)
    arg_adj = adj_binop(out_adj, ones_like_out)

    # ones for things that were sliced away
    ones_like_arg = Tensor(torch.full_like(arg.data, ops.UNITS[adj_binop]),
                           arg.inputs.copy(), arg.output.dtype)
    arg_adj = _scatter(arg_adj, ones_like_arg, slices)

    return {arg: arg_adj}
def _scatter(src, res, subs):
    # inverse of advanced indexing
    # TODO check types of subs, in case some logic from eager_subs was accidentally left out?

    # use advanced indexing logic copied from Tensor.eager_subs:

    # materialize after checking for renaming case
    subs = OrderedDict((k, materialize(v)) for k, v in subs)

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in res.inputs.items():
        inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(res.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in res.inputs.items():
        if k in subs:
            v = subs.get(k)
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(torch.arange(domain.dtype).reshape(
                (-1,) + (1,) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(res.output.shape):
        offset_from_right = len(res.output.shape) - i - 1
        index.append(torch.arange(size).reshape(
            (-1,) + (1,) * offset_from_right))

    # the only difference from Tensor.eager_subs is here:
    # instead of indexing the rhs (lhs = rhs[index]), we index the lhs (lhs[index] = rhs)

    # unsqueeze to make broadcasting work
    src_inputs, src_data = src.inputs.copy(), src.data
    for k, v in res.inputs.items():
        if k not in src.inputs and isinstance(subs[k], Number):
            src_inputs[k] = bint(1)
            src_data = src_data.unsqueeze(-1 - len(src.output.shape))
    src = Tensor(src_data, src_inputs, src.output.dtype).align(tuple(res.inputs.keys()))

    data = res.data
    data[tuple(index)] = src.data
    return Tensor(data, inputs, res.dtype)
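
# Illustrative sketch of _scatter (commented out; shapes and names hypothetical):
#
#     res = Tensor(torch.zeros(5), OrderedDict(i=bint(5)))   # template over the full domain
#     src = Tensor(torch.ones(3), OrderedDict(i=bint(3)))    # adjoint of the slice
#     out = _scatter(src, res, (("i", Slice("i", 0, 3, 1, 5)),))
#     # out.data == tensor([1., 1., 1., 0., 0.]), i.e. the advanced indexing
#     # res.data[index] performed by Tensor.eager_subs has been inverted.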


@adjoint_ops.register(Subs, ops.LogAddExpOp, ops.AddOp, GaussianMixture, GaussianMixture, tuple)
def adjoint_subs_gaussianmixture_gaussianmixture(adj_redop, adj_binop, out_adj, arg, subs):

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError("TODO implement adjoint for substitution into Gaussian real variable")

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    assert len(slices + renames) == len(subs)

    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj.terms[0], arg.terms[0], subs)[arg.terms[0]]

    arg_int_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype != 'real')
    out_adj_int_inputs = OrderedDict((k, v) for k, v in out_adj.inputs.items() if v.dtype != 'real')

    arg_real_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype == 'real')

    align_inputs = OrderedDict((k, v) for k, v in out_adj.terms[1].inputs.items() if v.dtype != 'real')
    align_inputs.update(arg_real_inputs)
    out_adj_info_vec, out_adj_precision = align_gaussian(align_inputs, out_adj.terms[1])

    in_adj_info_vec = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                       Tensor(out_adj_info_vec, out_adj_int_inputs),
                                       Tensor(arg.terms[1].info_vec, arg_int_inputs),
                                       slices).values())[0]

    in_adj_precision = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                        Tensor(out_adj_precision, out_adj_int_inputs),
                                        Tensor(arg.terms[1].precision, arg_int_inputs),
                                        slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data, arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
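
# Note (sketch): a GaussianMixture is a Contraction whose terms[0] is the discrete
# Tensor part and terms[1] the Gaussian part; the handler above inverts the
# substitution separately on each part and recombines them additively in log space.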


@adjoint_ops.register(Subs, ops.LogAddExpOp, ops.AddOp, Gaussian, GaussianMixture, tuple)
def adjoint_subs_gaussianmixture_gaussian(adj_redop, adj_binop, out_adj, arg, subs):

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError("TODO implement adjoint for substitution into Gaussian real variable")

    # pad the Gaussian adjoint with a zero discrete part, then defer to the
    # GaussianMixture/GaussianMixture handler above
    out_adj_int_inputs = OrderedDict((k, v) for k, v in out_adj.inputs.items() if v.dtype != 'real')
    out_adj_ = out_adj + Tensor(out_adj.info_vec.new_zeros(out_adj.info_vec.shape[:-1]), out_adj_int_inputs)
    return {arg: adjoint_ops(Subs, adj_redop, adj_binop, out_adj_, arg, subs)[arg]}


@adjoint_ops.register(Subs, ops.LogAddExpOp, ops.AddOp, (GaussianMixture, Gaussian), Gaussian, tuple)
def adjoint_subs_gaussian_gaussian(adj_redop, adj_binop, out_adj, arg, subs):

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError("TODO implement adjoint for substitution into Gaussian real variable")

    # pad the Gaussian argument with a zero discrete part, then defer to the
    # handlers for GaussianMixture arguments
    arg_int_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype != 'real')
    arg_ = arg + Tensor(arg.info_vec.new_zeros(arg.info_vec.shape[:-1]), arg_int_inputs)
    return {arg: adjoint_ops(Subs, adj_redop, adj_binop, out_adj, arg_, subs)[arg_]}


@adjoint_ops.register(Subs, ops.LogAddExpOp, ops.AddOp, (Number, Tensor), GaussianMixture, tuple)
def adjoint_subs_gaussianmixture_discrete(adj_redop, adj_binop, out_adj, arg, subs):

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError("TODO implement adjoint for substitution into Gaussian real variable")

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    arg_int_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype != 'real')

    zeros_like_out = Subs(Tensor(arg.terms[1].info_vec.new_full(arg.terms[1].info_vec.shape[:-1],
                                                                ops.UNITS[adj_binop]),
                                 arg_int_inputs),
                          slices)
    out_adj = adj_binop(out_adj, zeros_like_out)

    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj, arg.terms[0], subs)[arg.terms[0]]

    # invert the slicing for the Gaussian term even though the message does not affect the values
    in_adj_info_vec = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                       zeros_like_out,
                                       Tensor(arg.terms[1].info_vec, arg_int_inputs),
                                       slices).values())[0]

    in_adj_precision = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                        zeros_like_out,
                                        Tensor(arg.terms[1].precision, arg_int_inputs),
                                        slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data, arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}