import operator
from collections import namedtuple
from functools import reduce
import torch
from pyro.distributions.util import broadcast_shape
import funsor.ops as ops
from funsor.util import lazy_property, quote
[docs]class Domain(namedtuple('Domain', ['shape', 'dtype'])):
"""
An object representing the type and shape of a :class:`Funsor` input or
output.
"""
def __new__(cls, shape, dtype):
assert isinstance(shape, tuple)
if torch._C._get_tracing_state():
shape = tuple(map(int, shape))
assert all(isinstance(size, int) for size in shape), shape
if isinstance(dtype, int):
assert not shape
elif isinstance(dtype, str):
assert dtype == 'real'
else:
raise ValueError(repr(dtype))
return super(Domain, cls).__new__(cls, shape, dtype)
def __repr__(self):
shape = tuple(self.shape)
if isinstance(self.dtype, int):
if not shape:
return 'bint({})'.format(self.dtype)
return 'bint({}, {})'.format(self.dtype, shape)
if not shape:
return 'reals()'
return 'reals{}'.format(shape)
def __iter__(self):
if isinstance(self.dtype, int) and not self.shape:
from funsor.terms import Number
return (Number(i, self.dtype) for i in range(self.dtype))
raise NotImplementedError
[docs] @lazy_property
def num_elements(self):
return reduce(operator.mul, self.shape, 1)
@property
def size(self):
assert isinstance(self.dtype, int)
return self.dtype
@quote.register(Domain)
def _(arg, indent, out):
out.append((indent, repr(arg)))
[docs]def reals(*shape):
"""
Construct a real domain of given shape.
"""
return Domain(shape, 'real')
[docs]def bint(size):
"""
Construct a bounded integer domain of scalar shape.
"""
if torch._C._get_tracing_state():
size = int(size)
assert isinstance(size, int) and size >= 0
return Domain((), size)
[docs]def find_domain(op, *domains):
r"""
Finds the :class:`Domain` resulting when applying ``op`` to ``domains``.
:param callable op: An operation.
:param Domain \*domains: One or more input domains.
"""
assert callable(op), op
assert all(isinstance(arg, Domain) for arg in domains)
if len(domains) == 1:
dtype = domains[0].dtype
shape = domains[0].shape
if op is ops.log or op is ops.exp:
dtype = 'real'
elif isinstance(op, ops.ReshapeOp):
shape = op.shape
elif isinstance(op, ops.AssociativeOp):
shape = ()
return Domain(shape, dtype)
lhs, rhs = domains
if isinstance(op, ops.GetitemOp):
dtype = lhs.dtype
shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
return Domain(shape, dtype)
elif op == ops.matmul:
assert lhs.shape and rhs.shape
if len(rhs.shape) == 1:
assert lhs.shape[-1] == rhs.shape[-1]
shape = lhs.shape[:-1]
elif len(lhs.shape) == 1:
assert lhs.shape[-1] == rhs.shape[-2]
shape = rhs.shape[:-2] + rhs.shape[-1:]
else:
assert lhs.shape[-1] == rhs.shape[-2]
shape = broadcast_shape(lhs.shape[:-1], rhs.shape[:-2] + (1,)) + rhs.shape[-1:]
return Domain(shape, 'real')
if lhs.dtype == 'real' or rhs.dtype == 'real':
dtype = 'real'
elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
elif op in (ops.and_, ops.or_, ops.xor):
dtype = 2
elif lhs.dtype == rhs.dtype:
dtype = lhs.dtype
else:
raise NotImplementedError('TODO')
if lhs.shape == rhs.shape:
shape = lhs.shape
else:
shape = broadcast_shape(lhs.shape, rhs.shape)
return Domain(shape, dtype)
__all__ = [
'Domain',
'find_domain',
'bint',
'reals',
]