# Example: Biased Kalman Filter

import argparse
import itertools
import math
import os

import pyro.distributions as dist
import torch
import torch.nn as nn
from torch.optim import Adam

import funsor
import funsor.ops as ops
import funsor.torch.distributions as f_dist
from funsor.domains import Reals
from funsor.pyro.convert import dist_to_funsor, funsor_to_mvn
from funsor.tensor import Tensor, Variable

# We use a 2D continuous-time NCV dynamics model throughout.
# See http://webee.technion.ac.il/people/shimkin/Estimation09/ch8_target.pdf
#
# States are ordered (pos_x, pos_y, vel_x, vel_y) and are treated as row
# vectors, i.e. the matrices below are applied as ``state @ matrix``.
TIME_STEP = 1.0

# Each matrix is the same 2x2 (position, velocity) block for both spatial
# axes; the Kronecker product with the 2x2 identity interleaves the two axes.
NCV_PROCESS_NOISE = torch.kron(
    torch.tensor([[1 / 3, 1 / 2], [1 / 2, 1.0]]),
    torch.eye(2),
)
NCV_TRANSITION_MATRIX = torch.kron(
    torch.tensor([[1.0, 0.0], [1.0, 1.0]]),
    torch.eye(2),
)

def generate_data(num_frames, num_sensors):
    """
    Simulate a damped NCV trajectory observed by biased position sensors.

    :param int num_frames: number of time steps to simulate.
    :param int num_sensors: number of sensors; each contributes two
        observation dimensions.
    :return: a tuple ``(observations, states, sensor_bias)`` of tensors with
        shapes ``(num_frames, 2 * num_sensors)``, ``(num_frames, 4)`` and
        ``(2 * num_sensors,)`` respectively.
    """
    step = TIME_STEP
    bias_std = 4.0
    obs_std = 1.0
    trans_std = 0.3
    damping = 0.1  # damp the velocities

    # Initial state: Gaussian position (scale 10), uniform [0, 1) velocity.
    # The randn draw happens before the rand draw, as in a single torch.cat.
    start_position = 10.0 * torch.randn(2)
    start_velocity = torch.rand(2)
    state = torch.cat([start_position, start_velocity])

    # Row-vector dynamics: the next state is ``state @ dynamics``.
    decay = math.exp(-damping * step)
    dynamics = torch.tensor(
        [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [step * decay, 0, decay, 0],
            [0, step * decay, 0, decay],
        ]
    )
    process_noise = dist.MultivariateNormal(
        torch.zeros(4),
        scale_tril=trans_std * torch.linalg.cholesky(NCV_PROCESS_NOISE),
    )

    # Biased sensors: the bias is drawn once and acts as the mean of the
    # observation noise, so it is shared by every frame.
    sensor_bias = bias_std * torch.randn(2 * num_sensors)
    position_selector = torch.eye(4, 2).unsqueeze(-1)
    emission = position_selector.expand(-1, -1, num_sensors).reshape(4, -1)
    sensor_noise = dist.MultivariateNormal(
        sensor_bias, scale_tril=obs_std * torch.eye(2 * num_sensors)
    )

    states = []
    observations = []
    for _ in range(num_frames):
        state = state @ dynamics + process_noise.sample()
        states.append(state)
        observations.append(state @ emission + sensor_noise.sample())

    states = torch.stack(states)
    observations = torch.stack(observations)
    assert observations.shape == (num_frames, num_sensors * 2)
    return observations, states, sensor_bias

class Model(nn.Module):
    """
    Kalman filter with NCV dynamics and an optional shared sensor bias.

    The bias scale, observation noise scale and transition noise scale are
    learnable; each is stored as a log so that the scale itself is positive.

    :param int num_sensors: number of sensors; the observation dimension is
        ``2 * num_sensors``.
    """

    def __init__(self, num_sensors):
        super().__init__()
        self.num_sensors = num_sensors

        # learnable params
        self.log_bias_scale = nn.Parameter(torch.tensor(0.0))
        self.log_obs_noise = nn.Parameter(torch.tensor(0.0))
        self.log_trans_noise = nn.Parameter(torch.tensor(0.0))

    def forward(self, observations, add_bias=True):
        """
        Run the filter over ``observations`` and marginalize all latents.

        :param observations: iterable of per-frame observations, each of
            size ``2 * num_sensors``. Must contain at least one frame.
        :param bool add_bias: whether the latent bias is added to the
            observation mean. When false the model ignores sensor bias.
        :return: a tuple ``(log_prob, posterior)`` where ``log_prob`` is the
            scalar log marginal likelihood of the observations and
            ``posterior`` is the multivariate normal over the final state.
        """
        obs_dim = 2 * self.num_sensors
        bias_scale = self.log_bias_scale.exp()
        obs_noise = self.log_obs_noise.exp()
        trans_noise = self.log_trans_noise.exp()

        # bias distribution
        bias = Variable("bias", Reals[obs_dim])
        assert not torch.isnan(bias_scale), "bias scales was nan"
        bias_dist = dist_to_funsor(
            dist.MultivariateNormal(
                torch.zeros(obs_dim),
                scale_tril=bias_scale * torch.eye(2 * self.num_sensors),
            )
        )(value=bias)

        # broad prior over the initial state
        init_dist = dist.MultivariateNormal(
            torch.zeros(4), scale_tril=100.0 * torch.eye(4)
        )
        self.init = dist_to_funsor(init_dist)(value="state")

        # hidden states
        prev = Variable("prev", Reals[4])
        curr = Variable("curr", Reals[4])
        self.trans_dist = f_dist.MultivariateNormal(
            loc=prev @ NCV_TRANSITION_MATRIX,
            scale_tril=trans_noise * torch.linalg.cholesky(NCV_PROCESS_NOISE),
            value=curr,
        )

        state = Variable("state", Reals[4])
        obs = Variable("obs", Reals[obs_dim])
        observation_matrix = Tensor(
            torch.eye(4, 2)
            .unsqueeze(-1)
            .expand(-1, -1, self.num_sensors)
            .reshape(4, -1)
        )
        assert observation_matrix.output.shape == (
            4,
            obs_dim,
        ), observation_matrix.output.shape
        obs_loc = state @ observation_matrix
        # Only couple the observations to the bias when requested; otherwise
        # the "no bias" baseline would be identical to the biased model.
        if add_bias:
            obs_loc = obs_loc + bias
        self.observation_dist = f_dist.MultivariateNormal(
            loc=obs_loc, scale_tril=obs_noise * torch.eye(obs_dim), value=obs
        )

        logp = bias_dist
        curr = "state_init"
        logp += self.init(state=curr)
        for t, x in enumerate(observations):
            prev, curr = curr, "state_{}".format(t)
            logp += self.trans_dist(prev=prev, curr=curr)
            logp += self.observation_dist(state=curr, obs=x)
            # marginalize out previous state
            logp = logp.reduce(ops.logaddexp, prev)
        # marginalize out bias variable
        logp = logp.reduce(ops.logaddexp, "bias")

        # save posterior over the final state
        assert set(logp.inputs) == {"state_{}".format(len(observations) - 1)}
        posterior = funsor_to_mvn(logp, ndims=0)

        # marginalize out remaining variables (the final state), leaving a
        # scalar log marginal likelihood
        logp = logp.reduce(ops.logaddexp)
        assert isinstance(logp, Tensor) and logp.shape == (), logp.pretty()
        return logp.data, posterior

def track(args):
    """
    Fit the model for every (seed, bias, num_frames) combination.

    Data are generated once per seed at the longest requested track length
    and truncated for the shorter ones.

    :param args: parsed command line arguments. Reads ``seed``, ``bias``,
        ``num_frames``, ``num_sensors``, ``lr``, ``num_epochs`` and
        ``metrics_filename``. ``num_epochs`` must be at least 1, since the
        posterior is taken from the last training iteration.
    :return: dict keyed on ``(seed, bias, num_frames)`` holding the data,
        the learned scales, the loss curve and the final-state errors.
    """
    results = {}  # keyed on (seed, bias, num_frames)
    for seed in args.seed:
        torch.manual_seed(seed)
        observations, states, sensor_bias = generate_data(
            max(args.num_frames), args.num_sensors
        )
        for bias, num_frames in itertools.product(args.bias, args.num_frames):
            print(
                "tracking with seed={}, bias={}, num_frames={}".format(
                    seed, bias, num_frames
                )
            )
            model = Model(args.num_sensors)
            optim = Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.8))
            losses = []
            for i in range(args.num_epochs):
                # Clear gradients from the previous epoch; backward()
                # accumulates into .grad rather than overwriting it.
                optim.zero_grad()
                log_prob, posterior = model(observations[:num_frames], add_bias=bias)
                loss = -log_prob
                loss.backward()
                losses.append(loss.item())
                if i % 10 == 0:
                    print(loss.item())
                optim.step()

            # Collect evaluation metrics.
            final_state_true = states[num_frames - 1]
            assert final_state_true.shape == (4,)
            final_pos_true = final_state_true[:2]
            final_vel_true = final_state_true[2:]

            final_state_est = posterior.loc
            assert final_state_est.shape == (4,)
            final_pos_est = final_state_est[:2]
            final_vel_est = final_state_est[2:]
            final_pos_error = float(torch.norm(final_pos_true - final_pos_est))
            final_vel_error = float(torch.norm(final_vel_true - final_vel_est))
            print("final_pos_error = {}".format(final_pos_error))

            results[seed, bias, num_frames] = {
                "args": args,
                "observations": observations[:num_frames],
                "states": states[:num_frames],
                "sensor_bias": sensor_bias,
                "losses": losses,
                "bias_scale": float(model.log_bias_scale.exp()),
                "obs_noise": float(model.log_obs_noise.exp()),
                "trans_noise": float(model.log_trans_noise.exp()),
                "final_state_estimate": posterior,
                "final_pos_error": final_pos_error,
                "final_vel_error": final_vel_error,
            }
        # Checkpoint after each seed so partial results survive a crash.
        if args.metrics_filename:
            print("saving output to: {}".format(args.metrics_filename))
            torch.save(results, args.metrics_filename)
    return results

def main(args):
    """
    Compute (or reload) tracking metrics and optionally plot position RMSE.

    Metrics are recomputed when ``args.force`` is set, when no metrics file
    is configured, or when the configured file does not exist; otherwise
    they are loaded from ``args.metrics_filename``. When
    ``args.plot_filename`` is set, position RMSE against track length is
    plotted for the bias=0 (black, dashed) and bias=1 (red, solid) runs, so
    both bias settings must be present in the results.

    :param args: parsed command line arguments. Reads ``force``,
        ``metrics_filename``, ``plot_filename`` and ``num_frames`` here; the
        whole namespace is forwarded to ``track``.
    """
    funsor.set_backend("torch")
    if (
        args.force
        or not args.metrics_filename
        or not os.path.exists(args.metrics_filename)
    ):
        # Increase compression threshold for numerical stability.
        with funsor.gaussian.Gaussian.set_compression_threshold(3):
            results = track(args)
    else:
        # Reuse metrics cached by a previous run.
        # torch.load unpickles arbitrary objects, so only point
        # --metrics-filename at files written by this script.
        # NOTE(review): recent PyTorch releases may need weights_only=False
        # here because the file holds non-tensor objects -- confirm against
        # the installed version.
        results = torch.load(args.metrics_filename)

    if args.plot_filename:
        import matplotlib

        matplotlib.use("Agg")
        import numpy as np
        from matplotlib import pyplot

        seeds = {seed for seed, _, _ in results}
        X = args.num_frames
        pyplot.figure(figsize=(5, 1.4), dpi=300)

        # One curve per bias setting: (bias flag, line format, band color).
        for bias, line_format, band_color in ((0, "k--", "black"), (1, "r-", "red")):
            pos_error = np.array(
                [
                    [results[s, bias, f]["final_pos_error"] for s in seeds]
                    for f in args.num_frames
                ]
            )
            mse = (pos_error ** 2).mean(axis=1)
            # standard error of the mean squared error across seeds
            std = (pos_error ** 2).std(axis=1) / len(seeds) ** 0.5
            pyplot.plot(X, mse ** 0.5, line_format)
            pyplot.fill_between(
                X,
                (mse - std) ** 0.5,
                (mse + std) ** 0.5,
                color=band_color,
                alpha=0.15,
                lw=0,
            )

        pyplot.ylabel("Position RMSE")
        pyplot.xlabel("Track Length")
        pyplot.xticks((5, 10, 15, 20, 25, 30))
        pyplot.xlim(5, 30)
        # pad must be passed by keyword; newer Matplotlib rejects it
        # positionally.
        pyplot.tight_layout(pad=0)
        pyplot.savefig(args.plot_filename)

def int_list(args):
    """
    Parse a comma-delimited string of integers and inclusive ranges.

    Each comma-separated item is either a single integer (``"7"``) or an
    inclusive range written ``"beg-end"``; for example ``"1-3,7"`` gives
    ``[1, 2, 3, 7]``.

    :param str args: the text to parse.
    :return: list of ints in the order they were listed.
    :raises ValueError: if an item is not an integer or a two-part range.
    """
    values = []
    for token in args.split(","):
        if "-" not in token:
            values.append(int(token))
            continue
        first, last = (int(part) for part in token.split("-"))
        values.extend(range(first, last + 1))
    return values

if __name__ == "__main__":
    # Command line entry point: every option read by main() and track()
    # must be declared here.
    parser = argparse.ArgumentParser(description="Biased Kalman filter")
    parser.add_argument(
        "--seed",
        default="0",
        type=int_list,
        help="random seed, comma delimited for multiple runs",
    )
    parser.add_argument(
        "--bias",
        default="0,1",
        type=int_list,
        help="whether to model bias, comma deliminted for multiple runs",
    )
    parser.add_argument(
        "-f",
        "--num-frames",
        default="5,10,15,20,25,30",
        type=int_list,
        help="number of sensor frames, comma delimited for multiple runs",
    )
    parser.add_argument(
        "--num-sensors", default=5, type=int, help="number of biased sensors"
    )
    parser.add_argument(
        "--metrics-filename",
        default="",
        type=str,
        help="file in which to cache results; empty disables caching",
    )
    parser.add_argument(
        "--plot-filename",
        default="",
        type=str,
        help="file in which to save the RMSE plot; empty disables plotting",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="recompute results even if the metrics file exists",
    )
    parser.add_argument("--lr", default=0.1, type=float, help="Adam learning rate")
    parser.add_argument("-n", "--num-epochs", default=50, type=int)
    args = parser.parse_args()
    main(args)


# Gallery generated by Sphinx-Gallery