Source code for funsor.ops.array

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import numpy as np

from .builtin import AssociativeOp, add, exp, log, log1p, max, min, reciprocal, safediv, safesub, sqrt
from .op import DISTRIBUTIVE_OPS, Op

_builtin_all = all
_builtin_any = any

# This is used only for pattern matching.
array = (np.ndarray, np.generic)

all = Op(np.all)
amax = Op(np.amax)
amin = Op(np.amin)
any = Op(np.any)
astype = Op("astype")
cat = Op("cat")
clamp = Op("clamp")
diagonal = Op("diagonal")
einsum = Op("einsum")
full_like = Op(np.full_like)
prod = Op(np.prod)
stack = Op("stack")
sum = Op(np.sum)
transpose = Op("transpose")

sqrt.register(array)(np.sqrt)
exp.register(array)(np.exp)
log1p.register(array)(np.log1p)


[docs]class LogAddExpOp(AssociativeOp): pass
[docs]class SampleOp(LogAddExpOp): pass
@log.register(array) def _log(x): if x.dtype == 'bool': return np.where(x, 0., -math.inf) with np.errstate(divide='ignore'): # skip the warning of log(0.) return np.log(x) def _logaddexp(x, y): if hasattr(x, "__logaddexp__"): return x.__logaddexp__(y) if hasattr(y, "__rlogaddexp__"): return y.__logaddexp__(x) shift = max(detach(x), detach(y)) return log(exp(x - shift) + exp(y - shift)) + shift logaddexp = LogAddExpOp(_logaddexp, name="logaddexp") sample = SampleOp(_logaddexp, name="sample") class ReshapeMeta(type): _cache = {} def __call__(cls, shape): shape = tuple(shape) try: return ReshapeMeta._cache[shape] except KeyError: instance = super().__call__(shape) ReshapeMeta._cache[shape] = instance return instance
[docs]class ReshapeOp(Op, metaclass=ReshapeMeta): def __init__(self, shape): self.shape = shape super().__init__(self._default) def __reduce__(self): return ReshapeOp, (self.shape,) def _default(self, x): return x.reshape(self.shape)
@astype.register(array, str) def _astype(x, dtype): return x.astype(dtype) @cat.register(int, [array]) def _cat(dim, *x): return np.concatenate(x, axis=dim) @clamp.register(array, object, object) def _clamp(x, min, max): return np.clip(x, a_min=min, a_max=max) @Op def cholesky(x): """ Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices. """ if x.shape[-1] == 1: return np.sqrt(x) return np.linalg.cholesky(x) @Op def cholesky_inverse(x): """ Like :func:`torch.cholesky_inverse` but supports batching and gradients. """ return cholesky_solve(new_eye(x, x.shape[:-1]), x) @Op def cholesky_solve(x, y): y_inv = np.linalg.inv(y) A = np.swapaxes(y_inv, -2, -1) @ y_inv return A @ x @Op def detach(x): return x @diagonal.register(array, int, int) def _diagonal(x, dim1, dim2): return np.diagonal(x, axis1=dim1, axis2=dim2) @einsum.register(str, [array]) def _einsum(x, *operand): return np.einsum(x, *operand) @Op def expand(x, shape): prepend_dim = len(shape) - np.ndim(x) assert prepend_dim >= 0 shape = shape[:prepend_dim] + tuple(dx if size == -1 else size for dx, size in zip(np.shape(x), shape[prepend_dim:])) return np.broadcast_to(x, shape) return np.broadcast_to(x, shape) @Op def finfo(x): return np.finfo(x.dtype) @Op def is_numeric_array(x): return True if isinstance(x, array) else False @Op def logsumexp(x, dim): amax = np.amax(x, axis=dim, keepdims=True) # treat the case x = -inf amax = np.where(np.isfinite(amax), amax, 0.) return log(np.sum(np.exp(x - amax), axis=dim)) + amax.squeeze(axis=dim) @max.register(array, array) def _max(x, y): return np.maximum(x, y) @max.register((int, float), array) def _max(x, y): return np.clip(y, a_min=x, a_max=None) @max.register(array, (int, float)) def _max(x, y): return np.clip(x, a_min=y, a_max=None) @min.register(array, array) def _min(x, y): return np.minimum(x, y) @min.register((int, float), array) def _min(x, y): return np.clip(y, a_min=None, a_max=x) @min.register(array, (int, float)) def _min(x, y): return np.clip(x, a_min=None, a_max=y) @Op def new_arange(x, stop): return np.arange(stop) @new_arange.register(array, int, int, int) def _new_arange(x, start, stop, step): return np.arange(start, stop, step) @Op def new_zeros(x, shape): return np.zeros(shape, dtype=x.dtype) @Op def new_eye(x, shape): n = shape[-1] return np.broadcast_to(np.eye(n), shape + (n,)) @Op def permute(x, dims): return np.transpose(x, axes=dims) @reciprocal.register(array) def _reciprocal(x): result = np.clip(np.reciprocal(x), a_max=np.finfo(x.dtype).max) return result @safediv.register(object, array) def _safediv(x, y): try: finfo = np.finfo(y.dtype) except ValueError: finfo = np.iinfo(y.dtype) return x * np.clip(np.reciprocal(y), a_min=None, a_max=finfo.max) @safesub.register(object, array) def _safesub(x, y): try: finfo = np.finfo(y.dtype) except ValueError: finfo = np.iinfo(y.dtype) return x + np.clip(-y, a_min=None, a_max=finfo.max) @stack.register(int, [array]) def _stack(dim, *x): return np.stack(x, axis=dim) @transpose.register(array, int, int) def _transpose(x, dim1, dim2): return np.swapaxes(x, dim1, dim2) @Op def triangular_solve(x, y, upper=False, transpose=False): if transpose: y = np.swapaxes(y, -2, -1) return np.linalg.inv(y) @ x @Op def unsqueeze(x, dim): return np.expand_dims(x, axis=dim) DISTRIBUTIVE_OPS.add((logaddexp, add)) DISTRIBUTIVE_OPS.add((sample, add)) __all__ = [ 'LogAddExpOp', 'ReshapeOp', 'SampleOp', 'all', 'amax', 'amin', 'any', 'astype', 'cat', 'cholesky', 'cholesky_inverse', 'cholesky_solve', 'clamp', 'detach', 'diagonal', 'einsum', 'expand', 'finfo', 'full_like', 'is_numeric_array', 'logaddexp', 'logsumexp', 'new_arange', 'new_eye', 'new_zeros', 'permute', 'prod', 'sample', 'stack', 'sum', 'transpose', 'triangular_solve', 'unsqueeze', ] __doc__ = "\n".join(".. autodata:: {}\n".format(_name) for _name in __all__ if isinstance(globals()[_name], Op))