Named tensor notation with funsors (Part 1)¶
Introduction¶
Mathematical notation with named axes, introduced in Named Tensor Notation (Chiang, Rush, and Barak, 2021), improves the readability of mathematical formulas involving multidimensional arrays. This includes tensor operations such as elementwise operations, reductions, contractions, renaming, indexing, and broadcasting. In this tutorial we translate examples from Named Tensor Notation into funsors to demonstrate how these operations are implemented in the funsor library and to familiarize readers with funsor syntax. Part 1 covers examples from Section 2 (Informal Overview), Section 3.4.2 (Advanced Indexing), and Section 5 (Formal Definitions) of the paper.
First, let’s import some dependencies.
[ ]:
!pip install funsor[torch]@git+https://github.com/pyro-ppl/funsor
[1]:
from torch import tensor
import funsor
import funsor.ops as ops
from funsor import Number, Tensor, Variable
from funsor.domains import Bint
funsor.set_backend("torch")
Named Tensors¶
Each tensor axis is given a name:
[2]:
A = Tensor(tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]]))["height", "width"]
Access elements of \(A\) using named indices; since indices are identified by name, their order does not matter:
[3]:
# A(height=0, width=2) =
A(width=2, height=0)
[3]:
Tensor(tensor(4))
Partial indexing:
[4]:
A(height=0)
[4]:
Tensor(tensor([3, 1, 4]), {'width': Bint[3]})
[5]:
A(width=2)
[5]:
Tensor(tensor([4, 9, 5]), {'height': Bint[3]})
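Every funsor keeps track of its named axes: the inputs attribute maps each axis name to its domain, and output gives the domain of the tensor's elements. A quick check (a minimal sketch using these funsor attributes):
[ ]:
# inputs maps axis names to domains; output is the element domain
A.inputs, A.output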
Named tensor operations¶
Elementwise operations and broadcasting¶
Elementwise operations:
[6]:
# A.sigmoid() =
# ops.sigmoid(A) =
# 1 / (1 + ops.exp(-A)) =
1 / (1 + (-A).exp())
[6]:
Tensor(tensor([[0.9526, 0.7311, 0.9820],
[0.7311, 0.9933, 0.9999],
[0.8808, 0.9975, 0.9933]]), {'height': Bint[3], 'width': Bint[3]})
Tensors with different shapes are automatically broadcast against each other before an operation is applied; axes are aligned by name, not by position. Let
[7]:
x = Tensor(tensor([2, 7, 1]))["height"]
y = Tensor(tensor([1, 4, 1]))["width"]
Binary addition operation:
[8]:
# ops.add(A, x) =
A + x
[8]:
Tensor(tensor([[ 5, 3, 6],
[ 8, 12, 16],
[ 3, 7, 6]]), {'height': Bint[3], 'width': Bint[3]})
[9]:
# ops.add(A, y) =
A + y
[9]:
Tensor(tensor([[ 4, 5, 5],
[ 2, 9, 10],
[ 3, 10, 6]]), {'height': Bint[3], 'width': Bint[3]})
Binary multiplication operation:
[10]:
# ops.mul(A, x) =
A * x
[10]:
Tensor(tensor([[ 6, 2, 8],
[ 7, 35, 63],
[ 2, 6, 5]]), {'height': Bint[3], 'width': Bint[3]})
Binary maximum operation:
[11]:
ops.max(A, y)
[11]:
Tensor(tensor([[3, 4, 4],
[1, 5, 9],
[2, 6, 5]]), {'height': Bint[3], 'width': Bint[3]})
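Because broadcasting aligns axes by name, tensors over different axes combine freely in a single expression. A small sketch reusing A, x, and y from above:
[ ]:
# x broadcasts along width and y along height, so the result has both axes
A + x + y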
Reductions¶
Named axes can be reduced over by calling the .reduce method and specifying the reduction operator and the names of the axes to reduce. Note that reduction is defined only for operators that are associative and commutative.
[12]:
A.reduce(ops.add, "height")
[12]:
Tensor(tensor([ 6, 12, 18]), {'width': Bint[3]})
[13]:
A.reduce(ops.add, "width")
[13]:
Tensor(tensor([ 8, 15, 13]), {'height': Bint[3]})
Reduction over multiple axes:
[14]:
A.reduce(ops.add, {"height", "width"})
[14]:
Tensor(tensor(36))
Multiplication reduction:
[15]:
A.reduce(ops.mul, "height")
[15]:
Tensor(tensor([ 6, 30, 180]), {'width': Bint[3]})
Max reduction:
[16]:
A.reduce(ops.max, "height")
[16]:
Tensor(tensor([3, 6, 9]), {'width': Bint[3]})
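Any other associative and commutative operator can be used the same way; for example, a min reduction (a sketch assuming ops.min, the counterpart of ops.max):
[ ]:
A.reduce(ops.min, "width")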
Contraction¶
The contraction operation can be written as an elementwise multiplication followed by a summation over an axis:
[17]:
(A * y).reduce(ops.add, "width")
[17]:
Tensor(tensor([11, 30, 31]), {'height': Bint[3]})
Some other operations from linear algebra:
[18]:
# inner product of x with itself
(x * x).reduce(ops.add, "height")
[18]:
Tensor(tensor(54))
[19]:
# outer product of x and y
x * y
[19]:
Tensor(tensor([[ 2, 8, 2],
[ 7, 28, 7],
[ 1, 4, 1]]), {'height': Bint[3], 'width': Bint[3]})
[20]:
# matrix-vector product of A and y
(A * y).reduce(ops.add, "width")
[20]:
Tensor(tensor([11, 30, 31]), {'height': Bint[3]})
[21]:
# vector-matrix product of x and A
(x * A).reduce(ops.add, "height")
[21]:
Tensor(tensor([15, 43, 76]), {'width': Bint[3]})
[22]:
# matrix-matrix product of A and B
B = Tensor(
    tensor([[3, 2, 5], [5, 4, 0], [8, 3, 6]]),
)["width", "width2"]
(A * B).reduce(ops.add, "width")
[22]:
Tensor(tensor([[ 46, 22, 39],
[100, 49, 59],
[ 76, 43, 40]]), {'height': Bint[3], 'width2': Bint[3]})
Contraction can be generalized to other combinations of binary and reduction operations:
[23]:
(A + y).reduce(ops.max, "width")
[23]:
Tensor(tensor([ 5, 10, 10]), {'height': Bint[3]})
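For instance, replacing multiplication with addition and summation with max gives a max-plus (tropical) matrix product; a sketch reusing A and B from above:
[ ]:
# max-plus product: for each (height, width2) pair, max over width of A + B
(A + B).reduce(ops.max, "width")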
Renaming and reshaping¶
Renaming funsors is simple: substituting a string (or an equivalent fresh Variable, as in the comment below) for an axis name renames that axis:
[24]:
# A(height=Variable("height2", Bint[3]))
A(height="height2")
[24]:
Tensor(tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]), {'height2': Bint[3], 'width': Bint[3]})
Reshaping is a special case of substitution: the height and width axes of \(A\) can be flattened into a single layer axis by substituting arithmetic expressions of a fresh layer variable. (Number(3, 4) is the number 3 with dtype 4, i.e. an element of Bint[4]; funsor's bounded integer arithmetic requires the value 3 to be a valid member of its domain.)
[25]:
layer = Variable("layer", Bint[9])
A_layer = A(height=layer // Number(3, 4), width=layer % Number(3, 4))
A_layer
[25]:
Tensor(tensor([3, 1, 4, 1, 5, 9, 2, 6, 5]), {'layer': Bint[9]})
Conversely, the layer axis can be unflattened back into height and width:
[26]:
height = Variable("height", Bint[3])
width = Variable("width", Bint[3])
A_layer(layer=height * Number(3, 4) + width % Number(3, 4))
[26]:
Tensor(tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]), {'height': Bint[3], 'width': Bint[3]})
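As a sanity check, flattening followed by unflattening recovers the original tensor. This can be verified with funsor.testing.assert_close (a minimal sketch using funsor's test utility):
[ ]:
from funsor.testing import assert_close  # raises if the two funsors differ

assert_close(A_layer(layer=height * Number(3, 4) + width % Number(3, 4)), A)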
Advanced indexing¶
All forms of advanced indexing can be achieved through name substitution in funsors.
Partial indexing \(\underset{\mathsf{vocab}}{\mathrm{index}}(E, i)\):
[27]:
E = Tensor(
tensor([[2, 1, 5], [3, 4, 2], [1, 3, 7], [1, 4, 3], [5, 9, 2]]),
)["vocab", "emb"]
E(vocab=2)
[27]:
Tensor(tensor([1, 3, 7]), {'emb': Bint[3]})
Integer array indexing \(\underset{\mathsf{vocab}}{\mathrm{index}}(E, I)\):
[28]:
I = Tensor(tensor([3, 2, 4, 0]), dtype=5)["seq"]
E(vocab=I)
[28]:
Tensor(tensor([[1, 4, 3],
[1, 3, 7],
[5, 9, 2],
[2, 1, 5]]), {'seq': Bint[4], 'emb': Bint[3]})
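The same lookup can be expressed with raw PyTorch advanced indexing on the underlying arrays, accessible through the data attribute (a sketch for comparison):
[ ]:
# positional equivalent of E(vocab=I) on the backend tensors
E.data[I.data]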
Gather operation \(\underset{\mathsf{vocab}}{\mathrm{index}}(P, I)\):
[29]:
P = Tensor(
tensor([[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]),
)["vocab", "seq"]
P(vocab=I)
[29]:
Tensor(tensor([1, 5, 2, 2]), {'seq': Bint[4]})
Indexing with two integer arrays:
[30]:
I1 = Tensor(tensor([1, 2, 0]), dtype=4)["subseq"]
I2 = Tensor(tensor([3, 0, 4]), dtype=5)["subseq"]
P(seq=I1, vocab=I2)
[30]:
Tensor(tensor([3, 4, 5]), {'subseq': Bint[3]})
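Name substitution also composes with broadcasting: an index funsor with extra named axes yields a batched gather. A sketch with hypothetical batch and seq2 axes:
[ ]:
# indexing with a two-axis funsor broadcasts its axes into the result,
# giving a tensor over batch, seq2, and emb
I_batch = Tensor(tensor([[0, 1], [2, 3]]), dtype=5)["batch", "seq2"]
E(vocab=I_batch)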