# Example: Forward-Backward algorithm

```import argparse
from collections import OrderedDict
from typing import Dict, List, Tuple

import funsor.ops as ops
from funsor import Funsor, Tensor
from funsor.domains import Bint
from funsor.testing import assert_close, random_tensor

def forward_algorithm(
    factors: List[Funsor],
    step: Dict[str, str],
) -> Funsor:
    """
    Calculate log marginal probability using the forward algorithm:
    Z = p(y[0:T])

    Transition and emission probabilities are given by init and trans factors:
    init = p(y[0], x[0])
    trans[t] = p(y[t], x[t] | x[t-1])

    Forward probabilities are computed inductively:
    alpha[t] = p(y[0:t], x[t])
    alpha[0] = init
    alpha[t+1] = alpha[t] @ trans[t+1]

    All factors are combined in log space: products become ``+`` and
    sums over hidden states become ``ops.logaddexp`` reductions.

    :param factors: ``[init, trans[1], ..., trans[T-1]]``. ``init`` has the
        "current" variables as inputs; each ``trans[t]`` has both the
        "previous" and "current" variables.
    :param step: Mapping from each "previous" variable name to the
        matching "current" variable name, e.g. ``{"x_prev": "x_curr"}``.
    :return: The log normalizer ``Z`` with all hidden variables summed out.
    """
    step = OrderedDict(sorted(step.items()))
    # Temporary names used to join alpha[t]'s current variables with
    # trans[t+1]'s previous variables before summing them out.
    drop = tuple("_drop_{}".format(i) for i in range(len(step)))
    prev_to_drop = dict(zip(step.keys(), drop))
    curr_to_drop = dict(zip(step.values(), drop))
    reduce_vars = frozenset(drop)

    # base case: alpha[0] = init
    alpha = factors[0]
    # inductive steps: alpha[t+1] = sum_{x[t]} alpha[t] * trans[t+1]
    for trans in factors[1:]:
        alpha = (alpha(**curr_to_drop) + trans(**prev_to_drop)).reduce(
            ops.logaddexp, reduce_vars
        )
    # Z = sum_{x[T-1]} alpha[T-1]
    Z = alpha(**curr_to_drop).reduce(ops.logaddexp, reduce_vars)
    return Z

def forward_backward_algorithm(
    factors: List[Funsor],
    step: Dict[str, str],
) -> List[Tensor]:
    """
    Calculate marginal probabilities:
    p(x[t], x[t-1] | Y)

    Uses the same factor layout as ``forward_algorithm``:
    init = p(y[0], x[0]), trans[t] = p(y[t], x[t] | x[t-1]).
    All quantities are in log space (``+`` multiplies, ``ops.logaddexp``
    sums, ``- Z`` normalizes).

    :param factors: ``[init, trans[1], ..., trans[T-1]]``. ``init`` must be
        a ``Tensor`` (its ``.data`` is used to build the terminal beta).
    :param step: Mapping from each "previous" variable name to the
        matching "current" variable name, e.g. ``{"x_prev": "x_curr"}``.
    :return: ``T`` log-marginals. Entry 0 is ``p(x[0] | Y)`` over the
        current variables; entry ``t >= 1`` is ``p(x[t-1], x[t] | Y)`` over
        both the previous and current variables.
    """
    step = OrderedDict(sorted(step.items()))
    # Temporary names used to join two factors on a time slice before
    # summing that slice out.
    drop = tuple("_drop_{}".format(i) for i in range(len(step)))
    prev_to_drop = dict(zip(step.keys(), drop))
    curr_to_drop = dict(zip(step.values(), drop))
    reduce_vars = frozenset(drop)
    # Renamings between the previous and current time-slice variables,
    # derived from ``step`` instead of hard-coding "x_prev"/"x_curr".
    prev_to_curr = dict(step)
    curr_to_prev = {curr: prev for prev, curr in step.items()}
    prev_vars = frozenset(step.keys())

    # Base cases
    alpha = factors[0]  # alpha[0] = p(y[0], x[0])
    beta = Tensor(
        ops.full_like(alpha.data, 0.0), alpha(**curr_to_prev).inputs
    )  # beta[T-1] = 1 (log 1 = 0), indexed by the "previous" names

    # Backward algorithm
    # beta[t] = p(y[t+1:T] | x[t])
    # beta[t] = trans[t+1] @ beta[t+1]
    betas = [beta]
    for trans in reversed(factors[1:]):
        beta = (trans(**curr_to_drop) + beta(**prev_to_drop)).reduce(
            ops.logaddexp, reduce_vars
        )
        betas.append(beta)
    # Z = sum_{x[0]} init * beta[0]
    Z = (alpha(**curr_to_drop) + beta(**prev_to_drop)).reduce(
        ops.logaddexp, reduce_vars
    )
    betas.reverse()  # now betas[t] = beta[t]

    # Forward algorithm & Marginal computations
    marginal = alpha + betas[0](**prev_to_curr) - Z  # p(x[0] | Y)
    marginals = [marginal]
    # inductive steps
    for trans, beta_t in zip(factors[1:], betas[1:]):
        # alpha[t-1] * trans[t] = p(y[0:t], x[t-1], x[t])
        alpha_trans = alpha(**curr_to_prev) + trans
        # alpha[t] = p(y[0:t], x[t]); must be advanced every step, otherwise
        # every marginal would reuse alpha[0].
        alpha = alpha_trans.reduce(ops.logaddexp, prev_vars)
        # alpha[t-1] * trans[t] * beta[t] / Z = p(x[t-1], x[t] | Y)
        marginal = alpha_trans + beta_t(**prev_to_curr) - Z
        marginals.append(marginal)

    return marginals

def main(args):
    """
    Compute marginal probabilities p(x[t], x[t-1] | Y) for an HMM:

    x[0] -> ... -> x[t-1] -> x[t] -> ... -> x[T]
    |              |         |             |
    v              v         v             v
    y[0]           y[t-1]    y[t]           y[T]

    Z = p(y[0:T])
    alpha[t] = p(y[0:t], x[t])
    beta[t] = p(y[t+1:T] | x[t])
    trans[t] = p(y[t], x[t] | x[t-1])

    p(x[t], x[t-1] | Y) = alpha[t-1] * beta[t] * trans[t] / Z

    d Z / d trans[t] = alpha[t-1] * beta[t]

    The marginals are computed twice, once with the explicit
    forward-backward recursion and once by differentiating the forward
    algorithm, and the two results are checked for agreement.

    :param args: Parsed CLI arguments; reads ``args.hidden_dim`` and
        ``args.time_steps``.

    **References:**

    [1] Jason Eisner (2016)
    "Inside-Outside and Forward-Backward Algorithms Are Just Backprop
    (Tutorial Paper)"
    https://www.cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf
    [2] Zhifei Li and Jason Eisner (2009)
    "First- and Second-Order Expectation Semirings
    with Applications to Minimum-Risk Training on Translation Forests"
    http://www.cs.jhu.edu/~zfli/pubs/semiring_translation_zhifei_emnlp09.pdf
    """
    from funsor.adjoint import AdjointTape

    # transition and emission probabilities
    init = random_tensor(OrderedDict([("x_curr", Bint[args.hidden_dim])]))
    factors = [init]
    for _ in range(args.time_steps - 1):
        factors.append(
            random_tensor(
                OrderedDict(x_prev=Bint[args.hidden_dim], x_curr=Bint[args.hidden_dim])
            )
        )

    # Compute marginal probabilities using the forward-backward algorithm
    marginals = forward_backward_algorithm(factors, {"x_prev": "x_curr"})
    # Compute marginal probabilities using backpropagation: record the
    # forward pass, then take adjoints in the (logaddexp, add) semiring.
    with AdjointTape() as tape:
        Z = forward_algorithm(factors, {"x_prev": "x_curr"})
    result = tape.adjoint(ops.logaddexp, ops.add, Z, factors)
    # Look up by factor so each adjoint is paired with its own factor.
    adjoint_terms = [result[factor] for factor in factors]

    for t, (expected, adj, trans) in enumerate(
        zip(marginals, adjoint_terms, factors)
    ):
        # adjoint returns dZ / dtrans = alpha[t-1] * beta[t]
        # marginal = alpha[t-1] * beta[t] * trans / Z
        actual = adj + trans - Z
        assert_close(expected, actual.align(tuple(expected.inputs)), rtol=1e-4)
        print("")
        print(f"Marginal term: p(x[{t}], x[{t-1}] | Y)")
        print("Forward-backward algorithm:\n", expected.data)
        print("Differentiating forward algorithm:\n", actual.data)

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Forward-Backward Algorithm Is Just Backprop"
)