Skip to content

Commit

Permalink
Plot brain on expression-plot figure
Browse files Browse the repository at this point in the history
(view is not correct yet)
  • Loading branch information
caiw committed Mar 8, 2024
1 parent c549e2e commit 8f4ad4b
Showing 1 changed file with 78 additions and 43 deletions.
121 changes: 78 additions & 43 deletions kymata/plot/plot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from itertools import cycle
from pathlib import Path
from statistics import NormalDist
from typing import Optional, Sequence, Dict, NamedTuple
from warnings import warn

import numpy as np
import os
from matplotlib import pyplot
from matplotlib.colors import to_hex, LinearSegmentedColormap
from matplotlib.lines import Line2D
Expand Down Expand Up @@ -32,28 +36,66 @@ class _Ax:
minimap = "minimap"


def _get_mosaic_spec(paired_axes: bool, minimap: bool) -> list[list[str]]:
@dataclass
class _MosaicSpec:
mosaic: list[list[str]]
width_ratios: list[float] | None
fig_size: tuple[float, float]
subplots_adjust_kwargs: dict[str, float] = None

def __post_init__(self):
if self.subplots_adjust_kwargs is None:
self.subplots_adjust_kwargs = dict()

def to_subplots(self) -> tuple[pyplot.Figure, dict[str, pyplot.Axes]]:
return pyplot.subplot_mosaic(
self.mosaic,
width_ratios=self.width_ratios,
figsize=self.fig_size)


def _minimap_mosaic(paired_axes: bool, minimap: bool) -> _MosaicSpec:
# Set defaults:
if minimap:
width_ratios = [1, 3]
fig_size = (12, 7)
subplots_adjust = {
"hspace": 0, "wspace": 0.1,
"left": 0.04, "right": 0.8,
}
else:
width_ratios = None
fig_size = (12, 7)
subplots_adjust = {
"hspace": 0,
"left": 0.08, "right": 0.84,
}


if paired_axes:
if minimap:
return [
[_Ax.top, _Ax.minimap],
[_Ax.bottom, _Ax.minimap],
spec = [
[_Ax.minimap, _Ax.top],
[_Ax.minimap, _Ax.bottom],
]
else:
return [
spec = [
[_Ax.top],
[_Ax.bottom],
]
else:
if minimap:
return [
[_Ax.main, _Ax.minimap],
spec = [
[_Ax.minimap, _Ax.main],
]
else:
return [
spec = [
[_Ax.main],
]

return _MosaicSpec(mosaic=spec, width_ratios=width_ratios, fig_size=fig_size,
subplots_adjust_kwargs=subplots_adjust)


def _hexel_minimap_data(expression_set: HexelExpressionSet, alpha_logp: float) -> tuple[NDArray, NDArray]:
"""
Expand Down Expand Up @@ -135,13 +177,24 @@ def _plot_minimap_hexel(expression_set: HexelExpressionSet, minimap_axis: pyplot
stc = SourceEstimate(data=np.concatenate([data_left, data_right]),
vertices=[expression_set.hexels_left, expression_set.hexels_right],
tmin=0, tstep=1)
minimap_axis.imshow(stc.plot(subject='participant_01',
hemi="split",
colormap=colormap,
smoothing_steps = 2,
background="white",
spacing="ico5"
).to_image())
warn("Plotting on the fsaverage brain. Ensure that hexel numbers match those of the fsaverage brain.")
p = stc.plot(subject='fsaverage',
hemi="split",
colormap=colormap,
smoothing_steps = 2,
background="white",
spacing="ico5",
brain_kwargs={"offscreen": True},
)
minimap_axis.imshow(p.screenshot(), aspect="auto")
hide_axes(minimap_axis)


def hide_axes(axes: pyplot.Axes):
"""Hide all axes markings from a pyplot.Axes."""
axes.get_xaxis().set_visible(False)
axes.get_yaxis().set_visible(False)
axes.axis("off")


def _plot_minimap(expression_set: ExpressionSet, minimap_axis: pyplot.Axes, colors: dict[str, str], alpha_logp):
Expand Down Expand Up @@ -242,24 +295,19 @@ def expression_plot(

sidak_corrected_alpha = p_to_logp(sidak_corrected_alpha)

mosaic = _get_mosaic_spec(paired_axes=paired_axes, minimap=minimap)
mosaic = _minimap_mosaic(paired_axes=paired_axes, minimap=minimap)

fig: pyplot.Figure
axes: dict[str, pyplot.Axes]
fig, axes = mosaic.to_subplots()

expression_axes_list: list[pyplot.Axes]
if paired_axes:
fig, axes = pyplot.subplot_mosaic(mosaic,
width_ratios=[3, 1] if minimap else None,
figsize=(12, 7))
expression_axes_list = [axes[_Ax.top], axes[_Ax.bottom]] # For iterating over in a predictable order
else:
fig, axes = pyplot.subplot_mosaic(mosaic,
width_ratios=[3, 1] if minimap else None,
figsize=(12, 7))
expression_axes_list = [axes[_Ax.main]]

fig.subplots_adjust(hspace=0)
fig.subplots_adjust(right=0.84, left=0.08)
fig.subplots_adjust(**mosaic.subplots_adjust_kwargs)

custom_handles = []
custom_labels = []
Expand Down Expand Up @@ -337,7 +385,7 @@ def expression_plot(
# Plot minimap

if minimap is not None:
os.environ["SUBJECTS_DIR"] = Path(minimap.data_root_dir, minimap.mri_structurals_directory)
os.environ["SUBJECTS_DIR"] = str(Path(minimap["data_root_dir"], minimap["mri_structurals_directory"]))
_plot_minimap(expression_set=expression_set, minimap_axis=axes[_Ax.minimap], colors=color,
alpha_logp=sidak_corrected_alpha)

Expand All @@ -360,7 +408,10 @@ def expression_plot(
bottom_ax.text(s=axes_names[1],
x=bottom_ax_xmin + 20, y=ylim * 0.95,
style='italic', verticalalignment='center')
fig.supylabel('p-value (with α at 5-sigma, Šidák corrected)', x=0, y=0.5)
# TODO: revert this
# fig.text(x=0.04, y=0.5,
# s='p-value (with α at 5-sigma, Šidák corrected)',
# ha="center", va="center", rotation="vertical")
if bottom_ax_xmin <= 0 <= bottom_ax_xmax:
bottom_ax.text(s=' onset of environment ',
x=0, y=0 if paired_axes else ylim/2,
Expand Down Expand Up @@ -519,19 +570,3 @@ def plot_top_five_channels_of_gridsearch(

pyplot.clf()
pyplot.close()


# TODO: remove this bit
if __name__ == '__main__':
from kymata.datasets.sample import KymataMirror2023Q3Dataset
from kymata.entities.expression import HexelExpressionSet, SensorExpressionSet

expression_data_kymata_mirror: HexelExpressionSet = KymataMirror2023Q3Dataset().to_expressionset()
expression_plot(expression_data_kymata_mirror[
'CIECAM02 A',
'CIECAM02 a',
'CIELAB a*',
'CIELAB L'
],
minimap=True
)

0 comments on commit 8f4ad4b

Please sign in to comment.