# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from multipledispatch import Dispatcher
class Op(Dispatcher):
    """A named, multiply-dispatched operation.

    An ``Op`` is a ``Dispatcher`` whose fallback implementation (if one is
    given) is registered for every one- and two-argument call, so more
    specific signatures can be layered on top later.

    :param fn: Either the default implementation (a callable), or a string,
        in which case it is taken as the op's name and no default
        implementation is registered.
    :param str name: Optional explicit name. Defaults to ``fn.__name__``.
    :raises AttributeError: If ``name`` is omitted and ``fn`` has no
        ``__name__`` attribute.
    """

    def __init__(self, fn, *, name=None):
        # Allow ``Op("name")`` to declare an op with no default implementation.
        if isinstance(fn, str):
            fn, name = None, fn
        if name is None:
            name = fn.__name__
        super().__init__(name)
        if fn is not None:
            # Register ``fn`` as the default operation for unary and binary
            # calls by matching arguments of any type.
            for nargs in (1, 2):
                default_signature = (object,) * nargs
                self.add(default_signature, fn)

    def __copy__(self):
        # Copying hands back the same object instead of a duplicate.
        return self

    def __deepcopy__(self, memo):
        # Deep-copying likewise returns the same object; ``memo`` is unused.
        return self

    def __reduce__(self):
        # Returning a string makes pickle treat this object as a global
        # looked up by that name.
        # NOTE(review): presumably each op is bound to a module-level name
        # equal to its ``__name__`` -- confirm against where ops are defined.
        return self.__name__

    def __repr__(self):
        return "ops." + self.__name__

    def __str__(self):
        return self.__name__
# Op registration tables. Each container starts out empty in this module.

# Maps an op to its value (op -> value).
UNITS = {}

# Maps an op to its inverse op (op -> inverse op).
PRODUCT_INVERSES = {}

# Set of (add, mul) pairs.
DISTRIBUTIVE_OPS = set()
# Names exported by ``from <this module> import *``.
# NOTE(review): "TransformOp" is not defined in the portion of the file
# visible here -- confirm it is defined elsewhere in this module.
__all__ = [
    "DISTRIBUTIVE_OPS",
    "Op",
    "PRODUCT_INVERSES",
    "TransformOp",
    "UNITS",
]