# Source code for funsor.ops.builtin

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
import operator
from numbers import Number

from .op import DISTRIBUTIVE_OPS, PRODUCT_INVERSES, UNITS, Op, TransformOp

_builtin_abs = abs
_builtin_max = max
_builtin_min = min
_builtin_pow = pow
_builtin_sum = sum


# FIXME Most code assumes this is an AssociativeCommutativeOp.
class AssociativeOp(Op):
    # Marker base class for binary ops that are treated as associative.
    # In this module it wraps and_, or_, xor, min and max directly and is
    # the parent of AddOp, MulOp and NullOp.
    pass
class AddOp(AssociativeOp):
    # Distinct type for addition so it can be told apart from other
    # associative ops; instantiated in this module as ``add``.
    pass
class MulOp(AssociativeOp):
    # Distinct type for multiplication; instantiated in this module as
    # ``mul``.
    pass
class MatmulOp(Op):
    # Associative but not commutative.  It derives from plain Op rather
    # than AssociativeOp -- presumably because most code assumes an
    # AssociativeOp is also commutative (see the FIXME on AssociativeOp).
    pass
class SubOp(Op):
    # Type of subtraction-like ops; this module uses it for ``sub`` and
    # ``safesub``.
    pass
class NegOp(Op):
    # Type of the unary negation op; instantiated in this module as ``neg``.
    pass
class DivOp(Op):
    # Type of division-like ops; this module uses it for ``truediv`` and
    # ``safediv``.
    pass
class NullOp(AssociativeOp):
    """Placeholder associative op that unifies with any other op"""
@NullOp
def nullop(x, y):
    # This op is only ever meant to be matched against, never computed,
    # so any attempt to evaluate it is reported as an error.
    message = "should never actually evaluate this!"
    raise ValueError(message)
class GetitemMeta(type):
    """Metaclass that interns instances so equal offsets share one object.

    Calling a class that uses this metaclass with an ``offset`` returns the
    previously constructed instance for that class and offset, if any;
    otherwise a new instance is built, remembered and returned.

    The cache is keyed by ``(cls, offset)`` rather than by ``offset`` alone.
    The cache dict lives on the metaclass and is therefore shared by every
    class using it, so keying on the offset only would hand a subclass an
    instance of whichever class happened to be constructed first.
    """

    # Maps (class, offset) -> the unique instance for that pair.
    _cache = {}

    def __call__(cls, offset):
        """Return the interned instance of ``cls`` for ``offset``.

        :param offset: the single constructor argument; it must be hashable
            because it is used as part of the cache key.
        :return: the cached instance, constructing it on first use.
        """
        key = (cls, offset)
        try:
            return GetitemMeta._cache[key]
        except KeyError:
            # First request for this class/offset: run the normal
            # construction path (``__new__`` + ``__init__``) and remember it.
            instance = super(GetitemMeta, cls).__call__(offset)
            GetitemMeta._cache[key] = instance
            return instance
class GetitemOp(Op, metaclass=GetitemMeta):
    """
    Op encoding an index into one dimension, e.g. ``x[:,:,y]`` for offset of 2.

    ``offset`` is the number of leading dimensions that are passed over
    with a full slice before the index ``y`` is applied.
    """

    def __init__(self, offset):
        # Offsets must be non-negative integers.
        assert isinstance(offset, int) and offset >= 0
        self.offset = offset
        # One ``slice(None)`` (i.e. ``:``) per skipped leading dimension.
        self._prefix = (slice(None),) * offset
        super().__init__(self._default)
        self.__name__ = 'GetitemOp({})'.format(offset)

    def __reduce__(self):
        # Pickle as "call GetitemOp(offset)" so that unpickling goes back
        # through the class call instead of copying instance state.
        return (GetitemOp, (self.offset,))

    def _default(self, x, y):
        # Fallback implementation using plain Python indexing.
        if not self.offset:
            return x[y]
        return x[self._prefix + (y,)]
# Indexing into the leading dimension.
getitem = GetitemOp(0)

# Plain ops with no special algebraic class.
abs = Op(_builtin_abs)
eq = Op(operator.eq)
ge = Op(operator.ge)
gt = Op(operator.gt)
invert = Op(operator.invert)
le = Op(operator.le)
lt = Op(operator.lt)
ne = Op(operator.ne)

# Arithmetic ops that carry their own op class.
neg = NegOp(operator.neg)
sub = SubOp(operator.sub)
truediv = DivOp(operator.truediv)

# Associative ops (matmul is the exception: it is a MatmulOp, which
# derives from Op directly).
add = AddOp(operator.add)
and_ = AssociativeOp(operator.and_)
mul = MulOp(operator.mul)
matmul = MatmulOp(operator.matmul)
or_ = AssociativeOp(operator.or_)
xor = AssociativeOp(operator.xor)


@add.register(object)
def _unary_add(x):
    # Single-argument form of ``add``: delegates to the operand's own
    # ``sum`` method.
    return x.sum()


@Op
def sqrt(x):
    """Return the square root of ``x`` via :func:`math.sqrt`."""
    return math.sqrt(x)
class ExpOp(TransformOp):
    # Type of the exponential transform; instantiated in this module as
    # ``exp``.
    pass
@ExpOp
def exp(x):
    """Return ``e`` raised to the power ``x`` via :func:`math.exp`."""
    result = math.exp(x)
    return result
@exp.set_log_abs_det_jacobian
def log_abs_det_jacobian(x, y):
    # For y = exp(x) the log-derivative is x itself; the one-argument call
    # to ``add`` reduces it (presumably through the unary ``add``
    # registration in this module, i.e. ``x.sum()``).  ``y`` is unused.
    reduced = add(x)
    return reduced
class LogOp(TransformOp):
    # Type of the logarithm transform; instantiated in this module as
    # ``log``.
    pass
@LogOp
def log(x):
    """Return the natural log of ``x``, or ``-inf`` when ``x`` is not positive.

    Any input that fails the ``x > 0`` test (zero, negatives, NaN) maps to
    ``-math.inf`` instead of raising.
    """
    if x > 0:
        return math.log(x)
    return -math.inf
@log.set_log_abs_det_jacobian
def log_abs_det_jacobian(x, y):
    # For y = log(x) the log-derivative is -log(x) = -y; the one-argument
    # call to ``add`` reduces ``y`` before negation.  ``x`` is unused.
    return -add(y)


# exp and log are each other's inverse transform.
exp.set_inv(log)
log.set_inv(exp)


@Op
def log1p(x):
    """Return ``log(1 + x)`` via :func:`math.log1p`."""
    return math.log1p(x)


@Op
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    For negative scalar inputs the algebraically equal form
    ``exp(x) / (1 + exp(x))`` is used, so ``exp`` is never called with a
    large positive argument.  The direct form raises ``OverflowError`` for
    scalars below roughly -709 because ``math.exp`` overflows.
    """
    if isinstance(x, Number) and x < 0:
        # exp(x) lies in [0, 1) here, so neither the exponential nor the
        # division can overflow.
        z = exp(x)
        return z / (1 + z)
    # Non-negative scalars (exp(-x) <= 1) and all non-Number inputs keep
    # the original formula.
    return 1 / (1 + exp(-x))


@Op
def pow(x, y):
    """Return ``x`` raised to the power ``y`` using the ``**`` operator."""
    return x ** y
@AssociativeOp
def min(x, y):
    """Return the smaller of ``x`` and ``y``.

    An operand that defines ``__min__`` handles the computation itself;
    ``x`` is asked first, then ``y`` (with the arguments swapped).  If
    neither does, the builtin ``min`` is used.
    """
    for lhs, rhs in ((x, y), (y, x)):
        if hasattr(lhs, '__min__'):
            return lhs.__min__(rhs)
    return _builtin_min(x, y)
@AssociativeOp
def max(x, y):
    """Return the larger of ``x`` and ``y``.

    An operand that defines ``__max__`` handles the computation itself;
    ``x`` is asked first, then ``y`` (with the arguments swapped).  If
    neither does, the builtin ``max`` is used.
    """
    for lhs, rhs in ((x, y), (y, x)):
        if hasattr(lhs, '__max__'):
            return lhs.__max__(rhs)
    return _builtin_max(x, y)
@SubOp
def safesub(x, y):
    """Subtract ``y`` from ``x`` when ``y`` is a plain number.

    Returns ``sub(x, y)`` if ``y`` is a :class:`numbers.Number`; for any
    other ``y`` this default implementation returns ``None``.
    """
    if not isinstance(y, Number):
        # NOTE(review): non-Number operands fall through to None here --
        # confirm whether callers rely on that.
        return None
    return sub(x, y)
@DivOp
def safediv(x, y):
    """Divide ``x`` by ``y`` when ``y`` is a plain number.

    Returns ``truediv(x, y)`` if ``y`` is a :class:`numbers.Number`; for
    any other ``y`` this default implementation returns ``None``.
    """
    if not isinstance(y, Number):
        # NOTE(review): non-Number operands fall through to None here --
        # confirm whether callers rely on that.
        return None
    return truediv(x, y)
class ReciprocalOp(Op):
    # Type of the multiplicative-inverse op; instantiated in this module
    # as ``reciprocal``.
    pass
@ReciprocalOp
def reciprocal(x):
    """Return ``1 / x`` as a float for a numeric ``x``.

    :raises ValueError: if ``x`` is not a :class:`numbers.Number`.
    """
    if not isinstance(x, Number):
        raise ValueError("No reciprocal for type {}".format(type(x)))
    return 1. / x
# Register op pairs in the shared DISTRIBUTIVE_OPS table.
DISTRIBUTIVE_OPS.add((add, mul))
DISTRIBUTIVE_OPS.add((max, mul))
DISTRIBUTIVE_OPS.add((min, mul))
DISTRIBUTIVE_OPS.add((max, add))
DISTRIBUTIVE_OPS.add((min, add))

# Identity elements of the two arithmetic ops.
UNITS[mul] = 1.
UNITS[add] = 0.

# Op that undoes each product-like op.
PRODUCT_INVERSES[mul] = safediv
PRODUCT_INVERSES[add] = safesub

__all__ = [
    'AddOp',
    'AssociativeOp',
    'DivOp',
    'ExpOp',
    'GetitemOp',
    'LogOp',
    'MatmulOp',
    'MulOp',
    'NegOp',
    'NullOp',
    'ReciprocalOp',
    'SubOp',
    'abs',
    'add',
    'and_',
    'eq',
    'exp',
    'ge',
    'getitem',
    'gt',
    'invert',
    'le',
    'log',
    'log1p',
    'lt',
    'matmul',
    'max',
    'min',
    'mul',
    'ne',
    'neg',
    'nullop',
    'or_',
    'pow',
    'reciprocal',
    'safediv',
    'safesub',
    'sigmoid',
    'sqrt',
    'sub',
    'truediv',
    'xor',
]

# Module docstring: one Sphinx ``autodata`` directive per exported name
# that is bound to an Op instance (the exported classes are skipped).
__doc__ = "\n".join(
    ".. autodata:: {}\n".format(_name)
    for _name in __all__
    if isinstance(globals()[_name], Op)
)