{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Create circo-like graphs\n\nCredit: A Grigis\n\nAll the plots are performed with graph-tool embedded in a Singularity\ncontainer. Please install Singularity first.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# sphinx_gallery_thumbnail_path = '_static/carousel/circo.png'\n\nimport os\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom nilearn import datasets\nfrom nilearn.input_data import NiftiMapsMasker\nfrom nilearn.connectome import ConnectivityMeasure\nfrom consciousnet.plotting import plot_circo, plot_graph"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
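        "## Connectivity dataset\n\nFirst, generate a connectivity matrix using the ROIs defined in the MSDL\ntemplate: the absolute correlations between the ROI time series of one\nsubject are thresholded to keep the 50% strongest connections.\n\n"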
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def threshold_adj(adj, threshold):\n    \"\"\" Keep only the strongest connections of an adjacency matrix.\n\n    Parameters\n    ----------\n    adj: array (n_nodes, n_nodes)\n        square matrix of connection weights (here absolute correlations).\n        Only its upper triangle is used and it is not modified: the\n        thresholding is applied on a copy.\n    threshold: float\n        percentage (between 0 and 100) of the n_nodes * (n_nodes - 1) / 2\n        possible connections to keep; the largest weights are retained.\n\n    Returns\n    -------\n    thr_adj: array (n_nodes, n_nodes)\n        symmetric thresholded matrix with ones on the diagonal.\n\n    Raises\n    ------\n    ValueError\n        if threshold is not between 0 and 100.\n    \"\"\"\n    if not 0 <= threshold <= 100:\n        raise ValueError(\"threshold must be a percentage between 0 and 100.\")\n\n    # Work on a copy so that the caller's matrix is left untouched\n    thr_adj = np.array(adj)\n    n_nodes = len(thr_adj)\n\n    # Get number of connections to keep\n    n_conn_to_keep = int((threshold / 100.) * (n_nodes * (n_nodes - 1) / 2))\n\n    # Count each undirected connection once: zero out the lower triangle\n    # (including the diagonal)\n    thr_adj[np.tril_indices(n_nodes)] = 0\n\n    # Following code is similar to bctpy: zero out all but the\n    # n_conn_to_keep largest weights\n    rows, cols = np.where(thr_adj)\n    order = np.argsort(thr_adj[rows, cols])[::-1]\n    dropped = order[n_conn_to_keep:]\n    thr_adj[rows[dropped], cols[dropped]] = 0\n\n    # Mirror the upper triangle to get a symmetrical matrix\n    thr_adj = thr_adj + thr_adj.T\n\n    # Diagonals need connection of 1 for graph operations\n    thr_adj[np.diag_indices(n_nodes)] = 1.0\n\n    return thr_adj\n\n\ndef imshow(image_file):\n    \"\"\" Display an image file in a new figure, without axes.\n\n    Parameters\n    ----------\n    image_file: str\n        path to an image readable by matplotlib.pyplot.imread.\n\n    Returns\n    -------\n    ax: matplotlib Axes\n        the axes holding the image.\n    \"\"\"\n    image = plt.imread(image_file)\n    _, ax = plt.subplots()\n    ax.imshow(image)\n    ax.axis(\"off\")\n    return ax\n\n\n# Downloaded data and generated figures are all stored in this directory\ntmpdir = \"/tmp/circo\"\nos.makedirs(tmpdir, exist_ok=True)\n\n# MSDL probabilistic atlas: region maps, labels and network membership\natlas = datasets.fetch_atlas_msdl(data_dir=tmpdir)\natlas_filename = atlas[\"maps\"]\nlabels = atlas[\"labels\"]\n# Network names may be bytes or str depending on the nilearn version\nnetworks = [elem.decode(\"utf-8\") if isinstance(elem, bytes) else elem\n            for elem in atlas[\"networks\"]]\n\n# Extract the ROIs time series of a single subject, removing the confounds\ndata = datasets.fetch_development_fmri(n_subjects=1, data_dir=tmpdir)\nfmri_filenames = data.func[0]\nmasker = NiftiMapsMasker(maps_img=atlas_filename, standardize=True,\n                         verbose=5)\ntime_series = masker.fit_transform(\n    fmri_filenames, confounds=data.confounds[0])\n\n# Absolute correlations, keeping the 50% strongest connections\ncorrelation_measure = ConnectivityMeasure(kind=\"correlation\")\ncorrelation_matrix = correlation_measure.fit_transform([time_series])[0]\ncorrelation_matrix = threshold_adj(np.abs(correlation_matrix), 50)\n# Remove the self-connections (set to 1 by threshold_adj)\nnp.fill_diagonal(correlation_matrix, 0)\n\n# Emphasize the strongest connections for display: weights above 0.55 are\n# scaled by 1.1, then those now above 0.85 by a further 2.5, and the matrix is\n# finally rescaled so that its largest weight is 1\ncorrelation_matrix[correlation_matrix > 0.55] *= 1.1\ncorrelation_matrix[correlation_matrix > 0.85] *= 2.5\nmax_weight = correlation_matrix.max()\nif max_weight > 0:\n    correlation_matrix /= max_weight"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
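        "## Connectivity display\n\nNow display the connectivity as circular flow charts (without and with\nregion labels) and as a graph. Each figure is generated in the output\ndirectory and the path of the generated file is printed.\n\n"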
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "adj = correlation_matrix.tolist()\nnames = labels\n\n# Hemisphere of each region, read from the first word of its label:\n# 0 for right (\"R\"), 1 for left (\"L\"), 2 otherwise\nhemi_map = {\"R\": 0, \"L\": 1}\nhemi_groups = [hemi_map.get(elem.split(\" \", 1)[0], 2) for elem in names]\n\n# One color per network, as \"#RRGGBB\" hexadecimal codes\ncolors = [\n    \"#FFC020\", \"#640064\", \"#146432\", \"#3CDC3C\", \"#14DC3C\", \"#A08CB4\",\n    \"#DC1414\", \"#DC3C14\", \"#DCB4DC\", \"#9696C8\", \"#B42878\", \"#234B32\",\n    \"#141E8C\", \"#E18C8C\", \"#C8234B\", \"#50148C\", \"#B4DC8C\", \"#A06432\",\n    \"#641900\", \"#196428\", \"#78643C\", \"#50A014\", \"#14B48C\", \"#4B327D\",\n    \"#DC3CDC\", \"#7D64A0\", \"#8CDCDC\", \"#3C14DC\", \"#DCB48C\", \"#DC140A\",\n    \"#14DCA0\", \"#8C148C\", \"#DC1464\", \"#464646\"]\nassert all(len(elem) == 7 for elem in colors), \"malformed hex color code\"\n# Convert to (r, g, b) tuples with values in [0, 1]\ncolors = [tuple(int(elem.lstrip(\"#\")[i: i + 2], 16) / 255. for i in (0, 2, 4))\n          for elem in colors]\n\n# Sort the network names so that each network gets the same color from one\n# run to the next (set iteration order over strings changes between Python\n# sessions because of hash randomization)\nunique_networks = sorted(set(networks))\nassert len(unique_networks) <= len(colors), \"not enough colors\"\ncolor_map = dict(zip(unique_networks, colors))\ngroup_names = networks\ngroup_colors = [color_map[elem] for elem in group_names]\n\n# Circular flow charts, without and with the region labels\nfor with_labels in (False, True):\n    circo_file = plot_circo(\n        adj=adj, names=names, hemi_groups=hemi_groups,\n        group_names=group_names, group_colors=group_colors, outdir=tmpdir,\n        with_labels=with_labels)\n    print(circo_file)\n\n# Graph representation\ngraph_file = plot_graph(\n    adj=adj, names=names, hemi_groups=hemi_groups, outdir=tmpdir)\nprint(graph_file)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}