Source code for funsor.terms

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
import itertools
import numbers
import typing
import warnings
from collections import OrderedDict, namedtuple
from functools import reduce, singledispatch
from weakref import WeakValueDictionary

from multipledispatch import dispatch

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.domains import (
    Array,
    Bint,
    BintType,
    Domain,
    Product,
    ProductDomain,
    Real,
    find_domain,
)
from funsor.interpretations import (
    Interpretation,
    die,
    eager,
    lazy,
    moment_matching,
    reflect,
    sequential,
)
from funsor.interpreter import PatternMissingError, interpret
from funsor.ops import AssociativeOp, GetitemOp, Op
from funsor.ops.builtin import normalize_ellipsis, parse_ellipsis, parse_slice
from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS
from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_args, get_origin
from funsor.util import getargspec, lazy_property, pretty, quote, register_pprint

from . import instrument, interpreter, ops

_PREFIX = {k: v for v, k, _ in PREFIX_OPERATORS}
_INFIX = {k: v for v, k, _ in INFIX_OPERATORS}


# FIXME this can lead to linear nesting of interpretations
# when used in combination with alpha_convert and optimize.
# See failing example at https://github.com/pyro-ppl/funsor/pull/414
class SubstituteInterpretation(Interpretation):
    def __init__(self, subs, base_interpretation):
        super().__init__("subs")
        self.subs = subs
        self.base_interpretation = base_interpretation
        assert isinstance(subs, tuple)
        assert all(isinstance(v, Funsor) for k, v in subs)

    @property
    def is_total(self):
        return self.base_interpretation.is_total

    def interpret(self, cls, *args):
        with self.base_interpretation:
            expr = cls(*args)
            fresh_subs = tuple((k, v) for k, v in self.subs if k in expr.fresh)
            if fresh_subs:
                expr = instrument.debug_logged(expr.eager_subs)(fresh_subs)
            if instrument.PROFILE:
                instrument.COUNTERS["interpretation"]["substitute"] += 1
            return expr


def substitute(expr, subs):
    if isinstance(subs, (dict, OrderedDict)):
        subs = tuple(subs.items())
    support = frozenset(k for k, v in subs)

    def stop(x):
        if interpreter.is_atom(x):
            return True
        if isinstance(x, Funsor) and support.isdisjoint(x.inputs):
            return True
        return False

    if stop(expr):
        return expr

    env = interpreter.anf(expr, stop)

    with SubstituteInterpretation(subs, interpreter.get_interpretation()):
        for key, value in env.items():
            args = tuple(
                c if interpreter.is_atom(c) else env.get(c, c)
                for c in interpreter.children(value)
            )
            if isinstance(value, (tuple, frozenset)):  # TODO absorb this into interpret
                env[key] = type(value)(args)
            else:
                env[key] = type(value)(*args)
    return env[expr]
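

# Usage sketch (illustrative, not part of the library source): substitute()
# replaces free variables and rebuilds the expression under the current
# interpretation. Under the default eager interpretation, with cons hashing,
#
#     x = Variable("x", Real)
#     assert substitute(x + x, {"x": Number(2.0)}) is Number(4.0)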


def _alpha_mangle(expr):
    """
    Rename bound variables in expr to avoid conflict with any free variables.

    FIXME this does not avoid conflict with other bound variables.
    """
    alpha_subs = {
        name: interpreter.gensym(name + "__BOUND")
        for name in expr.bound
        if "__BOUND" not in name
    }
    if not alpha_subs:
        return expr

    ast_values = instrument.debug_logged(expr._alpha_convert)(alpha_subs)
    return reflect.interpret(type(expr), *ast_values)


@reflect.set_callable
def reflect(cls, *args, **kwargs):
    """
    Construct a funsor, populate ``._ast_values``, and cons hash.
    This is the only interpretation allowed to construct funsors.
    """
    if len(args) > len(cls._ast_fields):
        # handle varargs
        new_args = tuple(args[: len(cls._ast_fields) - 1]) + (
            args[len(cls._ast_fields) - 1 - len(args) :],
        )
        assert len(new_args) == len(cls._ast_fields)
        _, args = args, new_args

    cache_key = reflect.make_hash_key(cls, *args)
    if cache_key in cls._cons_cache:
        return cls._cons_cache[cache_key]

    arg_types = tuple(map(deep_type, args))
    cls_specific = get_origin(cls)[arg_types]
    result = super(FunsorMeta, cls_specific).__call__(*args)
    result._ast_values = args

    if instrument.PROFILE:
        size, depth, width = _get_ast_stats(result)
        instrument.COUNTERS["ast_size"][size] += 1
        instrument.COUNTERS["ast_depth"][depth] += 1
        classname = get_origin(cls).__name__
        instrument.COUNTERS["funsor"][classname] += 1
        instrument.COUNTERS[classname][width] += 1

    # alpha-convert eagerly upon binding any variable
    result = _alpha_mangle(result)

    cls._cons_cache[cache_key] = result
    return result


class FunsorMeta(GenericTypeMeta):
    """
    Metaclass for Funsors to perform four independent tasks:

    1.  Fill in default kwargs and convert kwargs to args before deferring to a
        nonstandard interpretation. This allows derived metaclasses to fill in
        defaults and do type conversion, thereby simplifying logic of
        interpretations.
    2.  Ensure each Funsor class has an attribute ``._ast_fields`` describing
        its input args and each Funsor instance has an attribute
        ``._ast_values`` with values corresponding to its input args. This
        allows the instance to be reflectively reconstructed under a different
        interpretation, and is used by :func:`funsor.interpreter.reinterpret`.
    3.  Cons-hash construction, so that repeatedly calling the constructor
        with identical args will produce the same object. This enables cheap
        syntactic equality testing using the ``is`` operator, which is
        important both for hashing (e.g. for memoizing funsor functions)
        and for unit testing, since ``.__eq__()`` is overloaded with
        elementwise semantics. Cons hashing differs from memoization in that
        it incurs no memory overhead beyond the cons hash dict.
    4.  Support subtyping with parameters for pattern matching, e.g. Number[int, int].
    """

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        register_pprint(cls)
        if not cls.__args__:
            cls._ast_fields = getargspec(cls.__init__)[0][1:]
            cls._cons_cache = WeakValueDictionary()

    def __getitem__(cls, arg_types):
        if not isinstance(arg_types, tuple):
            arg_types = (arg_types,)
        assert len(arg_types) == len(
            cls._ast_fields
        ), "Must provide exactly one type per subexpression"
        return super().__getitem__(arg_types)

    def __call__(cls, *args, **kwargs):
        if cls.__args__:
            cls = cls.__origin__

        # Convert kwargs to args.
        if kwargs:
            args = list(args)
            for name in cls._ast_fields[len(args) :]:
                args.append(kwargs.pop(name))
            assert not kwargs, kwargs
            args = tuple(args)

        return interpret(cls, *args)

    @lazy_property
    def classname(cls):
        return repr(cls)


def _convert_reduced_vars(reduced_vars, inputs):
    """
    Helper to convert the reduced_vars arg of ``.reduce()`` and friends.

    :param reduced_vars:
    :type reduced_vars: str, Variable, or set or frozenset thereof.
    :returns: A frozenset of reduced variables.
    :rtype: frozenset of :class:`Variable`
    """
    # Avoid copying if arg is of correct type.
    if isinstance(reduced_vars, frozenset):
        if all(isinstance(var, Variable) for var in reduced_vars):
            return reduced_vars

    if isinstance(reduced_vars, (str, Variable)):
        reduced_vars = {reduced_vars}
    assert isinstance(reduced_vars, (frozenset, set))
    assert all(isinstance(var, (str, Variable)) for var in reduced_vars)
    return frozenset(
        Variable(var, inputs[var]) if isinstance(var, str) else var
        for var in reduced_vars
    )



class Funsor(object, metaclass=FunsorMeta):
    """
    Abstract base class for immutable functional tensors.

    Concrete derived classes must implement ``__init__()`` methods taking
    hashable ``*args`` and no optional ``**kwargs`` so as to support cons
    hashing.

    Derived classes with ``.fresh`` variables must implement an
    :meth:`eager_subs` method. Derived classes with ``.bound`` variables must
    implement an :meth:`_alpha_convert` method.

    :param OrderedDict inputs: A mapping from input name to domain.
        This can be viewed as a typed context or a mapping from
        free variables to domains.
    :param Domain output: An output domain.
    """

    def __init__(self, inputs, output, fresh=None, bound=None):
        fresh = frozenset() if fresh is None else fresh
        bound = {} if bound is None else bound
        assert isinstance(inputs, OrderedDict)
        for name, input_ in inputs.items():
            assert isinstance(name, str)
            assert isinstance(input_, Domain)
        assert isinstance(output, Domain)
        assert getattr(output, "is_concrete", True)
        assert isinstance(fresh, frozenset)
        assert isinstance(bound, dict)
        super(Funsor, self).__init__()
        self.inputs = inputs
        self.output = output
        self.fresh = fresh
        self.bound = bound

    @property
    def dtype(self):
        return self.output.dtype

    @property
    def shape(self):
        return self.output.shape

    @lazy_property
    def input_vars(self):
        return frozenset(Variable(k, v) for k, v in self.inputs.items())

    def __copy__(self):
        return self

    def __reduce__(self):
        return type(self).__origin__, self._ast_values

    def __hash__(self):
        return id(self)

    @lazy_property
    def __annotations__(self):
        type_hints = dict(self.inputs)
        type_hints["return"] = self.output
        return type_hints

    def __repr__(self):
        try:
            ast_values = self._ast_values
        except AttributeError:
            # E.g. when printing errors during __init__, before ._ast_values is set.
            return f"{type(self).__name__}(...)"
        return "{}({})".format(type(self).__name__, ", ".join(map(repr, ast_values)))

    def __str__(self):
        return "{}({})".format(
            type(self).__name__, ", ".join(map(str, self._ast_values))
        )

    def quote(self):
        return quote(self)

    def pretty(self, *args, **kwargs):
        return pretty(self, *args, **kwargs)

    def __contains__(self, item):
        raise TypeError

    def _alpha_convert(self, alpha_subs):
        """
        Rename bound variables while preserving all free variables.
        """
        # Substitute all funsor values.
        # Subclasses must handle string conversion.
        assert set(alpha_subs).issubset(self.bound)
        return tuple(substitute(v, alpha_subs) for v in self._ast_values)

    def __call__(self, *args, **kwargs):
        """
        Partially evaluates this funsor by substituting dimensions.
        """
        # Eagerly restrict to this funsor's inputs.
        subs = OrderedDict(zip(self.inputs, args))
        for k in self.inputs:
            if k in kwargs:
                subs[k] = kwargs[k]
        return Subs(self, tuple(subs.items()))

    def __bool__(self):
        if self.inputs or self.output.shape:
            raise ValueError(
                "bool value of Funsor with more than one value is ambiguous"
            )
        raise NotImplementedError

    def __nonzero__(self):
        return self.__bool__()

    def __len__(self):
        if not self.output.shape:
            raise ValueError("Funsor with empty shape has no len()")
        return self.output.shape[0]

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def item(self):
        if self.inputs or self.output.shape:
            raise ValueError(
                "only one element Funsors can be converted to Python scalars"
            )
        raise NotImplementedError

    @property
    def requires_grad(self):
        return False

    def reduce(self, op, reduced_vars=None):
        """
        Reduce along all or a subset of inputs.

        :param op: A reduction operation.
        :type op: ~funsor.ops.AssociativeOp or ~funsor.ops.ReductionOp
        :param reduced_vars: An optional input name or set of names to reduce.
            If unspecified, all inputs will be reduced.
        :type reduced_vars: str, Variable, or set or frozenset thereof.
        """
        assert isinstance(op, (AssociativeOp, ops.ReductionOp))

        # Eagerly convert reduced_vars to appropriate things.
        if reduced_vars is None:
            # Empty reduced_vars means "reduce over everything".
            reduced_vars = frozenset(Variable(k, v) for k, v in self.inputs.items())
        else:
            reduced_vars = _convert_reduced_vars(reduced_vars, self.inputs)
        assert isinstance(reduced_vars, frozenset), reduced_vars

        # Attempt to convert ReductionOp to AssociativeOp.
        if isinstance(op, ops.ReductionOp):
            if isinstance(op, ops.MeanOp):
                reduced_vars &= self.input_vars
                if not reduced_vars:
                    return self
                scale = 1 / reduce(ops.mul, [v.output.size for v in reduced_vars], 1)
                return self.reduce(ops.add, reduced_vars) * scale
            if isinstance(op, ops.VarOp):
                diff = self - self.reduce(ops.mean, reduced_vars)
                return (diff * diff).reduce(ops.mean, reduced_vars)
            if isinstance(op, ops.StdOp):
                return self.reduce(ops.var, reduced_vars).sqrt()
            raise NotImplementedError(f"Unsupported reduction op: {op}")

        assert isinstance(op, AssociativeOp)
        if not reduced_vars:
            return self
        return Reduce(op, self, reduced_vars)

    def approximate(self, op, guide, approx_vars=None):
        """
        Approximate with respect to all or a subset of inputs.

        :param AssociativeOp op: A reduction operation.
        :param Funsor guide: A guide funsor (e.g. a proposal distribution).
        :param approx_vars: An optional input name or set of names to reduce.
            If unspecified, all inputs will be reduced.
        :type approx_vars: str, Variable, or set or frozenset thereof.
        """
        assert isinstance(op, AssociativeOp)
        assert self.output == Real
        assert guide.output == self.output

        # Eagerly convert approx_vars to appropriate things.
        inputs = self.inputs.copy()
        inputs.update(guide.inputs)
        input_vars = self.input_vars | guide.input_vars
        if approx_vars is None:
            # Empty approx_vars means "approximate everything".
            approx_vars = input_vars
        else:
            approx_vars = _convert_reduced_vars(approx_vars, inputs)
            approx_vars &= input_vars  # Drop unrelated vars.
        if not approx_vars:
            return self  # exact

        return Approximate(op, self, guide, approx_vars)

    def sample(self, sampled_vars, sample_inputs=None, rng_key=None):
        """
        Create a Monte Carlo approximation to this funsor by replacing
        functions of ``sampled_vars`` with :class:`~funsor.delta.Delta` s.

        The result is a :class:`Funsor` with the same ``.inputs`` and
        ``.output`` as the original funsor (plus ``sample_inputs`` if
        provided), so that self can be replaced by the sample in expectation
        computations::

            y = x.sample(sampled_vars)
            assert y.inputs == x.inputs
            assert y.output == x.output
            exact = (x.exp() * integrand).reduce(ops.add)
            approx = (y.exp() * integrand).reduce(ops.add)

        If ``sample_inputs`` is provided, this creates a batch of samples.

        :param sampled_vars: A set of input variables to sample.
        :type sampled_vars: str, Variable, or set or frozenset thereof.
        :param OrderedDict sample_inputs: An optional mapping from variable
            name to :class:`~funsor.domains.Domain` over which samples will
            be batched.
        :param rng_key: a PRNG state to be used by JAX backend to generate
            random samples
        :type rng_key: None or JAX's random.PRNGKey
        """
        assert self.output == Real
        sampled_vars = _convert_reduced_vars(sampled_vars, self.inputs)
        sampled_vars = frozenset(v.name for v in sampled_vars)
        assert isinstance(sampled_vars, frozenset)
        if sample_inputs is None:
            sample_inputs = OrderedDict()
        assert isinstance(sample_inputs, OrderedDict)
        if sampled_vars.isdisjoint(self.inputs):
            return self
        result = instrument.debug_logged(self._sample)(
            sampled_vars, sample_inputs, rng_key
        )
        return result

    def _sample(self, sampled_vars, sample_inputs, rng_key):
        """
        Internal method to draw samples.
        This should be overridden by subclasses.
        """
        assert self.output == Real
        assert isinstance(sampled_vars, frozenset)
        assert isinstance(sample_inputs, OrderedDict)
        if sampled_vars.isdisjoint(self.inputs):
            return self
        raise ValueError("Cannot sample from a {}".format(type(self).__name__))

    def align(self, names):
        """
        Align this funsor to match given ``names``.
        This is mainly useful in preparation for extracting ``.data``
        of a :class:`funsor.tensor.Tensor`.

        :param tuple names: A tuple of strings representing all names
            but in a new order.
        :return: A permuted funsor equivalent to self.
        :rtype: Funsor
        """
        assert isinstance(names, tuple)
        if not names or names == tuple(self.inputs):
            return self
        return Align(self, names)

    def eager_subs(self, subs):
        """
        Internal substitution function. This relies on the user-facing
        :meth:`__call__` method to coerce non-Funsors to Funsors. Once all
        inputs are Funsors, :meth:`eager_subs` implementations can recurse to
        call :class:`Subs`.
        """
        return None  # defer to default implementation

    def eager_unary(self, op):
        return None  # defer to default implementation

    def eager_reduce(self, op, reduced_vars):
        assert reduced_vars.issubset(self.inputs)
        if not reduced_vars:
            return self
        return None  # defer to default implementation

    def sequential_reduce(self, op, reduced_vars):
        assert reduced_vars.issubset(self.inputs)
        if not reduced_vars:
            return self

        # Try to sum out integer scalars. This is mainly useful for testing,
        # since reduction is more efficiently implemented by Tensor.
        eager_vars = []
        lazy_vars = []
        for k in reduced_vars:
            if isinstance(self.inputs[k].dtype, int) and not self.inputs[k].shape:
                eager_vars.append(k)
            else:
                lazy_vars.append(k)
        if eager_vars:
            result = None
            for values in itertools.product(*(self.inputs[k] for k in eager_vars)):
                subs = dict(zip(eager_vars, values))
                result = self(**subs) if result is None else op(result, self(**subs))
            if lazy_vars:
                result = Reduce(op, result, frozenset(lazy_vars))
            return result

        return None  # defer to default implementation

    def moment_matching_reduce(self, op, reduced_vars):
        assert reduced_vars.issubset(self.inputs)
        if not reduced_vars:
            return self
        return None  # defer to default implementation

    # The following methods conform to a standard array/tensor interface.

    def __invert__(self):
        return Unary(ops.invert, self)

    def __pos__(self):
        return Unary(ops.pos, self)

    def __neg__(self):
        return Unary(ops.neg, self)

    def abs(self):
        return Unary(ops.abs, self)

    def atanh(self):
        return Unary(ops.atanh, self)

    def sqrt(self):
        return Unary(ops.sqrt, self)

    def exp(self):
        return Unary(ops.exp, self)

    def log(self):
        return Unary(ops.log, self)

    def log1p(self):
        return Unary(ops.log1p, self)

    def sigmoid(self):
        return Unary(ops.sigmoid, self)

    def tanh(self):
        return Unary(ops.tanh, self)

    def reshape(self, shape):
        return Unary(ops.ReshapeOp(shape), self)

    # The following reductions are treated as Unary ops because they
    # reduce over output shape while preserving all inputs.
    # To reduce over inputs, instead call .reduce(op, reduced_vars).

    def all(self, axis=None, keepdims=False):
        return Unary(ops.AllOp(axis, keepdims), self)

    def any(self, axis=None, keepdims=False):
        return Unary(ops.AnyOp(axis, keepdims), self)

    def argmax(self, axis=None, keepdims=False):
        return Unary(ops.ArgmaxOp(axis, keepdims), self)

    def argmin(self, axis=None, keepdims=False):
        return Unary(ops.ArgminOp(axis, keepdims), self)

    def max(self, axis=None, keepdims=False):
        return Unary(ops.AmaxOp(axis, keepdims), self)

    def min(self, axis=None, keepdims=False):
        return Unary(ops.AminOp(axis, keepdims), self)

    def sum(self, axis=None, keepdims=False):
        return Unary(ops.SumOp(axis, keepdims), self)

    def prod(self, axis=None, keepdims=False):
        return Unary(ops.ProdOp(axis, keepdims), self)

    def logsumexp(self, axis=None, keepdims=False):
        return Unary(ops.LogsumexpOp(axis, keepdims), self)

    def mean(self, axis=None, keepdims=False):
        return Unary(ops.MeanOp(axis, keepdims), self)

    def std(self, axis=None, ddof=0, keepdims=False):
        return Unary(ops.StdOp(axis, ddof, keepdims), self)

    def var(self, axis=None, ddof=0, keepdims=False):
        return Unary(ops.VarOp(axis, ddof, keepdims), self)

    def __add__(self, other):
        return Binary(ops.add, self, to_funsor(other))

    def __radd__(self, other):
        return Binary(ops.add, self, to_funsor(other))

    def __sub__(self, other):
        return Binary(ops.sub, self, to_funsor(other))

    def __rsub__(self, other):
        return Binary(ops.sub, to_funsor(other), self)

    def __mul__(self, other):
        return Binary(ops.mul, self, to_funsor(other))

    def __rmul__(self, other):
        return Binary(ops.mul, self, to_funsor(other))

    def __truediv__(self, other):
        return Binary(ops.truediv, self, to_funsor(other))

    def __rtruediv__(self, other):
        return Binary(ops.truediv, to_funsor(other), self)

    def __floordiv__(self, other):
        return Binary(ops.floordiv, self, to_funsor(other))

    def __rfloordiv__(self, other):
        return Binary(ops.floordiv, to_funsor(other), self)

    def __matmul__(self, other):
        return Binary(ops.matmul, self, to_funsor(other))

    def __rmatmul__(self, other):
        return Binary(ops.matmul, to_funsor(other), self)

    def __mod__(self, other):
        return Binary(ops.mod, self, to_funsor(other))

    def __rmod__(self, other):
        return Binary(ops.mod, to_funsor(other), self)

    def __lshift__(self, other):
        return Binary(ops.lshift, self, to_funsor(other))

    def __rlshift__(self, other):
        return Binary(ops.lshift, to_funsor(other), self)

    def __rshift__(self, other):
        return Binary(ops.rshift, self, to_funsor(other))

    def __rrshift__(self, other):
        return Binary(ops.rshift, to_funsor(other), self)

    def __pow__(self, other):
        return Binary(ops.pow, self, to_funsor(other))

    def __rpow__(self, other):
        return Binary(ops.pow, to_funsor(other), self)

    def __and__(self, other):
        return Binary(ops.and_, self, to_funsor(other))

    def __rand__(self, other):
        return Binary(ops.and_, self, to_funsor(other))

    def __or__(self, other):
        return Binary(ops.or_, self, to_funsor(other))

    def __ror__(self, other):
        return Binary(ops.or_, self, to_funsor(other))

    def __xor__(self, other):
        return Binary(ops.xor, self, to_funsor(other))

    def __eq__(self, other):
        return Binary(ops.eq, self, to_funsor(other))

    def __ne__(self, other):
        return Binary(ops.ne, self, to_funsor(other))

    def __lt__(self, other):
        return Binary(ops.lt, self, to_funsor(other))

    def __le__(self, other):
        return Binary(ops.le, self, to_funsor(other))

    def __gt__(self, other):
        return Binary(ops.gt, self, to_funsor(other))

    def __ge__(self, other):
        return Binary(ops.ge, self, to_funsor(other))

    def __getitem__(self, other):
        """
        Helper to desugar into either ops.getitem (for advanced indexing
        involving Funsors as indices) or ops.getslice (for simple indexing
        involving only integers, slices, None, and Ellipsis).
        """
        if type(other) is not tuple:
            if isinstance(other, ops.getslice.supported_types):
                return ops.getslice(self, other)
            other = to_funsor(other, Bint[self.output.shape[0]])
            return Binary(ops.getitem, self, other)

        # Handle complex slicing operations involving no funsors.
        if all(isinstance(part, ops.getslice.supported_types) for part in other):
            return ops.getslice(self, other)

        # Handle Ellipsis slicing.
        if any(part is Ellipsis for part in other):
            left, right = parse_ellipsis(other)
            missing = len(self.output.shape) - len(left) - len(right)
            assert missing >= 0
            middle = [slice(None)] * missing
            other = tuple(left + middle + right)

        # Handle each slice separately.
        result = self
        offset = 0
        for part in other:
            if part is None:
                raise NotImplementedError("TODO")
            if isinstance(part, slice):
                if part != slice(None):
                    raise NotImplementedError("TODO support nontrivial slicing")
                offset += 1
            else:
                part = to_funsor(part, Bint[result.output.shape[offset]])
                result = Binary(GetitemOp(offset), result, part)
        return result
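

# Usage sketch (illustrative, not part of the library source): building lazy
# terms with the fluent interface above, assuming the default eager
# interpretation. ``f(x=2.0)`` desugars into ``Subs`` via ``__call__``, and
# ``.reduce(ops.add)`` with no names reduces over all remaining inputs:
#
#     x = Variable("x", Real)
#     i = Variable("i", Bint[3])
#     f = x * x + i
#     assert set(f.inputs) == {"x", "i"}
#     g = f(x=2.0)
#     assert set(g.inputs) == {"i"}
#     h = g.reduce(ops.add)
#     assert not h.inputs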


@quote.register(Funsor)
def _(arg, indent, out):
    name = type(arg).__name__
    if type(arg).__module__ in [
        "funsor.torch.distributions",
        "funsor.jax.distributions",
    ]:
        name = "dist." + name
    out.append((indent, name + "("))
    for value in arg._ast_values[:-1]:
        quote.inplace(value, indent + 1, out)
        i, line = out[-1]
        out[-1] = i, line + ","
    for value in arg._ast_values[-1:]:
        quote.inplace(value, indent + 1, out)
        i, line = out[-1]
        out[-1] = i, line + ")"


interpreter.children.register(Funsor)(interpreter.children_funsor)


@singledispatch
def to_funsor(x, output=None, dim_to_name=None, **kwargs):
    """
    Convert to a :class:`Funsor`.
    Only :class:`Funsor` s and scalars are accepted.

    :param x: An object.
    :param funsor.domains.Domain output: An optional output hint.
    :param OrderedDict dim_to_name: An optional mapping from negative
        batch dimensions to name strings.
    :return: A Funsor equivalent to ``x``.
    :rtype: Funsor
    :raises: ValueError
    """
    raise ValueError("Cannot convert to Funsor: {}".format(repr(x)))


@to_funsor.register(Funsor)
def funsor_to_funsor(x, output=None, dim_to_name=None):
    if output is not None and x.output != output:
        raise ValueError("Output mismatch: {} vs {}".format(x.output, output))
    if dim_to_name is not None:
        bint_names = {
            name for name, domain in x.inputs.items() if domain.dtype != "real"
        }
        if not bint_names.issubset(dim_to_name.values()):
            raise ValueError("Inputs mismatch: {} vs {}".format(x.inputs, dim_to_name))
    return x
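

# Conversion sketch (illustrative, not part of the library source): under the
# default interpretation, the singledispatch registrations in this module give
#
#     assert to_funsor(2.0) is Number(2.0)                # numbers.Number
#     assert to_funsor("x", Real) is Variable("x", Real)  # str (registered below)
#     assert to_funsor(f) is f                            # any Funsor f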


@singledispatch
def to_data(x, name_to_dim=None, **kwargs):
    """
    Extract a python object from a :class:`Funsor`.

    Raises a ``ValueError`` if free variables remain or if the funsor is lazy.

    :param x: An object, possibly a :class:`Funsor`.
    :param OrderedDict name_to_dim: An optional inputs hint.
    :return: A non-funsor equivalent to ``x``.
    :raises: ValueError if any free variables remain.
    :raises: PatternMissingError if funsor is not fully evaluated.
    """
    return x


@to_data.register(Funsor)
def _to_data_funsor(x, name_to_dim=None):
    if name_to_dim is None and x.inputs:
        raise ValueError(
            "cannot convert {} to data due to lazy inputs: {}".format(
                type(x), set(x.inputs)
            )
        )
    raise PatternMissingError("cannot convert to a non-Funsor: {}".format(repr(x)))


class Variable(Funsor):
    """
    Funsor representing a single free variable.

    :param str name: A variable name.
    :param funsor.domains.Domain output: A domain.
    """

    def __init__(self, name, output):
        inputs = OrderedDict([(name, output)])
        fresh = frozenset({name})
        super(Variable, self).__init__(inputs, output, fresh)
        self.name = name

    def __repr__(self):
        return "Variable({}, {})".format(repr(self.name), repr(self.output))

    def __str__(self):
        return self.name

    def eager_subs(self, subs):
        assert len(subs) == 1 and subs[0][0] == self.name
        return subs[0][1]
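

# Usage sketch (illustrative, not part of the library source): a Variable is
# fresh in its own name, so substitution consumes it; cons hashing makes the
# identity check below hold under the default eager interpretation:
#
#     x = Variable("x", Real)
#     assert x.output == Real and set(x.inputs) == {"x"}
#     assert x(x=0.5) is Number(0.5)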


@to_funsor.register(str)
def name_to_funsor(name, output=None):
    if output is None:
        raise ValueError("Missing output: {}".format(name))
    return Variable(name, output)


class SubsMeta(FunsorMeta):
    """
    Wrapper to call :func:`to_funsor` and check types.
    """

    def __call__(cls, arg, subs):
        subs = tuple(
            (k, to_funsor(v, arg.inputs[k])) for k, v in subs if k in arg.inputs
        )
        return super().__call__(arg, subs)


class Subs(Funsor, metaclass=SubsMeta):
    """
    Lazy substitution of the form ``x(u=y, v=z)``.

    :param Funsor arg: A funsor being substituted into.
    :param tuple subs: A tuple of ``(name, value)`` pairs, where ``name`` is a
        string and ``value`` can be coerced to a :class:`Funsor` via
        :func:`to_funsor`.
    """

    def __init__(self, arg, subs):
        assert isinstance(arg, Funsor)
        assert isinstance(subs, tuple)
        for key, value in subs:
            assert isinstance(key, str)
            assert key in arg.inputs
            assert isinstance(value, Funsor)
        inputs = arg.inputs.copy()
        for key, value in subs:
            del inputs[key]
        for key, value in subs:
            inputs.update(value.inputs)
        fresh = frozenset()
        bound = {key: value.output for key, value in subs}
        super(Subs, self).__init__(inputs, arg.output, fresh, bound)
        self.arg = arg
        self.subs = OrderedDict(subs)

    def __repr__(self):
        return "{}({})".format(
            repr(self.arg), ", ".join(f"{k}={repr(v)}" for k, v in self.subs.items())
        )

    def __str__(self):
        return "{}({})".format(
            str(self.arg), ", ".join(f"{k}={str(v)}" for k, v in self.subs.items())
        )

    def _alpha_convert(self, alpha_subs):
        assert set(alpha_subs).issubset(self.bound)
        alpha_subs = {
            k: to_funsor(v, self.subs[k].output) for k, v in alpha_subs.items()
        }
        arg, subs = self._ast_values
        arg = substitute(arg, alpha_subs)
        subs = tuple((str(alpha_subs.get(k, k)), v) for k, v in subs)
        return arg, subs

    def _sample(self, sampled_vars, sample_inputs, rng_key=None):
        if any(k in sample_inputs for k, v in self.subs.items()):
            raise NotImplementedError("TODO alpha-convert")
        subs_sampled_vars = set()
        for name in sampled_vars:
            if name in self.arg.inputs:
                if any(name in v.inputs for k, v in self.subs.items()):
                    raise ValueError("Cannot sample")
                subs_sampled_vars.add(name)
            else:
                for k, v in self.subs.items():
                    if name in v.inputs:
                        subs_sampled_vars.add(k)
        subs_sampled_vars = frozenset(subs_sampled_vars)
        arg = self.arg._sample(subs_sampled_vars, sample_inputs, rng_key)
        return Subs(arg, tuple(self.subs.items()))


@lazy.register(Subs, Funsor, object)
@eager.register(Subs, Funsor, object)
def eager_subs_funsor(arg, subs):
    assert isinstance(subs, tuple)
    if not any(k in arg.inputs for k, v in subs):
        return arg
    return substitute(arg, subs)


@lazy.register(Subs, Subs, object)
@eager.register(Subs, Subs, object)
def eager_subs_subs(arg, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, v) for k, v in subs if k in arg.inputs)
    if not subs:
        return arg
    # Fuse substitutions.
    fused_subs = tuple((k, Subs(v, subs)) for k, v in arg.subs.items())
    fused_subs += subs
    return Subs(arg.arg, fused_subs)


@die.register(Subs, Funsor, tuple)
def die_subs(arg, subs):
    expr = reflect.interpret(Subs, arg, subs)
    raise NotImplementedError(f"Missing pattern for {repr(expr)}")


class Unary(Funsor):
    """
    Lazy unary operation.

    :param ~funsor.ops.Op op: A unary operator.
    :param Funsor arg: An argument.
    """

    def __init__(self, op, arg):
        assert callable(op)
        assert isinstance(arg, Funsor)
        output = find_domain(op, arg.output)
        super(Unary, self).__init__(arg.inputs, output)
        self.op = op
        self.arg = arg

    def __repr__(self):
        if self.op in _PREFIX:
            return "({}{})".format(_PREFIX[self.op], repr(self.arg))
        return super().__repr__()

    def __str__(self):
        if self.op in _PREFIX:
            return "({}{})".format(_PREFIX[self.op], str(self.arg))
        return super().__str__()


@eager.register(Unary, Op, Funsor)
def eager_unary(op, arg):
    return instrument.debug_logged(arg.eager_unary)(op)


@eager.register(Unary, AssociativeOp, Funsor)
def eager_unary(op, arg):
    if not arg.output.shape:
        return arg
    return instrument.debug_logged(arg.eager_unary)(op)


@die.register(Unary, Op, Funsor)
def die_unary(op, arg):
    expr = reflect.interpret(Unary, op, arg)
    raise NotImplementedError(f"Missing pattern for {repr(expr)}")


class Binary(Funsor):
    """
    Lazy binary operation.

    :param ~funsor.ops.Op op: A binary operator.
    :param Funsor lhs: A left hand side argument.
    :param Funsor rhs: A right hand side argument.
    """

    def __init__(self, op, lhs, rhs):
        assert callable(op)
        assert isinstance(lhs, Funsor)
        assert isinstance(rhs, Funsor)
        inputs = lhs.inputs.copy()
        inputs.update(rhs.inputs)
        output = find_domain(op, lhs.output, rhs.output)
        super(Binary, self).__init__(inputs, output)
        self.op = op
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        if self.op in _INFIX:
            return "({} {} {})".format(repr(self.lhs), _INFIX[self.op], repr(self.rhs))
        return super().__repr__()

    def __str__(self):
        if self.op in _INFIX:
            return "({} {} {})".format(str(self.lhs), _INFIX[self.op], str(self.rhs))
        return super().__str__()


@die.register(Binary, Op, Funsor, Funsor)
def die_binary(op, lhs, rhs):
    expr = reflect.interpret(Binary, op, lhs, rhs)
    raise NotImplementedError(f"Missing pattern for {repr(expr)}")


class Reduce(Funsor):
    """
    Lazy reduction over multiple variables.

    The user-facing interface is the :meth:`Funsor.reduce` method.

    :param op: An associative operator.
    :type op: ~funsor.ops.AssociativeOp
    :param funsor arg: An argument to be reduced.
    :param frozenset reduced_vars: A set of variables over which to reduce.
    """

    def __init__(self, op, arg, reduced_vars):
        assert isinstance(op, AssociativeOp)
        assert isinstance(arg, Funsor)
        assert isinstance(reduced_vars, frozenset)
        assert all(isinstance(v, Variable) for v in reduced_vars)
        reduced_names = frozenset(v.name for v in reduced_vars)
        inputs = OrderedDict(
            (k, v) for k, v in arg.inputs.items() if k not in reduced_names
        )
        output = arg.output
        fresh = frozenset()
        bound = {var.name: var.output for var in reduced_vars}
        super(Reduce, self).__init__(inputs, output, fresh, bound)
        self.op = op
        self.arg = arg
        self.reduced_vars = reduced_vars

    def __repr__(self):
        assert self.reduced_vars
        if self.reduced_vars == self.arg.input_vars:
            return f"{repr(self.arg)}.reduce({self.op.__name__})"
        rvars = [
            f'"{v.name}"' if v in self.arg.input_vars else repr(v)
            for v in self.reduced_vars
        ]
        return "{}.reduce({}, {{{}}})".format(
            repr(self.arg), self.op.__name__, ", ".join(rvars)
        )

    def __str__(self):
        assert self.reduced_vars
        if self.reduced_vars == self.arg.input_vars:
            return f"{str(self.arg)}.reduce({self.op.__name__})"
        rvars = [
            f'"{v.name}"' if v in self.arg.input_vars else repr(v)
            for v in self.reduced_vars
        ]
        return "{}.reduce({}, {{{}}})".format(
            str(self.arg), self.op.__name__, ", ".join(rvars)
        )

    def _alpha_convert(self, alpha_subs):
        alpha_subs = {
            k: to_funsor(v, self.arg.inputs[k]) for k, v in alpha_subs.items()
        }
        op, arg, reduced_vars = super()._alpha_convert(alpha_subs)
        reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars)
        return op, arg, reduced_vars


def _reduce_unrelated_vars(op, arg, reduced_vars):
    factor_vars = reduced_vars - arg.input_vars
    if factor_vars:
        reduced_vars = reduced_vars & arg.input_vars
        multiplicity = reduce(
            ops.mul,
            [
                v.output.size**v.output.num_elements
                for v in factor_vars
                if v.dtype != "real"
            ],
        )
        for add_op, mul_op in ops.DISTRIBUTIVE_OPS:
            if add_op is op:
                arg = mul_op(arg, multiplicity).reduce(op, reduced_vars)
                return arg, None
        raise NotImplementedError(f"Cannot reduce {op}")
    return arg, frozenset(v.name for v in reduced_vars)


@lazy.register(Reduce, AssociativeOp, Funsor, frozenset)
def lazy_reduce(op, arg, reduced_vars):
    new_arg, new_reduced_vars = _reduce_unrelated_vars(op, arg, reduced_vars)
    if new_reduced_vars is None:
        return new_arg
    if new_arg is arg:
        return None
    return new_arg.reduce(op, new_reduced_vars)


@eager.register(Reduce, AssociativeOp, Funsor, frozenset)
def eager_reduce(op, arg, reduced_vars):
    arg, reduced_vars = _reduce_unrelated_vars(op, arg, reduced_vars)
    if reduced_vars is None:
        return arg
    return instrument.debug_logged(arg.eager_reduce)(op, reduced_vars)


@sequential.register(Reduce, AssociativeOp, Funsor, frozenset)
def sequential_reduce(op, arg, reduced_vars):
    arg, reduced_vars = _reduce_unrelated_vars(op, arg, reduced_vars)
    if reduced_vars is None:
        return arg
    return instrument.debug_logged(arg.sequential_reduce)(op, reduced_vars)


@moment_matching.register(Reduce, AssociativeOp, Funsor, frozenset)
def moment_matching_reduce(op, arg, reduced_vars):
    arg, reduced_vars = _reduce_unrelated_vars(op, arg, reduced_vars)
    if reduced_vars is None:
        return arg
    return instrument.debug_logged(arg.moment_matching_reduce)(op, reduced_vars)


@die.register(Reduce, Op, Funsor, frozenset)
def die_reduce(op, arg, reduced_vars):
    expr = reflect.interpret(Reduce, op, arg, reduced_vars)
    raise NotImplementedError(f"Missing pattern for {repr(expr)}")


class Scatter(Funsor):
    """
    Transpose of structurally linear :class:`Subs`, followed by
    :class:`Reduce`.

    For injective scatter operations this should satisfy the equation::

        if destin = Scatter(op, subs, source, frozenset())
        then source = Subs(destin, subs)

    The ``reduced_vars`` is merely for computational efficiency, and could
    always be split out into a separate ``.reduce()``. For example in the
    following equation, the left hand side uses much less memory than the
    right hand side::

        Scatter(op, subs, source, reduced_vars) ==
            Scatter(op, subs, source, frozenset()).reduce(op, reduced_vars)

    .. warning:: This is currently implemented only for injective scatter
        operations. In particular, this does not allow accumulation behavior
        like scatter-add.

    .. note:: ``Scatter(ops.add, ...)`` is the funsor analog of
        ``numpy.add.at()`` or :func:`torch.index_put` or
        :func:`jax.lax.scatter_add`. For injective substitutions,
        ``Scatter(ops.add, ...)`` is roughly equivalent to the tensor
        operation::

            result = zeros(...)  # since zero is the additive unit
            result[subs] = source

    :param AssociativeOp op: An op. The unit of this op will be used as
        default value.
    :param tuple subs: A substitution.
    :param Funsor source: A source for data to be scattered from.
    :param frozenset reduced_vars: A set of variables over which to reduce.
    """

    def __init__(self, op, subs, source, reduced_vars):
        assert isinstance(op, AssociativeOp)
        assert isinstance(subs, tuple)
        assert len(subs) == len(set(key for key, value in subs))
        assert isinstance(source, Funsor)
        assert isinstance(reduced_vars, frozenset)
        assert all(isinstance(v, Variable) for v in reduced_vars)
        reduced_names = frozenset(v.name for v in reduced_vars)

        # First compute inputs of the pure-scatter op with no reduction.
        inputs = OrderedDict()
        for key, value in subs:
            assert isinstance(key, str)
            assert isinstance(value, Funsor)
            assert key not in source.inputs
            assert key not in reduced_names
            for k, d in value.inputs.items():
                # These are "batch" inputs and should be left of subs keys.
                d2 = inputs.setdefault(k, d)
                assert d2 == d
        for k, d in source.inputs.items():
            # These are "batch" inputs and should be left of subs keys.
            d2 = inputs.setdefault(k, d)
            assert d2 == d
        for key, value in subs:
            assert key not in inputs
            # These are "event" inputs and should be right of "batch" inputs.
            inputs[key] = value.output

        # Then narrow these down to the fused scatter-reduce op.
        inputs = OrderedDict(
            (k, d) for k, d in inputs.items() if k not in reduced_names
        )
        fresh = frozenset(key for key, value in subs)
        bound = {v.name: v.output for v in reduced_vars}
        super().__init__(inputs, source.output, fresh, bound)
        self.op = op
        self.subs = subs
        self.source = source
        self.reduced_vars = reduced_vars

    def _alpha_convert(self, alpha_subs):
        alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()}
        op, subs, source, reduced_vars = super()._alpha_convert(alpha_subs)
        reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars)
        return op, subs, source, reduced_vars

    def eager_subs(self, subs):
        subs = OrderedDict(subs)
        new_subs = []
        for name, sub in self.subs:
            if name in subs and isinstance(subs[name], Variable):
                new_subs.append((subs[name].name, sub))
            else:
                new_subs.append((name, sub))
        return Scatter(self.op, tuple(new_subs), self.source, self.reduced_vars)


class Approximate(Funsor):
    """
    Interpretation-specific approximation wrt a set of variables.

    The default eager interpretation should be exact.
    The user-facing interface is the :meth:`Funsor.approximate` method.

    :param op: An associative operator.
    :type op: ~funsor.ops.AssociativeOp
    :param Funsor model: An exact funsor depending on ``approx_vars``.
    :param Funsor guide: A proposal funsor guiding optional approximation.
    :param frozenset approx_vars: A set of variables over which to approximate.
    """

    def __init__(self, op, model, guide, approx_vars):
        assert isinstance(op, AssociativeOp)
        assert isinstance(model, Funsor)
        assert isinstance(guide, Funsor)
        assert model.output is guide.output
        assert isinstance(approx_vars, frozenset), approx_vars
        inputs = model.inputs.copy()
        inputs.update(guide.inputs)
        output = model.output
        fresh = frozenset(v.name for v in approx_vars)
        bound = {v.name: v.output for v in approx_vars}
        super().__init__(inputs, output, fresh, bound)
        self.op = op
        self.model = model
        self.guide = guide
        self.approx_vars = approx_vars

    def _alpha_convert(self, alpha_subs):
        alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()}
        op, model, guide, approx_vars = super()._alpha_convert(alpha_subs)
        approx_vars = frozenset(alpha_subs.get(var.name, var) for var in approx_vars)
        return op, model, guide, approx_vars


@eager.register(Approximate, AssociativeOp, Funsor, Funsor, frozenset)
def eager_approximate(op, model, guide, approx_vars):
    return model  # exact


class NumberMeta(FunsorMeta):
    """
    Wrapper to fill in default ``dtype``.
    """

    def __call__(cls, data, dtype=None):
        if dtype is None:
            dtype = "real"
        return super(NumberMeta, cls).__call__(data, dtype)


class Number(Funsor, metaclass=NumberMeta):
    """
    Funsor backed by a Python number.

    :param numbers.Number data: A python number.
    :param dtype: A nonnegative integer or the string "real".
    """

    def __init__(self, data, dtype=None):
        assert isinstance(data, numbers.Number)
        if isinstance(dtype, int):
            data = type(dtype)(data)
            if dtype != 2:  # booleans have bitwise interpretation
                assert 0 <= data and data < dtype
        else:
            assert isinstance(dtype, str) and dtype == "real"
            data = float(data)
        inputs = OrderedDict()
        output = Array[dtype, ()]
        super(Number, self).__init__(inputs, output)
        self.data = data

    def __repr__(self):
        if self.dtype == "real":
            return f"Number({str(self.data)})"
        else:
            return f"Number({str(self.data)}, {self.dtype})"

    def __str__(self):
        return str(self.data)

    def __int__(self):
        return int(self.data)

    def __float__(self):
        return float(self.data)

    def __bool__(self):
        return bool(self.data)

    def item(self):
        return self.data

    def eager_unary(self, op):
        dtype = find_domain(op, self.output).dtype
        return Number(op(self.data), dtype)
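

# Usage sketch (illustrative, not part of the library source): Number
# arithmetic evaluates eagerly via the ``eager_binary_number_number`` pattern
# registered below, and cons hashing makes equal results identical objects:
#
#     assert Number(2.0) + Number(3.0) is Number(5.0)
#     assert Number(1, 3).dtype == 3  # a bounded integer in Bint[3]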


@to_funsor.register(numbers.Number)
def number_to_funsor(x, output=None, dim_to_name=None):
    if output is None:
        return Number(x)
    if output.shape:
        raise ValueError("Cannot create Number with shape {}".format(output.shape))
    return Number(x, output.dtype)


@to_data.register(Number)
def _to_data_number(x, name_to_dim=None):
    return x.data


@eager.register(Binary, Op, Number, Number)
def eager_binary_number_number(op, lhs, rhs):
    data = op(lhs.data, rhs.data)
    output = find_domain(op, lhs.output, rhs.output)
    dtype = output.dtype
    return Number(data, dtype)


class SliceMeta(FunsorMeta):
    """
    Wrapper to fill in ``start``, ``stop``, ``step``, ``dtype`` following
    Python conventions.
    """

    def __call__(cls, name, *args, **kwargs):
        start = 0
        step = 1
        dtype = None
        if len(args) == 1:
            stop = args[0]
            dtype = kwargs.pop("dtype", stop)
        elif len(args) == 2:
            start, stop = args
            dtype = kwargs.pop("dtype", stop)
        elif len(args) == 3:
            start, stop, step = args
            dtype = kwargs.pop("dtype", stop)
        elif len(args) == 4:
            start, stop, step, dtype = args
        else:
            raise ValueError
        if step <= 0:
            raise ValueError
        stop = min(dtype, max(start, stop))
        return super().__call__(name, start, stop, step, dtype)


class Slice(Funsor, metaclass=SliceMeta):
    """
    Symbolic representation of a Python :py:class:`slice` object.

    :param str name: A name for the new slice dimension.
    :param int start:
    :param int stop:
    :param int step: Three args following :py:class:`slice` semantics.
    :param int dtype: An optional bounded integer type of this slice.
    """

    def __init__(self, name, start, stop, step, dtype):
        assert isinstance(name, str)
        assert isinstance(start, int) and start >= 0
        assert isinstance(stop, int) and stop >= start
        assert isinstance(step, int) and step > 0
        assert isinstance(dtype, int)
        size = max(0, (stop + step - 1 - start) // step)
        inputs = OrderedDict([(name, Bint[size])])
        output = Bint[dtype]
        fresh = frozenset({name})
        super().__init__(inputs, output, fresh)
        self.name = name
        self.slice = slice(start, stop, step)

    def eager_subs(self, subs):
        assert len(subs) == 1 and subs[0][0] == self.name
        index = subs[0][1]

        if isinstance(index, Variable):
            name = index.name
            return Slice(
                name, self.slice.start, self.slice.stop, self.slice.step, self.dtype
            )
        elif isinstance(index, Number):
            data = self.slice.start + self.slice.step * index.data
            return Number(data, self.output.dtype)
        elif type(index).__name__ == "Tensor":  # avoid importing funsor.tensor.Tensor
            data = self.slice.start + self.slice.step * index.data
            return type(index)(data, index.inputs, self.output.dtype)
        elif isinstance(index, Slice):
            name = index.name
            start = self.slice.start + self.slice.step * index.slice.start
            step = self.slice.step * index.slice.step
            return Slice(name, start, self.slice.stop, step, self.dtype)
        else:
            raise NotImplementedError(
                "TODO support substitution of {} into Slice".format(type(index))
            )
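

# Usage sketch (illustrative, not part of the library source): a Slice with
# start 0, stop 10, step 2 has five elements drawn from Bint[10], and
# substituting a Number applies ``start + step * index``:
#
#     s = Slice("i", 0, 10, 2, 10)
#     assert s.inputs["i"] == Bint[5]
#     assert s(i=3) is Number(6, 10)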


@to_funsor.register(slice)
def slice_to_funsor(s, output=None, dim_to_name=None):
    if not isinstance(output, BintType):
        raise ValueError(f"Incompatible slice output: {output}")
    start, stop, step = parse_slice(s, output.size)
    i = Variable("slice", output)
    return Lambda(i, Slice("slice", start, stop, step, output.size))


class Align(Funsor):
    """
    Lazy call to ``.align(...)``.

    :param Funsor arg: A funsor to align.
    :param tuple names: A tuple of input names whose order to follow.
    """

    def __init__(self, arg, names):
        assert isinstance(arg, Funsor)
        assert isinstance(names, tuple)
        assert all(isinstance(name, str) for name in names)
        assert all(name in arg.inputs for name in names)
        inputs = OrderedDict((name, arg.inputs[name]) for name in names)
        inputs.update(arg.inputs)
        output = arg.output
        fresh = frozenset()  # TODO get this right
        bound = {}
        super(Align, self).__init__(inputs, output, fresh, bound)
        self.arg = arg

    def align(self, names):
        return self.arg.align(names)

    def eager_unary(self, op):
        return Unary(op, self.arg)

    def eager_reduce(self, op, reduced_vars):
        return self.arg.reduce(op, reduced_vars)


@eager.register(Align, Funsor, tuple)
def eager_align(arg, names):
    if not frozenset(names) == frozenset(arg.inputs.keys()):
        # assume there's been a substitution and this align is no longer valid
        return arg
    return None


@eager.register(Binary, Op, Align, Funsor)
def eager_binary_align_funsor(op, lhs, rhs):
    return Binary(op, lhs.arg, rhs)


@eager.register(Binary, Op, Funsor, Align)
def eager_binary_funsor_align(op, lhs, rhs):
    return Binary(op, lhs, rhs.arg)


@eager.register(Binary, Op, Align, Align)
def eager_binary_align_align(op, lhs, rhs):
    return Binary(op, lhs.arg, rhs.arg)


class Finitary(Funsor):
    def __init__(self, op, args):
        assert isinstance(op, ops.Op)
        assert isinstance(args, tuple)
        assert all(isinstance(v, Funsor) for v in args)
        inputs = OrderedDict()
        for arg in args:
            inputs.update(arg.inputs)
        output = find_domain(op, tuple(arg.output for arg in args))
        super().__init__(inputs, output)
        self.op = op
        self.args = args


class Stack(Funsor):
    """
    Stack of funsors along a new input dimension.

    :param str name: The name of the new input variable along which to stack.
    :param tuple parts: A tuple of Funsors of homogeneous output domain.
    """

    def __init__(self, name, parts):
        assert isinstance(name, str)
        assert isinstance(parts, tuple)
        assert parts
        assert not any(name in x.inputs for x in parts)
        assert len(set(x.output for x in parts)) == 1
        output = parts[0].output
        domain = Bint[len(parts)]
        inputs = OrderedDict([(name, domain)])
        for x in parts:
            inputs.update(x.inputs)
        fresh = frozenset({name})
        super().__init__(inputs, output, fresh)
        self.name = name
        self.parts = parts

    def eager_subs(self, subs):
        assert isinstance(subs, tuple) and len(subs) == 1 and subs[0][0] == self.name
        index = subs[0][1]

        # Try to eagerly select an index.
        if index.output == Bint[len(self.parts)]:
            if isinstance(index, Number):
                # Select a single part.
                return self.parts[index.data]
            elif isinstance(index, Variable):
                # Rename the stacking dimension.
                parts = self.parts
                return Stack(index.name, parts)
            elif isinstance(index, Slice):
                parts = self.parts[index.slice]
                return Stack(index.name, parts)
            else:
                raise NotImplementedError("TODO support advanced indexing in Stack")
        else:
            raise NotImplementedError("TODO support slicing in Stack")

    def eager_reduce(self, op, reduced_vars):
        parts = self.parts
        if self.name in reduced_vars:
            reduced_vars -= frozenset([self.name])
            if reduced_vars:
                parts = tuple(x.reduce(op, reduced_vars) for x in parts)
            return reduce(op, parts)
        parts = tuple(x.reduce(op, reduced_vars) for x in parts)
        return Stack(self.name, parts)
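

# Usage sketch (illustrative, not part of the library source): under the
# default eager interpretation, substituting a Number into the stacking
# dimension selects a part, and reducing over it folds the parts together:
#
#     x = Stack("t", (Number(0.0), Number(1.0), Number(4.0)))
#     assert x.inputs["t"] == Bint[3]
#     assert x(t=2) is Number(4.0)
#     assert x.reduce(ops.add) is Number(5.0)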


@eager.register(Stack, str, tuple)
def eager_stack(name, parts):
    return eager_stack_homogeneous(name, *parts)


@dispatch(str, Variadic[Funsor])
def eager_stack_homogeneous(name, *parts):
    return None  # defer to default implementation


class CatMeta(FunsorMeta):
    """
    Wrapper to fill in default value for ``part_name``.
    """

    def __call__(cls, name, parts, part_name=None):
        if part_name is None:
            part_name = name
        return super().__call__(name, parts, part_name)


class Cat(Funsor, metaclass=CatMeta):
    """
    Concatenate funsors along an existing input dimension.

    :param str name: The name of the input variable along which to
        concatenate.
    :param tuple parts: A tuple of Funsors of homogeneous output domain.
    """

    def __init__(self, name, parts, part_name=None):
        assert isinstance(name, str)
        assert isinstance(parts, tuple)
        assert isinstance(part_name, str)
        assert parts
        for part in parts:
            assert part_name in part.inputs, (part_name, part.inputs)
        if part_name != name:
            assert not any(name in x.inputs for x in parts)
        assert len(set(x.output for x in parts)) == 1
        output = parts[0].output
        inputs = OrderedDict()
        for x in parts:
            inputs.update(x.inputs)
        del inputs[part_name]
        inputs[name] = Bint[sum(x.inputs[part_name].size for x in parts)]
        fresh = frozenset({name})
        bound = {part_name: x.inputs[part_name]}
        super().__init__(inputs, output, fresh, bound)
        self.name = name
        self.parts = parts
        self.part_name = part_name

    def _alpha_convert(self, alpha_subs):
        assert len(alpha_subs) == 1
        part_name = alpha_subs[self.part_name]
        parts = tuple(
            substitute(
                p, {self.part_name: to_funsor(part_name, p.inputs[self.part_name])}
            )
            for p in self.parts
        )
        return self.name, parts, part_name

    def eager_subs(self, subs):
        assert len(subs) == 1 and subs[0][0] == self.name
        value = subs[0][1]

        if isinstance(value, Variable):
            return Cat(value.name, self.parts, self.part_name)
        elif isinstance(value, Number):
            n = value.data
            for part in self.parts:
                size = part.inputs[self.part_name].size
                if n < size:
                    return part(**{self.part_name: n})
                n -= size
            assert False
        elif isinstance(value, Slice):
            start, stop, step = value.slice.start, value.slice.stop, value.slice.step
            new_parts = []
            pos = 0
            for part in self.parts:
                psize = part.inputs[self.part_name].size
                if step > 1:
                    pstart = ((pos - start) // step) * step - (pos - start)
                    pstart = pstart + step if pstart < 0 else pstart
                else:
                    pstart = max(start - pos, 0)
                pstop = min(pos + psize, stop) - pos
                if not (pstart >= pstop or pos >= stop or pos + psize <= start):
                    pslice = Slice(self.part_name, pstart, pstop, step, psize)
                    part = part(**{self.part_name: pslice})
                    new_parts.append(part)
                pos += psize
            return Cat(self.name, tuple(new_parts), self.part_name)
        else:
            raise NotImplementedError(
                "TODO implement Cat.eager_subs for {}".format(type(value))
            )


@eager.register(Cat, str, tuple, str)
def eager_cat(name, parts, part_name):
    if len(parts) == 1:
        return parts[0](**{part_name: name})
    return eager_cat_homogeneous(name, part_name, *parts)


@dispatch(str, str, Variadic[Funsor])
def eager_cat_homogeneous(name, part_name, *parts):
    return None  # defer to default implementation
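

# Usage sketch (illustrative, not part of the library source): Cat joins
# parts along an existing input, and substituting a Number routes the index
# into the appropriate part (see Cat.eager_subs above):
#
#     x = Stack("i", (Number(0.0), Number(1.0)))
#     y = Stack("i", (Number(2.0), Number(3.0), Number(4.0)))
#     xy = Cat("i", (x, y))
#     assert xy.inputs["i"] == Bint[5]
#     assert xy(i=3) is Number(3.0)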


class Lambda(Funsor):
    """
    Lazy inverse to ``ops.getitem``.

    This is useful to simulate higher-order functions of integers
    by representing those functions as arrays.

    :param Variable var: A variable to bind.
    :param funsor expr: A funsor.
    """

    def __init__(self, var, expr):
        assert isinstance(var, Variable)
        assert isinstance(var.dtype, int)
        assert isinstance(expr, Funsor)
        inputs = expr.inputs.copy()
        inputs.pop(var.name, None)
        shape = (var.dtype,) + expr.output.shape
        output = Array[expr.dtype, shape]
        fresh = frozenset()
        bound = {var.name: var.output}
        super(Lambda, self).__init__(inputs, output, fresh, bound)
        self.var = var
        self.expr = expr

    def _alpha_convert(self, alpha_subs):
        alpha_subs = {
            k: to_funsor(v, self.var.inputs[k]) for k, v in alpha_subs.items()
        }
        return super()._alpha_convert(alpha_subs)


@eager.register(Binary, GetitemOp, Lambda, (Funsor, Align))
def eager_getitem_lambda(op, lhs, rhs):
    offset = op.defaults["offset"]
    if offset == 0:
        return Subs(lhs.expr, ((lhs.var.name, rhs),))
    expr = GetitemOp(offset - 1)(lhs.expr, rhs)
    return Lambda(lhs.var, expr)


@eager.register(Unary, ops.GetsliceOp, Lambda)
def eager_getslice_lambda(op, x):
    index = normalize_ellipsis(op.defaults["index"], len(x.shape))
    head, tail = index[0], index[1:]
    expr = x.expr
    if head != slice(None):
        expr = expr(**{x.var.name: head})
    if tail:
        expr = ops.getslice(expr, tail)
    if x.var.name in expr.inputs:  # dim is preserved, e.g. x[1:]
        return Lambda(x.var, expr)
    else:  # dim is eliminated, e.g. x[0]
        return expr
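

# Usage sketch (illustrative, not part of the library source; ``Reals`` is
# from funsor.domains): Lambda packs a bounded-integer input into a new
# leftmost output dimension, and ``eager_getitem_lambda`` above inverts it:
#
#     i = Variable("i", Bint[3])
#     x = Stack("i", (Number(1.0), Number(2.0), Number(3.0)))
#     f = Lambda(i, x)
#     assert f.output == Reals[3]
#     assert f[i] is x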


class Independent(Funsor):
    """
    Creates an independent diagonal distribution.

    This is equivalent to substitution followed by reduction::

        f = ...  # a batched distribution
        assert f.inputs['x_i'] == Reals[4, 5]
        assert f.inputs['i'] == Bint[3]

        g = Independent(f, 'x', 'i', 'x_i')
        assert g.inputs['x'] == Reals[3, 4, 5]
        assert 'x_i' not in g.inputs
        assert 'i' not in g.inputs

        x = Variable('x', Reals[3, 4, 5])
        g == f(x_i=x['i']).reduce(ops.add, 'i')

    :param Funsor fn: A funsor.
    :param str reals_var: The name of a real-tensor input.
    :param str bint_var: The name of a new batch input of ``fn``.
    :param diag_var: The name of a smaller-shape real input of ``fn``.
    """

    def __init__(self, fn, reals_var, bint_var, diag_var):
        assert isinstance(fn, Funsor)
        assert isinstance(reals_var, str)
        assert isinstance(bint_var, str)
        assert bint_var in fn.inputs, (bint_var, fn.inputs)
        assert isinstance(fn.inputs[bint_var].dtype, int)
        assert isinstance(diag_var, str)
        assert diag_var in fn.inputs
        inputs = fn.inputs.copy()
        diag_input = inputs.pop(diag_var)
        shape = (inputs.pop(bint_var).dtype,) + diag_input.shape
        assert reals_var not in inputs
        inputs[reals_var] = Array[diag_input.dtype, shape]
        fresh = frozenset({reals_var})
        bound = {bint_var: fn.inputs[bint_var], diag_var: fn.inputs[diag_var]}
        super(Independent, self).__init__(inputs, fn.output, fresh, bound)
        self.fn = fn
        self.reals_var = reals_var
        self.bint_var = bint_var
        self.diag_var = diag_var

    def _alpha_convert(self, alpha_subs):
        alpha_subs = {k: to_funsor(v, self.fn.inputs[k]) for k, v in alpha_subs.items()}
        fn, reals_var, bint_var, diag_var = super()._alpha_convert(alpha_subs)
        bint_var = str(alpha_subs.get(bint_var, bint_var))
        diag_var = str(alpha_subs.get(diag_var, diag_var))
        return fn, reals_var, bint_var, diag_var

    def _sample(self, sampled_vars, sample_inputs, rng_key=None):
        if self.bint_var in sampled_vars or self.bint_var in sample_inputs:
            raise NotImplementedError("TODO alpha-convert")
        sampled_vars = frozenset(
            self.diag_var if v == self.reals_var else v for v in sampled_vars
        )
        fn = self.fn._sample(sampled_vars, sample_inputs, rng_key)
        return Independent(fn, self.reals_var, self.bint_var, self.diag_var)

    def eager_subs(self, subs):
        assert len(subs) == 1 and subs[0][0] == self.reals_var
        value = subs[0][1]

        # Handle simple renaming to preserve Independent.
        if isinstance(value, Variable):
            return Independent(self.fn, value.name, self.bint_var, self.diag_var)

        # Otherwise convert to a Reduce.
        result = Subs(self.fn, ((self.diag_var, value[self.bint_var]),))
        result = result.reduce(ops.add, self.bint_var)
        return result

    def mean(self):
        raise NotImplementedError("mean() not yet implemented for Independent")

    def variance(self):
        raise NotImplementedError("variance() not yet implemented for Independent")

    def entropy(self):
        raise NotImplementedError("entropy() not yet implemented for Independent")


@eager.register(Independent, Funsor, str, str, str)
def eager_independent_trivial(fn, reals_var, bint_var, diag_var):
    # compare to Independent.eager_subs
    if diag_var not in fn.inputs:
        return fn.reduce(ops.add, bint_var)
    return None


class Tuple(Funsor):
    """
    Funsor term representing tuples of other terms of possibly heterogeneous
    type.
    """

    def __init__(self, args):
        assert isinstance(args, tuple)
        assert all(isinstance(arg, Funsor) for arg in args)
        inputs = OrderedDict()
        for arg in args:
            inputs.update(arg.inputs)
        output = Product[tuple(arg.output for arg in args)]
        super().__init__(inputs, output)
        self.args = args

    def __iter__(self):
        for i in range(len(self.args)):
            yield self[i]


@to_funsor.register(tuple)
def tuple_to_funsor(args, output=None, dim_to_name=None):
    if not isinstance(output, ProductDomain):
        raise NotImplementedError("TODO")
    outputs = get_args(output)
    assert len(outputs) == len(args)
    funsor_args = tuple(
        to_funsor(arg, output=arg_output, dim_to_name=dim_to_name)
        for arg, arg_output in zip(args, outputs)
    )
    return Tuple(funsor_args)


@lazy.register(Binary, GetitemOp, Tuple, Number)
@eager.register(Binary, GetitemOp, Tuple, Number)
def eager_getitem_tuple(op, lhs, rhs):
    return op(lhs.args, rhs.data)


@lazy.register(Unary, ops.GetsliceOp, Tuple)
@eager.register(Unary, ops.GetsliceOp, Tuple)
def eager_getslice_tuple(op, x):
    index = op.defaults["index"]
    if isinstance(index, tuple):
        assert len(index) == 1
        index = index[0]
    if isinstance(index, int):
        return op(x.args)
    elif isinstance(index, slice):
        return Tuple(op(x.args))
    else:
        raise ValueError(index)


def _symbolic(inputs, output, fn):
    args, vargs, kwargs, defaults = getargspec(fn)
    assert not vargs
    assert not kwargs
    names = tuple(args)
    if isinstance(inputs, dict):
        args = tuple(Variable(name, inputs[name]) for name in names if name in inputs)
    else:
        args = tuple(Variable(name, domain) for (name, domain) in zip(names, inputs))
    assert len(args) == len(inputs)
    return to_funsor(fn(*args), output).align(names)


def symbolic(*signature):
    r"""
    Decorator to construct a symbolic :class:`Funsor` with one free
    :class:`Variable` per function arg. This can be used either with explicit
    types or with type hints::

        # Using type hints:
        @symbolic
        def xpyi(x: Real, y: Reals[3], i: Bint[3]):
            return x + y[i]

        # Using explicit type annotations:
        @symbolic(Real, Reals[3], Bint[3])
        def xpyi(x: Real, y: Reals[3], i: Bint[3]):
            return x + y[i]

    :param \*signature: A sequence of input domains.
    """
    if len(signature) == 1:
        fn = signature[0]
        if callable(fn) and not isinstance(fn, Domain):
            # Usage: @symbolic
            inputs = typing.get_type_hints(fn)
            output = inputs.pop("return", None)
            return _symbolic(inputs, output, fn)

    # Usage: @symbolic(Real, Reals[3], Bint[3])
    output = None
    return functools.partial(_symbolic, signature, output)


# DEPRECATED
def of_shape(*shape):
    warnings.warn("@of_shape is deprecated, use @symbolic instead", DeprecationWarning)
    return symbolic(*shape)


AstStats = namedtuple("AstStats", ("size", "depth", "width"))


# Profiling helpers
@singledispatch
def _count_funsors(x):
    return 0


@_count_funsors.register(Funsor)
def _(x):
    return 1


@_count_funsors.register(tuple)
def _(x):
    return sum(map(_count_funsors, x))


@singledispatch
def _get_ast_stats(x):
    return AstStats(1, 1, 0)


@_get_ast_stats.register(Funsor)
def _(x):
    result = getattr(x, "_ast_stats", None)
    if result is None:
        size, depth, _ = _get_ast_stats(x._ast_values)
        width = _count_funsors(x._ast_values)
        result = x._ast_stats = AstStats(size + 1, depth + 1, width)
    return result


@_get_ast_stats.register(tuple)
def _(x):
    parts = list(map(_get_ast_stats, x))
    size = sum(p.size for p in parts)
    depth = max([0] + [p.depth for p in parts])
    return AstStats(size, depth, 0)


################################################################################
# Register Ops
################################################################################


@quote.register(Variable)
@quote.register(Number)
@quote.register(Slice)
def quote_inplace_oneline(arg, indent, out):
    out.append((indent, repr(arg)))


@quote.register(Unary)
@quote.register(Binary)
@quote.register(Reduce)
@quote.register(Stack)
@quote.register(Cat)
@quote.register(Lambda)
def quote_inplace_first_arg_on_first_line(arg, indent, out):
    line = "{}({},".format(type(arg).__name__, repr(arg._ast_values[0]))
    out.append((indent, line))
    for value in arg._ast_values[1:-1]:
        quote.inplace(value, indent + 1, out)
        i, line = out[-1]
        out[-1] = i, line + ","
    for value in arg._ast_values[-1:]:
        quote.inplace(value, indent + 1, out)
        i, line = out[-1]
        out[-1] = i, line + ")"


@ops.UnaryOp.subclass_register(Funsor)
def unary_funsor(cls, arg, *args, **kwargs):
    op = cls(*args, **kwargs)
    return Unary(op, arg)


@ops.BinaryOp.subclass_register(Funsor, Funsor)
def binary_funsor_funsor(cls, lhs, rhs, *args, **kwargs):
    op = cls(*args, **kwargs)
    return Binary(op, lhs, rhs)


@ops.BinaryOp.subclass_register(object, Funsor)
def binary_object_funsor(cls, lhs, rhs, *args, **kwargs):
    op = cls(*args, **kwargs)
    lhs = to_funsor(lhs)
    return Binary(op, lhs, rhs)


@ops.BinaryOp.subclass_register(Funsor, object)
def binary_funsor_object(cls, lhs, rhs, *args, **kwargs):
    op = cls(*args, **kwargs)
    rhs = to_funsor(rhs)
    return Binary(op, lhs, rhs)


@ops.TernaryOp.subclass_register(Funsor, Funsor, Funsor)
@ops.TernaryOp.subclass_register(Funsor, Funsor, object)
@ops.TernaryOp.subclass_register(Funsor, object, object)
@ops.TernaryOp.subclass_register(object, Funsor, object)
@ops.TernaryOp.subclass_register(object, object, Funsor)
def ternary_funsor_object(cls, x, y, z, *args, **kwargs):
    op = cls(*args, **kwargs)
    x = to_funsor(x)
    y = to_funsor(y)
    z = to_funsor(z)
    return Finitary(op, (x, y, z))


# FIXME allow some non-funsors
@ops.FinitaryOp.subclass_register(typing.Tuple[Funsor, ...])
def finitary_funsor(cls, arg, *args, **kwargs):
    op = cls(*args, **kwargs)
    return Finitary(op, arg)


__all__ = [
    "Approximate",
    "Binary",
    "Cat",
    "Funsor",
    "Independent",
    "Lambda",
    "Number",
    "Reduce",
    "Scatter",
    "Slice",
    "Stack",
    "Subs",
    "Unary",
    "Variable",
    "of_shape",
    "to_data",
    "to_funsor",
]