{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Jumpy prediction with Temporal Difference Variational Auto-Encoder (TD-VAE)\n\nCredit: A Grigis\n\nTD-VAE is designed such that it has all of the following three features:\n\n* it learns a state representation of observations and makes predictions on\n  the state level.\n* based on observations, it learns a belief state that contains all\n  the information required to make predictions about the future.\n* it also learns to make predictions multiple steps in the future directly,\n  instead of making predictions step by step, by connecting states\n  that are multiple steps apart.\n\nIn this example we reproduce the experiment about moving MNIST digits.\nIn this experiment, a sequence of an MNIST digit moving to the left or to the\nright is presented to the model. The model needs to predict how the\ndigit moves in the following steps. After training the model, a sequence of\ndigits can be fed into the model to see how well it can predict the future.\n\nThe `test` variable must be set to False to run a full training.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\nimport time\nimport copy\nimport numpy as np\nfrom matplotlib import gridspec\nimport matplotlib.pyplot as plt\nimport torch\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\nfrom dataify import MovingMNISTDataset\nfrom consciousnet.models import TDVAE\nfrom consciousnet.losses import TDVAELoss\n\n# Smoke-test switch: set to False to run a full training.\ntest = True\n# Seed the random generators so that the random subsets, the shuffling and\n# the weight initialization are repeatable from one run to the next.\nseed = 42\nnp.random.seed(seed)\ntorch.manual_seed(seed)\n# Root folder handed to the dataset loader.\ndatasetdir = \"/tmp/moving_mnist\"\n# makedirs(exist_ok=True) also creates missing parent folders and does not\n# fail if the folder appears between a check and the creation.\nos.makedirs(datasetdir, exist_ok=True)\n# 784 = 28 * 28, the size of one flattened frame.\ninput_size = 784\nprocessed_x_size = 784\nbelief_state_size = 50\nstate_size = 8\n# Forwarded as is to the TDVAE constructor (t and d arguments).\nt = 16\nd = 4\nadd_sigmoid = True\n# Batch size and number of epochs: tiny values in test mode.\nn_samples = 3 if test else 512\nn_epochs = 3 if test else 4000\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Moving MNIST digits dataset\n\nFetch & load the moving MNIST digits dataset.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Both splits share the same settings, only the `train` flag differs.\ndataset_kwargs = dict(root=datasetdir, seq_size=20, shift=1, binary=True)\nds_train = MovingMNISTDataset(train=True, **dataset_kwargs)\nds_val = MovingMNISTDataset(train=False, **dataset_kwargs)\nif test:\n    # In test mode keep only a random subset of 100 items per split.\n    n_kept = 100\n    ds_train = torch.utils.data.random_split(\n        ds_train, [n_kept, len(ds_train) - n_kept])[0]\n    ds_val = torch.utils.data.random_split(\n        ds_val, [n_kept, len(ds_val) - n_kept])[0]\ndatasets = {\"train\": ds_train, \"val\": ds_val}\n# One shuffled loader per phase, with batches of `n_samples` items.\ndataloaders = {\n    phase: torch.utils.data.DataLoader(\n        subset, batch_size=n_samples, shuffle=True, num_workers=1)\n    for phase, subset in datasets.items()}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Training\n\nCreate/train the model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def train_model(dataloaders, model, device, criterion, optimizer,\n                scheduler=None, n_epochs=100, checkpointdir=None,\n                save_after_epochs=1, board=None, board_updates=None,\n                load_best=False):\n    \"\"\" General function to train a model and display training metrics.\n\n    For each epoch a training phase then a validation phase are run. The\n    reported epoch loss is the batch-size weighted average of the batch\n    losses over all the samples of the phase.\n\n    Parameters\n    ----------\n    dataloaders: dict of torch.utils.data.DataLoader\n        the train & validation data loaders, stored under the 'train' and\n        'val' keys. Each batch is a (data, target) pair and only the data\n        is used.\n    model: nn.Module\n        the model to be trained. Calling it on a batch must return an\n        (outputs, layer_outputs) pair.\n    device: torch.device\n        the device to work on.\n    criterion: torch.nn._Loss\n        the criterion to be optimized. Calling it must return a\n        (loss, extra_loss) pair. The layer outputs are attached to it as\n        the 'layer_outputs' attribute before each call.\n    optimizer: torch.optim.Optimizer\n        the optimizer.\n    scheduler: torch.optim.lr_scheduler, default None\n        the scheduler, stepped once per epoch after the training phase.\n    n_epochs: int, default 100\n        the number of epochs.\n    checkpointdir: str, default None\n        a destination folder where intermediate models/histories will be\n        saved.\n    save_after_epochs: int, default 1\n        determines when the model is saved and represents the number of\n        epochs before saving.\n    board: brainboard.Board, default None\n        a board to display live results.\n    board_updates: list of callable, default None\n        update displayed item on the board.\n    load_best: bool, default False\n        optionally load the best model regarding the loss.\n    \"\"\"\n    since = time.time()\n    if board_updates is not None:\n        board_updates = listify(board_updates)\n    best_model_wts = copy.deepcopy(model.state_dict())\n    best_loss = sys.float_info.max\n    # Number of samples (not of batches) in each phase: the running loss is\n    # weighted by the batch size and must be averaged over the samples.\n    dataset_sizes = {\n        x: len(dataloaders[x].dataset) for x in [\"train\", \"val\"]}\n    model = model.to(device)\n    for epoch in range(n_epochs):\n        print(\"Epoch {0}/{1}\".format(epoch, n_epochs - 1))\n        print(\"-\" * 10)\n        for phase in [\"train\", \"val\"]:\n            if phase == \"train\":\n                model.train()\n            else:\n                model.eval()\n            running_loss = 0.0\n            for batch_data, _ in dataloaders[phase]:\n                batch_data = batch_data.to(device)\n                # Zero the parameter gradients\n                optimizer.zero_grad()\n                # Forward:\n                # track history if only in train\n                with torch.set_grad_enabled(phase == \"train\"):\n                    outputs, layer_outputs = model(batch_data)\n                    criterion.layer_outputs = layer_outputs\n                    loss, extra_loss = criterion(outputs, batch_data)\n                    # Backward + optimize only if in training phase\n                    if phase == \"train\":\n                        loss.backward()\n                        optimizer.step()\n                # Statistics: weight the batch loss by the number of samples\n                # in the batch, i.e. the first axis of the batch tensor.\n                # NOTE(review): assumes the criterion returns a loss averaged\n                # over the batch - TODO confirm.\n                running_loss += loss.item() * batch_data.size(0)\n            if scheduler is not None and phase == \"train\":\n                scheduler.step()\n            epoch_loss = running_loss / dataset_sizes[phase]\n            print(\"{0} Loss: {1:.4f}\".format(phase, epoch_loss))\n            if board is not None:\n                board.update_plot(\"loss_{0}\".format(phase), epoch, epoch_loss)\n            # Display validation results, computed on the last batch\n            if board_updates is not None and phase == \"val\":\n                for update in board_updates:\n                    update(model, board, outputs, layer_outputs)\n            # Deep copy the best model\n            if phase == \"val\" and epoch_loss < best_loss:\n                best_loss = epoch_loss\n                best_model_wts = copy.deepcopy(model.state_dict())\n        # Save intermediate results (epoch_loss is the validation loss here)\n        if checkpointdir is not None and epoch % save_after_epochs == 0:\n            outfile = os.path.join(\n                checkpointdir, \"model_{0}.pth\".format(epoch))\n            checkpoint(\n                model=model, outfile=outfile, optimizer=optimizer,\n                scheduler=scheduler, epoch=epoch, epoch_loss=epoch_loss)\n        print()\n    time_elapsed = time.time() - since\n    print(\"Training complete in {:.0f}m {:.0f}s\".format(\n        time_elapsed // 60, time_elapsed % 60))\n    print(\"Best val loss: {:.4f}\".format(best_loss))\n    # Load best model weights\n    if load_best:\n        model.load_state_dict(best_model_wts)\n\n\ndef listify(data):\n    \"\"\" Ensure that the input is a list or tuple.\n\n    Parameters\n    ----------\n    data: object\n        the input data.\n\n    Returns\n    -------\n    out: list or tuple\n        the input itself if it is already a list or a tuple, otherwise the\n        input wrapped in a one-element list.\n    \"\"\"\n    if isinstance(data, (list, tuple)):\n        return data\n    return [data]\n\n\ndef checkpoint(model, outfile, optimizer=None, scheduler=None,\n               **kwargs):\n    \"\"\" Save the weights of a given model.\n\n    The state dictionaries are stored with torch.save in a single\n    dictionary under the 'model', 'optimizer' and 'scheduler' keys, next to\n    the extra keyword arguments.\n\n    Parameters\n    ----------\n    model: nn.Module\n        the model to be saved.\n    outfile: str\n        the destination file name.\n    optimizer: torch.optim.Optimizer, default None\n        the optimizer.\n    scheduler: torch.optim.lr_scheduler, default None\n        the scheduler.\n    kwargs: dict\n        others parameters to be saved.\n    \"\"\"\n    kwargs.update(model=model.state_dict())\n    if optimizer is not None:\n        kwargs.update(optimizer=optimizer.state_dict())\n    if scheduler is not None:\n        kwargs.update(scheduler=scheduler.state_dict())\n    torch.save(kwargs, outfile)\n\n\nmodel = TDVAE(x_dim=input_size, b_dim=belief_state_size, z_dim=state_size,\n              t=t, d=d, n_layers=2, n_lstm_layers=1,\n              preproc_dim=processed_x_size, add_sigmoid=add_sigmoid)\nprint(model)\noptimizer = optim.Adam(model.parameters(), lr=0.0005)\ncriterion = TDVAELoss(obs_loss=torch.nn.functional.binary_cross_entropy)\ntrain_model(dataloaders, model, device, criterion, optimizer,\n            scheduler=None, n_epochs=n_epochs, checkpointdir=None,\n            board=None, load_best=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Jumpy predictions\n\nA sequence of digits is fed into the model to see how well it can\npredict the 4 further images with a time jump of 11 steps.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "t1, t2 = 11, 15\nmodel.eval()\n# Grab a single validation batch, the targets are not needed.\ndata, _ = next(iter(dataloaders[\"val\"]))\ndata = data.to(device)\n# Never display more rows than the batch actually holds.\nbatch_display_size = min(3, data.shape[0])\n# Inference only: do not record the operations for autograd.\nwith torch.no_grad():\n    # calculate belief\n    model(data)\n    # jumpy rollout\n    rollout_data = model.rollout(data, t1, t2)\n# plot results\nimages = data.cpu().detach().numpy()\nrollout_images = rollout_data.cpu().detach().numpy()\nfig = plt.figure(0, figsize=(12, 4))\ngs = gridspec.GridSpec(batch_display_size, t2 + 2)\ngs.update(wspace=0.05, hspace=0.05)\nfor i in range(batch_display_size):\n    # The first t1 input frames fill the first t1 columns.\n    for j in range(t1):\n        ax = plt.subplot(gs[i, j])\n        ax.imshow(1 - images[i, j].reshape(28, 28),\n                  cmap=\"binary\")\n        ax.axis(\"off\")\n    # The rollout frames are shifted by one column, which leaves an empty\n    # column between the inputs and the predictions.\n    for j in range(t1, t2 + 1):\n        ax = plt.subplot(gs[i, j + 1])\n        ax.imshow(1 - rollout_images[i, j - t1].reshape(28, 28),\n                  cmap=\"binary\")\n        ax.axis(\"off\")\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}