From 606545957b9c424fd7a90dafdc9c60c7d5194fd8 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 1 Jul 2021 10:44:52 -0700 Subject: [PATCH 1/2] intracranial gui getting closer to a working draft [skip ci] [skip circle] wip fixed some little things wip wip wip finally fixed the transform issue wip wip wip wip wip fix units wip working versions with separate plots fix coordinate frames, refactor surface.py change to this PR wip add 3d plot automatic detection wip almost feature complete a few fixes attempted fix but doesn't work on ubuntu for some reason [skip ci][skip circle] fix zoom fix draw bug allow no head lots of cleaning up fix, change to calling dipy directly feature complete version wip wip add snap feature wip fix tests fix tutorial fix flake fix thresh fix tests didn't save fix tests fix tests fix spelling fix test fix refs small bug fixes fix warnings clean up plotting, fix pernicious array overwriting bug wip wip C- version, finds 6 / 15 electrodes with most contacts fix bug, add ACPC alignment back to tutorial, outline Hough transform wip wip wip fix diff fix diff again wip remove auto-find for now fix tests wip review, add tests everything works great except compute_volume_registration for the MRI and CT, fixing that ASAP fix registration, remove link, fix test revert to older version Eric review increase cursor width fix tests fix subjects_dir fix tutorial fix tests fix tests Update mne/gui/_ieeg_locate_gui.py Update mne/gui/tests/test_ieeg_locate_gui.py FIX: pyvistaqt, not mayavi STY: Cleaner MAINT: pytest --- doc/changes/latest.inc | 1 + doc/conf.py | 1 + doc/mri.rst | 1 + mne/datasets/utils.py | 4 +- mne/gui/__init__.py | 41 ++ mne/gui/_ieeg_locate_gui.py | 864 +++++++++++++++++++++++++ mne/gui/tests/test_ieeg_locate_gui.py | 155 +++++ mne/surface.py | 88 ++- mne/tests/test_surface.py | 13 +- mne/utils/_testing.py | 2 +- mne/utils/docs.py | 8 +- mne/viz/misc.py | 2 +- tutorials/clinical/10_ieeg_localize.py | 165 +++-- 13 files changed, 1254 insertions(+), 91 deletions(-) create mode 100644 mne/gui/_ieeg_locate_gui.py create mode 100644 mne/gui/tests/test_ieeg_locate_gui.py diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index e0b2a8d0d21..4436ae2fc2a 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -159,6 +159,7 @@ Enhancements - Reading EDF files via :func:`mne.io.read_raw_edf` now can infer channel type from the signal label in the EDF header (:gh:`9694` by `Adam Li`_) +- Add :func:`mne.gui.locate_ieeg` to locate intracranial electrode contacts from a CT, an MRI (with Freesurfer ``recon-all``) and the channel names from an :class:`mne.Info` object (:gh:`9586` by `Alex Rockhill`_) Bugs ~~~~ diff --git a/doc/conf.py b/doc/conf.py index 91aba4bc6de..c5fe3ccb71b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -258,6 +258,7 @@ # unlinkable 'mayavi.mlab.pipeline.surface', 'CoregFrame', 'Kit2FiffFrame', 'FiducialsFrame', + 'IntracranialElectrodeLocator' } numpydoc_validate = True numpydoc_validation_checks = {'all'} | set(error_ignores) diff --git a/doc/mri.rst b/doc/mri.rst index 889246feb84..9689b4c705e 100644 --- a/doc/mri.rst +++ b/doc/mri.rst @@ -19,6 +19,7 @@ Step by step instructions for using :func:`gui.coregistration`: get_montage_volume_labels gui.coregistration gui.fiducials + gui.locate_ieeg create_default_subject head_to_mni head_to_mri diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 683313f722b..b4f099d47d0 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -254,7 +254,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, path = _get_path(path, key, name) # To update the testing or misc dataset, push commits, then make a new # release on GitHub. Then update the "releases" variable: - releases = dict(testing='0.123', misc='0.18') + releases = dict(testing='0.123', misc='0.19') # And also update the "md5_hashes['testing']" variable below. # To update any other dataset, update the data archive itself (upload # an updated version) and update the md5 hash. @@ -345,7 +345,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, bst_raw='fa2efaaec3f3d462b319bc24898f440c', bst_resting='70fc7bf9c3b97c4f2eab6260ee4a0430'), fake='3194e9f7b46039bb050a74f3e1ae9908', - misc='0aa25a9bb4f204b3d4769f0b84e9b526', + misc='b5ebe66dbe0f36cba9170a7ce909a66f', sample='12b75d1cb7df9dfb4ad73ed82f61094f', somato='32fd2f6c8c7eb0784a1de6435273c48b', spm='9f43f67150e3b694b523a21eb929ea75', diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index 40ea669b8b8..ae9e1ea308f 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -239,3 +239,44 @@ def kit2fiff(): from ._kit2fiff_gui import Kit2FiffFrame frame = Kit2FiffFrame() return _initialize_gui(frame) + + +@verbose +def locate_ieeg(info, trans, aligned_ct, subject=None, subjects_dir=None, + groups=None, verbose=None): + """Locate intracranial electrode contacts. + + Parameters + ---------- + %(info_not_none)s + %(trans_not_none)s + aligned_ct : str | pathlib.Path | nibabel.spatialimages.SpatialImage + The CT image that has been aligned to the Freesurfer T1. Path-like + inputs and nibabel image objects are supported. + %(subject)s + %(subjects_dir)s + groups : dict | None + A dictionary with channels as keys and their group index as values. + If None, the groups will be inferred by the channel names. Channel + names must have a format like ``LAMY 7`` where a string prefix + like ``LAMY`` precedes a numeric index like ``7``. If the channels + are formatted improperly, group plotting will work incorrectly. + Group assignments can be adjusted in the GUI. + %(verbose)s + + Returns + ------- + gui : instance of IntracranialElectrodeLocator + The graphical user interface (GUI) window. + """ + from ._ieeg_locate_gui import IntracranialElectrodeLocator + from PyQt5.QtWidgets import QApplication + # get application + app = QApplication.instance() + if app is None: + app = QApplication(["Intracranial Electrode Locator"]) + gui = IntracranialElectrodeLocator( + info, trans, aligned_ct, subject=subject, + subjects_dir=subjects_dir, groups=groups, verbose=verbose) + gui.show() + return gui diff --git a/mne/gui/_ieeg_locate_gui.py b/mne/gui/_ieeg_locate_gui.py new file mode 100644 index 00000000000..efb787cf9c7 --- /dev/null +++ b/mne/gui/_ieeg_locate_gui.py @@ -0,0 +1,864 @@ +# -*- coding: utf-8 -*- +"""Intracranial elecrode localization GUI for finding contact locations.""" + +# Authors: Alex Rockhill +# +# License: BSD (3-clause) + +import os.path as op +import numpy as np +from functools import partial + +from matplotlib.colors import LinearSegmentedColormap + +from PyQt5 import QtCore, QtGui, Qt +from PyQt5.QtCore import pyqtSlot +from PyQt5.QtWidgets import (QMainWindow, QGridLayout, + QVBoxLayout, QHBoxLayout, QLabel, + QMessageBox, QWidget, + QListView, QSlider, QPushButton, + QComboBox, QPlainTextEdit) +from matplotlib.backends.backend_qt5agg import FigureCanvas +from matplotlib.figure import Figure + +from .._freesurfer import _check_subject_dir, _import_nibabel +from ..viz.backends.renderer import _get_renderer +from ..surface import _read_mri_surface, _voxel_neighbors +from ..transforms import (apply_trans, _frame_to_str, _get_trans, + invert_transform) +from ..utils import logger, _check_fname, _validate_type, verbose, warn + +_IMG_LABELS = [['I', 'P'], ['I', 'L'], ['P', 'L']] +_CH_PLOT_SIZE = 1024 +_ZOOM_STEP_SIZE = 5 + +# 20 colors generated to be evenly spaced in a cube, worked better than +# matplotlib color cycle +_UNIQUE_COLORS = [(0.1, 0.42, 0.43), (0.9, 0.34, 0.62), (0.47, 0.51, 0.3), + (0.47, 0.55, 0.99), (0.79, 0.68, 0.06), (0.34, 0.74, 0.05), + (0.58, 0.87, 0.13), (0.86, 0.98, 0.4), (0.92, 0.91, 0.66), + (0.77, 0.38, 0.34), (0.9, 0.37, 0.1), (0.2, 0.62, 0.9), + (0.22, 0.65, 0.64), (0.14, 0.94, 0.8), (0.34, 0.31, 0.68), + (0.59, 0.28, 0.74), (0.46, 0.19, 0.94), (0.37, 0.93, 0.7), + (0.56, 0.86, 0.55), (0.67, 0.69, 0.44)] +_N_COLORS = len(_UNIQUE_COLORS) +_CMAP = LinearSegmentedColormap.from_list( + 'ch_colors', _UNIQUE_COLORS, N=_N_COLORS) + + +def _load_image(img, name, verbose=True): + """Load data from a 3D image file (e.g. CT, MR).""" + nib = _import_nibabel('use iEEG GUI') + if not isinstance(img, nib.spatialimages.SpatialImage): + if verbose: + logger.info(f'Loading {img}') + _check_fname(img, overwrite='read', must_exist=True, name=name) + img = nib.load(img) + # get data + orig_data = np.array(img.dataobj).astype(np.float32) + # reorient data to RAS + ornt = nib.orientations.axcodes2ornt( + nib.orientations.aff2axcodes(img.affine)).astype(int) + ras_ornt = nib.orientations.axcodes2ornt('RAS') + ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt) + img_data = nib.orientations.apply_orientation(orig_data, ornt_trans) + orig_mgh = nib.MGHImage(orig_data, img.affine) + aff_trans = nib.orientations.inv_ornt_aff(ornt_trans, img.shape) + vox_ras_t = np.dot(orig_mgh.header.get_vox2ras_tkr(), aff_trans) + return img_data, vox_ras_t + + +class ComboBox(QComboBox): + """Dropdown menu that emits a click when popped up.""" + + clicked = QtCore.pyqtSignal() + + def showPopup(self): + """Override show popup method to emit click.""" + self.clicked.emit() + super(ComboBox, self).showPopup() + + +def _make_slice_plot(width=4, height=4, dpi=300): + fig = Figure(figsize=(width, height)) + canvas = FigureCanvas(fig) + ax = fig.subplots() + fig.subplots_adjust(bottom=0, left=0, right=1, + top=1, wspace=0, hspace=0) + ax.set_facecolor('k') + # clean up excess plot text, invert + ax.invert_yaxis() + ax.set_xticks([]) + ax.set_yticks([]) + return canvas, fig + + +class IntracranialElectrodeLocator(QMainWindow): + """Locate electrode contacts using a coregistered MRI and CT.""" + + def __init__(self, info, trans, aligned_ct, subject=None, + subjects_dir=None, groups=None, verbose=None): + """GUI for locating intracranial electrodes. + + .. note:: Images will be displayed using orientation information + obtained from the image header. Images will be resampled to + dimensions [256, 256, 256] for display. + """ + # initialize QMainWindow class + super(IntracranialElectrodeLocator, self).__init__() + + if not info.ch_names: + raise ValueError('No channels found in `info` to locate') + + # store info for modification + self._info = info + self._verbose = verbose + + # load imaging data + self._subject_dir = _check_subject_dir(subject, subjects_dir) + self._load_image_data(aligned_ct) + + self._ch_alpha = 0.5 + self._radius = int(_CH_PLOT_SIZE // 100) # starting 1/200 of image + # initialize channel data + self._ch_index = 0 + # load data, apply trans + self._head_mri_t = _get_trans(trans, 'head', 'mri')[0] + self._mri_head_t = invert_transform(self._head_mri_t) + # load channels, convert from m to mm + self._chs = {name: apply_trans(self._head_mri_t, ch['loc'][:3]) * 1000 + for name, ch in zip(info.ch_names, info['chs'])} + self._ch_names = list(self._chs.keys()) + # set current position + if np.isnan(self._chs[self._ch_names[self._ch_index]]).any(): + self._ras = np.array([0., 0., 0.]) + else: + self._ras = self._chs[self._ch_names[self._ch_index]].copy() + self._current_slice = apply_trans( + self._ras_vox_t, self._ras).round().astype(int) + self._group_channels(groups) + + # GUI design + + # Main plots: make one plot for each view; sagittal, coronal, axial + plt_grid = QGridLayout() + plts = [_make_slice_plot(), _make_slice_plot(), _make_slice_plot()] + self._figs = [plts[0][1], plts[1][1], plts[2][1]] + plt_grid.addWidget(plts[0][0], 0, 0) + plt_grid.addWidget(plts[1][0], 0, 1) + plt_grid.addWidget(plts[2][0], 1, 0) + self._renderer = _get_renderer( + name='IEEG Locator', size=(400, 400), bgcolor='w') + # TODO: should eventually make sure the renderer here is actually + # some PyVista(Qt) variant, not mayavi, otherwise the following + # call will fail (hopefully it's rare that people who want to use this + # have also set their MNE_3D_BACKEND=mayavi and/or don't have a working + # pyvistaqt setup; also hopefully the refactoring to use the + # Qt/notebook abstraction will make this easier, too): + plt_grid.addWidget(self._renderer.plotter) + + # Channel selector + self._ch_list = QListView() + self._ch_list.setSelectionMode(Qt.QAbstractItemView.SingleSelection) + self._ch_list.setMinimumWidth(150) + self._set_ch_names() + + # Plots + self._plot_images() + + # Menus + button_hbox = self._get_button_bar() + slider_hbox = self._get_slider_bar() + bottom_hbox = self._get_bottom_bar() + + # Put everything together + plot_ch_hbox = QHBoxLayout() + plot_ch_hbox.addLayout(plt_grid) + plot_ch_hbox.addWidget(self._ch_list) + + main_vbox = QVBoxLayout() + main_vbox.addLayout(button_hbox) + main_vbox.addLayout(slider_hbox) + main_vbox.addLayout(plot_ch_hbox) + main_vbox.addLayout(bottom_hbox) + + central_widget = QWidget() + central_widget.setLayout(main_vbox) + self.setCentralWidget(central_widget) + + # ready for user + self._move_cursors_to_pos() + self._ch_list.setFocus() # always focus on list + + def _load_image_data(self, ct): + """Get MRI and CT data to display and transforms to/from vox/RAS.""" + self._mri_data, self._vox_ras_t = _load_image( + op.join(self._subject_dir, 'mri', 'brain.mgz'), + 'MRI Image', verbose=self._verbose) + self._ras_vox_t = np.linalg.inv(self._vox_ras_t) + + self._voxel_sizes = np.array(self._mri_data.shape) + self._img_ranges = [[0, self._voxel_sizes[1], 0, self._voxel_sizes[2]], + [0, self._voxel_sizes[0], 0, self._voxel_sizes[2]], + [0, self._voxel_sizes[0], 0, self._voxel_sizes[1]]] + + # ready ct + self._ct_data, vox_ras_t = _load_image(ct, 'CT', verbose=self._verbose) + if self._mri_data.shape != self._ct_data.shape or \ + not np.allclose(self._vox_ras_t, vox_ras_t, rtol=1e-6): + raise ValueError('CT is not aligned to MRI, got ' + f'CT shape={self._ct_data.shape}, ' + f'MRI shape={self._mri_data.shape}, ' + f'CT affine={vox_ras_t} and ' + f'MRI affine={self._vox_ras_t}') + + if op.exists(op.join(self._subject_dir, 'surf', 'lh.seghead')): + self._head = _read_mri_surface( + op.join(self._subject_dir, 'surf', 'lh.seghead')) + assert _frame_to_str[self._head['coord_frame']] == 'mri' + else: + warn('`seghead` not found, skipping head plot, see ' + ':ref:`mne.bem.make_scalp_surfaces` to add the head') + self._head = None + if op.exists(op.join(self._subject_dir, 'surf', 'lh.pial')): + self._lh = _read_mri_surface( + op.join(self._subject_dir, 'surf', 'lh.pial')) + assert _frame_to_str[self._lh['coord_frame']] == 'mri' + self._rh = _read_mri_surface( + op.join(self._subject_dir, 'surf', 'rh.pial')) + assert _frame_to_str[self._rh['coord_frame']] == 'mri' + else: + warn('`pial` surface not found, skipping adding to 3D ' + 'plot. This indicates the Freesurfer recon-all ' + 'has been modified and these files have been deleted.') + self._lh = self._rh = None + + def _make_ch_image(self, axis): + """Make a plot to display the channel locations.""" + # Make channel data higher resolution so it looks better. + ch_image = np.zeros((_CH_PLOT_SIZE, _CH_PLOT_SIZE)) * np.nan + vx, vy, vz = self._voxel_sizes + + def color_ch_radius(ch_image, xf, yf, group, radius): + # Take the fraction across each dimension of the RAS + # coordinates converted to xyz and put a circle in that + # position in this larger resolution image + ex, ey = np.round(np.array([xf, yf]) * _CH_PLOT_SIZE).astype(int) + for i in range(-radius, radius + 1): + for j in range(-radius, radius + 1): + if (i**2 + j**2)**0.5 < radius: + # negative y because y axis is inverted + ch_image[-(ey + i), ex + j] = group + return ch_image + + for name, ras in self._chs.items(): + # move from middle-centered (half coords positive, half negative) + # to bottom-left corner centered (all coords positive). + if np.isnan(ras).any(): + continue + xyz = apply_trans(self._ras_vox_t, ras) + # check if closest to that voxel + dist = np.linalg.norm(xyz - self._current_slice) + if dist < self._radius: + x, y, z = xyz + group = self._groups[name] + r = self._radius - np.round(abs(dist)).astype(int) + if axis == 0: + ch_image = color_ch_radius( + ch_image, y / vy, z / vz, group, r) + elif axis == 1: + ch_image = color_ch_radius( + ch_image, x / vx, z / vx, group, r) + elif axis == 2: + ch_image = color_ch_radius( + ch_image, x / vx, y / vy, group, r) + return ch_image + + @verbose + def _save_ch_coords(self, info=None, verbose=None): + """Save the location of the electrode contacts.""" + logger.info('Saving channel positions to `info`') + if info is None: + info = self._info + for name, ch in zip(info.ch_names, info['chs']): + ch['loc'][:3] = apply_trans( + self._mri_head_t, self._chs[name] / 1000) # mm->m + + def _plot_images(self): + """Use the MRI and CT to make plots.""" + # Plot sagittal (0), coronal (1) or axial (2) view + self._images = dict(ct=list(), chs=list(), + cursor=list(), cursor2=list()) + ct_min, ct_max = np.nanmin(self._ct_data), np.nanmax(self._ct_data) + text_kwargs = dict(fontsize=3, color='#66CCEE', family='monospace', + weight='bold', ha='center', va='center') + xyz = apply_trans(self._ras_vox_t, self._ras) + for axis in range(3): + ct_data = np.take(self._ct_data, self._current_slice[axis], + axis=axis).T + self._images['ct'].append(self._figs[axis].axes[0].imshow( + ct_data, cmap='gray', aspect='auto', + vmin=ct_min, vmax=ct_max)) + self._images['chs'].append( + self._figs[axis].axes[0].imshow( + self._make_ch_image(axis), aspect='auto', + extent=self._img_ranges[axis], + cmap=_CMAP, alpha=self._ch_alpha, vmin=0, vmax=_N_COLORS)) + self._images['cursor'].append( + self._figs[axis].axes[0].plot( + (xyz[axis], xyz[axis]), (0, self._voxel_sizes[axis]), + color=[0, 1, 0], linewidth=1, alpha=0.5)[0]) + self._images['cursor2'].append( + self._figs[axis].axes[0].plot( + (0, self._voxel_sizes[axis]), (xyz[axis], xyz[axis]), + color=[0, 1, 0], linewidth=1, alpha=0.5)[0]) + # label axes + self._figs[axis].text(0.5, 0.05, _IMG_LABELS[axis][0], + **text_kwargs) + self._figs[axis].text(0.05, 0.5, _IMG_LABELS[axis][1], + **text_kwargs) + self._figs[axis].axes[0].axis(self._img_ranges[axis]) + self._figs[axis].canvas.mpl_connect( + 'scroll_event', self._on_scroll) + self._figs[axis].canvas.mpl_connect( + 'button_release_event', partial(self._on_click, axis)) + # add head and brain in mm (convert from m) + if self._head is not None: + self._renderer.mesh( + *self._head['rr'].T * 1000, triangles=self._head['tris'], + color='gray', opacity=0.2, reset_camera=False, render=False) + if self._lh is not None and self._rh is not None: + self._renderer.mesh( + *self._lh['rr'].T * 1000, triangles=self._lh['tris'], + color='white', opacity=0.2, reset_camera=False, render=False) + self._renderer.mesh( + *self._rh['rr'].T * 1000, triangles=self._rh['tris'], + color='white', opacity=0.2, reset_camera=False, render=False) + self._3d_chs = dict() + self._plot_3d_ch_pos() + self._renderer.set_camera(azimuth=90, elevation=90, distance=300, + focalpoint=tuple(self._ras)) + # update plots + self._draw() + self._renderer._update() + + def _scale_radius(self): + """Scale the radius to mm.""" + shape = np.mean(self._ct_data.shape) # this is Freesurfer shape (256) + scale = np.diag(self._ras_vox_t)[:3].mean() + return scale * self._radius * (shape / _CH_PLOT_SIZE) + + def _update_camera(self, render=False): + """Update the camera position.""" + self._renderer.set_camera( + # needs fix, distance moves when focal point updates + distance=self._renderer.plotter.camera.distance * 0.9, + focalpoint=tuple(self._ras), + reset_camera=False) + + def _plot_3d_ch(self, name, render=False): + """Plot a single 3D channel.""" + if name in self._3d_chs: + self._renderer.plotter.remove_actor(self._3d_chs.pop(name)) + if not any(np.isnan(self._chs[name])): + radius = self._scale_radius() + self._3d_chs[name] = self._renderer.sphere( + tuple(self._chs[name]), scale=radius * 3, + color=_UNIQUE_COLORS[self._groups[name] % _N_COLORS], + opacity=self._ch_alpha)[0] + if render: + self._renderer._update() + + def _plot_3d_ch_pos(self, render=False): + for name in self._chs: + self._plot_3d_ch(name) + if render: + self._renderer._update() + + def _get_button_bar(self): + """Make a bar with buttons for user interactions.""" + hbox = QHBoxLayout() + + hbox.addStretch(10) + + hbox.addWidget(QLabel('Snap to Center')) + self._snap_button = QPushButton('Off') + self._snap_button.setMaximumWidth(25) # not too big + hbox.addWidget(self._snap_button) + self._snap_button.released.connect(self._toggle_snap) + self._toggle_snap() # turn on to start + + hbox.addStretch(1) + + mark_button = QPushButton('Mark') + hbox.addWidget(mark_button) + mark_button.released.connect(self._mark_ch) + + remove_button = QPushButton('Remove') + hbox.addWidget(remove_button) + remove_button.released.connect(self._remove_ch) + + self._group_selector = ComboBox() + group_model = self._group_selector.model() + + for i in range(_N_COLORS): + self._group_selector.addItem(' ') + color = QtGui.QColor() + color.setRgb(*(255 * np.array(_UNIQUE_COLORS[i % _N_COLORS]) + ).round().astype(int)) + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + group_model.setData(group_model.index(i, 0), + brush, QtCore.Qt.BackgroundRole) + self._group_selector.clicked.connect(self._select_group) + self._group_selector.currentIndexChanged.connect( + self._update_group) + hbox.addWidget(self._group_selector) + + # update background color for current selection + self._update_group() + + return hbox + + def _get_slider_bar(self): + """Make a bar with sliders on it.""" + + def make_label(name): + label = QLabel(name) + label.setAlignment(QtCore.Qt.AlignCenter) + return label + + def make_slider(smin, smax, sval, sfun=None): + slider = QSlider(QtCore.Qt.Horizontal) + slider.setMinimum(int(round(smin))) + slider.setMaximum(int(round(smax))) + slider.setValue(int(round(sval))) + slider.setTracking(False) # only update on release + if sfun is not None: + slider.valueChanged.connect(sfun) + slider.keyPressEvent = self._key_press_event + return slider + + slider_hbox = QHBoxLayout() + + ch_vbox = QVBoxLayout() + ch_vbox.addWidget(make_label('ch alpha')) + ch_vbox.addWidget(make_label('ch radius')) + slider_hbox.addLayout(ch_vbox) + + ch_slider_vbox = QVBoxLayout() + self._alpha_slider = make_slider(0, 100, self._ch_alpha * 100, + self._update_ch_alpha) + ch_plot_max = _CH_PLOT_SIZE // 50 # max 1 / 50 of plot size + ch_slider_vbox.addWidget(self._alpha_slider) + self._radius_slider = make_slider(0, ch_plot_max, self._radius, + self._update_radius) + ch_slider_vbox.addWidget(self._radius_slider) + slider_hbox.addLayout(ch_slider_vbox) + + ct_vbox = QVBoxLayout() + ct_vbox.addWidget(make_label('CT min')) + ct_vbox.addWidget(make_label('CT max')) + slider_hbox.addLayout(ct_vbox) + + ct_slider_vbox = QVBoxLayout() + ct_min = int(round(np.nanmin(self._ct_data))) + ct_max = int(round(np.nanmax(self._ct_data))) + self._ct_min_slider = make_slider( + ct_min, ct_max, ct_min, self._update_ct_scale) + ct_slider_vbox.addWidget(self._ct_min_slider) + self._ct_max_slider = make_slider( + ct_min, ct_max, ct_max, self._update_ct_scale) + ct_slider_vbox.addWidget(self._ct_max_slider) + slider_hbox.addLayout(ct_slider_vbox) + return slider_hbox + + def _get_bottom_bar(self): + """Make a bar at the bottom with information in it.""" + hbox = QHBoxLayout() + + hbox.addStretch(10) + + self._intensity_label = QLabel('') # update later + hbox.addWidget(self._intensity_label) + + RAS_label = QLabel('RAS =') + self._RAS_textbox = QPlainTextEdit('') # update later + self._RAS_textbox.setMaximumHeight(25) + self._RAS_textbox.setMaximumWidth(200) + self._RAS_textbox.focusOutEvent = self._update_RAS + self._RAS_textbox.textChanged.connect(self._check_update_RAS) + hbox.addWidget(RAS_label) + hbox.addWidget(self._RAS_textbox) + self._update_moved() # update text now + return hbox + + def _group_channels(self, groups): + """Automatically find a group based on the name of the channel.""" + if groups is not None: + for name in self._ch_names: + if name not in groups: + raise ValueError(f'{name} not found in ``groups``') + _validate_type(groups[name], (float, int), f'groups[{name}]') + self.groups = groups + else: + i = 0 + self._groups = dict() + base_names = dict() + for name in self._ch_names: + # strip all numbers from the name + base_name = ''.join([letter for letter in name if + not letter.isdigit() and letter != ' ']) + if base_name in base_names: + # look up group number by base name + self._groups[name] = base_names[base_name] + else: + self._groups[name] = i + base_names[base_name] = i + i += 1 + + def _set_ch_names(self): + """Add the channel names to the selector.""" + self._ch_list_model = QtGui.QStandardItemModel(self._ch_list) + for name in self._ch_names: + self._ch_list_model.appendRow(QtGui.QStandardItem(name)) + self._color_list_item(name=name) + self._ch_list.setModel(self._ch_list_model) + self._ch_list.clicked.connect(self._go_to_ch) + self._ch_list.setCurrentIndex( + self._ch_list_model.index(self._ch_index, 0)) + self._ch_list.keyPressEvent = self._key_press_event + + def _select_group(self): + """Change the group label to the selection.""" + group = self._group_selector.currentIndex() + self._groups[self._ch_names[self._ch_index]] = group + # color differently if found already + self._color_list_item(self._ch_names[self._ch_index]) + self._update_group() + + def _update_group(self): + """Set background for closed group menu.""" + group = self._group_selector.currentIndex() + rgb = (255 * np.array(_UNIQUE_COLORS[group % _N_COLORS]) + ).round().astype(int) + self._group_selector.setStyleSheet( + 'background-color: rgb({:d},{:d},{:d})'.format(*rgb)) + self._group_selector.update() + + def _on_scroll(self, event): + """Process mouse scroll wheel event to zoom.""" + self._zoom(event.step, draw=True) + + def _zoom(self, sign=1, draw=False): + """Zoom in on the image.""" + delta = _ZOOM_STEP_SIZE * sign + for axis, fig in enumerate(self._figs): + xmid = self._images['cursor'][axis].get_xdata()[0] + ymid = self._images['cursor2'][axis].get_ydata()[0] + xmin, xmax = fig.axes[0].get_xlim() + ymin, ymax = fig.axes[0].get_ylim() + xwidth = (xmax - xmin) / 2 - delta + ywidth = (ymax - ymin) / 2 - delta + if xwidth <= 0 or ywidth <= 0: + return + fig.axes[0].set_xlim(xmid - xwidth, xmid + xwidth) + fig.axes[0].set_ylim(ymid - ywidth, ymid + ywidth) + self._images['cursor'][axis].set_ydata([ymin, ymax]) + self._images['cursor2'][axis].set_xdata([xmin, xmax]) + if draw: + self._figs[axis].canvas.draw() + + def _update_ch_selection(self): + """Update which channel is selected.""" + name = self._ch_names[self._ch_index] + self._ch_list.setCurrentIndex( + self._ch_list_model.index(self._ch_index, 0)) + self._group_selector.setCurrentIndex(self._groups[name]) + self._update_group() + if not np.isnan(self._chs[name]).any(): + self._ras[:] = self._chs[name] + self._move_cursors_to_pos() + self._update_camera(render=True) + self._draw() + + def _go_to_ch(self, index): + """Change current channel to the item selected.""" + self._ch_index = index.row() + self._update_ch_selection() + + @pyqtSlot() + def _next_ch(self): + """Increment the current channel selection index.""" + self._ch_index = (self._ch_index + 1) % len(self._ch_names) + self._update_ch_selection() + + @pyqtSlot() + def _update_RAS(self, event): + """Interpret user input to the RAS textbox.""" + text = self._RAS_textbox.toPlainText().replace('\n', '') + ras = text.split(',') + if len(ras) != 3: + ras = text.split(' ') # spaces also okay as in freesurfer + ras = [var.lstrip().rstrip() for var in ras] + + if len(ras) != 3: + self._update_moved() # resets RAS label + return + all_float = all([all([dig.isdigit() or dig in ('-', '.') + for dig in var]) for var in ras]) + if not all_float: + self._update_moved() # resets RAS label + return + + ras = np.array([float(var) for var in ras]) + xyz = apply_trans(self._ras_vox_t, ras) + wrong_size = any([var < 0 or var > n for var, n in + zip(xyz, self._voxel_sizes)]) + if wrong_size: + self._update_moved() # resets RAS label + return + + # valid RAS position, update and move + self._ras = ras + self._move_cursors_to_pos() + + @pyqtSlot() + def _check_update_RAS(self): + """Check whether the RAS textbox is done being edited.""" + if '\n' in self._RAS_textbox.toPlainText(): + self._update_RAS(event=None) + self._ch_list.setFocus() # remove focus from text edit + + def _color_list_item(self, name=None): + """Color the item in the view list for easy id of marked channels.""" + name = self._ch_names[self._ch_index] if name is None else name + color = QtGui.QColor('white') + if not np.isnan(self._chs[name]).any(): + group = self._groups[name] + color.setRgb(*[int(c * 255) for c in + _UNIQUE_COLORS[int(group) % _N_COLORS]]) + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + self._ch_list_model.setData( + self._ch_list_model.index(self._ch_names.index(name), 0), + brush, QtCore.Qt.BackgroundRole) + # color text black + color = QtGui.QColor('black') + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + self._ch_list_model.setData( + self._ch_list_model.index(self._ch_names.index(name), 0), + brush, QtCore.Qt.ForegroundRole) + + @pyqtSlot() + def _toggle_snap(self): + """Toggle snapping the contact location to the center of mass.""" + if self._snap_button.text() == 'Off': + self._snap_button.setText('On') + self._snap_button.setStyleSheet("background-color: green") + else: # text == 'On', turn off + self._snap_button.setText('Off') + self._snap_button.setStyleSheet("background-color: red") + + @pyqtSlot() + def _mark_ch(self): + """Mark the current channel as being located at the crosshair.""" + name = self._ch_names[self._ch_index] + if self._snap_button.text() == 'Off': + self._chs[name][:] = self._ras + else: + coord = apply_trans(self._ras_vox_t, self._ras.copy()) + shape = np.mean(self._mri_data.shape) # Freesurfer shape (256) + voxels_max = int( + 4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE)**3) + neighbors = _voxel_neighbors( + coord, self._ct_data, thresh=0.5, + voxels_max=voxels_max, use_relative=True) + self._chs[name][:] = apply_trans( # to surface RAS + self._vox_ras_t, np.array(list(neighbors)).mean(axis=0)) + self._color_list_item() + self._update_ch_images(draw=True) + self._plot_3d_ch(name, render=True) + self._save_ch_coords() + self._next_ch() + self._ch_list.setFocus() + + @pyqtSlot() + def _remove_ch(self): + """Remove the location data for the current channel.""" + name = self._ch_names[self._ch_index] + self._chs[name] *= np.nan + self._color_list_item() + self._save_ch_coords() + self._update_ch_images(draw=True) + self._plot_3d_ch(name, render=True) + self._next_ch() + self._ch_list.setFocus() + + def _draw(self, axis=None): + """Update the figures with a draw call.""" + for axis in (range(3) if axis is None else [axis]): + self._figs[axis].canvas.draw() + + def _update_ch_images(self, axis=None, draw=False): + """Update the channel image(s).""" + for axis in range(3) if axis is None else [axis]: + self._images['chs'][axis].set_data( + self._make_ch_image(axis)) + if draw: + self._draw(axis) + + def _update_ct_images(self, axis=None, draw=False): + """Update the CT image(s).""" + for axis in range(3) if axis is None else [axis]: + ct_data = np.take(self._ct_data, self._current_slice[axis], + axis=axis).T + # Threshold the CT so only bright objects (electrodes) are visible + ct_data[ct_data < self._ct_min_slider.value()] = np.nan + ct_data[ct_data > self._ct_max_slider.value()] = np.nan + self._images['ct'][axis].set_data(ct_data) + if draw: + self._draw(axis) + + def _update_mri_images(self, axis=None, draw=False): + """Update the CT image(s).""" + if 'mri' in self._images: + for axis in range(3) if axis is None else [axis]: + self._images['mri'][axis].set_data( + np.take(self._mri_data, self._current_slice[axis], + axis=axis).T) + if draw: + self._draw(axis) + + def _update_images(self, axis=None, draw=True): + """Update CT and channel images when general changes happen.""" + self._update_ct_images(axis=axis) + self._update_ch_images(axis=axis) + self._update_mri_images(axis=axis) + if draw: + self._draw(axis) + + def _update_ct_scale(self): + """Update CT min slider value.""" + new_min = self._ct_min_slider.value() + new_max = self._ct_max_slider.value() + # handle inversions + self._ct_min_slider.setValue(min([new_min, new_max])) + self._ct_max_slider.setValue(max([new_min, new_max])) + self._update_ct_images(draw=True) + + def _update_radius(self): + """Update channel plot radius.""" + self._radius = np.round(self._radius_slider.value()).astype(int) + self._update_ch_images(draw=True) + self._plot_3d_ch_pos(render=True) + self._ch_list.setFocus() # remove focus from 3d plotter + + def _update_ch_alpha(self): + """Update channel plot alpha.""" + self._ch_alpha = self._alpha_slider.value() / 100 + for axis in range(3): + self._images['chs'][axis].set_alpha(self._ch_alpha) + self._draw() + self._plot_3d_ch_pos(render=True) + self._ch_list.setFocus() # remove focus from 3d plotter + + def _get_click_pos(self, axis, x, y): + """Get which axis was clicked and where.""" + fx, fy = self._figs[axis].transFigure.inverted().transform((x, y)) + xmin, xmax = self._figs[axis].axes[0].get_xlim() + ymin, ymax = self._figs[axis].axes[0].get_ylim() + return (fx * (xmax - xmin) + xmin, fy * (ymax - ymin) + ymin) + + def _move_cursors_to_pos(self): + """Move the cursors to a position.""" + x, y, z = apply_trans(self._ras_vox_t, self._ras) + self._current_slice = np.array([x, y, z]).round().astype(int) + self._move_cursor_to(0, x=y, y=z) + self._move_cursor_to(1, x=x, y=z) + self._move_cursor_to(2, x=x, y=y) + self._zoom(0) # doesn't actually zoom just resets view to center + self._update_images(draw=True) + self._update_moved() + + def _move_cursor_to(self, axis, x, y): + """Move the cursors to a position for a given subplot.""" + self._images['cursor2'][axis].set_ydata([y, y]) + self._images['cursor'][axis].set_xdata([x, x]) + + def _key_press_event(self, event): + """Execute functions when the user presses a key.""" + if event.key() == 'escape': + self.close() + + if event.text() == 'h': + # Show help + QMessageBox.information( + self, 'Help', + "Help:\n'm': mark channel location\n" + "'r': remove channel location\n" + "'b': toggle viewing of brain in T1\n" + "'+'/'-': zoom\nleft/right arrow: left/right\n" + "up/down arrow: superior/inferior\n" + "page up/page down arrow: anterior/posterior") + + if event.text() == 'm': + self._mark_ch() + + if event.text() == 'r': + self._remove_ch() + + if event.text() == 'b': + if 'mri' in self._images: + for img in self._images['mri']: + img.remove() + self._images.pop('mri') + else: + self._images['mri'] = list() + for axis in range(3): + mri_data = np.take(self._mri_data, + self._current_slice[axis], axis=axis).T + self._images['mri'].append(self._figs[axis].axes[0].imshow( + mri_data, cmap='hot', aspect='auto', alpha=0.25)) + self._draw() + + if event.text() in ('=', '+', '-'): + self._zoom(sign=-2 * (event.text() == '-') + 1, draw=True) + + # Changing slices + if event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down, + QtCore.Qt.Key_Left, QtCore.Qt.Key_Right, + QtCore.Qt.Key_PageUp, QtCore.Qt.Key_PageDown): + if event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down): + self._ras[2] += 2 * (event.key() == QtCore.Qt.Key_Up) - 1 + elif event.key() in (QtCore.Qt.Key_Left, QtCore.Qt.Key_Right): + self._ras[0] += 2 * (event.key() == QtCore.Qt.Key_Right) - 1 + elif event.key() in (QtCore.Qt.Key_PageUp, + QtCore.Qt.Key_PageDown): + self._ras[1] += 2 * (event.key() == QtCore.Qt.Key_PageUp) - 1 + self._move_cursors_to_pos() + + def _on_click(self, axis, event): + """Move to view on MRI and CT on click.""" + # Transform coordinates to figure coordinates + pos = self._get_click_pos(axis, event.x, event.y) + logger.info(f'Clicked axis {axis} at pos {pos}') + + if axis is not None and pos is not None: + xyz = apply_trans(self._ras_vox_t, self._ras) + if axis == 0: + xyz[[1, 2]] = pos + elif axis == 1: + xyz[[0, 2]] = pos + elif axis == 2: + xyz[[0, 1]] = pos + self._ras = apply_trans(self._vox_ras_t, xyz) + self._move_cursors_to_pos() + + def _update_moved(self): + """Update when cursor position changes.""" + self._RAS_textbox.setPlainText('{:.2f}, {:.2f}, {:.2f}'.format( + *self._ras)) + self._intensity_label.setText('intensity = {:.2f}'.format( + self._ct_data[tuple(self._current_slice)])) diff --git a/mne/gui/tests/test_ieeg_locate_gui.py b/mne/gui/tests/test_ieeg_locate_gui.py new file mode 100644 index 00000000000..e4568ff52bd --- /dev/null +++ b/mne/gui/tests/test_ieeg_locate_gui.py @@ -0,0 +1,155 @@ +# Authors: Alex Rockhill +# +# License: BSD-3-clause + +import os.path as op +import numpy as np +from numpy.testing import assert_allclose + +import pytest + +import mne +from mne.datasets import testing +from mne.utils import requires_nibabel +from mne.viz.utils import _fake_click + +data_path = testing.data_path(download=False) +subject = 'sample' +subjects_dir = op.join(data_path, 'subjects') +sample_dir = op.join(data_path, 'MEG', subject) +raw_path = op.join(sample_dir, 'sample_audvis_trunc_raw.fif') +fname_trans = op.join(sample_dir, 'sample_audvis_trunc-trans.fif') + + +@requires_nibabel() +@pytest.fixture +def _fake_CT_coords(skull_size=5, contact_size=2): + """Make somewhat realistic CT data with contacts.""" + import nibabel as nib + brain = nib.load( + op.join(subjects_dir, subject, 'mri', 'brain.mgz')) + verts = mne.read_surface( + op.join(subjects_dir, subject, 'bem', 'outer_skull.surf'))[0] + verts = mne.transforms.apply_trans( + np.linalg.inv(brain.header.get_vox2ras_tkr()), verts) + x, y, z = np.array(brain.shape).astype(int) // 2 + coords = [(x, y - 14, z), (x - 10, y - 15, z), + (x - 20, y - 16, z + 1), (x - 30, y - 16, z + 1)] + center = np.array(brain.shape) / 2 + # make image + np.random.seed(99) + ct_data = np.random.random(brain.shape).astype(np.float32) * 100 + # make skull + for vert in verts: + x, y, z = np.round(vert).astype(int) + ct_data[slice(x - skull_size, x + skull_size + 1), + slice(y - skull_size, y + skull_size + 1), + slice(z - skull_size, z + skull_size + 1)] = 1000 + # add electrode with contacts + for (x, y, z) in coords: + # make sure not in skull + assert np.linalg.norm(center - np.array((x, y, z))) < 50 + ct_data[slice(x - contact_size, x + contact_size + 1), + slice(y - contact_size, y + contact_size + 1), + slice(z - contact_size, z + contact_size + 1)] = \ + 1000 - np.linalg.norm(np.array(np.meshgrid( + *[range(-contact_size, contact_size + 1)] * 3)), axis=0) + ct = nib.MGHImage(ct_data, brain.affine) + coords = mne.transforms.apply_trans( + ct.header.get_vox2ras_tkr(), np.array(coords)) + return ct, coords + + +@requires_nibabel() +@pytest.fixture +def _locate_ieeg(renderer_interactive_pyvistaqt): + # Use a fixture to create these classes so we can ensure that they + # are closed at the end of the test + guis = list() + + def fun(*args, **kwargs): + guis.append(mne.gui.locate_ieeg(*args, **kwargs)) + return guis[-1] + + yield fun + + for gui in guis: + try: + gui.close() + except Exception: + pass + + +def test_ieeg_elec_locate_gui_io(_locate_ieeg): + """Test the input/output of the intracranial location GUI.""" + import nibabel as nib + info = mne.create_info([], 1000) + aligned_ct = nib.MGHImage(np.zeros((256, 256, 256), dtype=np.float32), + np.eye(4)) + trans = mne.transforms.Transform('head', 'mri') + with pytest.raises(ValueError, + match='No channels found in `info` to locate'): + _locate_ieeg(info, aligned_ct, subject, subjects_dir) + info = mne.create_info(['test'], 1000, ['seeg']) + with pytest.raises(ValueError, match='CT is not aligned to MRI'): + _locate_ieeg(info, trans, aligned_ct, subject=subject, + subjects_dir=subjects_dir) + + +@testing.requires_testing_data +def test_ieeg_elec_locate_gui_display(_locate_ieeg, _fake_CT_coords): + """Test that the intracranial location GUI displays properly.""" + raw = mne.io.read_raw_fif(raw_path) + raw.pick_types(eeg=True) + ch_dict = {'EEG 001': 'LAMY 1', 'EEG 002': 'LAMY 2', + 'EEG 003': 'LSTN 1', 'EEG 004': 'LSTN 2'} + raw.pick_channels(list(ch_dict.keys())) + raw.rename_channels(ch_dict) + raw.set_montage(None) + aligned_ct, coords = _fake_CT_coords + trans = mne.read_trans(fname_trans) + with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + gui = _locate_ieeg(raw.info, trans, aligned_ct, + subject=subject, subjects_dir=subjects_dir) + + gui._ras[:] = coords[0] # start in the right position + gui._move_cursors_to_pos() + for coord in coords: + coord_vox = mne.transforms.apply_trans(gui._ras_vox_t, coord) + _fake_click(gui._figs[2], gui._figs[2].axes[0], + coord_vox[:-1], xform='data', kind='release') + assert_allclose(coord, gui._ras, atol=3) # clicks are a bit off + + # test snap to center + gui._ras[:] = coords[0] # move to first position + gui._move_cursors_to_pos() + gui._mark_ch() + assert_allclose(coords[0], gui._chs['LAMY 1'], atol=0.2) + gui._snap_button.click() + assert gui._snap_button.text() == 'Off' + # now make sure no snap happens + gui._ras[:] = coords[1] + 1 + gui._mark_ch() + assert_allclose(coords[1] + 1, gui._chs['LAMY 2'], atol=0.01) + # check that it turns back on + gui._snap_button.click() + assert gui._snap_button.text() == 'On' + + # test remove + gui._ch_index = 1 + gui._update_ch_selection() + gui._remove_ch() + assert np.isnan(gui._chs['LAMY 2']).all() + + # check that raw object saved + assert not np.isnan(raw.info['chs'][0]['loc'][:3]).any() # LAMY 1 + assert np.isnan(raw.info['chs'][1]['loc'][:3]).all() # LAMY 2 (removed) + + # move sliders + gui._alpha_slider.setValue(75) + assert gui._ch_alpha == 0.75 + gui._radius_slider.setValue(5) + assert gui._radius == 5 + ct_sum_before = np.nansum(gui._images['ct'][0].get_array().data) + gui._ct_min_slider.setValue(500) + assert np.nansum(gui._images['ct'][0].get_array().data) < ct_sum_before diff --git a/mne/surface.py b/mne/surface.py index e94264da05d..586ff52cbce 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -21,7 +21,7 @@ from .channels.channels import _get_meg_system from .fixes import (_serialize_volume_info, _get_read_geometry, jit, - prange, bincount, _get_img_fdata) + prange, bincount) from .io.constants import FIFF from .io.pick import pick_types from .parallel import parallel_func @@ -1736,22 +1736,23 @@ def _marching_cubes(image, level, smooth=0): return out -def _warn_missing_chs(montage, dig_image, after_warp): +def _warn_missing_chs(info, dig_image, after_warp, verbose=None): """Warn that channels are missing.""" # ensure that each electrode contact was marked in at least one voxel - missing = set(np.arange(1, len(montage.ch_names) + 1)).difference( - set(np.unique(np.asanyarray(dig_image.dataobj)))) - missing_ch = [montage.ch_names[idx - 1] for idx in missing] - if missing_ch: + missing = set(np.arange(1, len(info.ch_names) + 1)).difference( + set(np.unique(np.array(dig_image.dataobj)))) + missing_ch = [info.ch_names[idx - 1] for idx in missing] + if missing_ch and verbose != 'error': warn('Channels ' + ', '.join(missing_ch) + ' were not assigned ' 'voxels' + (' after applying SDR warp' if after_warp else '')) -@fill_doc +@verbose def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, subject_from, subject_to='fsaverage', - subjects_dir=None, thresh=0.95, - max_peak_dist=1, voxels_max=100, use_min=False): + subjects_dir_from=None, subjects_dir_to=None, + thresh=0.5, max_peak_dist=1, voxels_max=100, + use_min=False, verbose=None): """Warp a montage to a template with image volumes using SDR. Find areas of the input volume with intensity greater than @@ -1777,10 +1778,17 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, subject_to : str The name of the subject to use as a template to morph to (e.g. 'fsaverage'). - %(subjects_dir)s + subjects_dir_from : str | pathlib.Path | None + The path to the Freesurfer ``recon-all`` directory for the + ``subject_from`` subject. The ``SUBJECTS_DIR`` environment + variable will be used when ``None``. + subjects_dir_to : str | pathlib.Path | None + The path to the Freesurfer ``recon-all`` directory for the + ``subject_to`` subject. ``subject_dir_from`` will be used + when ``None``. thresh : float - The quantile of the image data to use to threshold the - channel size on the volume. + The threshold relative to the peak to determine the size + of the sensors on the volume. max_peak_dist : int The number of voxels away from the channel location to look in the ``image``. This will depend on the accuracy of @@ -1791,6 +1799,7 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, use_min : bool Whether to hypointensities in the volume as channel locations. Default False uses hyperintensities. + %(verbose)s Returns ------- @@ -1820,14 +1829,16 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, _validate_type(use_min, bool, 'use_min') # first, make sure we have the necessary freesurfer surfaces - _check_subject_dir(subject_from, subjects_dir) - _check_subject_dir(subject_to, subjects_dir) + _check_subject_dir(subject_from, subjects_dir_from) + if subjects_dir_to is None: # assume shared + subjects_dir_to = subjects_dir_from + _check_subject_dir(subject_to, subjects_dir_to) # load image and make sure it's in surface RAS if not isinstance(base_image, nib.spatialimages.SpatialImage): base_image = nib.load(base_image) fs_from_img = nib.load( - op.join(subjects_dir, subject_from, 'mri', 'brain.mgz')) + op.join(subjects_dir_from, subject_from, 'mri', 'brain.mgz')) if not np.allclose(base_image.affine, fs_from_img.affine, atol=1e-6): raise RuntimeError('The `base_image` is not aligned to Freesurfer ' 'surface RAS space. This space is required as ' @@ -1853,12 +1864,13 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, # take channel coordinates and use the image to transform them # into a volume where all the voxels over a threshold nearby # are labeled with an index - image_data = _get_img_fdata(base_image) + image_data = np.array(base_image.dataobj) if use_min: image_data *= -1 - thresh = np.quantile(image_data, thresh) image_from = np.zeros(base_image.shape, dtype=int) for i, ch_coord in enumerate(ch_coords): + if np.isnan(ch_coord).any(): + continue # this looks up to a voxel away, it may be marked imperfectly volume = _voxel_neighbors(ch_coord, image_data, thresh=thresh, max_peak_dist=max_peak_dist, @@ -1880,7 +1892,7 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, _warn_missing_chs(montage, image_from, after_warp=False) template_brain = nib.load( - op.join(subjects_dir, subject_to, 'mri', 'brain.mgz')) + op.join(subjects_dir_to, subject_to, 'mri', 'brain.mgz')) image_to = apply_volume_registration( image_from, template_brain, reg_affine, sdr_morph, @@ -1891,12 +1903,11 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, # recover the contact positions as the center of mass warped_data = np.asanyarray(image_to.dataobj) for val, ch_coord in enumerate(ch_coords, 1): - ch_coord[:] = np.asanyarray( - np.where(warped_data == val), float).mean(axis=1) + ch_coord[:] = np.mean(np.where(warped_data == val), axis=1) # convert back to surface RAS of the template fs_to_img = nib.load( - op.join(subjects_dir, subject_to, 'mri', 'brain.mgz')) + op.join(subjects_dir_to, subject_to, 'mri', 'brain.mgz')) ch_coords = apply_trans( fs_to_img.header.get_vox2ras_tkr(), ch_coords) / 1000 @@ -1961,13 +1972,16 @@ def get_montage_volume_labels(montage, subject, subjects_dir=None, ch_coords = apply_trans( np.linalg.inv(aseg.header.get_vox2ras_tkr()), ch_coords * 1000) labels = OrderedDict() - for ch_name, seed in zip(montage.ch_names, ch_coords): - voxels = _voxel_neighbors( - seed, aseg_data, dist=dist, vox2ras_tkr=vox2ras_tkr, - voxels_max=_VOXELS_MAX) - label_idxs = set([aseg_data[tuple(voxel)].astype(int) - for voxel in voxels]) - labels[ch_name] = [label_lut[idx] for idx in label_idxs] + for ch_name, ch_coord in zip(montage.ch_names, ch_coords): + if np.isnan(ch_coord).any(): + labels[ch_name] = list() + else: + voxels = _voxel_neighbors( + ch_coord, aseg_data, dist=dist, vox2ras_tkr=vox2ras_tkr, + voxels_max=_VOXELS_MAX) + label_idxs = set([aseg_data[tuple(voxel)].astype(int) + for voxel in voxels]) + labels[ch_name] = [label_lut[idx] for idx in label_idxs] all_labels = set([label for val in labels.values() for label in val]) colors = {label: tuple(fs_colors[label][:3] / 255) + (1.,) @@ -2000,8 +2014,9 @@ def _get_neighbors(loc, image, voxels, thresh, dist_params): return neighbors -def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, - dist=None, vox2ras_tkr=None, voxels_max=100): +def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=1, + use_relative=True, dist=None, vox2ras_tkr=None, + voxels_max=100): """Find voxels above a threshold contiguous with a seed location. Parameters @@ -2011,11 +2026,14 @@ def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, image : ndarray The image to search. thresh : float - The threshold to use as a cutoff for what qualifies as a - neighbor. + The threshold to use as a cutoff for what qualifies as a neighbor. + Will be relative to the peak if ``use_relative`` or absolute if not. max_peak_dist : int The maximum number of voxels to search for the peak near the seed location. + use_relative : bool + If ``True``, the threshold will be relative to the peak, if + ``False``, the threshold will be absolute. dist : float The distance in mm to include surrounding voxels. vox2ras_tkr : ndarray @@ -2037,8 +2055,8 @@ def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, only voxels within ``dist`` mm of the seed are included. """ seed = np.array(seed).round().astype(int) - assert ((dist is not None) + (max_peak_dist is not None)) == 1 - if max_peak_dist is not None: + assert ((dist is not None) + (thresh is not None)) == 1 + if thresh is not None: dist_params = None check_grid = image[tuple([ slice(idx - max_peak_dist, idx + max_peak_dist + 1) @@ -2046,6 +2064,8 @@ def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, peak = np.array(np.unravel_index( np.argmax(check_grid), check_grid.shape)) - max_peak_dist + seed voxels = neighbors = set([tuple(peak)]) + if use_relative: + thresh *= image[tuple(peak)] else: assert vox2ras_tkr is not None seed_fs_ras = apply_trans(vox2ras_tkr, seed + 0.5) # center of voxel diff --git a/mne/tests/test_surface.py b/mne/tests/test_surface.py index 851ed065518..094e82740a1 100644 --- a/mne/tests/test_surface.py +++ b/mne/tests/test_surface.py @@ -290,7 +290,7 @@ def test_voxel_neighbors(): image[4:7, 4:7, 4:7] = 3 image[5, 5, 5] = 4 volume = _voxel_neighbors( - (5.5, 5.1, 4.9), image, thresh=2, max_peak_dist=1) + (5.5, 5.1, 4.9), image, thresh=2, max_peak_dist=1, use_relative=False) true_volume = set([(5, 4, 5), (5, 5, 4), (5, 5, 5), (6, 5, 5), (5, 6, 5), (5, 5, 6), (4, 5, 5)]) assert volume.difference(true_volume) == set() @@ -341,7 +341,7 @@ def test_warp_montage_volume(): rpa=rpa['r'], coord_frame='mri') montage_warped, image_from, image_to = warp_montage_volume( montage, CT, reg_affine, sdr_morph, 'sample', - subjects_dir=subjects_dir, thresh=0.99) + subjects_dir_from=subjects_dir, thresh=0.99) # checked with nilearn plot from `tut-ieeg-localize` # check montage in surface RAS ground_truth_warped = np.array([[-0.009, -0.00133333, -0.033], @@ -375,13 +375,15 @@ def test_warp_montage_volume(): CT_unaligned = nib.Nifti1Image(CT_data, template_brain.affine) with pytest.raises(RuntimeError, match='not aligned to Freesurfer'): warp_montage_volume(montage, CT_unaligned, reg_affine, - sdr_morph, 'sample', subjects_dir=subjects_dir) + sdr_morph, 'sample', + subjects_dir_from=subjects_dir) bad_montage = montage.copy() for d in bad_montage.dig: d['coord_frame'] = 99 with pytest.raises(RuntimeError, match='Coordinate frame not supported'): warp_montage_volume(bad_montage, CT, reg_affine, - sdr_morph, 'sample', subjects_dir=subjects_dir) + sdr_morph, 'sample', + subjects_dir_from=subjects_dir) # check channel not warped ch_pos_doubled = ch_pos.copy() @@ -391,4 +393,5 @@ def test_warp_montage_volume(): rpa=rpa['r'], coord_frame='mri') with pytest.warns(RuntimeWarning, match='not assigned'): warp_montage_volume(doubled_montage, CT, reg_affine, - sdr_morph, 'sample', subjects_dir=subjects_dir) + sdr_morph, 'sample', + subjects_dir_from=subjects_dir) diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index f48515d8b61..1321538b35c 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -66,7 +66,7 @@ def requires_nibabel(): def requires_dipy(): """Check for dipy.""" import pytest - # for some strange reason on CIs we cane get: + # for some strange reason on CIs we can get: # # can get weird ImportError: dlopen: cannot load any more object # with static TLS diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 79e2f3d66a0..7e6630f1bef 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2577,10 +2577,10 @@ step as a key. Steps not in the dictionary will use the default value. The default (None) is equivalent to: - niter=dict(translation=(10000, 1000, 100), - rigid=(10000, 1000, 100), - affine=(10000, 1000, 100), - sdr=(10, 10, 5)) + niter=dict(translation=(100, 100, 10), + rigid=(100, 100, 10), + affine=(100, 100, 10), + sdr=(5, 5, 3)) """ docdict['pipeline'] = """ pipeline : str | tuple diff --git a/mne/viz/misc.py b/mne/viz/misc.py index cdbfb3c0cdd..dfe0e15d92a 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -513,7 +513,7 @@ def plot_bem(subject=None, subjects_dir=None, orientation='coronal', bem_path = op.join(subjects_dir, subject, 'bem') if not op.isdir(bem_path): - raise IOError('Subject bem directory "%s" does not exist' % bem_path) + raise IOError(f'Subject bem directory "{bem_path}" does not exist') surfaces = _get_bem_plotting_surfaces(bem_path) if brain_surfaces is not None: diff --git a/tutorials/clinical/10_ieeg_localize.py b/tutorials/clinical/10_ieeg_localize.py index 2eace180f59..7c2b5f311b9 100644 --- a/tutorials/clinical/10_ieeg_localize.py +++ b/tutorials/clinical/10_ieeg_localize.py @@ -6,15 +6,18 @@ Locating Intracranial Electrode Contacts ======================================== -Intracranial electrophysiology recording contacts are generally localized -based on a post-implantation computed tomography (CT) image and a -pre-implantation magnetic resonance (MR) image. The CT image has greater -intensity than the background at each of the electrode contacts and -for the skull. Using the skull, the CT can be aligned to MR-space. -Contact locations in MR-space are the goal because this is the image from which -brain structures can be determined using the -:ref:`tut-freesurfer-reconstruction`. Contact locations in MR-space can also -be translated to a template space such as ``fsaverage`` for group comparisons. +Analysis of intracranial electrophysiology recordings typically involves +finding the position of each contact relative to brain structures. In a +typical setup, the brain and the electrode locations will be in two places +and will have to be aligned; the brain is best visualized by a +pre-implantation magnetic resonance (MR) image whereas the electrode contact +locations are best visualized in a post-implantation computed tomography (CT) +image. The CT image has greater intensity than the background at each of the +electrode contacts and for the skull. Using the skull, the CT can be aligned +to MR-space. This accomplishes our goal of obtaining contact locations in +MR-space (which is where the brain structures are best determined using the +:ref:`tut-freesurfer-reconstruction`). Contact locations in MR-space can also +be warped to a template space such as ``fsaverage`` for group comparisons. """ # Authors: Alex Rockhill @@ -24,7 +27,6 @@ # %% -import os import os.path as op import numpy as np @@ -46,6 +48,58 @@ # use mne-python's fsaverage data fetch_fsaverage(subjects_dir=subjects_dir, verbose=True) # downloads if needed +############################################################################### +# Aligning the T1 to ACPC +# ======================= +# +# For intracranial electrophysiology recordings, the Brain Imaging Data +# Structure (BIDS) standard requires that coordinates be aligned to the +# anterior commissure and posterior commissure (ACPC-aligned). Therefore, it is +# recommended that you do this alignment before finding the positions of the +# channels in your recording. Doing this will make the "mri" (aka surface RAS) +# coordinate frame an ACPC coordinate frame. This can be done using +# Freesurfer's freeview: +# +# .. code-block:: bash +# +# $ freeview $MISC_PATH/seeg/sample_seeg_T1.mgz +# +# And then interact with the graphical user interface: +# +# First, it is recommended to change the cursor style to long, this can be done +# through the menu options like so: +# +# ``Freeview -> Preferences -> General -> Cursor style -> Long`` +# +# Then, the image needs to be aligned to ACPC to look like the image below. +# This can be done by pulling up the transform popup from the menu like so: +# +# ``Tools -> Transform Volume`` +# +# .. note:: +# Be sure to set the text entry box labeled RAS (not TkReg RAS) to +# ``0 0 0`` before beginning the transform. +# +# Then translate the image until the crosshairs meet on the AC and +# run through the PC as shown in the plot. The eyes should be in +# the ACPC plane and the image should be rotated until they are symmetrical, +# and the crosshairs should transect the midline of the brain. +# Be sure to use both the rotate and the translate menus and save the volume +# after you're finished using ``Save Volume As`` in the transform popup +# :footcite:`HamiltonEtAl2017`. + +T1 = nib.load(op.join(misc_path, 'seeg', 'sample_seeg', 'mri', 'T1.mgz')) +viewer = T1.orthoview() +viewer.set_position(0, 9.9, 5.8) +viewer.figs[0].axes[0].annotate( + 'PC', (107, 108), xytext=(10, 75), color='white', + horizontalalignment='center', + arrowprops=dict(facecolor='white', lw=0.5, width=2, headwidth=5)) +viewer.figs[0].axes[0].annotate( + 'AC', (137, 108), xytext=(246, 75), color='white', + horizontalalignment='center', + arrowprops=dict(facecolor='white', lw=0.5, width=2, headwidth=5)) + # %% # Freesurfer recon-all # ==================== @@ -105,7 +159,6 @@ def plot_overlay(image, compare, title, thresh=None): fig.tight_layout() -T1 = nib.load(op.join(misc_path, 'seeg', 'sample_seeg', 'mri', 'T1.mgz')) CT_orig = nib.load(op.join(misc_path, 'seeg', 'sample_seeg_CT.mgz')) # resample to T1's definition of world coordinates @@ -125,7 +178,7 @@ def plot_overlay(image, compare, title, thresh=None): # here:: # # reg_affine, _ = mne.transforms.compute_volume_registration( -# CT_orig, T1, pipeline='rigids', verbose=True) +# CT_orig, T1, pipeline='rigids') # # And instead we just hard-code the resulting 4x4 matrix: @@ -173,6 +226,21 @@ def plot_overlay(image, compare, title, thresh=None): fig.tight_layout() del CT_data, T1 +# %% +# Now we need to estimate the "head" coordinate transform. +# +# MNE stores digitization montages in a coordinate frame called "head" +# defined by fiducial points (origin is halfway between the LPA and RPA +# see :ref:`tut-source-alignment`). For sEEG, it is convenient to get an +# estimate of the location of the fiducial points for the subject +# using the Talairach transform (see :func:`mne.coreg.get_mni_fiducials`) +# to use to define the coordinate frame so that we don't have to manually +# identify their location. + +# estimate head->mri transform +subj_trans = mne.coreg.estimate_head_mri_t( + 'sample_seeg', op.join(misc_path, 'seeg')) + # %% # Marking the Location of Each Electrode Contact # ============================================== @@ -180,35 +248,46 @@ def plot_overlay(image, compare, title, thresh=None): # Now, the CT and the MR are in the same space, so when you are looking at a # point in CT space, it is the same point in MR space. So now everything is # ready to determine the location of each electrode contact in the -# individual subject's anatomical space (T1-space). To do this, can make -# list of ``TkReg RAS`` points from the lower panel in freeview or use the -# mne graphical user interface (coming soon). The electrode locations will then -# be in the ``surface RAS`` coordinate frame, which is helpful because that is -# the coordinate frame that all the surface and image files that freesurfer -# outputs are in, see :ref:`tut-freesurfer-mne`. +# individual subject's anatomical space (T1-space). To do this, we can use the +# MNE intracranial electrode location graphical user interface. # -# The electrode contact locations could be determined using freeview by -# clicking through and noting each contact position in the interface launched -# by the following command: +# .. note: The most useful coordinate frame for intracranial electrodes is +# generally the ``surface RAS`` coordinate frame because that is +# the coordinate frame that all the surface and image files that +# Freesurfer outputs are in, see :ref:`tut-freesurfer-mne`. These are +# useful for finding the brain structures nearby each contact and +# plotting the results. # -# .. code-block:: bash +# To operate the GUI: # -# $ freeview $MISC_PATH/seeg/sample_seeg_T1.mgz \ -# $MISC_PATH/seeg/sample_seeg_CT.mgz - -# %% -# Now, we'll make a montage with the channels that we've found in the -# previous step. +# - Click in each image to navigate to each electrode contact +# - Select the contact name in the right panel +# - Press the "Mark" button or the "m" key to associate that +# position with that contact +# - Repeat until each contact is marked, they will both appear as circles +# in the plots and be colored in the sidebar when marked # -# .. note:: MNE represents data in the "head" space internally +# .. note:: The channel locations are saved to the ``raw`` object every time +# a location is marked or removed so there is no "Save" button. +# +# .. note:: Using the scroll or +/- arrow keys you can zoom in and out, +# and the up/down, left/right and page up/page down keys allow +# you to move one slice in any direction. This information is +# available in the help menu, accessible by pressing the "h" key. +# +# .. note:: If "Snap to Center" is on, this will use the radius so be +# sure to set it properly. -# load electrophysiology data with channel locations +# load electrophysiology data to find channel locations for +# (the channels are already located in the example) raw = mne.io.read_raw(op.join(misc_path, 'seeg', 'sample_seeg_ieeg.fif')) -# create symbolic link to share ``subjects_dir`` -if not op.exists(op.join(subjects_dir, 'sample_seeg')): - os.symlink(op.join(misc_path, 'seeg', 'sample_seeg'), - op.join(subjects_dir, 'sample_seeg')) +gui = mne.gui.locate_ieeg(raw.info, subj_trans, CT_aligned, + subject='sample_seeg', + subjects_dir=op.join(misc_path, 'seeg')) +# The `raw` object is modified to contain the channel locations +# after closing the GUI and can now be saved +gui.close() # close when done # %% # Let's plot the electrode contact locations on the subject's brain. @@ -222,13 +301,10 @@ def plot_overlay(image, compare, title, thresh=None): # identify their location. The estimated head->mri ``trans`` was used # when the electrode contacts were localized so we need to use it again here. -# estimate head->mri transform -subj_trans = mne.coreg.estimate_head_mri_t('sample_seeg', subjects_dir) - # plot the alignment -brain_kwargs = dict(cortex='low_contrast', alpha=0.2, background='white', - subjects_dir=subjects_dir) -brain = mne.viz.Brain('sample_seeg', **brain_kwargs) +brain_kwargs = dict(cortex='low_contrast', alpha=0.2, background='white') +brain = mne.viz.Brain('sample_seeg', subjects_dir=op.join(misc_path, 'seeg'), + **brain_kwargs) brain.add_sensors(raw.info, trans=subj_trans) view_kwargs = dict(azimuth=60, elevation=100, distance=350, focalpoint=(0, 0, -15)) @@ -269,7 +345,6 @@ def plot_overlay(image, compare, title, thresh=None): # is useful for getting a quick view of the data, but finalized # pipelines should use ``zooms=None`` instead! -CT_thresh = 0.8 # 0.95 is better for zooms=None! reg_affine, sdr_morph = mne.transforms.compute_volume_registration( subject_brain, template_brain, zooms=5, verbose=True) subject_brain_sdr = mne.transforms.apply_volume_registration( @@ -297,9 +372,11 @@ def plot_overlay(image, compare, title, thresh=None): montage = raw.get_montage() montage.apply_trans(subj_trans) +# higher thresh such as 0.5 (default) works when `zooms=None` montage_warped, elec_image, warped_elec_image = mne.warp_montage_volume( - montage, CT_aligned, reg_affine, sdr_morph, - subject_from='sample_seeg', subjects_dir=subjects_dir, thresh=CT_thresh) + montage, CT_aligned, reg_affine, sdr_morph, thresh=0.1, + subject_from='sample_seeg', subjects_dir_from=op.join(misc_path, 'seeg'), + subject_to='fsaverage', subjects_dir_to=subjects_dir) fig, axes = plt.subplots(2, 1, figsize=(8, 8)) nilearn.plotting.plot_glass_brain(elec_image, axes=axes[0], cmap='Dark2') @@ -334,7 +411,7 @@ def plot_overlay(image, compare, title, thresh=None): raw.set_montage(montage_warped) # plot the resulting alignment -brain = mne.viz.Brain('fsaverage', **brain_kwargs) +brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, **brain_kwargs) brain.add_sensors(raw.info, trans=template_trans) brain.show_view(**view_kwargs) From 57094f1970724a94481d9712af0daffd7cf1afa0 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 15 Sep 2021 13:26:09 -0700 Subject: [PATCH 2/2] mainak + eric suggestions --- mne/gui/_ieeg_locate_gui.py | 65 ++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/mne/gui/_ieeg_locate_gui.py b/mne/gui/_ieeg_locate_gui.py index efb787cf9c7..a88333af304 100644 --- a/mne/gui/_ieeg_locate_gui.py +++ b/mne/gui/_ieeg_locate_gui.py @@ -379,7 +379,11 @@ def _get_button_bar(self): """Make a bar with buttons for user interactions.""" hbox = QHBoxLayout() - hbox.addStretch(10) + help_button = QPushButton('Help') + help_button.released.connect(self._show_help) + hbox.addWidget(help_button) + + hbox.addStretch(8) hbox.addWidget(QLabel('Snap to Center')) self._snap_button = QPushButton('Off') @@ -390,6 +394,12 @@ def _get_button_bar(self): hbox.addStretch(1) + self._toggle_brain_button = QPushButton('Show Brain') + self._toggle_brain_button.released.connect(self._toggle_show_brain) + hbox.addWidget(self._toggle_brain_button) + + hbox.addStretch(1) + mark_button = QPushButton('Mark') hbox.addWidget(mark_button) mark_button.released.connect(self._mark_ch) @@ -412,7 +422,7 @@ def _get_button_bar(self): brush, QtCore.Qt.BackgroundRole) self._group_selector.clicked.connect(self._select_group) self._group_selector.currentIndexChanged.connect( - self._update_group) + self._select_group) hbox.addWidget(self._group_selector) # update background color for current selection @@ -787,21 +797,41 @@ def _move_cursor_to(self, axis, x, y): self._images['cursor2'][axis].set_ydata([y, y]) self._images['cursor'][axis].set_xdata([x, x]) + def _show_help(self): + """Show the help menu.""" + QMessageBox.information( + self, 'Help', + "Help:\n'm': mark channel location\n" + "'r': remove channel location\n" + "'b': toggle viewing of brain in T1\n" + "'+'/'-': zoom\nleft/right arrow: left/right\n" + "up/down arrow: superior/inferior\n" + "page up/page down arrow: anterior/posterior") + + def _toggle_show_brain(self): + """Toggle whether the brain/MRI is being shown.""" + if 'mri' in self._images: + for img in self._images['mri']: + img.remove() + self._images.pop('mri') + self._toggle_brain_button.setText('Show Brain') + else: + self._images['mri'] = list() + for axis in range(3): + mri_data = np.take(self._mri_data, + self._current_slice[axis], axis=axis).T + self._images['mri'].append(self._figs[axis].axes[0].imshow( + mri_data, cmap='hot', aspect='auto', alpha=0.25)) + self._toggle_brain_button.setText('Hide Brain') + self._draw() + def _key_press_event(self, event): """Execute functions when the user presses a key.""" if event.key() == 'escape': self.close() if event.text() == 'h': - # Show help - QMessageBox.information( - self, 'Help', - "Help:\n'm': mark channel location\n" - "'r': remove channel location\n" - "'b': toggle viewing of brain in T1\n" - "'+'/'-': zoom\nleft/right arrow: left/right\n" - "up/down arrow: superior/inferior\n" - "page up/page down arrow: anterior/posterior") + self._show_help() if event.text() == 'm': self._mark_ch() @@ -810,18 +840,7 @@ def _key_press_event(self, event): self._remove_ch() if event.text() == 'b': - if 'mri' in self._images: - for img in self._images['mri']: - img.remove() - self._images.pop('mri') - else: - self._images['mri'] = list() - for axis in range(3): - mri_data = np.take(self._mri_data, - self._current_slice[axis], axis=axis).T - self._images['mri'].append(self._figs[axis].axes[0].imshow( - mri_data, cmap='hot', aspect='auto', alpha=0.25)) - self._draw() + self._toggle_show_brain() if event.text() in ('=', '+', '-'): self._zoom(sign=-2 * (event.text() == '-') + 1, draw=True)