{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Barlow Twins: Self-Supervised Learning for clustering\n\nCredit: A Grigis\n\nA simple example of how to use Barlow Twins to learn a data representation\nin an unsupervised way via redundancy reduction. The learned representation\nis in turn used for classification: a simple linear layer associates each\nlearned representation with a label.\n\nSet the `test` variable to False to run a full training.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import time\n",
        "import json\n",
        "import math\n",
        "import random\n",
        "import types\n",
        "import argparse\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "from torch import nn, optim\n",
        "import torch.nn.functional as func\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "from consciousnet.augmentation import ContrastiveImageTransform\n",
        "from consciousnet.optim import LARS\n",
        "from consciousnet.models import BarlowTwins\n",
        "\n",
        "# Set to False to run a full training; True runs a quick smoke test on a\n",
        "# 100-sample subset for a few epochs.\n",
        "test = True\n",
        "n_epochs = 3 if test else 1000\n",
        "batch_size = 10\n",
        "# Learning-rate multipliers of the weight and of the 1-D (bias/norm)\n",
        "# parameter groups, see adjust_learning_rate().\n",
        "learning_rate_weights = 0.2\n",
        "learning_rate_biases = 0.0048\n",
        "weight_decay = 1e-6\n",
        "# Passed to BarlowTwins; presumably the weight of the off-diagonal\n",
        "# (redundancy reduction) loss term -- TODO confirm in consciousnet.\n",
        "lambd = 0.0051\n",
        "# Projection head specification passed to BarlowTwins (presumably the\n",
        "# '-'-separated layer sizes -- TODO confirm).\n",
        "projector = \"64-64\"\n",
        "print_freq = 10\n",
        "# CIFAR10 download/cache directory.\n",
        "datasetdir = \"/tmp/cifar10\"\n",
        "seed = 42\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Seed the random generators so that subset selection, augmentations and\n",
        "# weight initialization are repeatable (some CUDA kernels may remain\n",
        "# non-deterministic).\n",
        "random.seed(seed)\n",
        "np.random.seed(seed)\n",
        "torch.manual_seed(seed)"
      ]
    },
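    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## CIFAR10 dataset\n\nFetch/load the CIFAR10 dataset. Each image is transformed into a pair of\nrandomly augmented views, the two inputs expected by Barlow Twins.\n\n"
      ]
    },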
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def imshow(img):\n",
        "    \"\"\" Unnormalize a channel-first (C, H, W) image tensor and display it.\n",
        "\n",
        "    The image is assumed to be normalized with mean 0.5 and std 0.5 per\n",
        "    channel (TODO confirm against ContrastiveImageTransform).\n",
        "    \"\"\"\n",
        "    img = img / 2 + 0.5\n",
        "    # Clip to [0, 1]: values outside this range would otherwise be clipped\n",
        "    # by matplotlib with a warning.\n",
        "    npimg = np.clip(img.numpy(), 0, 1)\n",
        "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "# Each sample is transformed into a pair of augmented views, the two\n",
        "# inputs of Barlow Twins (see the training loop).\n",
        "dataset = torchvision.datasets.CIFAR10(\n",
        "    root=datasetdir, train=True, download=True,\n",
        "    transform=ContrastiveImageTransform(50))\n",
        "if test:\n",
        "    # Keep a random 100-sample subset for a quick smoke test.\n",
        "    dataset = torch.utils.data.random_split(\n",
        "        dataset, [100, len(dataset) - 100])[0]\n",
        "dataloader = torch.utils.data.DataLoader(\n",
        "    dataset, batch_size=batch_size, shuffle=True, num_workers=1,\n",
        "    pin_memory=True)\n",
        "\n",
        "# Display the first view of the first four images of a batch. Use the\n",
        "# built-in next(): recent PyTorch DataLoader iterators have no .next().\n",
        "dataiter = iter(dataloader)\n",
        "images, labels = next(dataiter)\n",
        "imshow(torchvision.utils.make_grid(images[0][:4]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Training\n\nCreate and train a simple convolutional network whose fully connected\nlayer is replaced by a projection head.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class ConvNet(nn.Module):\n",
        "    \"\"\" Small LeNet-like convolutional network.\n",
        "\n",
        "    The two conv/pool stages map a 3x50x50 input to 16 * 9 * 9 = 1296\n",
        "    features, hence the input size of the final fully connected layer\n",
        "    `fc`, which outputs 10 logits (one per CIFAR10 class). Barlow Twins\n",
        "    replaces `fc` by a projection head.\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
        "        self.pool = nn.MaxPool2d(2, 2)\n",
        "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
        "        self.fc = nn.Linear(1296, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.pool(func.relu(self.conv1(x)))\n",
        "        x = self.pool(func.relu(self.conv2(x)))\n",
        "        x = torch.flatten(x, 1)\n",
        "        x = self.fc(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "def adjust_learning_rate(n_epochs, batch_size, learning_rate_weights,\n",
        "                         learning_rate_biases, optimizer, loader, step,\n",
        "                         warmup_epochs=10):\n",
        "    \"\"\" Set the learning rates of the two optimizer parameter groups.\n",
        "\n",
        "    The base learning rate batch_size / 256 is linearly warmed up from 0\n",
        "    during the first `warmup_epochs` epochs, then follows a cosine decay\n",
        "    down to 0.1% of its value at the end of the training.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    n_epochs: int\n",
        "        the total number of training epochs.\n",
        "    batch_size: int\n",
        "        the training batch size.\n",
        "    learning_rate_weights: float\n",
        "        multiplier of the scheduled rate for param_groups[0] (weights).\n",
        "    learning_rate_biases: float\n",
        "        multiplier of the scheduled rate for param_groups[1] (1-D\n",
        "        parameters, e.g. biases and normalization weights).\n",
        "    optimizer: torch.optim.Optimizer\n",
        "        an optimizer with (at least) two parameter groups.\n",
        "    loader: sized iterable\n",
        "        the training loader, used to convert epochs into steps.\n",
        "    step: int\n",
        "        the global (0-based) training step.\n",
        "    warmup_epochs: int, default 10\n",
        "        the number of warmup epochs.\n",
        "    \"\"\"\n",
        "    steps_per_epoch = len(loader)\n",
        "    max_steps = n_epochs * steps_per_epoch\n",
        "    warmup_steps = warmup_epochs * steps_per_epoch\n",
        "    base_lr = batch_size / 256\n",
        "    if step < warmup_steps:\n",
        "        lr = base_lr * step / warmup_steps\n",
        "    else:\n",
        "        step -= warmup_steps\n",
        "        max_steps -= warmup_steps\n",
        "        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))\n",
        "        end_lr = base_lr * 0.001\n",
        "        lr = base_lr * q + end_lr * (1 - q)\n",
        "    optimizer.param_groups[0][\"lr\"] = lr * learning_rate_weights\n",
        "    optimizer.param_groups[1][\"lr\"] = lr * learning_rate_biases\n",
        "\n",
        "\n",
        "def exclude_bias_and_norm(p):\n",
        "    \"\"\" Return True for 1-D parameters (e.g. biases and norm weights).\n",
        "\n",
        "    Used below both to split the parameters into the two optimizer groups\n",
        "    and as the LARS weight decay / adaptation filter.\n",
        "    \"\"\"\n",
        "    return p.ndim == 1\n",
        "\n",
        "\n",
        "# Wrap the ConvNet: its `fc` layer (1296 input features) is replaced by the\n",
        "# `projector` head.\n",
        "model = BarlowTwins(\n",
        "    model=ConvNet(), fc_layer_name=\"fc\", fc_in_features=1296,\n",
        "    projector=projector, batch_size=batch_size, lambd=lambd).to(device)\n",
        "print(model)\n",
        "\n",
        "# Group 0: weights, group 1: 1-D parameters; each group gets its own\n",
        "# learning-rate multiplier in adjust_learning_rate().\n",
        "param_weights = []\n",
        "param_biases = []\n",
        "for param in model.parameters():\n",
        "    if exclude_bias_and_norm(param):\n",
        "        param_biases.append(param)\n",
        "    else:\n",
        "        param_weights.append(param)\n",
        "parameters = [{\"params\": param_weights}, {\"params\": param_biases}]\n",
        "# lr=0: the learning rates are set at every step by adjust_learning_rate().\n",
        "optimizer = LARS(parameters, lr=0, weight_decay=weight_decay,\n",
        "                 weight_decay_filter=exclude_bias_and_norm,\n",
        "                 lars_adaptation_filter=exclude_bias_and_norm)\n",
        "\n",
        "# Mixed precision (autocast + gradient scaling) is only used on GPU.\n",
        "use_amp = device.type != \"cpu\"\n",
        "scaler = torch.cuda.amp.GradScaler() if use_amp else None\n",
        "start_time = time.time()\n",
        "model.train()\n",
        "for epoch in range(n_epochs):\n",
        "    # `step` is the global step index across epochs (drives the schedule).\n",
        "    for step, ((y1, y2), _) in enumerate(\n",
        "            dataloader, start=epoch * len(dataloader)):\n",
        "        y1 = y1.to(device, non_blocking=True)\n",
        "        y2 = y2.to(device, non_blocking=True)\n",
        "        adjust_learning_rate(n_epochs, batch_size, learning_rate_weights,\n",
        "                             learning_rate_biases, optimizer, dataloader,\n",
        "                             step)\n",
        "        optimizer.zero_grad()\n",
        "        # Call the module rather than .forward() so that hooks are run.\n",
        "        if use_amp:\n",
        "            with torch.cuda.amp.autocast():\n",
        "                loss = model(y1, y2)\n",
        "        else:\n",
        "            loss = model(y1, y2)\n",
        "        if scaler is not None:\n",
        "            # scaler.step() unscales the gradients itself: an explicit\n",
        "            # unscale_() is only needed before gradient clipping.\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(optimizer)\n",
        "            scaler.update()\n",
        "        else:\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "        if step % print_freq == 0:\n",
        "            stats = dict(epoch=epoch, step=step,\n",
        "                         lr_weights=optimizer.param_groups[0][\"lr\"],\n",
        "                         lr_biases=optimizer.param_groups[1][\"lr\"],\n",
        "                         loss=loss.item(),\n",
        "                         time=int(time.time() - start_time))\n",
        "            print(json.dumps(stats))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluation: linear classification\n\nTrain a linear probe on the representations learned by Barlow Twins: the\nweights of the pretrained ConvNet backbone are frozen and only its final\nfully connected layer is trained.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class AverageMeter(object):\n",
        "    \"\"\" Computes and stores the average and current value.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    name: str\n",
        "        the meter name, used when printing it.\n",
        "    fmt: str, default ':f'\n",
        "        the format spec of the printed current and average values.\n",
        "    \"\"\"\n",
        "    def __init__(self, name, fmt=\":f\"):\n",
        "        self.name = name\n",
        "        self.fmt = fmt\n",
        "        self.reset()\n",
        "\n",
        "    def reset(self):\n",
        "        \"\"\" Reset all the statistics to zero.\n",
        "        \"\"\"\n",
        "        self.val = 0\n",
        "        self.avg = 0\n",
        "        self.sum = 0\n",
        "        self.count = 0\n",
        "\n",
        "    def update(self, val, n=1):\n",
        "        \"\"\" Record `val`, a value averaged over `n` samples.\n",
        "        \"\"\"\n",
        "        self.val = val\n",
        "        self.sum += val * n\n",
        "        self.count += n\n",
        "        self.avg = self.sum / self.count\n",
        "\n",
        "    def __str__(self):\n",
        "        fmtstr = \"{name} {val\" + self.fmt + \"} ({avg\" + self.fmt + \"})\"\n",
        "        return fmtstr.format(**self.__dict__)\n",
        "\n",
        "\n",
        "def accuracy(output, target, topk=(1,)):\n",
        "    \"\"\" Computes the accuracy over the k top predictions for the specified\n",
        "    values of k.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    output: torch.Tensor\n",
        "        the (N, n_classes) predicted scores.\n",
        "    target: torch.Tensor\n",
        "        the N ground truth class indices.\n",
        "    topk: tuple of int, default (1,)\n",
        "        the values of k.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    res: list of torch.Tensor\n",
        "        one single-element tensor per k: the top-k accuracy in percent.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        maxk = max(topk)\n",
        "        batch_size = target.size(0)\n",
        "\n",
        "        # (maxk, N) indices of the best scoring classes, best first.\n",
        "        _, pred = output.topk(maxk, 1, True, True)\n",
        "        pred = pred.t()\n",
        "        correct = pred.eq(target.reshape(1, -1).expand_as(pred))\n",
        "\n",
        "        res = []\n",
        "        for k in topk:\n",
        "            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
        "            res.append(correct_k.mul_(100.0 / batch_size))\n",
        "        return res\n",
        "\n",
        "\n",
        "n_epochs = 3 if test else 100\n",
        "lr_classifier = 0.3\n",
        "\n",
        "# NOTE(review): ImageNet statistics; check that they match the\n",
        "# normalization applied by ContrastiveImageTransform during pretraining.\n",
        "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                                 std=[0.229, 0.224, 0.225])\n",
        "transform_train = transforms.Compose([\n",
        "    transforms.RandomResizedCrop(50),\n",
        "    transforms.RandomHorizontalFlip(),\n",
        "    transforms.ToTensor(),\n",
        "    normalize\n",
        "])\n",
        "# CIFAR10 images are 32x32: rescale them to 50x50 like the training\n",
        "# crops. CenterCrop(50) alone would zero-pad them instead of resizing.\n",
        "transform_val = transforms.Compose([\n",
        "    transforms.Resize(50),\n",
        "    transforms.CenterCrop(50),\n",
        "    transforms.ToTensor(),\n",
        "    normalize\n",
        "])\n",
        "ds_train = torchvision.datasets.CIFAR10(\n",
        "    root=datasetdir, train=True, download=True, transform=transform_train)\n",
        "ds_val = torchvision.datasets.CIFAR10(\n",
        "    root=datasetdir, train=False, download=True, transform=transform_val)\n",
        "if test:\n",
        "    ds_train = torch.utils.data.random_split(\n",
        "        ds_train, [100, len(ds_train) - 100])[0]\n",
        "    ds_val = torch.utils.data.random_split(\n",
        "        ds_val, [100, len(ds_val) - 100])[0]\n",
        "datasets = {\"train\": ds_train, \"val\": ds_val}\n",
        "dataloaders = {x: torch.utils.data.DataLoader(\n",
        "    datasets[x], batch_size=batch_size, shuffle=True, num_workers=1)\n",
        "        for x in [\"train\", \"val\"]}\n",
        "\n",
        "# Rebuild a ConvNet from the pretrained Barlow Twins backbone: all the\n",
        "# weights are restored except the `fc` layer, absent from the backbone\n",
        "# state dict, which is re-initialized as the linear classifier.\n",
        "reference_state_dict = model.backbone.state_dict()\n",
        "model = ConvNet().to(device)\n",
        "print(model)\n",
        "missing_keys, unexpected_keys = model.load_state_dict(\n",
        "    reference_state_dict, strict=False)\n",
        "assert (\n",
        "    missing_keys == [\"fc.weight\", \"fc.bias\"] and\n",
        "    unexpected_keys == [])\n",
        "model.fc.weight.data.normal_(mean=0.0, std=0.01)\n",
        "model.fc.bias.data.zero_()\n",
        "# Freeze the backbone: only the linear classifier (fc) is trained.\n",
        "model.requires_grad_(False)\n",
        "model.fc.requires_grad_(True)\n",
        "classifier_parameters, model_parameters = [], []\n",
        "for name, param in model.named_parameters():\n",
        "    if name in {\"fc.weight\", \"fc.bias\"}:\n",
        "        classifier_parameters.append(param)\n",
        "    else:\n",
        "        model_parameters.append(param)\n",
        "criterion = nn.CrossEntropyLoss().to(device)\n",
        "param_groups = [dict(params=classifier_parameters, lr=lr_classifier)]\n",
        "optimizer = optim.SGD(\n",
        "    param_groups, 0, momentum=0.9, weight_decay=weight_decay)\n",
        "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)\n",
        "best_acc = argparse.Namespace(top1=0, top5=0)\n",
        "start_time = time.time()\n",
        "for epoch in range(n_epochs):\n",
        "    # eval() on purpose: the backbone is frozen, so its layers stay in\n",
        "    # inference mode while the classifier is trained.\n",
        "    model.eval()\n",
        "    for step, (y, labels) in enumerate(\n",
        "            dataloaders[\"train\"], start=epoch * len(dataloaders[\"train\"])):\n",
        "        output = model(y.to(device, non_blocking=True))\n",
        "        loss = criterion(output, labels.to(device, non_blocking=True))\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        if step % print_freq == 0:\n",
        "            pg = optimizer.param_groups\n",
        "            # Local name: keep the `lr_classifier` setting unchanged.\n",
        "            current_lr_classifier = pg[0][\"lr\"]\n",
        "            lr_backbone = pg[1][\"lr\"] if len(pg) == 2 else 0\n",
        "            stats = dict(epoch=epoch, step=step, lr_backbone=lr_backbone,\n",
        "                         lr_classifier=current_lr_classifier,\n",
        "                         loss=loss.item(),\n",
        "                         time=int(time.time() - start_time))\n",
        "            print(json.dumps(stats))\n",
        "\n",
        "    # Evaluate\n",
        "    model.eval()\n",
        "    top1 = AverageMeter(\"Acc@1\")\n",
        "    top5 = AverageMeter(\"Acc@5\")\n",
        "    with torch.no_grad():\n",
        "        for y, labels in dataloaders[\"val\"]:\n",
        "            output = model(y.to(device, non_blocking=True))\n",
        "            acc1, acc5 = accuracy(\n",
        "                output, labels.to(device, non_blocking=True), topk=(1, 5))\n",
        "            top1.update(acc1[0].item(), y.size(0))\n",
        "            top5.update(acc5[0].item(), y.size(0))\n",
        "    best_acc.top1 = max(best_acc.top1, top1.avg)\n",
        "    best_acc.top5 = max(best_acc.top5, top5.avg)\n",
        "    stats = dict(\n",
        "        epoch=epoch, acc1=top1.avg, acc5=top5.avg,\n",
        "        best_acc1=best_acc.top1, best_acc5=best_acc.top5)\n",
        "    print(json.dumps(stats))\n",
        "\n",
        "    # Sanity check: the frozen backbone weights must be unchanged.\n",
        "    model_state_dict = model.state_dict()\n",
        "    for k in reference_state_dict:\n",
        "        assert torch.equal(\n",
        "            model_state_dict[k].cpu(), reference_state_dict[k].cpu()), k\n",
        "\n",
        "    scheduler.step()\n",
        "    # Checkpoint-like summary of the epoch (kept in memory, not saved).\n",
        "    state = dict(\n",
        "        epoch=epoch + 1, best_acc=best_acc, model=model.state_dict(),\n",
        "        optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict())"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}