Menu

Module providing Consciousness Exploration Tools for PyTorch.

Source code for consciousnet.plotting.patterns

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Spatiotemporal pattern plots.
"""

# Imports
import numpy as np
import matplotlib.pyplot as plt


def plot_spatiotemporal_patterns(patterns, sigma, channel_id, fig=None,
                                 outfile=None):
    """ Display the spatiotemporal patterns.

    Parameters
    ----------
    patterns: array (l, s, c, d)
        the patterns to be displayed.
    sigma: float
        the traversal range [+sigma, -sigma].
    channel_id: int
        the channel to be displayed [0, c[.
    fig: Figure, default None
        a matplotlib figure, a new one is created if not specified.
    outfile: str, default None
        optionally specify a file to save the plot.
    """
    if fig is None:
        fig = plt.figure()
    # Draw on the figure we already hold. The result of `mosaic` is
    # deliberately not assigned back to `fig`: rebinding it would lose the
    # figure handle whenever `mosaic` returns nothing, and the save below
    # would then fail.
    mosaic(
        patterns[:, :, channel_id],
        fig=fig,
        title=(r"Temporal patterns using [$-{0}\sigma$, $+{0}\sigma$] "
               "traversal").format(sigma),
        y_labels=[r"$-{0}\sigma$".format(sigma), r"0",
                  r"$+{0}\sigma$".format(sigma)])
    if outfile is not None:
        fig.savefig(outfile)
def mosaic(data, title=None, y_labels=None, ncol=4, fig=None):
    """ Display a mosaic of images.

    Parameters
    ----------
    data: array
        the images to be displayed: the first axis is iterated and each
        item is drawn in its own subplot with a shared color range.
    title: str, default None
        optionally specify a title for the figure.
    y_labels: list of str, default None
        optionally specify the y tick labels: one label per y tick is
        expected.
    ncol: int, default 4
        the number of columns in the mosaic.
    fig: Figure, default None
        a matplotlib figure, a new one is created if not specified.

    Returns
    -------
    fig: Figure
        the figure containing the mosaic.
    """
    n_plots = len(data)
    # Enough rows to hold every image, the last row may be partially filled.
    nrow = n_plots // ncol
    if n_plots % ncol != 0:
        nrow += 1
    # Shared color range: the upper bound is the 99th percentile of the data.
    vmin = data.min()
    vmax = np.percentile(data, 99)
    if fig is None:
        fig = plt.figure()
    for idx, img in enumerate(data):
        height, width = img.shape[0], img.shape[1]
        ax = fig.add_subplot(nrow, ncol, idx + 1)
        ax.imshow(img, vmin=vmin, vmax=vmax, aspect="auto", cmap="jet")
        # Ticks at start / middle / end of each axis. The step is clamped
        # to 1 so that size-1 dimensions do not produce a zero step.
        ax.set_xlim(0, width)
        ax.set_xticks(np.arange(0, width + 1, max(width // 2, 1)))
        ax.set_ylim(0, height)
        y_ticks = np.arange(0, height + 1, max(height // 2, 1))
        ax.set_yticks(y_ticks)
        # Only the first column displays a y axis.
        if idx % ncol != 0:
            ax.get_yaxis().set_visible(False)
        # Only plots without another plot directly below display an x axis.
        if idx < n_plots - ncol:
            ax.get_xaxis().set_visible(False)
        if y_labels is not None:
            # The labels are applied to the y ticks, so compare against the
            # y ticks (not the x ticks).
            assert len(y_ticks) == len(y_labels), (
                "Expected {0} y labels, got {1}.".format(
                    len(y_ticks), len(y_labels)))
            ax.set_yticklabels(y_labels, fontsize=8)
    # Unused cells of the grid hold no axes, so there is nothing to switch
    # off there (and the last image must keep its axes).
    # Work on the requested figure rather than on pyplot's current figure.
    fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.87,
                        wspace=0.1, hspace=0.25)
    if title is not None:
        fig.suptitle(title)
    return fig

Follow us

© 2021, consciousnet developers