{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Unsupervised clustering with GMVAE\n\nCredit: A Grigis\n\nUnsupervised Gaussian Mixture Variational Auto-encoder (GMVAE) on a synthetic\ndataset. In this example we attempt to replicate the work described in this\n[blog](http://ruishu.io/2016/12/25/gmvae) inspired from this\n[paper](https://arxiv.org/abs/1611.02648).\n\nThe `test` variable must be set to False to run a full training.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# sphinx_gallery_thumbnail_path = '_static/carousel/latent-space.jpg'\n\nimport os\nimport sys\nimport time\nimport copy\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom matplotlib.ticker import NullFormatter\nfrom sklearn import manifold\nfrom sklearn.cluster import KMeans\nfrom sklearn.preprocessing import StandardScaler\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import Dataset\nfrom consciousnet.models import GMVAE\nfrom consciousnet.losses import GMVAELoss\n\ntest = True\nn_samples = 100\nn_classes = 3\nn_feats = 4\ntrue_lat_dims = 2\nfit_lat_dims = 5\nsnr = 10\nbatch_size = 10\nadam_lr = 2e-3\nn_epochs = 3 if test else 100\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Synthetic dataset\n\nA Gaussian Linear multi-class synthetic dataset is generated as\nfollows. The number of the latent dimensions used to generate the data can be\ncontrolled.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class GeneratorUniform(nn.Module):\n    \"\"\" Generate multiple sources (channels) of data through a linear\n    generative model:\n\n    z ~ N(mu,sigma)\n    for c_idx in n_channels:\n        x_ch = W_ch(c_idx)\n    where 'W_ch' is an arbitrary linear mapping z -> x_ch\n    \"\"\"\n    def __init__(self, lat_dim=2, n_channels=2, n_feats=5, seed=100):\n        super(GeneratorUniform, self).__init__()\n        self.lat_dim = lat_dim\n        self.n_channels = n_channels\n        self.n_feats = n_feats\n        self.seed = seed\n        np.random.seed(self.seed)\n        W = []\n        for c_idx in range(n_channels):\n            w_ = np.random.uniform(-1, 1, (self.n_feats, lat_dim))\n            u, s, vt = np.linalg.svd(w_, full_matrices=False)\n            w = (u if self.n_feats >= lat_dim else vt)\n            W.append(torch.nn.Linear(lat_dim, self.n_feats, bias=False))\n            W[c_idx].weight.data = torch.FloatTensor(w)\n        self.W = torch.nn.ModuleList(W)\n\n    def forward(self, z):\n        if isinstance(z, list):\n            return [self.forward(_) for _ in z]\n        if type(z) == np.ndarray:\n            z = torch.FloatTensor(z)\n        assert z.size(dim=1) == self.lat_dim\n        obs = []\n        for c_idx in range(self.n_channels):\n            x = self.W[c_idx](z)\n            obs.append(x.detach())\n        return obs\n\n\nclass SyntheticDataset(Dataset):\n    def __init__(self, n_samples=500, lat_dim=2, n_feats=5, n_classes=2,\n                 generatorclass=GeneratorUniform, snr=1, train=True):\n        super(SyntheticDataset, self).__init__()\n        self.n_samples = n_samples\n        self.lat_dim = lat_dim\n        self.n_feats = n_feats\n        self.n_classes = n_classes\n        self.snr = snr\n        self.train = train\n        self.labels = []\n        self.z = []\n        self.x = []\n        seed = 7 if self.train else 14\n        np.random.seed(seed)\n        locs = np.random.uniform(-5, 5, (self.n_classes, ))\n        np.random.seed(seed)\n        scales = np.random.uniform(0, 2, (self.n_classes, ))\n        np.random.seed(seed)\n        for k_idx in range(self.n_classes):\n            self.z.append(\n                np.random.normal(loc=locs[k_idx], scale=scales[k_idx],\n                                 size=(self.n_samples, self.lat_dim)))\n            self.generator = generatorclass(\n                lat_dim=self.lat_dim, n_channels=1, n_feats=self.n_feats)\n            self.x.append(self.generator(self.z[-1])[0])\n            self.labels += [k_idx] * self.n_samples\n        self.data = np.concatenate(self.x, axis=0)\n        self.labels = np.asarray(self.labels)\n        _, self.data = preprocess_and_add_noise(self.data, snr=snr)\n        self.data = self.data.astype(np.float32)\n\n    def __len__(self):\n        return self.n_samples\n\n    def __getitem__(self, item):\n        return self.data[item], self.labels[item]\n\n\ndef preprocess_and_add_noise(x, snr, seed=0):\n    scalers = StandardScaler().fit(x)\n    x_std = scalers.transform(x)\n    np.random.seed(seed)\n    sigma_noise = np.sqrt(1. / snr)\n    x_std_noisy = x_std + sigma_noise * np.random.randn(*x_std.shape)\n    return x_std, x_std_noisy\n\n\nds_train = SyntheticDataset(\n    n_samples=n_samples, lat_dim=true_lat_dims, n_feats=n_feats,\n    n_classes=n_classes, train=True, snr=snr)\nds_val = SyntheticDataset(\n    n_samples=n_samples, lat_dim=true_lat_dims, n_feats=n_feats,\n    n_classes=n_classes, train=False, snr=snr)\ndatasets = {\"train\": ds_train, \"val\": ds_val}\ndataloaders = {x: torch.utils.data.DataLoader(\n    datasets[x], batch_size=batch_size, shuffle=True, num_workers=1)\n        for x in [\"train\", \"val\"]}\n\nmethod = manifold.TSNE(n_components=2, init=\"pca\", random_state=0)\ny_train = method.fit_transform(ds_train.data)\ny_val = method.fit_transform(ds_val.data)\nfig, axs = plt.subplots(nrows=3, ncols=2)\nfor cnt, (name, y, labels) in enumerate((\n        (\"train\", y_train, ds_train.labels),\n        (\"val\", y_val, ds_val.labels))):\n    colors = labels.astype(float)\n    colors /= colors.max()\n    axs[0, cnt].scatter(y[:, 0], y[:, 1], c=colors, cmap=plt.cm.Spectral)\n    axs[0, cnt].xaxis.set_major_formatter(NullFormatter())\n    axs[0, cnt].yaxis.set_major_formatter(NullFormatter())\n    axs[0, cnt].set_title(\"GT clustering ({0})\".format(name))\n    axs[0, cnt].axis(\"tight\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## ML clustering\n\nAs a ground truth we performed a K-means clustering of the data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "kmeans = KMeans(n_clusters=n_classes, random_state=0).fit(ds_train.data)\ntrain_labels = kmeans.labels_\ntrain_acc = GMVAELoss.cluster_acc(train_labels, ds_train.labels)\nprint(\"-- K-Means ACC train\", train_acc)\nval_labels = kmeans.predict(ds_val.data)\nval_acc = GMVAELoss.cluster_acc(val_labels, ds_val.labels)\nprint(\"-- K-Means ACC val\",val_acc)\n\nfor cnt, (name, y, labels, acc) in enumerate((\n        (\"train\", y_train, train_labels, train_acc),\n        (\"val\", y_val, val_labels, val_acc))):\n    colors = labels.astype(float)\n    colors /= colors.max()\n    axs[1, cnt].scatter(y[:, 0], y[:, 1], c=colors, cmap=plt.cm.Spectral)\n    axs[1, cnt].xaxis.set_major_formatter(NullFormatter())\n    axs[1, cnt].yaxis.set_major_formatter(NullFormatter())\n    axs[1, cnt].set_title(\n        \"K-means clustering ({0}-ACC:{1:.3f})\".format(name, acc))\n    axs[1, cnt].axis(\"tight\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Training\n\nCreate/train the model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def train_model(dataloaders, model, device, criterion, optimizer,\n                scheduler=None, n_epochs=100, checkpointdir=None,\n                save_after_epochs=1, board=None, board_updates=None,\n                load_best=False):\n    \"\"\" General function to train a model and display training metrics.\n\n    Parameters\n    ----------\n    dataloaders: dict of torch.utils.data.DataLoader\n        the train & validation data loaders.\n    model: nn.Module\n        the model to be trained.\n    device: torch.device\n        the device to work on.\n    criterion: torch.nn._Loss\n        the criterion to be optimized.\n    optimizer: torch.optim.Optimizer\n        the optimizer.\n    scheduler: torch.optim.lr_scheduler, default None\n        the scheduler.\n    n_epochs: int, default 100\n        the number of epochs.\n    checkpointdir: str, default None\n        a destination folder where intermediate models/histories will be\n        saved.\n    save_after_epochs: int, default 1\n        determines when the model is saved and represents the number of\n        epochs before saving.\n    board: brainboard.Board, default None\n        a board to display live results.\n    board_updates: list of callable, default None\n        update displayed item on the board.\n    load_best: bool, default False\n        optionally load the best model regarding the loss.\n    \"\"\"\n    since = time.time()\n    if board_updates is not None:\n        board_updates = listify(board_updates)\n    best_model_wts = copy.deepcopy(model.state_dict())\n    best_loss = sys.float_info.max\n    dataset_sizes = {x: len(dataloaders[x]) for x in [\"train\", \"val\"]}\n    model = model.to(device)\n    for epoch in range(n_epochs):\n        print(\"Epoch {0}/{1}\".format(epoch, n_epochs - 1))\n        print(\"-\" * 10)\n        for phase in [\"train\", \"val\"]:\n            if phase == \"train\":\n                model.train()  \n            else:\n                model.eval()   \n            running_loss = 0.0\n            for batch_data, batch_labels in dataloaders[phase]:\n                batch_data = batch_data.to(device)\n                batch_labels = batch_labels.to(device)\n                # Zero the parameter gradients\n                optimizer.zero_grad()\n                # Forward:\n                # track history if only in train\n                with torch.set_grad_enabled(phase == \"train\"):\n                    outputs, layer_outputs = model(batch_data)\n                    criterion.layer_outputs = layer_outputs\n                    loss, extra_loss = criterion(\n                        outputs, batch_data, labels=batch_labels)\n                    # Backward + optimize only if in training phase\n                    if phase == \"train\":\n                        loss.backward()\n                        optimizer.step()\n                # Statistics\n                running_loss += loss.item() * batch_data[0].size(0)\n            if scheduler is not None and phase == \"train\":\n                scheduler.step()\n            epoch_loss = running_loss / dataset_sizes[phase]\n            print(\"{0} Loss: {1:.4f}\".format(phase, epoch_loss))\n            if board is not None:\n                board.update_plot(\"loss_{0}\".format(phase), epoch, epoch_loss)\n            # Display validation classification results\n            if board_updates is not None and phase == \"val\":\n                for update in board_updates:\n                    update(model, board, outputs, layer_outputs)\n            # Deep copy the best model\n            if phase == \"val\" and epoch_loss < best_loss:\n                best_loss = epoch_loss\n                best_model_wts = copy.deepcopy(model.state_dict())\n        # Save intermediate results\n        if checkpointdir is not None and epoch % save_after_epochs == 0:\n            outfile = os.path.join(\n                checkpointdir, \"model_{0}.pth\".format(epoch))\n            checkpoint(\n                model=model, outfile=outfile, optimizer=optimizer,\n                scheduler=scheduler, epoch=epoch, epoch_loss=epoch_loss)\n        print()\n    time_elapsed = time.time() - since\n    print(\"Training complete in {:.0f}m {:.0f}s\".format(\n        time_elapsed // 60, time_elapsed % 60))\n    print(\"Best val loss: {:4f}\".format(best_loss))\n    # Load best model weights\n    if load_best:\n        model.load_state_dict(best_model_wts)\n\n\ndef listify(data):\n    \"\"\" Ensure that the input is a list or tuple.\n\n    Parameters\n    ----------\n    arr: list or array\n        the input data.\n\n    Returns\n    -------\n    out: list\n        the liftify input data.\n    \"\"\"\n    if isinstance(data, list) or isinstance(data, tuple):\n        return data\n    else:\n        return [data]\n\n\ndef checkpoint(model, outfile, optimizer=None, scheduler=None,\n               **kwargs):\n    \"\"\" Save the weights of a given model.\n\n    Parameters\n    ----------\n    model: nn.Module\n        the model to be saved.\n    outfile: str\n        the destination file name.\n    optimizer: torch.optim.Optimizer\n        the optimizer.\n    scheduler: torch.optim.lr_scheduler, default None\n        the scheduler.\n    kwargs: dict\n        others parameters to be saved.\n    \"\"\"\n    kwargs.update(model=model.state_dict())\n    if optimizer is not None:\n        kwargs.update(optimizer=optimizer.state_dict())\n    if scheduler is not None:\n        kwargs.update(scheduler=scheduler.state_dict())\n    torch.save(kwargs, outfile)\n\nmodel = GMVAE(\n    input_dim=n_feats, latent_dim=fit_lat_dims, n_mix_components=n_classes,\n    sigma_min=0.001, raw_sigma_bias=0.25, dropout=0, temperature=1,\n    gen_bias_init=0.)\nprint(model)\noptimizer = optim.Adam(model.parameters(), lr=adam_lr)\ncriterion = GMVAELoss()\ntrain_model(dataloaders, model, device, criterion, optimizer,\n            scheduler=None, n_epochs=n_epochs, checkpointdir=None,\n            board=None, load_best=False)\n\nmodel.eval()\nwith torch.no_grad():\n    p_x_given_z, dists = model(\n        torch.from_numpy(ds_train.data.astype(np.float32)).to(device))\nq_y_given_x = dists[\"q_y_given_x\"]\ntrain_labels = np.argmax(q_y_given_x.logits.detach().cpu().numpy(), axis=1)\ntrain_acc = GMVAELoss.cluster_acc(\n    q_y_given_x.logits, ds_train.labels, is_logits=True)\nprint(\"-- GMVAE ACC train\", train_acc)\nwith torch.no_grad():\n    p_x_given_z, dists = model(\n            torch.from_numpy(ds_val.data.astype(np.float32)).to(device))\nq_y_given_x = dists[\"q_y_given_x\"]\nval_labels = np.argmax(q_y_given_x.logits.detach().cpu().numpy(), axis=1)\nval_acc = GMVAELoss.cluster_acc(\n    q_y_given_x.logits, ds_val.labels, is_logits=True)\nprint(\"-- GMVAE ACC val\", val_acc)\n\nfor cnt, (name, y, labels, acc) in enumerate((\n        (\"train\", y_train, train_labels, train_acc),\n        (\"val\", y_val, val_labels, val_acc))):\n    colors = labels.astype(float)\n    colors /= colors.max()\n    axs[2, cnt].scatter(y[:, 0], y[:, 1], c=colors, cmap=plt.cm.Spectral)\n    axs[2, cnt].xaxis.set_major_formatter(NullFormatter())\n    axs[2, cnt].yaxis.set_major_formatter(NullFormatter())\n    axs[2, cnt].set_title(\n        \"GMVAE clustering ({0}-ACC:{1:.3f})\".format(name, acc))\n    axs[2, cnt].axis(\"tight\")\nplt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95,\n                    wspace=0.1, hspace=0.5)\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}