Source code for funsor.ops.op

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from multipledispatch import Dispatcher


class Op(Dispatcher):
    """
    A named :class:`~multipledispatch.Dispatcher` that can be seeded with a
    default implementation.

    :param fn: Either a callable to register as the default implementation,
        or a string giving the name of the op (in which case no default
        implementation is registered).
    :param str name: Optional name of the op. Defaults to ``fn.__name__``.
        Overwritten by ``fn`` when ``fn`` is a string.
    """

    def __init__(self, fn, *, name=None):
        if isinstance(fn, str):
            # ``Op("name")`` form: the positional argument is the name and
            # there is no default implementation to register.
            name = fn
            fn = None
        elif name is None:
            name = fn.__name__
        super().__init__(name)

        if fn is None:
            return
        # Register fn as the default operation for unary and binary calls.
        for nargs in (1, 2):
            self.add((object,) * nargs, fn)

    def __copy__(self):
        # Copying an op yields the very same object.
        return self

    def __deepcopy__(self, memo):
        # Deep-copying an op also yields the very same object.
        return self

    def __reduce__(self):
        # Returning a string asks pickle to treat this object as a global
        # variable with that name.
        return self.__name__

    def __repr__(self):
        return "ops." + self.__name__

    def __str__(self):
        return self.__name__
class TransformOp(Op):
    """
    An :class:`Op` that can additionally carry an inverse function and a
    log-abs-det-Jacobian function, attached via :meth:`set_inv` and
    :meth:`set_log_abs_det_jacobian`.
    """

    @staticmethod
    def inv(x):
        """
        Placeholder inverse; raises until replaced through :meth:`set_inv`.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError

    @staticmethod
    def log_abs_det_jacobian(x, y):
        """
        Placeholder log-abs-det-Jacobian; raises until replaced through
        :meth:`set_log_abs_det_jacobian`.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError

    def set_inv(self, fn):
        """
        Store ``fn`` as this op's ``inv`` attribute and return it unchanged.

        :param callable fn: A function that inputs an arg ``y`` and outputs a
            value ``x`` such that ``y=self(x)``.
        :return: ``fn`` itself.
        """
        assert callable(fn)
        # The instance attribute shadows the class-level placeholder.
        self.inv = fn
        return fn

    def set_log_abs_det_jacobian(self, fn):
        """
        Store ``fn`` as this op's ``log_abs_det_jacobian`` attribute and
        return it unchanged.

        :param callable fn: A function that inputs two args ``x, y``, where
            ``y=self(x)``, and returns ``log(abs(det(dy/dx)))``.
        :return: ``fn`` itself.
        """
        assert callable(fn)
        # The instance attribute shadows the class-level placeholder.
        self.log_abs_det_jacobian = fn
        return fn
# Op registration tables. All three start empty at import time.

# Set of (add, mul) pairs.
DISTRIBUTIVE_OPS = set()

# Mapping: op -> value.
UNITS = {}

# Mapping: op -> inverse op.
PRODUCT_INVERSES = {}


__all__ = [
    "DISTRIBUTIVE_OPS",
    "Op",
    "PRODUCT_INVERSES",
    "TransformOp",
    "UNITS",
]