Source code for funsor.factory

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import inspect
import typing
import warnings
from collections import OrderedDict
from functools import singledispatch

import makefun

from funsor.instrument import debug_logged
from funsor.terms import Funsor, FunsorMeta, Variable, eager, to_funsor
from funsor.util import as_callable


def _get_name(fn):
    return getattr(fn, "__name__", type(fn).__name__)


def _erase_types(fn):
    def result(*args):
        return fn(*args)

    result.__name__ = _get_name(fn)
    result.__module__ = fn.__module__
    return debug_logged(result)


class FreshMeta(type):
    def __getitem__(cls, fn):
        return Fresh(fn)


[docs]class Fresh(metaclass=FreshMeta): """ Type hint for :func:`make_funsor` decorated functions. This provides hints for fresh variables (names) and the return type. Examples:: Fresh[Real] # a constant known domain Fresh[lambda x: Array[x.dtype, x.shape[1:]] # args are Domains Fresh[lambda x, y: Bint[x.size + y.size]] :param callable fn: A lambda taking named arguments (in any order) which will be filled in with the domain of the similarly named funsor argument to the decorated function. This lambda should compute a desired resulting domain given domains of arguments. """ def __init__(self, fn): function = type(lambda: None) self.fn = fn if isinstance(fn, function) else lambda: fn self.args = inspect.getfullargspec(fn)[0] def __call__(self, **kwargs): return self.fn(*map(kwargs.__getitem__, self.args))
[docs]class Bound: """ Type hint for :func:`make_funsor` decorated functions. This provides hints for bound variables (names). """ pass
class ValueMeta(type): def __getitem__(cls, value_type): return Value(value_type) class Value(metaclass=ValueMeta): def __init__(self, value_type): if issubclass(value_type, Funsor): raise TypeError("Types cannot depend on Funsor values") self.value_type = value_type class HasMeta(type): def __getitem__(cls, bound): return Has(bound)
[docs]class Has(metaclass=HasMeta): """ Type hint for :func:`make_funsor` decorated functions. This hint asserts that a set of :class:`Bound` variables always appear in the ``.inputs`` of the annotated argument. For example, we could write a named ``matmul`` function that asserts that both arguments always contain the reduced input, and cannot be constant with respect to that input:: @make_funsor def MatMul( x: Has[{"i"}], y: Has[{"i"}], i: Bound, ) -> Fresh[lambda x: x]: return (x * y).reduce(ops.add, i) Here the string ``"i"`` in the annotations for ``x`` and ``y`` refer to the argument ``i`` of our ``MatMul`` function, which is known to be ``Bound`` (i.e it does not appear in the ``.inputs`` of evaluating ``Matmul(x, y, "i")``. .. warning :: This annotation is experimental and may be removed in the future. Note that because Funsor is inherently extensional, violating a `Has` constraint only raises a :class:`SyntaxWarning` rather than a full :class:`TypeError` and even then only under the :func:`~funsor.interpretations.reflect` interpretation. As such, :class:`Has` annotations should be used sparingly, reserved for cases where the programmer has complete control over the inputs to a function and knows that an argument will always depend on a bound variable, e.g. when writing one-off Funsor terms to describe custom layers in a neural network. :param set bound: A :class:`~builtins.set` of strings of names of :class:`Bound` arguments of a :func:`make_funsor` -decorated function. """ def __init__(self, bound): assert isinstance(bound, set) assert all(isinstance(v, str) for v in bound) self.bound = bound
def _get_dependent_args(fields, hints, args): return { name: arg if isinstance(hint, Value) else arg.output for name, arg, hint in zip(fields, args, hints) if hint in (Funsor, Bound) or isinstance(hint, (Has, Value)) }
[docs]def make_funsor(fn): """ Decorator to dynamically create a subclass of :class:`~funsor.terms.Funsor`, together with a single default eager pattern. This infers inputs, outputs, fresh, and bound variables from type hints follow the following convention: - Funsor inputs are typed :class:`~funsor.terms.Funsor`. - Bound variable inputs (names) are typed :class:`Bound`. - Fresh variable inputs (names) are typed :class:`Fresh` together with lambda to compute the dependent domain. - Ground value inputs (e.g. Python ints) are typed :class:`Value` together with their actual data type, e.g. ``Value[int]``. - The return value is typed :class:`Fresh` together with a lambda to compute the dependent return domain. For example to unflatten a single coordinate into a pair of coordinates we could define:: @make_funsor def Unflatten( x: Funsor, i: Bound, i_over_2: Fresh[lambda i: Bint[i.size // 2]], i_mod_2: Fresh[lambda: Bint[2]], ) -> Fresh[lambda x: x]: assert i.output.size % 2 == 0 return x(**{i.name: i_over_2 * Number(2, 3) + i_mod_2}) :param callable fn: A type annotated function of Funsors. :rtype: subclas of :class:`~funsor.terms.Funsor` """ input_types = typing.get_type_hints(as_callable(fn)) for name, hint in input_types.items(): if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has))): raise TypeError(f"Invalid type hint {name}: {hint}") output_type = input_types.pop("return") hints = tuple(input_types.values()) class ResultMeta(FunsorMeta): def __call__(cls, *args): args = list(args) # Compute domains of bound variables. for i, (name, arg) in enumerate(zip(cls._ast_fields, args)): hint = input_types[name] if hint is Funsor or isinstance(hint, Has): # TODO support domains args[i] = to_funsor(arg) elif hint is Bound: for other in args: if isinstance(other, Funsor): domain = other.inputs.get(arg, None) if domain is not None: arg = to_funsor(arg, domain) if not isinstance(arg, Variable): raise ValueError(f"Cannot infer domain of {name}={arg}") args[i] = arg elif isinstance(hint, Value): if not isinstance(arg, hint.value_type): raise TypeError( f"invalid dependent value type: {arg}: {hint.value_type}" ) args[i] = arg # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) for i, (hint, arg) in enumerate(zip(hints, args)): if isinstance(hint, Fresh): domain = hint(**dependent_args) args[i] = to_funsor(arg, domain) return super().__call__(*args) @makefun.with_signature( "__init__({})".format(", ".join(["self"] + list(input_types))) ) def __init__(self, **kwargs): args = tuple(kwargs[k] for k in self._ast_fields) dependent_args = _get_dependent_args(self._ast_fields, hints, args) output = output_type(**dependent_args) inputs = OrderedDict() bound = {} for hint, arg, arg_name in zip(hints, args, self._ast_fields): if hint is Funsor: assert isinstance(arg, Funsor) inputs.update(arg.inputs) elif isinstance(hint, Has): assert isinstance(arg, Funsor) inputs.update(arg.inputs) for name in hint.bound: if kwargs[name] not in arg.input_vars: warnings.warn( f"Argument {arg_name} is missing bound variable {kwargs[name]} from argument {name}." f"Are you sure {name} will always appear in {arg_name}?", SyntaxWarning, ) for hint, arg in zip(hints, args): if hint is Bound: bound[arg.name] = inputs.pop(arg.name) for hint, arg in zip(hints, args): if isinstance(hint, Fresh): for k, d in arg.inputs.items(): if k not in bound: inputs[k] = d fresh = frozenset() Funsor.__init__(self, inputs, output, fresh, bound) for name, arg in zip(self._ast_fields, args): setattr(self, name, arg) def _alpha_convert(self, alpha_subs): alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} return Funsor._alpha_convert(self, alpha_subs) name = _get_name(fn) ResultMeta.__name__ = f"{name}Meta" Result = ResultMeta( name, (Funsor,), {"__init__": __init__, "_alpha_convert": _alpha_convert} ) pattern = (Result,) + tuple( _hint_to_pattern(input_types[k]) for k in Result._ast_fields ) eager.register(*pattern)(_erase_types(fn)) return Result
@singledispatch def _hint_to_pattern(t): return Funsor @_hint_to_pattern.register(Value) def _(t): return t.value_type