# Source code for funsor.interpreter

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
import inspect
import os
import re
import types
from collections import OrderedDict, deque, namedtuple
from contextlib import contextmanager
from functools import singledispatch

import numpy as np

from funsor.domains import ArrayType
from funsor.ops import Op, is_numeric_array
from funsor.registry import KeyedRegistry
from funsor.util import is_nn_module

# Directory two levels above this module's file (presumably the project root).
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Trace verbosity read from FUNSOR_DEBUG: 0 disables tracing, 1 traces calls by
# type ``__name__``, values above 1 also use ``classname`` and print results.
_DEBUG = int(os.getenv("FUNSOR_DEBUG", "0"))
# Depth of the currently traced call chain; consumed by ``_indent``.
_STACK_SIZE = 0

# The active interpretation; None here, to be set later in funsor.terms.
_INTERPRETATION = None
# When nonzero, ``reinterpret`` uses the explicit-stack implementation.
_USE_TCO = int(os.getenv("FUNSOR_USE_TCO", "0"))

# Counter incremented by every ``gensym`` call.
_GENSYM_COUNTER = 0


def _indent():
    """Return the trace prefix for the current ``_STACK_SIZE`` depth.

    For a non-negative depth the prefix is exactly ``_STACK_SIZE`` characters,
    cut from a repeating "four spaces then a vertical bar" pattern.
    """
    depth = _STACK_SIZE
    pattern = u'    \u2502' * (depth // 4 + 3)
    return pattern[:depth]


if _DEBUG:
    class DebugLogged(object):
        """
        Tracing wrapper used when ``_DEBUG`` is set.

        Each call prints the wrapped function's name and source location,
        prefixed by ``_indent()``, and deepens ``_STACK_SIZE`` for the
        duration of the call so nested traced calls are indented further.
        """

        def __init__(self, fn):
            self.fn = fn
            # Peel off functools.partial layers so inspect sees a real function.
            target = fn
            while isinstance(target, functools.partial):
                target = target.func
            source_path = inspect.getabsfile(target)
            first_line = inspect.getsourcelines(target)[1]
            self._message = "{} file://{} {}".format(
                target.__name__, source_path, first_line)

        def __call__(self, *args, **kwargs):
            global _STACK_SIZE
            print(_indent() + self._message)
            _STACK_SIZE += 1
            try:
                return self.fn(*args, **kwargs)
            finally:
                # Restore depth even when the wrapped call raises.
                _STACK_SIZE -= 1

        @property
        def register(self):
            # Forward ``register`` to the wrapped callable so registration
            # decorators keep working after wrapping.
            return self.fn.register

    def debug_logged(fn):
        """Wrap ``fn`` in :class:`DebugLogged` unless it is already wrapped."""
        return fn if isinstance(fn, DebugLogged) else DebugLogged(fn)
else:
    def debug_logged(fn):
        """Tracing disabled: hand ``fn`` back untouched."""
        return fn


def _classname(cls):
    return getattr(cls, "classname", cls.__name__)


class Interpreter:
    """
    Callable proxy for the current global interpretation.

    ``__call__`` is a property rather than a method, so calling an instance
    fetches the module-level ``_INTERPRETATION`` at call time and invokes it
    directly with the call's arguments (presumably to avoid adding a Python
    frame per call).
    """

    __call__ = property(lambda self: _INTERPRETATION)


def debug_interpret(cls, *args):
    """
    Tracing replacement for ``interpret`` used when ``_DEBUG`` is set.

    Prints the type names of ``cls`` and ``args``, runs the active
    interpretation one level deeper, then prints the result: its type name
    when ``_DEBUG == 1``, or its full (re-indented) ``str`` when ``_DEBUG > 1``.

    :return: Whatever ``_INTERPRETATION(cls, *args)`` returns.
    """
    global _STACK_SIZE
    indent = _indent()
    name_of = _classname if _DEBUG > 1 else (lambda t: t.__name__)
    header = [name_of(cls)] + [name_of(type(arg)) for arg in args]
    print(indent + ' '.join(header))

    _STACK_SIZE += 1
    try:
        result = _INTERPRETATION(cls, *args)
    finally:
        _STACK_SIZE -= 1

    if _DEBUG > 1:
        # Continuation lines of a multi-line repr are aligned under the arrow.
        summary = str(result).replace('\n', '\n          ' + indent)
    else:
        summary = type(result).__name__
    print(indent + '-> ' + summary)
    return result


# Entry point used to build terms: the tracing wrapper in debug mode,
# otherwise a thin proxy that forwards to the active interpretation.
if _DEBUG:
    interpret = debug_interpret
else:
    interpret = Interpreter()


def set_interpretation(new):
    """
    Replace the global interpretation with ``new`` outright.

    Unlike :func:`interpretation`, no fallback is stacked and the previous
    interpretation is not restored.

    :param new: A callable taking ``(cls, *args)``.
    """
    global _INTERPRETATION
    assert callable(new)
    _INTERPRETATION = new
@contextmanager
def interpretation(new):
    """
    Temporarily layer ``new`` over the current interpretation.

    Inside the ``with`` block the global interpretation is an
    ``InterpreterStack(new, previous)``: ``new`` is tried first and the
    previous interpretation is used whenever ``new`` returns None. The previous
    interpretation is restored on exit, including when an exception escapes.

    :param new: A callable taking ``(cls, *args)``.
    """
    global _INTERPRETATION
    assert callable(new)
    previous = _INTERPRETATION
    layered = InterpreterStack(new, previous)
    _INTERPRETATION = layered
    try:
        yield
    finally:
        _INTERPRETATION = previous
@singledispatch
def recursion_reinterpret(x):
    r"""
    Overloaded reinterpretation of a deferred expression.
    This interpreter uses the Python stack and is subject to the recursion limit.

    This handles a limited class of expressions, raising
    ``ValueError`` in unhandled cases.

    :param x: An input, typically involving deferred
        :class:`~funsor.terms.Funsor` s.
    :type x: A funsor or data structure holding funsors.
    :return: A reinterpreted version of the input.
    :raises: ValueError
    """
    raise ValueError(type(x))


# The Funsor case cannot be registered here because Funsor is declared later,
# in terms.py; that module is expected to register ``reinterpret_funsor`` with
# ``recursion_reinterpret`` for Funsor (verify there).
@debug_logged
def reinterpret_funsor(x):
    """
    Rebuild the funsor ``x`` under the active interpretation.

    Every entry of ``x._ast_values`` is reinterpreted first; the active
    interpretation is then called with ``type(x)`` and those values.
    """
    return _INTERPRETATION(type(x), *map(recursion_reinterpret, x._ast_values))


# Leaf types: returned unchanged by reinterpretation and having no children.
_ground_types = (
    str,
    int,
    float,
    type,
    functools.partial,
    types.FunctionType,
    types.BuiltinFunctionType,
    ArrayType,
    Op,
    np.generic,
    np.ndarray,
    np.ufunc,
)


def recursion_reinterpret_ground(x):
    """Ground values are already fully interpreted: return them as-is."""
    return x


for _ground_type in _ground_types:
    recursion_reinterpret.register(_ground_type, recursion_reinterpret_ground)


@recursion_reinterpret.register(tuple)
@debug_logged
def recursion_reinterpret_tuple(x):
    """Reinterpret each element; the result is a plain tuple."""
    return tuple(map(recursion_reinterpret, x))


@recursion_reinterpret.register(frozenset)
@debug_logged
def recursion_reinterpret_frozenset(x):
    """Reinterpret each element; the result is a plain frozenset."""
    return frozenset(map(recursion_reinterpret, x))


@recursion_reinterpret.register(dict)
@debug_logged
def recursion_reinterpret_dict(x):
    """Reinterpret each value, keeping keys; the result is a plain dict."""
    return {key: recursion_reinterpret(value) for key, value in x.items()}


@recursion_reinterpret.register(OrderedDict)
@debug_logged
def recursion_reinterpret_ordereddict(x):
    """Reinterpret each value, keeping keys and their order."""
    return OrderedDict((key, recursion_reinterpret(value)) for key, value in x.items())


@singledispatch
def children(x):
    """
    Return the immediate sub-terms of ``x`` for :func:`stack_reinterpret`.

    :raises: ValueError for unregistered types.
    """
    raise ValueError(type(x))


# Has to be registered for Funsor in terms.py, once Funsor is declared.
def children_funsor(x):
    return x._ast_values


@children.register(tuple)
@children.register(frozenset)
def _children_tuple(x):
    return x


@children.register(dict)
@children.register(OrderedDict)
def _children_dict(x):
    # Values only; keys are recovered from the original dict when rebuilding.
    return x.values()


def _children_ground(x):
    return ()


for _ground_type in _ground_types:
    children.register(_ground_type, _children_ground)


def is_atom(x):
    """
    Return True if ``x`` needs no reinterpretation.

    Tuples and frozensets are atoms when all of their elements are (so an empty
    one is an atom); otherwise ``x`` must be a ground type, a numeric array, or
    an nn module.
    """
    if isinstance(x, (tuple, frozenset)):
        return all(is_atom(c) for c in x)
    return isinstance(x, _ground_types) or is_numeric_array(x) or is_nn_module(x)


def gensym(x=None):
    """
    Return a fresh name, incrementing the global counter on every call.

    :param x: Optional hint. ``None`` gives ``"V<n>"``; a string gives
        ``x + "_<n>"``; any other object gives ``id(x)``.
    """
    global _GENSYM_COUNTER
    _GENSYM_COUNTER += 1
    sym = _GENSYM_COUNTER
    if x is not None:
        if isinstance(x, str):
            return x + "_" + str(sym)
        return id(x)
    return "V" + str(sym)


def stack_reinterpret(x):
    r"""
    Overloaded reinterpretation of a deferred expression.
    This interpreter uses an explicit stack and no recursion but is much slower.

    This handles a limited class of expressions, raising
    ``ValueError`` in unhandled cases.

    :param x: An input, typically involving deferred
        :class:`~funsor.terms.Funsor` s.
    :type x: A funsor or data structure holding funsors.
    :return: A reinterpreted version of the input.
    :raises: ValueError
    """
    # Nodes are keyed by object identity, not equality. An equality-keyed dict
    # raises TypeError on unhashable leaves such as numpy arrays (a ground type)
    # and silently merges distinct leaves with equal hashes such as 1/1.0/True.
    x_name = gensym(x)
    node_vars = {x_name: x}  # name -> original node
    node_names = {id(x): x_name}  # id(node) -> name
    env = {}  # name -> reinterpreted value
    parent_to_children = OrderedDict()  # name -> ordered child names
    child_to_parents = OrderedDict()  # name -> parent names (with repeats)

    # Pass 1: breadth-first walk recording each distinct node exactly once.
    queue = deque([(x_name, x)])
    while queue:
        h_name, h = queue.popleft()
        child_names = parent_to_children[h_name] = []
        for c in children(h):
            c_name = node_names.get(id(c))
            if c_name is None:
                c_name = gensym(c)
                node_names[id(c)] = c_name
                node_vars[c_name] = c
                queue.append((c_name, c))
            child_names.append(c_name)
            child_to_parents.setdefault(c_name, []).append(h_name)

    # Pass 2: bottom-up evaluation. A node becomes ready once all of its
    # children have been evaluated into ``env``.
    remaining = OrderedDict(
        (name, len(kids)) for name, kids in parent_to_children.items())
    ready = deque(name for name, count in remaining.items() if count == 0)
    while ready:
        h_name = ready.popleft()
        for parent in child_to_parents.get(h_name, ()):
            remaining[parent] -= 1
            if remaining[parent] == 0:
                ready.append(parent)
        h = node_vars[h_name]
        if is_atom(h):
            env[h_name] = h
            continue
        values = [env[c_name] for c_name in parent_to_children[h_name]]
        # Containers are rebuilt the same way recursion_reinterpret does it.
        if isinstance(h, tuple):
            env[h_name] = tuple(values)
        elif isinstance(h, frozenset):
            env[h_name] = frozenset(values)
        elif isinstance(h, OrderedDict):
            env[h_name] = OrderedDict(zip(h.keys(), values))
        elif isinstance(h, dict):
            env[h_name] = dict(zip(h.keys(), values))
        else:
            env[h_name] = _INTERPRETATION(type(h), *values)
    return env[x_name]
def reinterpret(x):
    r"""
    Overloaded reinterpretation of a deferred expression.

    This handles a limited class of expressions, raising
    ``ValueError`` in unhandled cases. The explicit-stack implementation is
    used when ``FUNSOR_USE_TCO`` is nonzero; otherwise the recursive one is.

    :param x: An input, typically involving deferred
        :class:`~funsor.terms.Funsor` s.
    :type x: A funsor or data structure holding funsors.
    :return: A reinterpreted version of the input.
    :raises: ValueError
    """
    implementation = stack_reinterpret if _USE_TCO else recursion_reinterpret
    return implementation(x)
class InterpreterStack(namedtuple("InterpreterStack", ["default", "fallback"])):
    """
    A pair of interpretations tried in order.

    ``default`` is called first; ``fallback`` is consulted only if ``default``
    returns None. The first non-None result is returned, else None.
    """

    def __call__(self, cls, *args):
        attempts = (interpreter(cls, *args) for interpreter in self)
        for result in attempts:
            if result is not None:
                return result
        return None
def dispatched_interpretation(fn):
    """
    Decorator to create a dispatched interpretation function.

    Attaches ``fn.register`` and ``fn.dispatch`` backed by a fresh
    ``KeyedRegistry`` whose default implementation returns None. When
    ``_DEBUG`` is set, registered implementations are wrapped with
    ``debug_logged`` first. Returns ``fn`` itself.
    """
    registry = KeyedRegistry(default=lambda *args: None)
    if _DEBUG:
        def register(*args):
            return lambda impl: registry.register(*args)(debug_logged(impl))
    else:
        register = registry.register
    fn.register = register
    fn.dispatch = registry.dispatch
    return fn
class StatefulInterpretationMeta(type):
    """
    Metaclass giving every class it creates, subclasses included, its own
    ``registry`` (a ``KeyedRegistry`` whose default returns None) together
    with a ``dispatch`` shortcut to that registry.
    """

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        own_registry = KeyedRegistry(default=lambda *args: None)
        cls.registry = own_registry
        cls.dispatch = own_registry.dispatch
class StatefulInterpretation(metaclass=StatefulInterpretationMeta):
    """
    Base class for interpreters with instance-dependent state or parameters.

    Calling an instance with ``(cls, *args)`` dispatches on ``(cls, *args)``
    and invokes the matching implementation as ``impl(self, *args)``. Note
    that ``cls`` is NOT passed to the implementation.

    Example usage::

        class MyInterpretation(StatefulInterpretation):

            def __init__(self, my_param):
                self.my_param = my_param

        @MyInterpretation.register(...)
        def my_impl(interpreter_state, *args):
            my_param = interpreter_state.my_param
            ...

        with interpretation(MyInterpretation(my_param=0.1)):
            ...
    """

    def __call__(self, cls, *args):
        return self.dispatch(cls, *args)(self, *args)

    if _DEBUG:
        @classmethod
        def register(cls, *args):
            """Register an implementation for ``args``, wrapped with tracing."""
            return lambda fn: cls.registry.register(*args)(debug_logged(fn))
    else:
        @classmethod
        def register(cls, *args):
            """Register an implementation for ``args`` in this class's registry."""
            return cls.registry.register(*args)
class PatternMissingError(NotImplementedError):
    """NotImplementedError whose message adds a hint that a pattern is missing."""

    def __str__(self):
        base_message = super().__str__()
        return "{}\nThis is most likely due to a missing pattern.".format(base_message)
# Names exported by ``from funsor.interpreter import *``.
__all__ = [
    'PatternMissingError', 'StatefulInterpretation', 'dispatched_interpretation',
    'interpret', 'interpretation', 'reinterpret', 'set_interpretation',
]