Operation classes

class BinaryOp(*args, **kwargs)

Bases: Op

arity = 2
class FinitaryOp(*args, **kwargs)

Bases: Op

arity = 1
class LogAbsDetJacobianOp(*args, **kwargs)

Bases: BinaryOp

static default(x, y, fn)
dispatcher = <dispatched log_abs_det_jacobian>
name = 'log_abs_det_jacobian'
signature = <Signature (x, y, fn)>
class NullaryOp(*args, **kwargs)

Bases: Op

arity = 0
class Op(*args, **kwargs)

Bases: object

Abstract base class for all mathematical operations on ground terms.

Ops take arity-many leftmost positional args that may be funsors, followed by additional non-funsor args and kwargs. The additional args and kwargs must have default values.

When wrapping new backend ops, keep in mind these restrictions, which may require you to wrap backend functions before making them into ops:

  • Create new ops only by decorating a default implementation with @UnaryOp.make, @BinaryOp.make, etc.

  • Register backend-specific implementations via @my_op.register(type1) for arity 1, @my_op.register(type1, type2) for arity 2, and so on. Patterns may include only the first arity-many types.

  • Only the first arity-many arguments may be funsors. Remaining args and kwargs must all be ground Python data.
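The restrictions above can be illustrated with a small, self-contained toy. This is not funsor's implementation: `ToyOp`, `ToyUnaryOp`, `make` and `register` below are hypothetical stand-ins that only mimic the described conventions (dispatch on the first arity-many positional arguments, with any remaining arguments being plain Python data that have defaults).

```python
class ToyOp:
    """Hypothetical stand-in for an Op: dispatches on the types of the first
    `arity` positional args and otherwise falls back to a default implementation."""

    arity = None  # set by subclasses

    def __init__(self, default):
        self.default = default
        self.registry = {}  # maps a tuple of arity-many types to an implementation

    @classmethod
    def make(cls, fn):
        # Decorating a default implementation creates a new op.
        return cls(fn)

    def register(self, *types):
        # Patterns may mention only (exactly) the first arity-many argument types.
        if len(types) != self.arity:
            raise TypeError(f"expected {self.arity} types, got {len(types)}")

        def decorator(fn):
            self.registry[types] = fn
            return fn

        return decorator

    def __call__(self, *args, **kwargs):
        # Only the leading arity-many args take part in dispatch (exact-type match;
        # a real dispatcher would also consider subclasses).
        key = tuple(type(arg) for arg in args[: self.arity])
        impl = self.registry.get(key, self.default)
        return impl(*args, **kwargs)


class ToyUnaryOp(ToyOp):
    arity = 1


@ToyUnaryOp.make
def scale(x, factor=2):
    # Default implementation; `factor` is ground Python data with a default value.
    return x * factor


@scale.register(list)
def _scale_list(x, factor=2):
    # "Backend" implementation for a toy array type (a Python list): elementwise.
    return [item * factor for item in x]
```

Without the registration, `scale([1, 2])` would fall back to `x * factor` and repeat the list instead of scaling its elements.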


Variables

.arity (int) – The number of funsor arguments this op takes. Must be defined by subclasses.

Parameters

  • *args

  • **kwargs – All extra arguments to this op, excluding the first .arity arguments.

arity = NotImplemented
classmethod subclass_register(*pattern)
classmethod make(fn=None, *, name=None, metaclass=None, module_name='funsor.ops')

Factory to create a new Op subclass together with a new default instance of that class.


Parameters

fn (callable) – A function whose signature can be inspected.

Returns

The new default instance.

Return type

Op


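As a rough illustration of "a new subclass together with a new default instance", the hypothetical `toy_make` below builds a fresh subclass with `type()` and instantiates it exactly once. It defaults the op name to `fn.__name__` and stores `inspect.signature(fn)` on the class; the class-naming scheme and the `ToyOp` base are assumptions for illustration, not funsor's actual code.

```python
import inspect


class ToyOp:
    """Hypothetical minimal base class standing in for Op."""

    def __init__(self, name):
        self.name = name

    def __call__(self, *args, **kwargs):
        return self.default(*args, **kwargs)


def toy_make(fn, *, name=None, base=ToyOp):
    """Create a new subclass of `base` plus one default instance of it."""
    name = name or fn.__name__  # default the op name to the function name
    signature = inspect.signature(fn)  # fn's signature must be inspectable
    # Assumed naming scheme: "scaled_add" -> "ScaledAddOp".
    class_name = "".join(part.title() for part in name.split("_")) + "Op"
    op_class = type(
        class_name,
        (base,),
        {"default": staticmethod(fn), "signature": signature},
    )
    return op_class(name)


def scaled_add(x, y, scale=1.0):
    return (x + y) * scale


op = toy_make(scaled_add)
```

Each call to `toy_make` yields a distinct class, so per-op behaviour (for example backend registrations) can live on that class rather than on a shared base.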
class TernaryOp(*args, **kwargs)

Bases: Op

arity = 3
class TransformOp(*args, **kwargs)

Bases: UnaryOp

set_inv(fn)

Parameters

fn (callable) – A function that inputs an arg y and outputs a value x such that y=self(x).

set_log_abs_det_jacobian(fn)

Parameters

fn (callable) – A function that inputs two args x, y, where y=self(x), and returns log(abs(det(dy/dx))).

static inv(x)
static log_abs_det_jacobian(x, y)
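To make the inverse and Jacobian contracts concrete, here is a hypothetical toy transform (not funsor code) for y = exp(x). Its inverse is log(y), and since dy/dx = exp(x), log(abs(det(dy/dx))) is simply x. The `ToyTransformOp` class and its two registration methods are illustrative assumptions that follow the parameter descriptions above.

```python
import math


class ToyTransformOp:
    """Hypothetical unary transform op with pluggable inverse and log|det J|."""

    def __init__(self, fn):
        self.fn = fn
        self.inv = None
        self.log_abs_det_jacobian = None

    def __call__(self, x):
        return self.fn(x)

    def set_inv(self, fn):
        # fn(y) must return x such that y == self(x).
        self.inv = fn
        return fn

    def set_log_abs_det_jacobian(self, fn):
        # fn(x, y) must return log(abs(det(dy/dx))) where y == self(x).
        self.log_abs_det_jacobian = fn
        return fn


exp = ToyTransformOp(math.exp)


@exp.set_inv
def exp_inv(y):
    return math.log(y)


@exp.set_log_abs_det_jacobian
def exp_ladj(x, y):
    # dy/dx = exp(x), so log|dy/dx| = x; y is not needed here.
    return x
```

Passing both x and y lets implementations reuse the forward result when the Jacobian is cheaper to express in terms of y.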
class UnaryOp(*args, **kwargs)

Bases: Op

arity = 1
class WrappedTransformOp(*args, **kwargs)

Bases: TransformOp

Wrapper for a backend Transform object that provides .inv and .log_abs_det_jacobian. This additionally validates shapes on the first __call__().

static default(x, fn, *, validate_args=True)

dispatcher = <dispatched wrapped_transform>
property inv
property log_abs_det_jacobian
name = 'wrapped_transform'
signature = <Signature (x, fn, *, validate_args=True)>
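A validate-once wrapper of this kind might look like the sketch below. Everything here is hypothetical: `ToyScaleTransform` stands in for a backend Transform object, and the particular check (an elementwise transform must preserve the input length) is an assumption chosen for illustration, not the validation funsor actually performs.

```python
import math


class ToyScaleTransform:
    """Hypothetical backend transform y = c * x, applied elementwise to a list."""

    def __init__(self, c):
        self.c = c

    def __call__(self, xs):
        return [self.c * x for x in xs]

    @property
    def inv(self):
        return ToyScaleTransform(1.0 / self.c)

    def log_abs_det_jacobian(self, xs, ys):
        # Diagonal Jacobian with entries c, so log|det| = n * log|c|.
        return len(xs) * math.log(abs(self.c))


class ToyWrappedTransform:
    """Hypothetical wrapper exposing .inv and .log_abs_det_jacobian of a
    backend transform and validating shapes only on the first call."""

    def __init__(self, fn, *, validate_args=True):
        self.fn = fn
        self.validate_args = validate_args
        self._validated = False

    def __call__(self, xs):
        ys = self.fn(xs)
        if self.validate_args and not self._validated:
            # Assumed check: an elementwise transform must preserve length.
            if len(ys) != len(xs):
                raise ValueError(f"shape mismatch: {len(xs)} -> {len(ys)}")
            self._validated = True
        return ys

    @property
    def inv(self):
        return self.fn.inv

    @property
    def log_abs_det_jacobian(self):
        return self.fn.log_abs_det_jacobian


op = ToyWrappedTransform(ToyScaleTransform(2.0))
```

Validating only once keeps the check off the hot path for repeated calls with the same transform.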
declare_op_types(locals_, all_, name_)

Builtin operations

abs = ops.abs

Return the absolute value of the argument.

add = ops.add

Same as a + b.

and_ = ops.and_

Same as a & b.

atanh = ops.atanh

Return the inverse hyperbolic tangent of x.

eq = ops.eq

Same as a == b.

exp = ops.exp

Return e raised to the power of x.

floordiv = ops.floordiv

Same as a // b.

ge = ops.ge

Same as a >= b.

getitem = ops.getitem
getslice = ops.getslice
gt = ops.gt

Same as a > b.

invert = ops.invert

Same as ~a.

le = ops.le

Same as a <= b.

lgamma = ops.lgamma

Natural logarithm of absolute value of Gamma function at x.

log = ops.log
log1p = ops.log1p

Return the natural logarithm of 1+x (base e).

The result is computed in a way which is accurate for x near zero.
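The accuracy claim is easy to observe with Python's standard `math` module, whose `log1p` documents the same behaviour (this uses `math` directly, not funsor's backend dispatch). For tiny x, forming 1 + x first rounds away most of the digits of x before the logarithm is taken.

```python
import math

x = 1e-15
naive = math.log(1.0 + x)  # 1.0 + x is rounded before the log is taken
accurate = math.log1p(x)   # computed without forming 1.0 + x

# For tiny x, log(1 + x) is approximately x - x**2 / 2, i.e. essentially x.
naive_rel_err = abs(naive - x) / x
accurate_rel_err = abs(accurate - x) / x
```

Here the naive route is off by roughly ten percent, while `log1p` is accurate to full double precision.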

lshift = ops.lshift

Same as a << b.

lt = ops.lt

Same as a < b.

matmul = ops.matmul

Same as a @ b.

max = ops.max
min = ops.min
mod = ops.mod

Same as a % b.

mul = ops.mul

Same as a * b.

ne = ops.ne

Same as a != b.

neg = ops.neg

Same as -a.

null = ops.null

Placeholder associative op that unifies with any other op.

or_ = ops.or_

Same as a | b.

pos = ops.pos

Same as +a.

pow = ops.pow

Same as a ** b.

reciprocal = ops.reciprocal
rshift = ops.rshift

Same as a >> b.

safediv = ops.safediv
safesub = ops.safesub
sigmoid = ops.sigmoid
sqrt = ops.sqrt

Return the square root of x.

sub = ops.sub

Same as a - b.

tanh = ops.tanh

Return the hyperbolic tangent of x.

truediv = ops.truediv

Same as a / b.

xor = ops.xor

Same as a ^ b.

Array operations

all = ops.all
amax = ops.amax
amin = ops.amin
any = ops.any
argmax = ops.argmax
argmin = ops.argmin
astype = ops.astype
cat = ops.cat
cholesky = ops.cholesky

Like numpy.linalg.cholesky() but uses sqrt for 1×1 (scalar) matrices.
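A pure-Python sketch of that behaviour for a single matrix (no batching), assuming "scalar matrices" means 1×1 inputs: those take the square root directly, and anything larger uses the textbook Cholesky–Banachiewicz recurrence as a stand-in for numpy.linalg.cholesky(). The helper name `toy_cholesky` is hypothetical.

```python
import math


def toy_cholesky(a):
    """Lower-triangular L with L @ L^T == a, for a symmetric positive-definite
    matrix given as a list of lists (hypothetical helper, single matrix only)."""
    n = len(a)
    if n == 1:
        # 1x1 case: the Cholesky factor is just the square root of the entry.
        return [[math.sqrt(a[0][0])]]
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L
```

The 1×1 branch gives the same answer as the general loop; it simply skips the factorization machinery for the common scalar case.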

cholesky_inverse = ops.cholesky_inverse

Like torch.cholesky_inverse() but supports batching and gradients.

cholesky_solve = ops.cholesky_solve
clamp = ops.clamp
detach = ops.detach
diagonal = ops.diagonal
einsum = ops.einsum
expand = ops.expand
finfo = ops.finfo
flip = ops.flip
full_like = ops.full_like
isnan = ops.isnan
logaddexp = ops.logaddexp
logsumexp = ops.logsumexp
mean = ops.mean
new_arange = ops.new_arange
new_eye = ops.new_eye
new_full = ops.new_full
new_zeros = ops.new_zeros
permute = ops.permute
prod = ops.prod
qr = ops.qr
randn = ops.randn
sample = ops.sample
scatter = ops.scatter
scatter_add = ops.scatter_add
stack = ops.stack
std = ops.std
sum = ops.sum
transpose = ops.transpose
triangular_inv = ops.triangular_inv
triangular_solve = ops.triangular_solve
unsqueeze = ops.unsqueeze
var = ops.var