# Example: Talbot’s method for numerical inversion of the Laplace transform

```python
import argparse
import math

import torch

import funsor
import funsor.ops as ops
from funsor.domains import Real
from funsor.factory import Bound, Fresh, Has, make_funsor
from funsor.interpretations import StatefulInterpretation
from funsor.tensor import Tensor
from funsor.terms import Funsor, Variable
from funsor.util import get_backend

@make_funsor
def InverseLaplace(
    F: Has[{"s"}],  # noqa: F821
    t: Funsor,
    s: Bound,
) -> Fresh[lambda F: F]:
    """
    Symbolic inverse Laplace transform of a function ``F(s)``.

    An arbitrary ``F(s)`` has no closed-form inverse, so this term carries no
    evaluation rule of its own: the body simply returns ``None``. Numerical
    approximations are instead stored in interpretations; see Talbot's method
    below for an example.

    :param F: function of ``s`` to invert.
    :param t: times at which the inverse Laplace transform of ``F`` is
        evaluated.
    :param s: the ``s`` Variable (declared ``Bound`` in the signature).
    """
    return None

class Talbot(StatefulInterpretation):
    """
    Interpretation that inverts Laplace transforms numerically with Talbot's
    method.

    Reference
    Abate, Joseph, and Ward Whitt. "A Unified Framework for Numerically
    Inverting Laplace Transforms." INFORMS Journal of Computing, vol. 18.4
    (2006): 408-421. Print. (http://www.columbia.edu/~ww2040/allpapers.html)

    The implementation is adapted from the MATLAB implementation of the
    algorithm by Tucker McClure (2021), Numerical Inverse Laplace Transform
    (https://www.mathworks.com/matlabcentral/fileexchange/39035-numerical-inverse-laplace-transform),
    MATLAB Central File Exchange. Retrieved April 4, 2021.

    :param num_steps: number of terms to sum for each t.
    """

    def __init__(self, num_steps):
        # "talbot" is the name passed to the StatefulInterpretation base class.
        super().__init__("talbot")
        # Kept on the instance so rules registered on this interpretation can
        # read how many series terms to use.
        self.num_steps = num_steps

@Talbot.register(InverseLaplace, Funsor, Funsor, Variable)
def talbot(self, F, t, s):
    """
    Approximate the inverse Laplace transform of ``F`` at times ``t``.

    ``F`` is evaluated at ``self.num_steps`` complex nodes ``delta / t``; each
    value is multiplied by its weight ``gamma``, the products are summed over
    the ``"num_steps"`` input, scaled by ``0.4 / t``, and the real part of the
    result is returned.

    :param F: function of the variable ``s``.
    :param t: times at which to evaluate the inverse transform.
    :param s: Variable of ``F`` that is substituted with the complex nodes.
    :return: a Tensor built from the real part of the weighted sum.
    :raises NotImplementedError: if the active backend is not ``"torch"``.
    """
    if get_backend() != "torch":
        raise NotImplementedError(f"Unsupported backend {get_backend()}")

    import torch

    n_terms = self.num_steps
    k = torch.arange(1, n_terms)
    # tan(k * pi / n_terms) appears in both the nodes and the weights.
    tan_theta = (math.pi / n_terms * k).tan()

    # Complex nodes: index 0 is real, the remaining ones lie off the real axis.
    nodes = torch.zeros(n_terms, dtype=torch.complex64)
    nodes[0] = 2 * n_terms / 5
    nodes[1:] = 2 * math.pi / 5 * k * (1 / tan_theta + 1j)

    # Complex weights paired with the nodes.
    weights = torch.zeros(n_terms, dtype=torch.complex64)
    weights[0] = 0.5 * nodes[0].exp()
    prefactor = (
        1
        + 1j * math.pi / n_terms * k * (1 + 1 / tan_theta**2)
        - 1j / tan_theta
    )
    weights[1:] = prefactor * nodes[1:].exp()

    # Name the series dimension so it can be reduced away after evaluating F.
    nodes = Tensor(nodes)["num_steps"]
    weights = Tensor(weights)["num_steps"]
    series_sum = (weights * F(**{s.name: nodes / t})).reduce(ops.add, "num_steps")
    ilt = 0.4 / t * series_sum

    # Only the real part is kept.
    return Tensor(ilt.data.real, ilt.inputs)

def main(args):
    """
    Fit the rate and number of steps of a sequential n-step model to noisy
    synthetic data, using Talbot's method to invert the Laplace-domain model.

    Reference for the n-step sequential model used here:

    Aaron L. Lucius et al (2003).
    "General Methods for Analysis of Sequential ‘‘n-step’’ Kinetic Mechanisms:
    Application to Single Turnover Kinetics of Helicase-Catalyzed DNA Unwinding"
    https://www.sciencedirect.com/science/article/pii/S0006349503746487

    :param args: parsed command-line arguments; this function reads
        ``talbot_num_steps``, ``num_steps``, ``learning_rate`` and
        ``log_every``.
    """

    funsor.set_backend("torch")

    # Problem definition.
    true_rate = 20
    true_nsteps = 4
    rate = Variable("rate", Real)
    nsteps = Variable("nsteps", Real)
    s = Variable("s", Real)
    time = Tensor(torch.arange(0.04, 1.04, 0.04))["timepoint"]
    noise = Tensor(torch.randn(time.inputs["timepoint"].size))["timepoint"] / 500
    # Synthetic observations: time-domain curve for the true parameters plus
    # small Gaussian noise.
    data = (
        Tensor(1 - torch.igammac(torch.tensor(true_nsteps), true_rate * time.data))[
            "timepoint"
        ]
        + noise
    )
    # Laplace-domain model as a function of s, rate and nsteps.
    F = rate**nsteps / (s * (rate + s) ** nsteps)
    # Inverse Laplace.
    pred = InverseLaplace(F, time, "s")

    # Loss function.
    loss = (pred - data).abs().reduce(ops.add, "timepoint")
    # NOTE(review): the entries of this dict were missing from this copy of
    # the file. The keys are required by the optim.param(...) calls below; the
    # starting values 5.0 and 2.0 are restored from memory of the upstream
    # Funsor example -- confirm before relying on them.
    init_params = {
        "rate": Tensor(torch.tensor(5.0, requires_grad=True)),
        "nsteps": Tensor(torch.tensor(2.0, requires_grad=True)),
    }

    with Talbot(num_steps=args.talbot_num_steps):
        # Fit the data
        # NOTE(review): the opening line of this `with` statement was missing,
        # leaving a dangling `) as optim:`. funsor.adam.Adam matches the
        # surviving arguments and the optim.param(...) usage -- confirm
        # against the upstream example.
        with funsor.adam.Adam(
            args.num_steps,
            lr=args.learning_rate,
            log_every=args.log_every,
            params=init_params,
        ) as optim:
            loss.reduce(ops.min, {"rate", "nsteps"})
        # Fitted curve.
        fitted_curve = pred(rate=optim.param("rate"), nsteps=optim.param("nsteps"))

    print(f"Data\n{data}")
    print(f"Fit curve\n{fitted_curve}")
    print(f"True rate\n{true_rate}")
    print("Learned rate\n{}".format(optim.param("rate").item()))
    print(f"True number of steps\n{true_nsteps}")
    print("Learned number of steps\n{}".format(optim.param("nsteps").item()))

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Numerical inversion of the Laplace transform using Talbot's method"
)