{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook output.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Spatiotemporal Trajectories in Resting-state fMRI\n\nCredit: A Grigis\n\nIn this example we illustrate how meaningful spatiotemporal information\ncan be extracted from a Variational Auto-Encoder (VAE) trained on\nrfMRI-like time series. Here, a synthetic dataset of noisy sinusoidal\noscillators stands in for real resting-state fMRI (rfMRI) data.\n\n\nThe `test` variable must be set to False to run a full training.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import time\n",
        "import copy\n",
        "from pprint import pprint\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "from dataify import SinOscillatorDataset\n",
        "from brainite.models import VAE\n",
        "from brainite.losses import BetaHLoss\n",
        "from brainite.utils import traversals\n",
        "from consciousnet.plotting import plot_reconstruction_error\n",
        "from consciousnet.plotting import plot_spatiotemporal_patterns\n",
        "\n",
        "\n",
        "# Set `test` to False to run a full training (more epochs).\n",
        "test = True\n",
        "n_samples = 20\n",
        "adam_lr = 0.01\n",
        "batch_size = 10\n",
        "# NOTE(review): 100 epochs mirrors the `train_model` default; tune this\n",
        "# value for the full run.\n",
        "n_epochs = 10 if test else 100\n",
        "# Seed torch so that weight initialization and batch shuffling are\n",
        "# reproducible (the dataset is seeded separately via `seed=42`).\n",
        "torch.manual_seed(42)\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Sinusoidal oscillator dataset\n\nGenerate the synthetic SinOscillator dataset and wrap it in a shuffled\ndata loader.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dataset = SinOscillatorDataset(\n",
        "    n_samples=n_samples, duration=4, fs=10, freq=(0.6, 0.7), amp=1,\n",
        "    phase=np.pi, target_snr=20, seed=42)\n",
        "# Shuffled mini-batches, loaded by a single worker process.\n",
        "dataloader = torch.utils.data.DataLoader(\n",
        "    dataset=dataset, batch_size=batch_size, num_workers=1,\n",
        "    shuffle=True)\n",
        "# Peek at one batch to check the tensor layout.\n",
        "item = next(iter(dataloader))\n",
        "print(item.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Training\n\nTrain a VAE with 1-D temporal convolutions.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def train_model(dataloader, model, device, criterion, optimizer,\n",
        "                scheduler=None, n_epochs=100, checkpointdir=None,\n",
        "                save_after_epochs=1, board=None, board_updates=None,\n",
        "                load_best=False):\n",
        "    \"\"\" General function to train a model and display training metrics.\n",
        "\n",
        "    The reported epoch loss is the mean of the per-batch losses weighted\n",
        "    by the number of samples in each batch.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    dataloader: torch.utils.data.DataLoader\n",
        "        the data loader.\n",
        "    model: nn.Module\n",
        "        the model to be trained.\n",
        "    device: torch.device\n",
        "        the device to work on.\n",
        "    criterion: torch.nn._Loss\n",
        "        the criterion to be optimized: its `layer_outputs` attribute is\n",
        "        set before each call and it must return a `(loss, extra_loss)`\n",
        "        tuple.\n",
        "    optimizer: torch.optim.Optimizer\n",
        "        the optimizer.\n",
        "    scheduler: torch.optim.lr_scheduler, default None\n",
        "        the scheduler, stepped once at the end of each epoch.\n",
        "    n_epochs: int, default 100\n",
        "        the number of epochs.\n",
        "    checkpointdir: str, default None\n",
        "        a destination folder where intermediate models/histories will be\n",
        "        saved.\n",
        "    save_after_epochs: int, default 1\n",
        "        determines when the model is saved and represents the number of\n",
        "        epochs before saving.\n",
        "    board: brainboard.Board, default None\n",
        "        a board to display live results.\n",
        "    board_updates: list of callable, default None\n",
        "        update displayed item on the board: each callable is called as\n",
        "        `update(model, board, outputs, layer_outputs)` with the outputs\n",
        "        of the last batch of the epoch.\n",
        "    load_best: bool, default False\n",
        "        optionally load the best model regarding the loss.\n",
        "    \"\"\"\n",
        "    since = time.time()\n",
        "    if board_updates is not None:\n",
        "        board_updates = listify(board_updates)\n",
        "    best_model_wts = copy.deepcopy(model.state_dict())\n",
        "    best_loss = sys.float_info.max\n",
        "    # Number of samples, not of batches (which is what len(dataloader)\n",
        "    # returns), so that the epoch loss is a per-sample average.\n",
        "    dataset_size = len(dataloader.dataset)\n",
        "    model = model.to(device)\n",
        "    for epoch in range(n_epochs):\n",
        "        print(\"Epoch {0}/{1}\".format(epoch, n_epochs - 1))\n",
        "        print(\"-\" * 10)\n",
        "        model.train()\n",
        "        running_loss = 0.0\n",
        "        for batch_data in dataloader:\n",
        "            # Swap axes 1 and 2 before feeding the model.\n",
        "            batch_data = torch.transpose(batch_data, 1, 2)\n",
        "            batch_data = batch_data.to(device)\n",
        "            # Zero the parameter gradients\n",
        "            optimizer.zero_grad()\n",
        "            # Forward with gradient tracking forced on, even if the caller\n",
        "            # disabled it globally.\n",
        "            with torch.set_grad_enabled(True):\n",
        "                outputs, layer_outputs = model(batch_data)\n",
        "                criterion.layer_outputs = layer_outputs\n",
        "                loss, extra_loss = criterion(outputs, batch_data)\n",
        "                # Backward + optimize\n",
        "                loss.backward()\n",
        "                optimizer.step()\n",
        "            # Statistics: weight the batch loss by the batch size (dim 0;\n",
        "            # batch_data[0].size(0) would be the size of dim 1 instead).\n",
        "            # NOTE(review): assumes `loss` is a per-sample mean - confirm\n",
        "            # for BetaHLoss.\n",
        "            running_loss += loss.item() * batch_data.size(0)\n",
        "        if scheduler is not None:\n",
        "            scheduler.step()\n",
        "        epoch_loss = running_loss / dataset_size\n",
        "        print(\"Loss: {:.4f}\".format(epoch_loss))\n",
        "        if board is not None:\n",
        "            board.update_plot(\"loss\", epoch, epoch_loss)\n",
        "        # Refresh the custom board items with the last batch of the epoch\n",
        "        if board_updates is not None:\n",
        "            for update in board_updates:\n",
        "                update(model, board, outputs, layer_outputs)\n",
        "        # Deep copy the best model\n",
        "        if epoch_loss < best_loss:\n",
        "            best_loss = epoch_loss\n",
        "            best_model_wts = copy.deepcopy(model.state_dict())\n",
        "        # Save intermediate results\n",
        "        if checkpointdir is not None and epoch % save_after_epochs == 0:\n",
        "            outfile = os.path.join(\n",
        "                checkpointdir, \"model_{0}.pth\".format(epoch))\n",
        "            checkpoint(\n",
        "                model=model, outfile=outfile, optimizer=optimizer,\n",
        "                scheduler=scheduler, epoch=epoch, epoch_loss=epoch_loss)\n",
        "        print()\n",
        "    time_elapsed = time.time() - since\n",
        "    print(\"Training complete in {:.0f}m {:.0f}s\".format(\n",
        "        time_elapsed // 60, time_elapsed % 60))\n",
        "    # There is no validation set here: this is the best training loss.\n",
        "    print(\"Best train loss: {:.4f}\".format(best_loss))\n",
        "    # Load best model weights\n",
        "    if load_best:\n",
        "        model.load_state_dict(best_model_wts)\n",
        "\n",
        "\n",
        "def listify(data):\n",
        "    \"\"\" Ensure that the input is a list or tuple.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    data: object\n",
        "        the input data.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    out: list or tuple\n",
        "        the input itself if it is already a list or tuple, otherwise a\n",
        "        one-element list wrapping it.\n",
        "    \"\"\"\n",
        "    if isinstance(data, (list, tuple)):\n",
        "        return data\n",
        "    return [data]\n",
        "\n",
        "\n",
        "def checkpoint(model, outfile, optimizer=None, scheduler=None,\n",
        "               **kwargs):\n",
        "    \"\"\" Save the weights of a given model.\n",
        "\n",
        "    The saved file is a dict holding `kwargs` plus the `model`,\n",
        "    `optimizer` and `scheduler` state dicts (the last two only when\n",
        "    given).\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    model: nn.Module\n",
        "        the model to be saved.\n",
        "    outfile: str\n",
        "        the destination file name.\n",
        "    optimizer: torch.optim.Optimizer, default None\n",
        "        the optimizer.\n",
        "    scheduler: torch.optim.lr_scheduler, default None\n",
        "        the scheduler.\n",
        "    kwargs: dict\n",
        "        others parameters to be saved.\n",
        "    \"\"\"\n",
        "    kwargs.update(model=model.state_dict())\n",
        "    if optimizer is not None:\n",
        "        kwargs.update(optimizer=optimizer.state_dict())\n",
        "    if scheduler is not None:\n",
        "        kwargs.update(scheduler=scheduler.state_dict())\n",
        "    torch.save(kwargs, outfile)\n",
        "\n",
        "\n",
        "# NOTE(review): input_dim=40 presumably equals duration * fs (4 * 10) in\n",
        "# the dataset cell; keep the two in sync.\n",
        "model = VAE(\n",
        "    input_channels=1, input_dim=40, conv_flts=[16], dense_hidden_dims=None,\n",
        "    latent_dim=8, noise_fixed=True, act_func=None, dropout=0, sparse=False)\n",
        "print(model)\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=adam_lr)\n",
        "criterion = BetaHLoss(beta=1, use_mse=True)\n",
        "train_model(dataloader, model, device, criterion, optimizer,\n",
        "            n_epochs=n_epochs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
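        "## Exploring the VAE\n\nTraverse the latent space of the trained VAE to visualize the\nspatiotemporal patterns it encodes, then compare the input signals with\ntheir reconstructions.\n\n"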
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def test_model(dataloader, model, device):\n",
        "    \"\"\" General function to test a model.\n",
        "\n",
        "    The model is run in evaluation mode without gradient tracking, and\n",
        "    its previous train/eval mode is restored afterwards.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    dataloader: torch.utils.data.DataLoader\n",
        "        the data loader.\n",
        "    model: nn.Module\n",
        "        the trained model.\n",
        "    device: torch.device\n",
        "        the device to work on.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    data: numpy.ndarray\n",
        "        the input batches (with axes 1 and 2 swapped) concatenated along\n",
        "        the first axis, with singleton dimensions squeezed.\n",
        "    rec_data: numpy.ndarray\n",
        "        the matching predictions from `VAE.p_to_prediction`, concatenated\n",
        "        and squeezed the same way.\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    data, rec_data = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch_data in dataloader:\n",
        "            batch_data = torch.transpose(batch_data, 1, 2)\n",
        "            # Keep a host copy of the input before moving it to the device.\n",
        "            data.append(batch_data.numpy())\n",
        "            batch_data = batch_data.to(device)\n",
        "            outputs, _ = model(batch_data)\n",
        "            rec_data.append(VAE.p_to_prediction(outputs))\n",
        "    model.train(mode=was_training)\n",
        "    data = np.concatenate(data, axis=0).squeeze()\n",
        "    rec_data = np.concatenate(rec_data, axis=0).squeeze()\n",
        "    return data, rec_data\n",
        "\n",
        "\n",
        "# Use a dedicated name so the `n_samples` dataset setting from the\n",
        "# configuration cell is not overwritten.\n",
        "n_per_latent = 30\n",
        "sigma = 3\n",
        "st_patterns = traversals(\n",
        "    model, device, n_per_latent=n_per_latent, max_traversal=sigma)\n",
        "plot_spatiotemporal_patterns(st_patterns, sigma, channel_id=0)\n",
        "\n",
        "data, rec_data = test_model(dataloader, model, device)\n",
        "similarity = plot_reconstruction_error(data, rec_data)\n",
        "pprint(similarity)\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}