Source code for funsor.interpreter

import functools
import inspect
import os
import re
import types
from collections import OrderedDict, deque
from contextlib import contextmanager
from functools import singledispatch

import numpy
import torch

from funsor.domains import Domain
from funsor.ops import Op
from funsor.registry import KeyedRegistry

# Directory two levels above this file's absolute path.
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Integer debug level read from the FUNSOR_DEBUG environment variable
# (default 0). Any nonzero value enables the tracing wrappers in this module;
# values greater than 1 make debug_interpret print more detail.
_DEBUG = int(os.environ.get("FUNSOR_DEBUG", 0))
# Current nesting depth of traced calls; drives the width of _indent().
_STACK_SIZE = 0

_INTERPRETATION = None  # To be set later in funsor.terms
# Integer flag read from FUNSOR_USE_TCO (default 0). When nonzero,
# reinterpret() uses stack_reinterpret instead of recursion_reinterpret.
_USE_TCO = int(os.environ.get("FUNSOR_USE_TCO", 0))

# Monotonically increasing counter bumped by every gensym() call.
_GENSYM_COUNTER = 0


def _indent():
    """
    Build the prefix printed in front of debug trace lines.

    The prefix is exactly ``_STACK_SIZE`` characters long and is cut from a
    repeating pattern of four spaces followed by a vertical bar.
    """
    depth = _STACK_SIZE
    unit = u'    \u2502'
    # Repeat the unit often enough that the slice below never runs short.
    repeats = depth // 4 + 3
    return (unit * repeats)[:depth]


if _DEBUG:
    class DebugLogged(object):
        """
        Callable wrapper that prints a trace line on every call to ``fn``.

        The trace line holds the wrapped function's name, the absolute path
        of its source file and the line number where its source starts.
        """

        def __init__(self, fn):
            self.fn = fn
            # Look through functools.partial layers to find the function
            # whose source location should be reported.
            target = fn
            while isinstance(target, functools.partial):
                target = target.func
            source_path = inspect.getabsfile(target)
            _, first_line = inspect.getsourcelines(target)
            self._message = "{} file://{} {}".format(
                target.__name__, source_path, first_line)

        def __call__(self, *args, **kwargs):
            global _STACK_SIZE
            print(_indent() + self._message)
            # Deepen the indent for nested traced calls and always undo it,
            # even when the wrapped call raises.
            _STACK_SIZE += 1
            try:
                return self.fn(*args, **kwargs)
            finally:
                _STACK_SIZE -= 1

        @property
        def register(self):
            # Forward to the wrapped function's own ``register`` attribute.
            return self.fn.register

    def debug_logged(fn):
        """Wrap ``fn`` in DebugLogged unless it is already wrapped."""
        return fn if isinstance(fn, DebugLogged) else DebugLogged(fn)
else:
    def debug_logged(fn):
        """Return ``fn`` unchanged; tracing is disabled."""
        return fn


def _classname(cls):
    """Return ``cls.classname`` when that attribute exists, else ``cls.__name__``."""
    fallback = cls.__name__
    return getattr(cls, "classname", fallback)


class Interpreter:
    """
    Callable object whose ``__call__`` resolves to the module-level
    ``_INTERPRETATION`` at the time of each call.
    """
    @property
    def __call__(self):
        # Defined as a property so the callable is read from the global on
        # every lookup instead of being captured when the instance is built.
        return _INTERPRETATION


def debug_interpret(cls, *args):
    """
    Run the current interpretation on ``(cls, *args)`` while printing a trace.

    One line with the type names of ``cls`` and the arguments is printed
    before the call and one ``->`` line describing the result after it. When
    ``_DEBUG`` is greater than 1 the names come from ``_classname`` and the
    full ``str`` of the result is printed; otherwise plain ``__name__`` values
    are used.

    :param cls: The class passed as first argument to the interpretation.
    :param args: Remaining positional arguments for the interpretation.
    :return: Whatever the current interpretation returns.
    """
    global _STACK_SIZE
    prefix = _indent()
    if _DEBUG > 1:
        names = [_classname(cls)]
        names.extend(_classname(type(arg)) for arg in args)
    else:
        names = [cls.__name__]
        names.extend(type(arg).__name__ for arg in args)
    print(prefix + ' '.join(names))

    # Nested traced calls are indented one level deeper; the depth is
    # restored even when the interpretation raises.
    _STACK_SIZE += 1
    try:
        result = _INTERPRETATION(cls, *args)
    finally:
        _STACK_SIZE -= 1

    # Continuation lines of a multi-line result are re-indented so they line
    # up under the arrow.
    summary = (
        re.sub('\n', '\n          ' + prefix, str(result))
        if _DEBUG > 1
        else type(result).__name__
    )
    print(prefix + '-> ' + summary)
    return result


# Public entry point: the tracing function when debugging is enabled,
# otherwise an Interpreter instance that forwards to _INTERPRETATION.
interpret = debug_interpret if _DEBUG else Interpreter()


def set_interpretation(new):
    """
    Install ``new`` as the module-wide interpretation.

    :param new: The callable to store in ``_INTERPRETATION``.
    :raises AssertionError: If ``new`` is not callable.
    """
    global _INTERPRETATION
    assert callable(new)
    _INTERPRETATION = new
@contextmanager
def interpretation(new):
    """
    Context manager that temporarily installs ``new`` as the interpretation.

    The previous value of ``_INTERPRETATION`` is saved on entry and restored
    on exit, including when the body of the ``with`` block raises.

    :param new: The callable to use inside the ``with`` block.
    :raises AssertionError: If ``new`` is not callable.
    """
    global _INTERPRETATION
    assert callable(new)
    old = _INTERPRETATION
    try:
        _INTERPRETATION = new
        yield
    finally:
        _INTERPRETATION = old
@singledispatch
def recursion_reinterpret(x):
    r"""
    Overloaded reinterpretation of a deferred expression.
    This interpreter uses the Python stack and is subject to the recursion
    limit.

    This handles a limited class of expressions, raising
    ``ValueError`` in unhandled cases.

    :param x: An input, typically involving deferred
        :class:`~funsor.terms.Funsor` s.
    :type x: A funsor or data structure holding funsors.
    :return: A reinterpreted version of the input.
    :raises: ValueError
    """
    raise ValueError(type(x))


# We need to register this later in terms.py after declaring Funsor.
# recursion_reinterpret.register(Funsor)
@debug_logged
def reinterpret_funsor(x):
    """Reinterpret each of ``x._ast_values`` and rebuild via the current interpretation."""
    return _INTERPRETATION(type(x), *map(recursion_reinterpret, x._ast_values))


@recursion_reinterpret.register(str)
@recursion_reinterpret.register(int)
@recursion_reinterpret.register(float)
@recursion_reinterpret.register(type)
@recursion_reinterpret.register(functools.partial)
@recursion_reinterpret.register(types.FunctionType)
@recursion_reinterpret.register(types.BuiltinFunctionType)
@recursion_reinterpret.register(numpy.ndarray)
@recursion_reinterpret.register(torch.Tensor)
@recursion_reinterpret.register(torch.nn.Module)
@recursion_reinterpret.register(Domain)
@recursion_reinterpret.register(Op)
def recursion_reinterpret_ground(x):
    """Ground values are returned unchanged."""
    return x


@recursion_reinterpret.register(tuple)
@debug_logged
def recursion_reinterpret_tuple(x):
    """Reinterpret every element and return a plain ``tuple``."""
    return tuple(map(recursion_reinterpret, x))


@recursion_reinterpret.register(frozenset)
@debug_logged
def recursion_reinterpret_frozenset(x):
    """Reinterpret every element and return a ``frozenset``."""
    return frozenset(map(recursion_reinterpret, x))


@recursion_reinterpret.register(dict)
@debug_logged
def recursion_reinterpret_dict(x):
    """Reinterpret every value, keeping keys as they are."""
    return {key: recursion_reinterpret(value) for key, value in x.items()}


@recursion_reinterpret.register(OrderedDict)
@debug_logged
def recursion_reinterpret_ordereddict(x):
    """Reinterpret every value, keeping keys and their order."""
    return OrderedDict((key, recursion_reinterpret(value)) for key, value in x.items())


@singledispatch
def children(x):
    """
    Return the direct sub-expressions of ``x``.

    :raises: ValueError for types with no registered handler.
    """
    raise ValueError(type(x))


# has to be registered in terms.py
def children_funsor(x):
    """Children of a funsor are its ``_ast_values``."""
    return x._ast_values


@children.register(tuple)
@children.register(frozenset)
def _children_tuple(x):
    return x


# Previously also named ``_children_tuple``, which shadowed the handler above
# at module level; registration itself was unaffected.
@children.register(dict)
@children.register(OrderedDict)
def _children_dict(x):
    return x.values()


@children.register(str)
@children.register(int)
@children.register(float)
@children.register(type)
@children.register(functools.partial)
@children.register(types.FunctionType)
@children.register(types.BuiltinFunctionType)
@children.register(numpy.ndarray)
@children.register(torch.Tensor)
@children.register(torch.nn.Module)
@children.register(Domain)
@children.register(Op)
def _children_ground(x):
    return ()


def is_atom(x):
    """
    Return True if ``x`` can be copied into a reinterpreted result as-is.

    A tuple or frozenset that is not a ``Domain`` is an atom when all of its
    elements are atoms (vacuously true when empty). Anything else is an atom
    when it is an instance of one of the ground types listed below.
    """
    if isinstance(x, (tuple, frozenset)) and not isinstance(x, Domain):
        return all(is_atom(c) for c in x)
    return isinstance(x, (
        int,
        str,
        float,
        type,
        functools.partial,
        types.FunctionType,
        types.BuiltinFunctionType,
        torch.Tensor,
        torch.nn.Module,
        numpy.ndarray,
        Domain,
        Op,
    ))


def gensym(x=None):
    """
    Generate a name for ``x`` and advance the global gensym counter.

    :param x: Optional object to name.
    :return: ``"V<n>"`` when ``x`` is None, ``"<x>_<n>"`` when ``x`` is a
        string, and the integer ``id(x)`` for any other object, where ``n``
        is the updated counter value.
    """
    global _GENSYM_COUNTER
    _GENSYM_COUNTER += 1
    sym = _GENSYM_COUNTER
    if x is None:
        return "V" + str(sym)
    if isinstance(x, str):
        return x + "_" + str(sym)
    return id(x)


def stack_reinterpret(x):
    r"""
    Overloaded reinterpretation of a deferred expression.
    This interpreter uses an explicit stack and no recursion but is much
    slower.

    This handles a limited class of expressions, raising
    ``ValueError`` in unhandled cases.

    :param x: An input, typically involving deferred
        :class:`~funsor.terms.Funsor` s.
    :type x: A funsor or data structure holding funsors.
    :return: A reinterpreted version of the input.
    :raises: ValueError
    """
    x_name = gensym(x)
    node_vars = {x_name: x}
    # NOTE(review): nodes are deduplicated through dict lookup, i.e. by
    # hash/equality, so every node must be hashable and equal values (such
    # as 1 and 1.0) share a single node -- confirm this is intended.
    node_names = {x: x_name}
    env = {}
    parent_to_children = OrderedDict()
    child_to_parents = OrderedDict()

    # Phase 1: breadth-first walk that names every distinct node and records
    # parent/child edges. A deque keeps the FIFO pop O(1).
    pending = deque([(x_name, x)])
    while pending:
        h_name, h = pending.popleft()
        kids = parent_to_children[h_name] = []
        for c in children(h):
            if c in node_names:
                c_name = node_names[c]
            else:
                c_name = gensym(c)
                node_names[c] = c_name
                node_vars[c_name] = c
                pending.append((c_name, c))
            kids.append(c_name)
            child_to_parents.setdefault(c_name, []).append(h_name)

    # Phase 2: evaluate bottom-up. A node becomes ready once the count of its
    # not-yet-evaluated child edges drops to zero.
    children_counts = OrderedDict((k, len(v)) for k, v in parent_to_children.items())
    ready = deque(name for name, count in children_counts.items() if count == 0)
    while ready:
        h_name = ready.popleft()
        for parent in child_to_parents.get(h_name, ()):
            children_counts[parent] -= 1
            if children_counts[parent] == 0:
                ready.append(parent)

        h = node_vars[h_name]
        if is_atom(h):
            env[h_name] = h
        elif isinstance(h, (tuple, frozenset)):
            env[h_name] = type(h)(
                env[c_name] for c_name in parent_to_children[h_name])
        else:
            env[h_name] = _INTERPRETATION(
                type(h), *(env[c_name] for c_name in parent_to_children[h_name]))

    return env[x_name]
def reinterpret(x):
    r"""
    Overloaded reinterpretation of a deferred expression.

    This handles a limited class of expressions, raising
    ``ValueError`` in unhandled cases.

    :param x: An input, typically involving deferred
        :class:`~funsor.terms.Funsor` s.
    :type x: A funsor or data structure holding funsors.
    :return: A reinterpreted version of the input.
    :raises: ValueError
    """
    # _USE_TCO selects the explicit-stack implementation over the recursive one.
    if _USE_TCO:
        return stack_reinterpret(x)
    return recursion_reinterpret(x)
def dispatched_interpretation(fn):
    """
    Decorator to create a dispatched interpretation function.

    A fresh ``KeyedRegistry`` is created for ``fn`` and exposed through two
    attributes set on it: ``fn.register`` and ``fn.dispatch``. When debugging
    is enabled, handlers passed through ``fn.register`` are wrapped with
    ``debug_logged`` before they are stored.

    :param fn: The interpretation function to decorate.
    :return: The same ``fn`` object, with the two attributes attached.
    """
    # The registry is built with a default callable that returns None.
    registry = KeyedRegistry(default=lambda *args: None)
    if _DEBUG:
        def register(*args):
            def decorator(handler):
                return registry.register(*args)(debug_logged(handler))
            return decorator

        fn.register = register
    else:
        fn.register = registry.register
    fn.dispatch = registry.dispatch
    return fn
class PatternMissingError(NotImplementedError):
    """
    ``NotImplementedError`` subclass whose string form appends a hint that a
    pattern is probably missing.
    """

    def __str__(self):
        return f"{super().__str__()}\nThis is most likely due to a missing pattern."
# Names exported by ``from funsor.interpreter import *``.
__all__ = [ 'PatternMissingError', 'dispatched_interpretation', 'interpret', 'interpretation', 'reinterpret', 'set_interpretation', ]