{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Switching Linear Dynamical System\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "\n",
        "import torch\n",
        "\n",
        "import funsor\n",
        "import funsor.ops as ops\n",
        "import funsor.torch.distributions as dist\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Fit the switching linear dynamical system parameters with Adam.\n",
        "\n",
        "    Reads ``time_steps``, ``train_steps``, ``learning_rate`` and ``verbose``\n",
        "    from ``args``. The training data is ``torch.randn(args.time_steps)`` drawn\n",
        "    after seeding torch with 0, and the loss is the negated model log-probability.\n",
        "    \"\"\"\n",
        "    funsor.set_backend(\"torch\")\n",
        "\n",
        "    # Declare parameters.\n",
        "    # NOTE(review): these raw tensors are handed to Adam with no constraint or\n",
        "    # projection step, so nothing here keeps the rows of trans_probs normalized\n",
        "    # or the noise scales positive during training - confirm this is intended.\n",
        "    trans_probs = funsor.Tensor(\n",
        "        torch.tensor([[0.9, 0.1], [0.1, 0.9]], requires_grad=True)\n",
        "    )\n",
        "    trans_noise = funsor.Tensor(\n",
        "        torch.tensor(\n",
        "            [\n",
        "                0.1,  # low noise component\n",
        "                1.0,  # high noise component\n",
        "            ],\n",
        "            requires_grad=True,\n",
        "        )\n",
        "    )\n",
        "    emit_noise = funsor.Tensor(torch.tensor(0.5, requires_grad=True))\n",
        "    params = [trans_probs.data, trans_noise.data, emit_noise.data]\n",
        "\n",
        "    # A switching linear dynamical system: a discrete state s selects the\n",
        "    # transition noise of a continuous state x, which is observed through y.\n",
        "    @funsor.interpretations.moment_matching\n",
        "    def model(data):\n",
        "        \"\"\"Return the log-probability of ``data`` with every latent state reduced out.\"\"\"\n",
        "        log_prob = funsor.Number(0.0)\n",
        "\n",
        "        # s is the discrete latent state,\n",
        "        # x is the continuous latent state,\n",
        "        # y is the observed state.\n",
        "        # Both latent chains start from a fixed value of zero.\n",
        "        s_curr = funsor.Tensor(torch.tensor(0), dtype=2)\n",
        "        x_curr = funsor.Tensor(torch.tensor(0.0))\n",
        "        for t, y in enumerate(data):\n",
        "            s_prev = s_curr\n",
        "            x_prev = x_curr\n",
        "\n",
        "            # A delayed sample statement.\n",
        "            s_curr = funsor.Variable(f\"s_{t}\", funsor.Bint[2])\n",
        "            log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)\n",
        "\n",
        "            # A delayed sample statement.\n",
        "            x_curr = funsor.Variable(f\"x_{t}\", funsor.Real)\n",
        "            log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)\n",
        "\n",
        "            # Marginalize out previous delayed sample statements.\n",
        "            # At t == 0 the previous states are constant tensors, not named\n",
        "            # variables, so there is nothing to eliminate yet.\n",
        "            if t > 0:\n",
        "                log_prob = log_prob.reduce(ops.logaddexp, {s_prev.name, x_prev.name})\n",
        "\n",
        "            # An observe statement.\n",
        "            log_prob += dist.Normal(x_curr, emit_noise, value=y)\n",
        "\n",
        "        # Reduce over whatever variables are still free (the final s and x).\n",
        "        log_prob = log_prob.reduce(ops.logaddexp)\n",
        "        return log_prob\n",
        "\n",
        "    # Train model parameters.\n",
        "    torch.manual_seed(0)\n",
        "    data = torch.randn(args.time_steps)\n",
        "    optim = torch.optim.Adam(params, lr=args.learning_rate)\n",
        "    for step in range(args.train_steps):\n",
        "        optim.zero_grad()\n",
        "        log_prob = model(data)\n",
        "        assert not log_prob.inputs, \"free variables remain\"\n",
        "        # Minimize the negative log-probability.\n",
        "        loss = -log_prob.data\n",
        "        loss.backward()\n",
        "        optim.step()\n",
        "        if args.verbose and step % 10 == 0:\n",
        "            print(f\"step {step} loss = {loss.item()}\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    parser = argparse.ArgumentParser(description=\"Switching linear dynamical system\")\n",
        "    parser.add_argument(\"-t\", \"--time-steps\", default=10, type=int)\n",
        "    parser.add_argument(\"-n\", \"--train-steps\", default=101, type=int)\n",
        "    parser.add_argument(\"-lr\", \"--learning-rate\", default=0.01, type=float)\n",
        "    # NOTE(review): --filter is accepted here but main() never reads args.filter.\n",
        "    parser.add_argument(\"--filter\", action=\"store_true\")\n",
        "    parser.add_argument(\"-v\", \"--verbose\", action=\"store_true\")\n",
        "    # A Jupyter kernel is started with its own command-line arguments (for\n",
        "    # example ``-f <connection file>``); parse_args() would reject them as\n",
        "    # unrecognized and exit. parse_known_args() sets them aside so this cell\n",
        "    # runs on a fresh kernel while still honouring the options defined above.\n",
        "    args, _ = parser.parse_known_args()\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}