diff --git a/kymata/plot/plot.py b/kymata/plot/plot.py index 79e0336c..dba977bd 100644 --- a/kymata/plot/plot.py +++ b/kymata/plot/plot.py @@ -1,10 +1,14 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass from itertools import cycle from pathlib import Path from statistics import NormalDist from typing import Optional, Sequence, Dict, NamedTuple +from warnings import warn import numpy as np -import os from matplotlib import pyplot from matplotlib.colors import to_hex, LinearSegmentedColormap from matplotlib.lines import Line2D @@ -32,28 +36,66 @@ class _Ax: minimap = "minimap" -def _get_mosaic_spec(paired_axes: bool, minimap: bool) -> list[list[str]]: +@dataclass +class _MosaicSpec: + mosaic: list[list[str]] + width_ratios: list[float] | None + fig_size: tuple[float, float] + subplots_adjust_kwargs: dict[str, float] = None + + def __post_init__(self): + if self.subplots_adjust_kwargs is None: + self.subplots_adjust_kwargs = dict() + + def to_subplots(self) -> tuple[pyplot.Figure, dict[str, pyplot.Axes]]: + return pyplot.subplot_mosaic( + self.mosaic, + width_ratios=self.width_ratios, + figsize=self.fig_size) + + +def _minimap_mosaic(paired_axes: bool, minimap: bool) -> _MosaicSpec: + # Set defaults: + if minimap: + width_ratios = [1, 3] + fig_size = (12, 7) + subplots_adjust = { + "hspace": 0, "wspace": 0.1, + "left": 0.04, "right": 0.8, + } + else: + width_ratios = None + fig_size = (12, 7) + subplots_adjust = { + "hspace": 0, + "left": 0.08, "right": 0.84, + } + + if paired_axes: if minimap: - return [ - [_Ax.top, _Ax.minimap], - [_Ax.bottom, _Ax.minimap], + spec = [ + [_Ax.minimap, _Ax.top], + [_Ax.minimap, _Ax.bottom], ] else: - return [ + spec = [ [_Ax.top], [_Ax.bottom], ] else: if minimap: - return [ - [_Ax.main, _Ax.minimap], + spec = [ + [_Ax.minimap, _Ax.main], ] else: - return [ + spec = [ [_Ax.main], ] + return _MosaicSpec(mosaic=spec, width_ratios=width_ratios, fig_size=fig_size, + subplots_adjust_kwargs=subplots_adjust) + def _hexel_minimap_data(expression_set: HexelExpressionSet, alpha_logp: float) -> tuple[NDArray, NDArray]: """ @@ -135,13 +177,24 @@ def _plot_minimap_hexel(expression_set: HexelExpressionSet, minimap_axis: pyplot stc = SourceEstimate(data=np.concatenate([data_left, data_right]), vertices=[expression_set.hexels_left, expression_set.hexels_right], tmin=0, tstep=1) - minimap_axis.imshow(stc.plot(subject='participant_01', - hemi="split", - colormap=colormap, - smoothing_steps = 2, - background="white", - spacing="ico5" - ).to_image()) + warn("Plotting on the fsaverage brain. Ensure that hexel numbers match those of the fsaverage brain.") + p = stc.plot(subject='fsaverage', + hemi="split", + colormap=colormap, + smoothing_steps = 2, + background="white", + spacing="ico5", + brain_kwargs={"offscreen": True}, + ) + minimap_axis.imshow(p.screenshot(), aspect="auto") + hide_axes(minimap_axis) + + +def hide_axes(axes: pyplot.Axes): + """Hide all axes markings from a pyplot.Axes.""" + axes.get_xaxis().set_visible(False) + axes.get_yaxis().set_visible(False) + axes.axis("off") def _plot_minimap(expression_set: ExpressionSet, minimap_axis: pyplot.Axes, colors: dict[str, str], alpha_logp): @@ -242,24 +295,19 @@ def expression_plot( sidak_corrected_alpha = p_to_logp(sidak_corrected_alpha) - mosaic = _get_mosaic_spec(paired_axes=paired_axes, minimap=minimap) + mosaic = _minimap_mosaic(paired_axes=paired_axes, minimap=minimap) fig: pyplot.Figure axes: dict[str, pyplot.Axes] + fig, axes = mosaic.to_subplots() + expression_axes_list: list[pyplot.Axes] if paired_axes: - fig, axes = pyplot.subplot_mosaic(mosaic, - width_ratios=[3, 1] if minimap else None, - figsize=(12, 7)) expression_axes_list = [axes[_Ax.top], axes[_Ax.bottom]] # For iterating over in a predictable order else: - fig, axes = pyplot.subplot_mosaic(mosaic, - width_ratios=[3, 1] if minimap else None, - figsize=(12, 7)) expression_axes_list = [axes[_Ax.main]] - fig.subplots_adjust(hspace=0) - fig.subplots_adjust(right=0.84, left=0.08) + fig.subplots_adjust(**mosaic.subplots_adjust_kwargs) custom_handles = [] custom_labels = [] @@ -337,7 +385,7 @@ def expression_plot( # Plot minimap if minimap is not None: - os.environ["SUBJECTS_DIR"] = Path(minimap.data_root_dir, minimap.mri_structurals_directory) + os.environ["SUBJECTS_DIR"] = str(Path(minimap["data_root_dir"], minimap["mri_structurals_directory"])) _plot_minimap(expression_set=expression_set, minimap_axis=axes[_Ax.minimap], colors=color, alpha_logp=sidak_corrected_alpha) @@ -360,7 +408,10 @@ def expression_plot( bottom_ax.text(s=axes_names[1], x=bottom_ax_xmin + 20, y=ylim * 0.95, style='italic', verticalalignment='center') - fig.supylabel('p-value (with α at 5-sigma, Šidák corrected)', x=0, y=0.5) + # TODO: revert this + # fig.text(x=0.04, y=0.5, + # s='p-value (with α at 5-sigma, Šidák corrected)', + # ha="center", va="center", rotation="vertical") if bottom_ax_xmin <= 0 <= bottom_ax_xmax: bottom_ax.text(s=' onset of environment ', x=0, y=0 if paired_axes else ylim/2, @@ -519,19 +570,3 @@ def plot_top_five_channels_of_gridsearch( pyplot.clf() pyplot.close() - - -# TODO: remove this bit -if __name__ == '__main__': - from kymata.datasets.sample import KymataMirror2023Q3Dataset - from kymata.entities.expression import HexelExpressionSet, SensorExpressionSet - - expression_data_kymata_mirror: HexelExpressionSet = KymataMirror2023Q3Dataset().to_expressionset() - expression_plot(expression_data_kymata_mirror[ - 'CIECAM02 A', - 'CIECAM02 a', - 'CIELAB a*', - 'CIELAB L' - ], - minimap=True - )