diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index e0b2a8d0d21..4436ae2fc2a 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -159,6 +159,7 @@ Enhancements - Reading EDF files via :func:`mne.io.read_raw_edf` now can infer channel type from the signal label in the EDF header (:gh:`9694` by `Adam Li`_) +- Add :func:`mne.gui.locate_ieeg` to locate intracranial electrode contacts from a CT, an MRI (with Freesurfer ``recon-all``) and the channel names from an :class:`mne.Info` object (:gh:`9586` by `Alex Rockhill`_) Bugs ~~~~ diff --git a/doc/conf.py b/doc/conf.py index 91aba4bc6de..c5fe3ccb71b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -258,6 +258,7 @@ # unlinkable 'mayavi.mlab.pipeline.surface', 'CoregFrame', 'Kit2FiffFrame', 'FiducialsFrame', + 'IntracranialElectrodeLocator' } numpydoc_validate = True numpydoc_validation_checks = {'all'} | set(error_ignores) diff --git a/doc/mri.rst b/doc/mri.rst index 889246feb84..9689b4c705e 100644 --- a/doc/mri.rst +++ b/doc/mri.rst @@ -19,6 +19,7 @@ Step by step instructions for using :func:`gui.coregistration`: get_montage_volume_labels gui.coregistration gui.fiducials + gui.locate_ieeg create_default_subject head_to_mni head_to_mri diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 683313f722b..b4f099d47d0 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -254,7 +254,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, path = _get_path(path, key, name) # To update the testing or misc dataset, push commits, then make a new # release on GitHub. Then update the "releases" variable: - releases = dict(testing='0.123', misc='0.18') + releases = dict(testing='0.123', misc='0.19') # And also update the "md5_hashes['testing']" variable below. # To update any other dataset, update the data archive itself (upload # an updated version) and update the md5 hash. @@ -345,7 +345,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, bst_raw='fa2efaaec3f3d462b319bc24898f440c', bst_resting='70fc7bf9c3b97c4f2eab6260ee4a0430'), fake='3194e9f7b46039bb050a74f3e1ae9908', - misc='0aa25a9bb4f204b3d4769f0b84e9b526', + misc='b5ebe66dbe0f36cba9170a7ce909a66f', sample='12b75d1cb7df9dfb4ad73ed82f61094f', somato='32fd2f6c8c7eb0784a1de6435273c48b', spm='9f43f67150e3b694b523a21eb929ea75', diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index 40ea669b8b8..ae9e1ea308f 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -239,3 +239,44 @@ def kit2fiff(): from ._kit2fiff_gui import Kit2FiffFrame frame = Kit2FiffFrame() return _initialize_gui(frame) + + +@verbose +def locate_ieeg(info, trans, aligned_ct, subject=None, subjects_dir=None, + groups=None, verbose=None): + """Locate intracranial electrode contacts. + + Parameters + ---------- + %(info_not_none)s + %(trans_not_none)s + aligned_ct : str | pathlib.Path | nibabel.spatialimages.SpatialImage + The CT image that has been aligned to the Freesurfer T1. Path-like + inputs and nibabel image objects are supported. + %(subject)s + %(subjects_dir)s + groups : dict | None + A dictionary with channels as keys and their group index as values. + If None, the groups will be inferred by the channel names. Channel + names must have a format like ``LAMY 7`` where a string prefix + like ``LAMY`` precedes a numeric index like ``7``. If the channels + are formatted improperly, group plotting will work incorrectly. + Group assignments can be adjusted in the GUI. 
+ %(verbose)s + + Returns + ------- + gui : instance of IntracranialElectrodeLocator + The graphical user interface (GUI) window. + """ + from ._ieeg_locate_gui import IntracranialElectrodeLocator + from PyQt5.QtWidgets import QApplication + # get application + app = QApplication.instance() + if app is None: + app = QApplication(["Intracranial Electrode Locator"]) + gui = IntracranialElectrodeLocator( + info, trans, aligned_ct, subject=subject, + subjects_dir=subjects_dir, groups=groups, verbose=verbose) + gui.show() + return gui diff --git a/mne/gui/_ieeg_locate_gui.py b/mne/gui/_ieeg_locate_gui.py new file mode 100644 index 00000000000..a88333af304 --- /dev/null +++ b/mne/gui/_ieeg_locate_gui.py @@ -0,0 +1,883 @@ +# -*- coding: utf-8 -*- +"""Intracranial electrode localization GUI for finding contact locations.""" + +# Authors: Alex Rockhill + +# +# License: BSD (3-clause) + +import os.path as op +import numpy as np +from functools import partial + +from matplotlib.colors import LinearSegmentedColormap + +from PyQt5 import QtCore, QtGui, Qt +from PyQt5.QtCore import pyqtSlot +from PyQt5.QtWidgets import (QMainWindow, QGridLayout, + QVBoxLayout, QHBoxLayout, QLabel, + QMessageBox, QWidget, + QListView, QSlider, QPushButton, + QComboBox, QPlainTextEdit) +from matplotlib.backends.backend_qt5agg import FigureCanvas +from matplotlib.figure import Figure + +from .._freesurfer import _check_subject_dir, _import_nibabel +from ..viz.backends.renderer import _get_renderer +from ..surface import _read_mri_surface, _voxel_neighbors +from ..transforms import (apply_trans, _frame_to_str, _get_trans, + invert_transform) +from ..utils import logger, _check_fname, _validate_type, verbose, warn + +_IMG_LABELS = [['I', 'P'], ['I', 'L'], ['P', 'L']] +_CH_PLOT_SIZE = 1024 +_ZOOM_STEP_SIZE = 5 + +# 20 colors generated to be evenly spaced in a cube, worked better than +# matplotlib color cycle +_UNIQUE_COLORS = [(0.1, 0.42, 0.43), (0.9, 0.34, 0.62), (0.47, 0.51, 0.3), + (0.47, 0.55, 0.99), (0.79, 0.68, 0.06), (0.34, 0.74, 0.05), + (0.58, 0.87, 0.13), (0.86, 0.98, 0.4), (0.92, 0.91, 0.66), + (0.77, 0.38, 0.34), (0.9, 0.37, 0.1), (0.2, 0.62, 0.9), + (0.22, 0.65, 0.64), (0.14, 0.94, 0.8), (0.34, 0.31, 0.68), + (0.59, 0.28, 0.74), (0.46, 0.19, 0.94), (0.37, 0.93, 0.7), + (0.56, 0.86, 0.55), (0.67, 0.69, 0.44)] +_N_COLORS = len(_UNIQUE_COLORS) +_CMAP = LinearSegmentedColormap.from_list( + 'ch_colors', _UNIQUE_COLORS, N=_N_COLORS) + + +def _load_image(img, name, verbose=True): + """Load data from a 3D image file (e.g. 
CT, MR).""" + nib = _import_nibabel('use iEEG GUI') + if not isinstance(img, nib.spatialimages.SpatialImage): + if verbose: + logger.info(f'Loading {img}') + _check_fname(img, overwrite='read', must_exist=True, name=name) + img = nib.load(img) + # get data + orig_data = np.array(img.dataobj).astype(np.float32) + # reorient data to RAS + ornt = nib.orientations.axcodes2ornt( + nib.orientations.aff2axcodes(img.affine)).astype(int) + ras_ornt = nib.orientations.axcodes2ornt('RAS') + ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt) + img_data = nib.orientations.apply_orientation(orig_data, ornt_trans) + orig_mgh = nib.MGHImage(orig_data, img.affine) + aff_trans = nib.orientations.inv_ornt_aff(ornt_trans, img.shape) + vox_ras_t = np.dot(orig_mgh.header.get_vox2ras_tkr(), aff_trans) + return img_data, vox_ras_t + + +class ComboBox(QComboBox): + """Dropdown menu that emits a click when popped up.""" + + clicked = QtCore.pyqtSignal() + + def showPopup(self): + """Override show popup method to emit click.""" + self.clicked.emit() + super(ComboBox, self).showPopup() + + +def _make_slice_plot(width=4, height=4, dpi=300): + fig = Figure(figsize=(width, height)) + canvas = FigureCanvas(fig) + ax = fig.subplots() + fig.subplots_adjust(bottom=0, left=0, right=1, + top=1, wspace=0, hspace=0) + ax.set_facecolor('k') + # clean up excess plot text, invert + ax.invert_yaxis() + ax.set_xticks([]) + ax.set_yticks([]) + return canvas, fig + + +class IntracranialElectrodeLocator(QMainWindow): + """Locate electrode contacts using a coregistered MRI and CT.""" + + def __init__(self, info, trans, aligned_ct, subject=None, + subjects_dir=None, groups=None, verbose=None): + """GUI for locating intracranial electrodes. + + .. note:: Images will be displayed using orientation information + obtained from the image header. Images will be resampled to + dimensions [256, 256, 256] for display. 
+ """ + # initialize QMainWindow class + super(IntracranialElectrodeLocator, self).__init__() + + if not info.ch_names: + raise ValueError('No channels found in `info` to locate') + + # store info for modification + self._info = info + self._verbose = verbose + + # load imaging data + self._subject_dir = _check_subject_dir(subject, subjects_dir) + self._load_image_data(aligned_ct) + + self._ch_alpha = 0.5 + self._radius = int(_CH_PLOT_SIZE // 100) # starting 1/200 of image + # initialize channel data + self._ch_index = 0 + # load data, apply trans + self._head_mri_t = _get_trans(trans, 'head', 'mri')[0] + self._mri_head_t = invert_transform(self._head_mri_t) + # load channels, convert from m to mm + self._chs = {name: apply_trans(self._head_mri_t, ch['loc'][:3]) * 1000 + for name, ch in zip(info.ch_names, info['chs'])} + self._ch_names = list(self._chs.keys()) + # set current position + if np.isnan(self._chs[self._ch_names[self._ch_index]]).any(): + self._ras = np.array([0., 0., 0.]) + else: + self._ras = self._chs[self._ch_names[self._ch_index]].copy() + self._current_slice = apply_trans( + self._ras_vox_t, self._ras).round().astype(int) + self._group_channels(groups) + + # GUI design + + # Main plots: make one plot for each view; sagittal, coronal, axial + plt_grid = QGridLayout() + plts = [_make_slice_plot(), _make_slice_plot(), _make_slice_plot()] + self._figs = [plts[0][1], plts[1][1], plts[2][1]] + plt_grid.addWidget(plts[0][0], 0, 0) + plt_grid.addWidget(plts[1][0], 0, 1) + plt_grid.addWidget(plts[2][0], 1, 0) + self._renderer = _get_renderer( + name='IEEG Locator', size=(400, 400), bgcolor='w') + # TODO: should eventually make sure the renderer here is actually + # some PyVista(Qt) variant, not mayavi, otherwise the following + # call will fail (hopefully it's rare that people who want to use this + # have also set their MNE_3D_BACKEND=mayavi and/or don't have a working + # pyvistaqt setup; also hopefully the refactoring to use the + # Qt/notebook abstraction will make this easier, too): + plt_grid.addWidget(self._renderer.plotter) + + # Channel selector + self._ch_list = QListView() + self._ch_list.setSelectionMode(Qt.QAbstractItemView.SingleSelection) + self._ch_list.setMinimumWidth(150) + self._set_ch_names() + + # Plots + self._plot_images() + + # Menus + button_hbox = self._get_button_bar() + slider_hbox = self._get_slider_bar() + bottom_hbox = self._get_bottom_bar() + + # Put everything together + plot_ch_hbox = QHBoxLayout() + plot_ch_hbox.addLayout(plt_grid) + plot_ch_hbox.addWidget(self._ch_list) + + main_vbox = QVBoxLayout() + main_vbox.addLayout(button_hbox) + main_vbox.addLayout(slider_hbox) + main_vbox.addLayout(plot_ch_hbox) + main_vbox.addLayout(bottom_hbox) + + central_widget = QWidget() + central_widget.setLayout(main_vbox) + self.setCentralWidget(central_widget) + + # ready for user + self._move_cursors_to_pos() + self._ch_list.setFocus() # always focus on list + + def _load_image_data(self, ct): + """Get MRI and CT data to display and transforms to/from vox/RAS.""" + self._mri_data, self._vox_ras_t = _load_image( + op.join(self._subject_dir, 'mri', 'brain.mgz'), + 'MRI Image', verbose=self._verbose) + self._ras_vox_t = np.linalg.inv(self._vox_ras_t) + + self._voxel_sizes = np.array(self._mri_data.shape) + self._img_ranges = [[0, self._voxel_sizes[1], 0, self._voxel_sizes[2]], + [0, self._voxel_sizes[0], 0, self._voxel_sizes[2]], + [0, self._voxel_sizes[0], 0, self._voxel_sizes[1]]] + + # ready ct + self._ct_data, vox_ras_t = _load_image(ct, 'CT', 
verbose=self._verbose) + if self._mri_data.shape != self._ct_data.shape or \ + not np.allclose(self._vox_ras_t, vox_ras_t, rtol=1e-6): + raise ValueError('CT is not aligned to MRI, got ' + f'CT shape={self._ct_data.shape}, ' + f'MRI shape={self._mri_data.shape}, ' + f'CT affine={vox_ras_t} and ' + f'MRI affine={self._vox_ras_t}') + + if op.exists(op.join(self._subject_dir, 'surf', 'lh.seghead')): + self._head = _read_mri_surface( + op.join(self._subject_dir, 'surf', 'lh.seghead')) + assert _frame_to_str[self._head['coord_frame']] == 'mri' + else: + warn('`seghead` not found, skipping head plot, see ' + ':ref:`mne.bem.make_scalp_surfaces` to add the head') + self._head = None + if op.exists(op.join(self._subject_dir, 'surf', 'lh.pial')): + self._lh = _read_mri_surface( + op.join(self._subject_dir, 'surf', 'lh.pial')) + assert _frame_to_str[self._lh['coord_frame']] == 'mri' + self._rh = _read_mri_surface( + op.join(self._subject_dir, 'surf', 'rh.pial')) + assert _frame_to_str[self._rh['coord_frame']] == 'mri' + else: + warn('`pial` surface not found, skipping adding to 3D ' + 'plot. This indicates the Freesurfer recon-all ' + 'has been modified and these files have been deleted.') + self._lh = self._rh = None + + def _make_ch_image(self, axis): + """Make a plot to display the channel locations.""" + # Make channel data higher resolution so it looks better. + ch_image = np.zeros((_CH_PLOT_SIZE, _CH_PLOT_SIZE)) * np.nan + vx, vy, vz = self._voxel_sizes + + def color_ch_radius(ch_image, xf, yf, group, radius): + # Take the fraction across each dimension of the RAS + # coordinates converted to xyz and put a circle in that + # position in this larger resolution image + ex, ey = np.round(np.array([xf, yf]) * _CH_PLOT_SIZE).astype(int) + for i in range(-radius, radius + 1): + for j in range(-radius, radius + 1): + if (i**2 + j**2)**0.5 < radius: + # negative y because y axis is inverted + ch_image[-(ey + i), ex + j] = group + return ch_image + + for name, ras in self._chs.items(): + # move from middle-centered (half coords positive, half negative) + # to bottom-left corner centered (all coords positive). 
+ if np.isnan(ras).any(): + continue + xyz = apply_trans(self._ras_vox_t, ras) + # check if closest to that voxel + dist = np.linalg.norm(xyz - self._current_slice) + if dist < self._radius: + x, y, z = xyz + group = self._groups[name] + r = self._radius - np.round(abs(dist)).astype(int) + if axis == 0: + ch_image = color_ch_radius( + ch_image, y / vy, z / vz, group, r) + elif axis == 1: + ch_image = color_ch_radius( + ch_image, x / vx, z / vz, group, r) + elif axis == 2: + ch_image = color_ch_radius( + ch_image, x / vx, y / vy, group, r) + return ch_image + + @verbose + def _save_ch_coords(self, info=None, verbose=None): + """Save the location of the electrode contacts.""" + logger.info('Saving channel positions to `info`') + if info is None: + info = self._info + for name, ch in zip(info.ch_names, info['chs']): + ch['loc'][:3] = apply_trans( + self._mri_head_t, self._chs[name] / 1000) # mm->m + + def _plot_images(self): + """Use the MRI and CT to make plots.""" + # Plot sagittal (0), coronal (1) or axial (2) view + self._images = dict(ct=list(), chs=list(), + cursor=list(), cursor2=list()) + ct_min, ct_max = np.nanmin(self._ct_data), np.nanmax(self._ct_data) + text_kwargs = dict(fontsize=3, color='#66CCEE', family='monospace', + weight='bold', ha='center', va='center') + xyz = apply_trans(self._ras_vox_t, self._ras) + for axis in range(3): + ct_data = np.take(self._ct_data, self._current_slice[axis], + axis=axis).T + self._images['ct'].append(self._figs[axis].axes[0].imshow( + ct_data, cmap='gray', aspect='auto', + vmin=ct_min, vmax=ct_max)) + self._images['chs'].append( + self._figs[axis].axes[0].imshow( + self._make_ch_image(axis), aspect='auto', + extent=self._img_ranges[axis], + cmap=_CMAP, alpha=self._ch_alpha, vmin=0, vmax=_N_COLORS)) + self._images['cursor'].append( + self._figs[axis].axes[0].plot( + (xyz[axis], xyz[axis]), (0, self._voxel_sizes[axis]), + color=[0, 1, 0], linewidth=1, alpha=0.5)[0]) + self._images['cursor2'].append( + self._figs[axis].axes[0].plot( + (0, self._voxel_sizes[axis]), (xyz[axis], xyz[axis]), + color=[0, 1, 0], linewidth=1, alpha=0.5)[0]) + # label axes + self._figs[axis].text(0.5, 0.05, _IMG_LABELS[axis][0], + **text_kwargs) + self._figs[axis].text(0.05, 0.5, _IMG_LABELS[axis][1], + **text_kwargs) + self._figs[axis].axes[0].axis(self._img_ranges[axis]) + self._figs[axis].canvas.mpl_connect( + 'scroll_event', self._on_scroll) + self._figs[axis].canvas.mpl_connect( + 'button_release_event', partial(self._on_click, axis)) + # add head and brain in mm (convert from m) + if self._head is not None: + self._renderer.mesh( + *self._head['rr'].T * 1000, triangles=self._head['tris'], + color='gray', opacity=0.2, reset_camera=False, render=False) + if self._lh is not None and self._rh is not None: + self._renderer.mesh( + *self._lh['rr'].T * 1000, triangles=self._lh['tris'], + color='white', opacity=0.2, reset_camera=False, render=False) + self._renderer.mesh( + *self._rh['rr'].T * 1000, triangles=self._rh['tris'], + color='white', opacity=0.2, reset_camera=False, render=False) + self._3d_chs = dict() + self._plot_3d_ch_pos() + self._renderer.set_camera(azimuth=90, elevation=90, distance=300, + focalpoint=tuple(self._ras)) + # update plots + self._draw() + self._renderer._update() + + def _scale_radius(self): + """Scale the radius to mm.""" + shape = np.mean(self._ct_data.shape) # this is Freesurfer shape (256) + scale = np.diag(self._ras_vox_t)[:3].mean() + return scale * self._radius * (shape / _CH_PLOT_SIZE) + + def _update_camera(self, render=False): 
"""Update the camera position.""" + self._renderer.set_camera( + # needs fix, distance moves when focal point updates + distance=self._renderer.plotter.camera.distance * 0.9, + focalpoint=tuple(self._ras), + reset_camera=False) + + def _plot_3d_ch(self, name, render=False): + """Plot a single 3D channel.""" + if name in self._3d_chs: + self._renderer.plotter.remove_actor(self._3d_chs.pop(name)) + if not any(np.isnan(self._chs[name])): + radius = self._scale_radius() + self._3d_chs[name] = self._renderer.sphere( + tuple(self._chs[name]), scale=radius * 3, + color=_UNIQUE_COLORS[self._groups[name] % _N_COLORS], + opacity=self._ch_alpha)[0] + if render: + self._renderer._update() + + def _plot_3d_ch_pos(self, render=False): + for name in self._chs: + self._plot_3d_ch(name) + if render: + self._renderer._update() + + def _get_button_bar(self): + """Make a bar with buttons for user interactions.""" + hbox = QHBoxLayout() + + help_button = QPushButton('Help') + help_button.released.connect(self._show_help) + hbox.addWidget(help_button) + + hbox.addStretch(8) + + hbox.addWidget(QLabel('Snap to Center')) + self._snap_button = QPushButton('Off') + self._snap_button.setMaximumWidth(25) # not too big + hbox.addWidget(self._snap_button) + self._snap_button.released.connect(self._toggle_snap) + self._toggle_snap() # turn on to start + + hbox.addStretch(1) + + self._toggle_brain_button = QPushButton('Show Brain') + self._toggle_brain_button.released.connect(self._toggle_show_brain) + hbox.addWidget(self._toggle_brain_button) + + hbox.addStretch(1) + + mark_button = QPushButton('Mark') + hbox.addWidget(mark_button) + mark_button.released.connect(self._mark_ch) + + remove_button = QPushButton('Remove') + hbox.addWidget(remove_button) + remove_button.released.connect(self._remove_ch) + + self._group_selector = ComboBox() + group_model = self._group_selector.model() + + for i in range(_N_COLORS): + self._group_selector.addItem(' ') + color = QtGui.QColor() + color.setRgb(*(255 * np.array(_UNIQUE_COLORS[i % _N_COLORS]) + ).round().astype(int)) + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + group_model.setData(group_model.index(i, 0), + brush, QtCore.Qt.BackgroundRole) + self._group_selector.clicked.connect(self._select_group) + self._group_selector.currentIndexChanged.connect( + self._select_group) + hbox.addWidget(self._group_selector) + + # update background color for current selection + self._update_group() + + return hbox + + def _get_slider_bar(self): + """Make a bar with sliders on it.""" + + def make_label(name): + label = QLabel(name) + label.setAlignment(QtCore.Qt.AlignCenter) + return label + + def make_slider(smin, smax, sval, sfun=None): + slider = QSlider(QtCore.Qt.Horizontal) + slider.setMinimum(int(round(smin))) + slider.setMaximum(int(round(smax))) + slider.setValue(int(round(sval))) + slider.setTracking(False) # only update on release + if sfun is not None: + slider.valueChanged.connect(sfun) + slider.keyPressEvent = self._key_press_event + return slider + + slider_hbox = QHBoxLayout() + + ch_vbox = QVBoxLayout() + ch_vbox.addWidget(make_label('ch alpha')) + ch_vbox.addWidget(make_label('ch radius')) + slider_hbox.addLayout(ch_vbox) + + ch_slider_vbox = QVBoxLayout() + self._alpha_slider = make_slider(0, 100, self._ch_alpha * 100, + self._update_ch_alpha) + ch_plot_max = _CH_PLOT_SIZE // 50 # max 1 / 50 of plot size + ch_slider_vbox.addWidget(self._alpha_slider) + self._radius_slider = make_slider(0, ch_plot_max, self._radius, + self._update_radius) + 
ch_slider_vbox.addWidget(self._radius_slider) + slider_hbox.addLayout(ch_slider_vbox) + + ct_vbox = QVBoxLayout() + ct_vbox.addWidget(make_label('CT min')) + ct_vbox.addWidget(make_label('CT max')) + slider_hbox.addLayout(ct_vbox) + + ct_slider_vbox = QVBoxLayout() + ct_min = int(round(np.nanmin(self._ct_data))) + ct_max = int(round(np.nanmax(self._ct_data))) + self._ct_min_slider = make_slider( + ct_min, ct_max, ct_min, self._update_ct_scale) + ct_slider_vbox.addWidget(self._ct_min_slider) + self._ct_max_slider = make_slider( + ct_min, ct_max, ct_max, self._update_ct_scale) + ct_slider_vbox.addWidget(self._ct_max_slider) + slider_hbox.addLayout(ct_slider_vbox) + return slider_hbox + + def _get_bottom_bar(self): + """Make a bar at the bottom with information in it.""" + hbox = QHBoxLayout() + + hbox.addStretch(10) + + self._intensity_label = QLabel('') # update later + hbox.addWidget(self._intensity_label) + + RAS_label = QLabel('RAS =') + self._RAS_textbox = QPlainTextEdit('') # update later + self._RAS_textbox.setMaximumHeight(25) + self._RAS_textbox.setMaximumWidth(200) + self._RAS_textbox.focusOutEvent = self._update_RAS + self._RAS_textbox.textChanged.connect(self._check_update_RAS) + hbox.addWidget(RAS_label) + hbox.addWidget(self._RAS_textbox) + self._update_moved() # update text now + return hbox + + def _group_channels(self, groups): + """Automatically find a group based on the name of the channel.""" + if groups is not None: + for name in self._ch_names: + if name not in groups: + raise ValueError(f'{name} not found in ``groups``') + _validate_type(groups[name], (float, int), f'groups[{name}]') + self._groups = groups + else: + i = 0 + self._groups = dict() + base_names = dict() + for name in self._ch_names: + # strip all numbers from the name + base_name = ''.join([letter for letter in name if + not letter.isdigit() and letter != ' ']) + if base_name in base_names: + # look up group number by base name + self._groups[name] = base_names[base_name] + else: + self._groups[name] = i + base_names[base_name] = i + i += 1 + + def _set_ch_names(self): + """Add the channel names to the selector.""" + self._ch_list_model = QtGui.QStandardItemModel(self._ch_list) + for name in self._ch_names: + self._ch_list_model.appendRow(QtGui.QStandardItem(name)) + self._color_list_item(name=name) + self._ch_list.setModel(self._ch_list_model) + self._ch_list.clicked.connect(self._go_to_ch) + self._ch_list.setCurrentIndex( + self._ch_list_model.index(self._ch_index, 0)) + self._ch_list.keyPressEvent = self._key_press_event + + def _select_group(self): + """Change the group label to the selection.""" + group = self._group_selector.currentIndex() + self._groups[self._ch_names[self._ch_index]] = group + # color differently if found already + self._color_list_item(self._ch_names[self._ch_index]) + self._update_group() + + def _update_group(self): + """Set background for closed group menu.""" + group = self._group_selector.currentIndex() + rgb = (255 * np.array(_UNIQUE_COLORS[group % _N_COLORS]) + ).round().astype(int) + self._group_selector.setStyleSheet( + 'background-color: rgb({:d},{:d},{:d})'.format(*rgb)) + self._group_selector.update() + + def _on_scroll(self, event): + """Process mouse scroll wheel event to zoom.""" + self._zoom(event.step, draw=True) + + def _zoom(self, sign=1, draw=False): + """Zoom in on the image.""" + delta = _ZOOM_STEP_SIZE * sign + for axis, fig in enumerate(self._figs): + xmid = self._images['cursor'][axis].get_xdata()[0] + ymid = self._images['cursor2'][axis].get_ydata()[0] + 
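# positive ``sign`` zooms in by shrinking each view half-width by + # ``delta`` (negative ``sign`` zooms out); the new axis limits are + # re-centered on the crosshair position (xmid, ymid) found above +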
xmin, xmax = fig.axes[0].get_xlim() + ymin, ymax = fig.axes[0].get_ylim() + xwidth = (xmax - xmin) / 2 - delta + ywidth = (ymax - ymin) / 2 - delta + if xwidth <= 0 or ywidth <= 0: + return + fig.axes[0].set_xlim(xmid - xwidth, xmid + xwidth) + fig.axes[0].set_ylim(ymid - ywidth, ymid + ywidth) + self._images['cursor'][axis].set_ydata([ymin, ymax]) + self._images['cursor2'][axis].set_xdata([xmin, xmax]) + if draw: + self._figs[axis].canvas.draw() + + def _update_ch_selection(self): + """Update which channel is selected.""" + name = self._ch_names[self._ch_index] + self._ch_list.setCurrentIndex( + self._ch_list_model.index(self._ch_index, 0)) + self._group_selector.setCurrentIndex(self._groups[name]) + self._update_group() + if not np.isnan(self._chs[name]).any(): + self._ras[:] = self._chs[name] + self._move_cursors_to_pos() + self._update_camera(render=True) + self._draw() + + def _go_to_ch(self, index): + """Change current channel to the item selected.""" + self._ch_index = index.row() + self._update_ch_selection() + + @pyqtSlot() + def _next_ch(self): + """Increment the current channel selection index.""" + self._ch_index = (self._ch_index + 1) % len(self._ch_names) + self._update_ch_selection() + + @pyqtSlot() + def _update_RAS(self, event): + """Interpret user input to the RAS textbox.""" + text = self._RAS_textbox.toPlainText().replace('\n', '') + ras = text.split(',') + if len(ras) != 3: + ras = text.split(' ') # spaces also okay as in freesurfer + ras = [var.lstrip().rstrip() for var in ras] + + if len(ras) != 3: + self._update_moved() # resets RAS label + return + all_float = all([all([dig.isdigit() or dig in ('-', '.') + for dig in var]) for var in ras]) + if not all_float: + self._update_moved() # resets RAS label + return + + ras = np.array([float(var) for var in ras]) + xyz = apply_trans(self._ras_vox_t, ras) + wrong_size = any([var < 0 or var > n for var, n in + zip(xyz, self._voxel_sizes)]) + if wrong_size: + self._update_moved() # resets RAS label + return + + # valid RAS position, update and move + self._ras = ras + self._move_cursors_to_pos() + + @pyqtSlot() + def _check_update_RAS(self): + """Check whether the RAS textbox is done being edited.""" + if '\n' in self._RAS_textbox.toPlainText(): + self._update_RAS(event=None) + self._ch_list.setFocus() # remove focus from text edit + + def _color_list_item(self, name=None): + """Color the item in the view list for easy id of marked channels.""" + name = self._ch_names[self._ch_index] if name is None else name + color = QtGui.QColor('white') + if not np.isnan(self._chs[name]).any(): + group = self._groups[name] + color.setRgb(*[int(c * 255) for c in + _UNIQUE_COLORS[int(group) % _N_COLORS]]) + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + self._ch_list_model.setData( + self._ch_list_model.index(self._ch_names.index(name), 0), + brush, QtCore.Qt.BackgroundRole) + # color text black + color = QtGui.QColor('black') + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + self._ch_list_model.setData( + self._ch_list_model.index(self._ch_names.index(name), 0), + brush, QtCore.Qt.ForegroundRole) + + @pyqtSlot() + def _toggle_snap(self): + """Toggle snapping the contact location to the center of mass.""" + if self._snap_button.text() == 'Off': + self._snap_button.setText('On') + self._snap_button.setStyleSheet("background-color: green") + else: # text == 'On', turn off + self._snap_button.setText('Off') + self._snap_button.setStyleSheet("background-color: red") + + @pyqtSlot() + def 
_mark_ch(self): + """Mark the current channel as being located at the crosshair.""" + name = self._ch_names[self._ch_index] + if self._snap_button.text() == 'Off': + self._chs[name][:] = self._ras + else: + coord = apply_trans(self._ras_vox_t, self._ras.copy()) + shape = np.mean(self._mri_data.shape) # Freesurfer shape (256) + voxels_max = int( + 4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE)**3) + neighbors = _voxel_neighbors( + coord, self._ct_data, thresh=0.5, + voxels_max=voxels_max, use_relative=True) + self._chs[name][:] = apply_trans( # to surface RAS + self._vox_ras_t, np.array(list(neighbors)).mean(axis=0)) + self._color_list_item() + self._update_ch_images(draw=True) + self._plot_3d_ch(name, render=True) + self._save_ch_coords() + self._next_ch() + self._ch_list.setFocus() + + @pyqtSlot() + def _remove_ch(self): + """Remove the location data for the current channel.""" + name = self._ch_names[self._ch_index] + self._chs[name] *= np.nan + self._color_list_item() + self._save_ch_coords() + self._update_ch_images(draw=True) + self._plot_3d_ch(name, render=True) + self._next_ch() + self._ch_list.setFocus() + + def _draw(self, axis=None): + """Update the figures with a draw call.""" + for axis in (range(3) if axis is None else [axis]): + self._figs[axis].canvas.draw() + + def _update_ch_images(self, axis=None, draw=False): + """Update the channel image(s).""" + for axis in range(3) if axis is None else [axis]: + self._images['chs'][axis].set_data( + self._make_ch_image(axis)) + if draw: + self._draw(axis) + + def _update_ct_images(self, axis=None, draw=False): + """Update the CT image(s).""" + for axis in range(3) if axis is None else [axis]: + ct_data = np.take(self._ct_data, self._current_slice[axis], + axis=axis).T + # Threshold the CT so only bright objects (electrodes) are visible + ct_data[ct_data < self._ct_min_slider.value()] = np.nan + ct_data[ct_data > self._ct_max_slider.value()] = np.nan + self._images['ct'][axis].set_data(ct_data) + if draw: + self._draw(axis) + + def _update_mri_images(self, axis=None, draw=False): + """Update the MRI image(s).""" + if 'mri' in self._images: + for axis in range(3) if axis is None else [axis]: + self._images['mri'][axis].set_data( + np.take(self._mri_data, self._current_slice[axis], + axis=axis).T) + if draw: + self._draw(axis) + + def _update_images(self, axis=None, draw=True): + """Update CT and channel images when general changes happen.""" + self._update_ct_images(axis=axis) + self._update_ch_images(axis=axis) + self._update_mri_images(axis=axis) + if draw: + self._draw(axis) + + def _update_ct_scale(self): + """Update the CT scale from the min/max slider values.""" + new_min = self._ct_min_slider.value() + new_max = self._ct_max_slider.value() + # handle inversions + self._ct_min_slider.setValue(min([new_min, new_max])) + self._ct_max_slider.setValue(max([new_min, new_max])) + self._update_ct_images(draw=True) + + def _update_radius(self): + """Update channel plot radius.""" + self._radius = np.round(self._radius_slider.value()).astype(int) + self._update_ch_images(draw=True) + self._plot_3d_ch_pos(render=True) + self._ch_list.setFocus() # remove focus from 3d plotter + + def _update_ch_alpha(self): + """Update channel plot alpha.""" + self._ch_alpha = self._alpha_slider.value() / 100 + for axis in range(3): + self._images['chs'][axis].set_alpha(self._ch_alpha) + self._draw() + self._plot_3d_ch_pos(render=True) + self._ch_list.setFocus() # remove focus from 3d plotter + + def _get_click_pos(self, axis, x, y): + """Get which axis was clicked and 
where.""" + fx, fy = self._figs[axis].transFigure.inverted().transform((x, y)) + xmin, xmax = self._figs[axis].axes[0].get_xlim() + ymin, ymax = self._figs[axis].axes[0].get_ylim() + return (fx * (xmax - xmin) + xmin, fy * (ymax - ymin) + ymin) + + def _move_cursors_to_pos(self): + """Move the cursors to a position.""" + x, y, z = apply_trans(self._ras_vox_t, self._ras) + self._current_slice = np.array([x, y, z]).round().astype(int) + self._move_cursor_to(0, x=y, y=z) + self._move_cursor_to(1, x=x, y=z) + self._move_cursor_to(2, x=x, y=y) + self._zoom(0) # doesn't actually zoom just resets view to center + self._update_images(draw=True) + self._update_moved() + + def _move_cursor_to(self, axis, x, y): + """Move the cursors to a position for a given subplot.""" + self._images['cursor2'][axis].set_ydata([y, y]) + self._images['cursor'][axis].set_xdata([x, x]) + + def _show_help(self): + """Show the help menu.""" + QMessageBox.information( + self, 'Help', + "Help:\n'm': mark channel location\n" + "'r': remove channel location\n" + "'b': toggle viewing of brain in T1\n" + "'+'/'-': zoom\nleft/right arrow: left/right\n" + "up/down arrow: superior/inferior\n" + "page up/page down arrow: anterior/posterior") + + def _toggle_show_brain(self): + """Toggle whether the brain/MRI is being shown.""" + if 'mri' in self._images: + for img in self._images['mri']: + img.remove() + self._images.pop('mri') + self._toggle_brain_button.setText('Show Brain') + else: + self._images['mri'] = list() + for axis in range(3): + mri_data = np.take(self._mri_data, + self._current_slice[axis], axis=axis).T + self._images['mri'].append(self._figs[axis].axes[0].imshow( + mri_data, cmap='hot', aspect='auto', alpha=0.25)) + self._toggle_brain_button.setText('Hide Brain') + self._draw() + + def _key_press_event(self, event): + """Execute functions when the user presses a key.""" + if event.key() == 'escape': + self.close() + + if event.text() == 'h': + self._show_help() + + if event.text() == 'm': + self._mark_ch() + + if event.text() == 'r': + self._remove_ch() + + if event.text() == 'b': + self._toggle_show_brain() + + if event.text() in ('=', '+', '-'): + self._zoom(sign=-2 * (event.text() == '-') + 1, draw=True) + + # Changing slices + if event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down, + QtCore.Qt.Key_Left, QtCore.Qt.Key_Right, + QtCore.Qt.Key_PageUp, QtCore.Qt.Key_PageDown): + if event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down): + self._ras[2] += 2 * (event.key() == QtCore.Qt.Key_Up) - 1 + elif event.key() in (QtCore.Qt.Key_Left, QtCore.Qt.Key_Right): + self._ras[0] += 2 * (event.key() == QtCore.Qt.Key_Right) - 1 + elif event.key() in (QtCore.Qt.Key_PageUp, + QtCore.Qt.Key_PageDown): + self._ras[1] += 2 * (event.key() == QtCore.Qt.Key_PageUp) - 1 + self._move_cursors_to_pos() + + def _on_click(self, axis, event): + """Move to view on MRI and CT on click.""" + # Transform coordinates to figure coordinates + pos = self._get_click_pos(axis, event.x, event.y) + logger.info(f'Clicked axis {axis} at pos {pos}') + + if axis is not None and pos is not None: + xyz = apply_trans(self._ras_vox_t, self._ras) + if axis == 0: + xyz[[1, 2]] = pos + elif axis == 1: + xyz[[0, 2]] = pos + elif axis == 2: + xyz[[0, 1]] = pos + self._ras = apply_trans(self._vox_ras_t, xyz) + self._move_cursors_to_pos() + + def _update_moved(self): + """Update when cursor position changes.""" + self._RAS_textbox.setPlainText('{:.2f}, {:.2f}, {:.2f}'.format( + *self._ras)) + self._intensity_label.setText('intensity = {:.2f}'.format( + 
self._ct_data[tuple(self._current_slice)])) diff --git a/mne/gui/tests/test_ieeg_locate_gui.py b/mne/gui/tests/test_ieeg_locate_gui.py new file mode 100644 index 00000000000..e4568ff52bd --- /dev/null +++ b/mne/gui/tests/test_ieeg_locate_gui.py @@ -0,0 +1,155 @@ +# Authors: Alex Rockhill + +# +# License: BSD-3-clause + +import os.path as op +import numpy as np +from numpy.testing import assert_allclose + +import pytest + +import mne +from mne.datasets import testing +from mne.utils import requires_nibabel +from mne.viz.utils import _fake_click + +data_path = testing.data_path(download=False) +subject = 'sample' +subjects_dir = op.join(data_path, 'subjects') +sample_dir = op.join(data_path, 'MEG', subject) +raw_path = op.join(sample_dir, 'sample_audvis_trunc_raw.fif') +fname_trans = op.join(sample_dir, 'sample_audvis_trunc-trans.fif') + + +@requires_nibabel() +@pytest.fixture +def _fake_CT_coords(skull_size=5, contact_size=2): + """Make somewhat realistic CT data with contacts.""" + import nibabel as nib + brain = nib.load( + op.join(subjects_dir, subject, 'mri', 'brain.mgz')) + verts = mne.read_surface( + op.join(subjects_dir, subject, 'bem', 'outer_skull.surf'))[0] + verts = mne.transforms.apply_trans( + np.linalg.inv(brain.header.get_vox2ras_tkr()), verts) + x, y, z = np.array(brain.shape).astype(int) // 2 + coords = [(x, y - 14, z), (x - 10, y - 15, z), + (x - 20, y - 16, z + 1), (x - 30, y - 16, z + 1)] + center = np.array(brain.shape) / 2 + # make image + np.random.seed(99) + ct_data = np.random.random(brain.shape).astype(np.float32) * 100 + # make skull + for vert in verts: + x, y, z = np.round(vert).astype(int) + ct_data[slice(x - skull_size, x + skull_size + 1), + slice(y - skull_size, y + skull_size + 1), + slice(z - skull_size, z + skull_size + 1)] = 1000 + # add electrode with contacts + for (x, y, z) in coords: + # make sure not in skull + assert np.linalg.norm(center - np.array((x, y, z))) < 50 + ct_data[slice(x - contact_size, x + contact_size + 1), + slice(y - contact_size, y + contact_size + 1), + slice(z - contact_size, z + contact_size + 1)] = \ + 1000 - np.linalg.norm(np.array(np.meshgrid( + *[range(-contact_size, contact_size + 1)] * 3)), axis=0) + ct = nib.MGHImage(ct_data, brain.affine) + coords = mne.transforms.apply_trans( + ct.header.get_vox2ras_tkr(), np.array(coords)) + return ct, coords + + +@requires_nibabel() +@pytest.fixture +def _locate_ieeg(renderer_interactive_pyvistaqt): + # Use a fixture to create these classes so we can ensure that they + # are closed at the end of the test + guis = list() + + def fun(*args, **kwargs): + guis.append(mne.gui.locate_ieeg(*args, **kwargs)) + return guis[-1] + + yield fun + + for gui in guis: + try: + gui.close() + except Exception: + pass + + +def test_ieeg_elec_locate_gui_io(_locate_ieeg): + """Test the input/output of the intracranial location GUI.""" + import nibabel as nib + info = mne.create_info([], 1000) + aligned_ct = nib.MGHImage(np.zeros((256, 256, 256), dtype=np.float32), + np.eye(4)) + trans = mne.transforms.Transform('head', 'mri') + with pytest.raises(ValueError, + match='No channels found in `info` to locate'): + _locate_ieeg(info, trans, aligned_ct, subject, subjects_dir) + info = mne.create_info(['test'], 1000, ['seeg']) + with pytest.raises(ValueError, match='CT is not aligned to MRI'): + _locate_ieeg(info, trans, aligned_ct, subject=subject, + subjects_dir=subjects_dir) + + +@testing.requires_testing_data +def test_ieeg_elec_locate_gui_display(_locate_ieeg, _fake_CT_coords): + """Test that the intracranial 
location GUI displays properly.""" + raw = mne.io.read_raw_fif(raw_path) + raw.pick_types(eeg=True) + ch_dict = {'EEG 001': 'LAMY 1', 'EEG 002': 'LAMY 2', + 'EEG 003': 'LSTN 1', 'EEG 004': 'LSTN 2'} + raw.pick_channels(list(ch_dict.keys())) + raw.rename_channels(ch_dict) + raw.set_montage(None) + aligned_ct, coords = _fake_CT_coords + trans = mne.read_trans(fname_trans) + with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + gui = _locate_ieeg(raw.info, trans, aligned_ct, + subject=subject, subjects_dir=subjects_dir) + + gui._ras[:] = coords[0] # start in the right position + gui._move_cursors_to_pos() + for coord in coords: + coord_vox = mne.transforms.apply_trans(gui._ras_vox_t, coord) + _fake_click(gui._figs[2], gui._figs[2].axes[0], + coord_vox[:-1], xform='data', kind='release') + assert_allclose(coord, gui._ras, atol=3) # clicks are a bit off + + # test snap to center + gui._ras[:] = coords[0] # move to first position + gui._move_cursors_to_pos() + gui._mark_ch() + assert_allclose(coords[0], gui._chs['LAMY 1'], atol=0.2) + gui._snap_button.click() + assert gui._snap_button.text() == 'Off' + # now make sure no snap happens + gui._ras[:] = coords[1] + 1 + gui._mark_ch() + assert_allclose(coords[1] + 1, gui._chs['LAMY 2'], atol=0.01) + # check that it turns back on + gui._snap_button.click() + assert gui._snap_button.text() == 'On' + + # test remove + gui._ch_index = 1 + gui._update_ch_selection() + gui._remove_ch() + assert np.isnan(gui._chs['LAMY 2']).all() + + # check that raw object saved + assert not np.isnan(raw.info['chs'][0]['loc'][:3]).any() # LAMY 1 + assert np.isnan(raw.info['chs'][1]['loc'][:3]).all() # LAMY 2 (removed) + + # move sliders + gui._alpha_slider.setValue(75) + assert gui._ch_alpha == 0.75 + gui._radius_slider.setValue(5) + assert gui._radius == 5 + ct_sum_before = np.nansum(gui._images['ct'][0].get_array().data) + gui._ct_min_slider.setValue(500) + assert np.nansum(gui._images['ct'][0].get_array().data) < ct_sum_before diff --git a/mne/surface.py b/mne/surface.py index e94264da05d..586ff52cbce 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -21,7 +21,7 @@ from .channels.channels import _get_meg_system from .fixes import (_serialize_volume_info, _get_read_geometry, jit, - prange, bincount, _get_img_fdata) + prange, bincount) from .io.constants import FIFF from .io.pick import pick_types from .parallel import parallel_func @@ -1736,22 +1736,23 @@ def _marching_cubes(image, level, smooth=0): return out -def _warn_missing_chs(montage, dig_image, after_warp): +def _warn_missing_chs(info, dig_image, after_warp, verbose=None): """Warn that channels are missing.""" # ensure that each electrode contact was marked in at least one voxel - missing = set(np.arange(1, len(montage.ch_names) + 1)).difference( - set(np.unique(np.asanyarray(dig_image.dataobj)))) - missing_ch = [montage.ch_names[idx - 1] for idx in missing] - if missing_ch: + missing = set(np.arange(1, len(info.ch_names) + 1)).difference( + set(np.unique(np.array(dig_image.dataobj)))) + missing_ch = [info.ch_names[idx - 1] for idx in missing] + if missing_ch and verbose != 'error': warn('Channels ' + ', '.join(missing_ch) + ' were not assigned ' 'voxels' + (' after applying SDR warp' if after_warp else '')) -@fill_doc +@verbose def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, subject_from, subject_to='fsaverage', - subjects_dir=None, thresh=0.95, - max_peak_dist=1, voxels_max=100, use_min=False): + subjects_dir_from=None, subjects_dir_to=None, + thresh=0.5, max_peak_dist=1, 
voxels_max=100, + use_min=False, verbose=None): """Warp a montage to a template with image volumes using SDR. Find areas of the input volume with intensity greater than @@ -1777,10 +1778,17 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, subject_to : str The name of the subject to use as a template to morph to (e.g. 'fsaverage'). - %(subjects_dir)s + subjects_dir_from : str | pathlib.Path | None + The path to the Freesurfer ``recon-all`` directory for the + ``subject_from`` subject. The ``SUBJECTS_DIR`` environment + variable will be used when ``None``. + subjects_dir_to : str | pathlib.Path | None + The path to the Freesurfer ``recon-all`` directory for the + ``subject_to`` subject. ``subjects_dir_from`` will be used + when ``None``. thresh : float - The quantile of the image data to use to threshold the - channel size on the volume. + The threshold relative to the peak to determine the size + of the sensors on the volume. max_peak_dist : int The number of voxels away from the channel location to look in the ``image``. This will depend on the accuracy of @@ -1791,6 +1799,7 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, use_min : bool Whether to use hypointensities in the volume as channel locations. Default False uses hyperintensities. + %(verbose)s Returns ------- @@ -1820,14 +1829,16 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, _validate_type(use_min, bool, 'use_min') # first, make sure we have the necessary freesurfer surfaces - _check_subject_dir(subject_from, subjects_dir) - _check_subject_dir(subject_to, subjects_dir) + _check_subject_dir(subject_from, subjects_dir_from) + if subjects_dir_to is None: # assume shared + subjects_dir_to = subjects_dir_from + _check_subject_dir(subject_to, subjects_dir_to) # load image and make sure it's in surface RAS if not isinstance(base_image, nib.spatialimages.SpatialImage): base_image = nib.load(base_image) fs_from_img = nib.load( - op.join(subjects_dir, subject_from, 'mri', 'brain.mgz')) + op.join(subjects_dir_from, subject_from, 'mri', 'brain.mgz')) if not np.allclose(base_image.affine, fs_from_img.affine, atol=1e-6): raise RuntimeError('The `base_image` is not aligned to Freesurfer ' 'surface RAS space. 
This space is required as ' @@ -1853,12 +1864,13 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, # take channel coordinates and use the image to transform them # into a volume where all the voxels over a threshold nearby # are labeled with an index - image_data = _get_img_fdata(base_image) + image_data = np.array(base_image.dataobj) if use_min: image_data *= -1 - thresh = np.quantile(image_data, thresh) image_from = np.zeros(base_image.shape, dtype=int) for i, ch_coord in enumerate(ch_coords): + if np.isnan(ch_coord).any(): + continue # this looks up to a voxel away, it may be marked imperfectly volume = _voxel_neighbors(ch_coord, image_data, thresh=thresh, max_peak_dist=max_peak_dist, @@ -1880,7 +1892,7 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, _warn_missing_chs(montage, image_from, after_warp=False) template_brain = nib.load( - op.join(subjects_dir, subject_to, 'mri', 'brain.mgz')) + op.join(subjects_dir_to, subject_to, 'mri', 'brain.mgz')) image_to = apply_volume_registration( image_from, template_brain, reg_affine, sdr_morph, @@ -1891,12 +1903,11 @@ def warp_montage_volume(montage, base_image, reg_affine, sdr_morph, # recover the contact positions as the center of mass warped_data = np.asanyarray(image_to.dataobj) for val, ch_coord in enumerate(ch_coords, 1): - ch_coord[:] = np.asanyarray( - np.where(warped_data == val), float).mean(axis=1) + ch_coord[:] = np.mean(np.where(warped_data == val), axis=1) # convert back to surface RAS of the template fs_to_img = nib.load( - op.join(subjects_dir, subject_to, 'mri', 'brain.mgz')) + op.join(subjects_dir_to, subject_to, 'mri', 'brain.mgz')) ch_coords = apply_trans( fs_to_img.header.get_vox2ras_tkr(), ch_coords) / 1000 @@ -1961,13 +1972,16 @@ def get_montage_volume_labels(montage, subject, subjects_dir=None, ch_coords = apply_trans( np.linalg.inv(aseg.header.get_vox2ras_tkr()), ch_coords * 1000) labels = OrderedDict() - for ch_name, seed in zip(montage.ch_names, ch_coords): - voxels = _voxel_neighbors( - seed, aseg_data, dist=dist, vox2ras_tkr=vox2ras_tkr, - voxels_max=_VOXELS_MAX) - label_idxs = set([aseg_data[tuple(voxel)].astype(int) - for voxel in voxels]) - labels[ch_name] = [label_lut[idx] for idx in label_idxs] + for ch_name, ch_coord in zip(montage.ch_names, ch_coords): + if np.isnan(ch_coord).any(): + labels[ch_name] = list() + else: + voxels = _voxel_neighbors( + ch_coord, aseg_data, dist=dist, vox2ras_tkr=vox2ras_tkr, + voxels_max=_VOXELS_MAX) + label_idxs = set([aseg_data[tuple(voxel)].astype(int) + for voxel in voxels]) + labels[ch_name] = [label_lut[idx] for idx in label_idxs] all_labels = set([label for val in labels.values() for label in val]) colors = {label: tuple(fs_colors[label][:3] / 255) + (1.,) @@ -2000,8 +2014,9 @@ def _get_neighbors(loc, image, voxels, thresh, dist_params): return neighbors -def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, - dist=None, vox2ras_tkr=None, voxels_max=100): +def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=1, + use_relative=True, dist=None, vox2ras_tkr=None, + voxels_max=100): """Find voxels above a threshold contiguous with a seed location. Parameters @@ -2011,11 +2026,14 @@ def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, image : ndarray The image to search. thresh : float - The threshold to use as a cutoff for what qualifies as a - neighbor. + The threshold to use as a cutoff for what qualifies as a neighbor. + Will be relative to the peak if ``use_relative`` or absolute if not. 
max_peak_dist : int The maximum number of voxels to search for the peak near the seed location. + use_relative : bool + If ``True``, the threshold will be relative to the peak, if + ``False``, the threshold will be absolute. dist : float The distance in mm to include surrounding voxels. vox2ras_tkr : ndarray @@ -2037,8 +2055,8 @@ def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, only voxels within ``dist`` mm of the seed are included. """ seed = np.array(seed).round().astype(int) - assert ((dist is not None) + (max_peak_dist is not None)) == 1 - if max_peak_dist is not None: + assert ((dist is not None) + (thresh is not None)) == 1 + if thresh is not None: dist_params = None check_grid = image[tuple([ slice(idx - max_peak_dist, idx + max_peak_dist + 1) @@ -2046,6 +2064,8 @@ def _voxel_neighbors(seed, image, thresh=None, max_peak_dist=None, peak = np.array(np.unravel_index( np.argmax(check_grid), check_grid.shape)) - max_peak_dist + seed voxels = neighbors = set([tuple(peak)]) + if use_relative: + thresh *= image[tuple(peak)] else: assert vox2ras_tkr is not None seed_fs_ras = apply_trans(vox2ras_tkr, seed + 0.5) # center of voxel diff --git a/mne/tests/test_surface.py b/mne/tests/test_surface.py index 851ed065518..094e82740a1 100644 --- a/mne/tests/test_surface.py +++ b/mne/tests/test_surface.py @@ -290,7 +290,7 @@ def test_voxel_neighbors(): image[4:7, 4:7, 4:7] = 3 image[5, 5, 5] = 4 volume = _voxel_neighbors( - (5.5, 5.1, 4.9), image, thresh=2, max_peak_dist=1) + (5.5, 5.1, 4.9), image, thresh=2, max_peak_dist=1, use_relative=False) true_volume = set([(5, 4, 5), (5, 5, 4), (5, 5, 5), (6, 5, 5), (5, 6, 5), (5, 5, 6), (4, 5, 5)]) assert volume.difference(true_volume) == set() @@ -341,7 +341,7 @@ def test_warp_montage_volume(): rpa=rpa['r'], coord_frame='mri') montage_warped, image_from, image_to = warp_montage_volume( montage, CT, reg_affine, sdr_morph, 'sample', - subjects_dir=subjects_dir, thresh=0.99) + subjects_dir_from=subjects_dir, thresh=0.99) # checked with nilearn plot from `tut-ieeg-localize` # check montage in surface RAS ground_truth_warped = np.array([[-0.009, -0.00133333, -0.033], @@ -375,13 +375,15 @@ def test_warp_montage_volume(): CT_unaligned = nib.Nifti1Image(CT_data, template_brain.affine) with pytest.raises(RuntimeError, match='not aligned to Freesurfer'): warp_montage_volume(montage, CT_unaligned, reg_affine, - sdr_morph, 'sample', subjects_dir=subjects_dir) + sdr_morph, 'sample', + subjects_dir_from=subjects_dir) bad_montage = montage.copy() for d in bad_montage.dig: d['coord_frame'] = 99 with pytest.raises(RuntimeError, match='Coordinate frame not supported'): warp_montage_volume(bad_montage, CT, reg_affine, - sdr_morph, 'sample', subjects_dir=subjects_dir) + sdr_morph, 'sample', + subjects_dir_from=subjects_dir) # check channel not warped ch_pos_doubled = ch_pos.copy() @@ -391,4 +393,5 @@ def test_warp_montage_volume(): rpa=rpa['r'], coord_frame='mri') with pytest.warns(RuntimeWarning, match='not assigned'): warp_montage_volume(doubled_montage, CT, reg_affine, - sdr_morph, 'sample', subjects_dir=subjects_dir) + sdr_morph, 'sample', + subjects_dir_from=subjects_dir) diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index f48515d8b61..1321538b35c 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -66,7 +66,7 @@ def requires_nibabel(): def requires_dipy(): """Check for dipy.""" import pytest - # for some strange reason on CIs we cane get: + # for some strange reason on CIs we can get: # # can get weird ImportError: dlopen: 
cannot load any more object # with static TLS diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 79e2f3d66a0..7e6630f1bef 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2577,10 +2577,10 @@ step as a key. Steps not in the dictionary will use the default value. The default (None) is equivalent to: - niter=dict(translation=(10000, 1000, 100), - rigid=(10000, 1000, 100), - affine=(10000, 1000, 100), - sdr=(10, 10, 5)) + niter=dict(translation=(100, 100, 10), + rigid=(100, 100, 10), + affine=(100, 100, 10), + sdr=(5, 5, 3)) """ docdict['pipeline'] = """ pipeline : str | tuple diff --git a/mne/viz/misc.py b/mne/viz/misc.py index cdbfb3c0cdd..dfe0e15d92a 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -513,7 +513,7 @@ def plot_bem(subject=None, subjects_dir=None, orientation='coronal', bem_path = op.join(subjects_dir, subject, 'bem') if not op.isdir(bem_path): - raise IOError('Subject bem directory "%s" does not exist' % bem_path) + raise IOError(f'Subject bem directory "{bem_path}" does not exist') surfaces = _get_bem_plotting_surfaces(bem_path) if brain_surfaces is not None: diff --git a/tutorials/clinical/10_ieeg_localize.py b/tutorials/clinical/10_ieeg_localize.py index 2eace180f59..7c2b5f311b9 100644 --- a/tutorials/clinical/10_ieeg_localize.py +++ b/tutorials/clinical/10_ieeg_localize.py @@ -6,15 +6,18 @@ Locating Intracranial Electrode Contacts ======================================== -Intracranial electrophysiology recording contacts are generally localized -based on a post-implantation computed tomography (CT) image and a -pre-implantation magnetic resonance (MR) image. The CT image has greater -intensity than the background at each of the electrode contacts and -for the skull. Using the skull, the CT can be aligned to MR-space. -Contact locations in MR-space are the goal because this is the image from which -brain structures can be determined using the -:ref:`tut-freesurfer-reconstruction`. Contact locations in MR-space can also -be translated to a template space such as ``fsaverage`` for group comparisons. +Analysis of intracranial electrophysiology recordings typically involves +finding the position of each contact relative to brain structures. In a +typical setup, the brain and the electrode locations will be in two places +and will have to be aligned; the brain is best visualized by a +pre-implantation magnetic resonance (MR) image whereas the electrode contact +locations are best visualized in a post-implantation computed tomography (CT) +image. The CT image has greater intensity than the background at each of the +electrode contacts and for the skull. Using the skull, the CT can be aligned +to MR-space. This accomplishes our goal of obtaining contact locations in +MR-space (which is where the brain structures are best determined using the +:ref:`tut-freesurfer-reconstruction`). Contact locations in MR-space can also +be warped to a template space such as ``fsaverage`` for group comparisons. 
""" # Authors: Alex Rockhill @@ -24,7 +27,6 @@ # %% -import os import os.path as op import numpy as np @@ -46,6 +48,58 @@ # use mne-python's fsaverage data fetch_fsaverage(subjects_dir=subjects_dir, verbose=True) # downloads if needed +############################################################################### +# Aligning the T1 to ACPC +# ======================= +# +# For intracranial electrophysiology recordings, the Brain Imaging Data +# Structure (BIDS) standard requires that coordinates be aligned to the +# anterior commissure and posterior commissure (ACPC-aligned). Therefore, it is +# recommended that you do this alignment before finding the positions of the +# channels in your recording. Doing this will make the "mri" (aka surface RAS) +# coordinate frame an ACPC coordinate frame. This can be done using +# Freesurfer's freeview: +# +# .. code-block:: bash +# +# $ freeview $MISC_PATH/seeg/sample_seeg_T1.mgz +# +# And then interact with the graphical user interface: +# +# First, it is recommended to change the cursor style to long, this can be done +# through the menu options like so: +# +# ``Freeview -> Preferences -> General -> Cursor style -> Long`` +# +# Then, the image needs to be aligned to ACPC to look like the image below. +# This can be done by pulling up the transform popup from the menu like so: +# +# ``Tools -> Transform Volume`` +# +# .. note:: +# Be sure to set the text entry box labeled RAS (not TkReg RAS) to +# ``0 0 0`` before beginning the transform. +# +# Then translate the image until the crosshairs meet on the AC and +# run through the PC as shown in the plot. The eyes should be in +# the ACPC plane and the image should be rotated until they are symmetrical, +# and the crosshairs should transect the midline of the brain. +# Be sure to use both the rotate and the translate menus and save the volume +# after you're finished using ``Save Volume As`` in the transform popup +# :footcite:`HamiltonEtAl2017`. + +T1 = nib.load(op.join(misc_path, 'seeg', 'sample_seeg', 'mri', 'T1.mgz')) +viewer = T1.orthoview() +viewer.set_position(0, 9.9, 5.8) +viewer.figs[0].axes[0].annotate( + 'PC', (107, 108), xytext=(10, 75), color='white', + horizontalalignment='center', + arrowprops=dict(facecolor='white', lw=0.5, width=2, headwidth=5)) +viewer.figs[0].axes[0].annotate( + 'AC', (137, 108), xytext=(246, 75), color='white', + horizontalalignment='center', + arrowprops=dict(facecolor='white', lw=0.5, width=2, headwidth=5)) + # %% # Freesurfer recon-all # ==================== @@ -105,7 +159,6 @@ def plot_overlay(image, compare, title, thresh=None): fig.tight_layout() -T1 = nib.load(op.join(misc_path, 'seeg', 'sample_seeg', 'mri', 'T1.mgz')) CT_orig = nib.load(op.join(misc_path, 'seeg', 'sample_seeg_CT.mgz')) # resample to T1's definition of world coordinates @@ -125,7 +178,7 @@ def plot_overlay(image, compare, title, thresh=None): # here:: # # reg_affine, _ = mne.transforms.compute_volume_registration( -# CT_orig, T1, pipeline='rigids', verbose=True) +# CT_orig, T1, pipeline='rigids') # # And instead we just hard-code the resulting 4x4 matrix: @@ -173,6 +226,21 @@ def plot_overlay(image, compare, title, thresh=None): fig.tight_layout() del CT_data, T1 +# %% +# Now we need to estimate the "head" coordinate transform. +# +# MNE stores digitization montages in a coordinate frame called "head" +# defined by fiducial points (origin is halfway between the LPA and RPA +# see :ref:`tut-source-alignment`). 
For sEEG, it is convenient to get an +# estimate of the location of the fiducial points for the subject +# using the Talairach transform (see :func:`mne.coreg.get_mni_fiducials`) +# to use to define the coordinate frame so that we don't have to manually +# identify their location. + +# estimate head->mri transform +subj_trans = mne.coreg.estimate_head_mri_t( + 'sample_seeg', op.join(misc_path, 'seeg')) + # %% # Marking the Location of Each Electrode Contact # ============================================== @@ -180,35 +248,46 @@ def plot_overlay(image, compare, title, thresh=None): # Now, the CT and the MR are in the same space, so when you are looking at a # point in CT space, it is the same point in MR space. So now everything is # ready to determine the location of each electrode contact in the -# individual subject's anatomical space (T1-space). To do this, can make -# list of ``TkReg RAS`` points from the lower panel in freeview or use the -# mne graphical user interface (coming soon). The electrode locations will then -# be in the ``surface RAS`` coordinate frame, which is helpful because that is -# the coordinate frame that all the surface and image files that freesurfer -# outputs are in, see :ref:`tut-freesurfer-mne`. +# individual subject's anatomical space (T1-space). To do this, we can use the +# MNE intracranial electrode location graphical user interface. # -# The electrode contact locations could be determined using freeview by -# clicking through and noting each contact position in the interface launched -# by the following command: +# .. note:: The most useful coordinate frame for intracranial electrodes is +# generally the ``surface RAS`` coordinate frame because that is +# the coordinate frame that all the surface and image files that +# Freesurfer outputs are in, see :ref:`tut-freesurfer-mne`. These are +# useful for finding the brain structures nearby each contact and +# plotting the results. # -# .. code-block:: bash +# To operate the GUI: # -# $ freeview $MISC_PATH/seeg/sample_seeg_T1.mgz \ -# $MISC_PATH/seeg/sample_seeg_CT.mgz - -# %% -# Now, we'll make a montage with the channels that we've found in the -# previous step. +# - Click in each image to navigate to each electrode contact +# - Select the contact name in the right panel +# - Press the "Mark" button or the "m" key to associate that +# position with that contact +# - Repeat until each contact is marked; marked contacts will appear as +# circles in the plots and be colored in the sidebar # -# .. note:: MNE represents data in the "head" space internally +# .. note:: The channel locations are saved to the ``raw`` object every time +# a location is marked or removed so there is no "Save" button. +# +# .. note:: Using the scroll wheel or the +/- keys you can zoom in and out, +# and the up/down, left/right and page up/page down keys allow +# you to move one slice in any direction. This information is +# available in the help menu, accessible by pressing the "h" key. +# +# .. note:: If "Snap to Center" is on, this will use the radius so be +# sure to set it properly. 
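+# +# .. note:: Contact groups are inferred from the channel name prefixes +# (e.g. ``LAMY 1`` and ``LAMY 2`` are assigned to the same group). If +# your channel names do not follow this convention, a ``groups`` +# dictionary mapping each channel name to a group index can be passed +# to :func:`mne.gui.locate_ieeg` instead; a minimal sketch, assuming +# hypothetical channel names: +# +# .. code-block:: python +# +# groups = {'LAMY 1': 0, 'LAMY 2': 0, 'LSTN 1': 1, 'LSTN 2': 1} +# gui = mne.gui.locate_ieeg(raw.info, subj_trans, CT_aligned, +# subject='sample_seeg', +# subjects_dir=op.join(misc_path, 'seeg'), +# groups=groups)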
-# load electrophysiology data with channel locations +# load electrophysiology data to find channel locations for +# (the channels are already located in the example) raw = mne.io.read_raw(op.join(misc_path, 'seeg', 'sample_seeg_ieeg.fif')) -# create symbolic link to share ``subjects_dir`` -if not op.exists(op.join(subjects_dir, 'sample_seeg')): - os.symlink(op.join(misc_path, 'seeg', 'sample_seeg'), - op.join(subjects_dir, 'sample_seeg')) +gui = mne.gui.locate_ieeg(raw.info, subj_trans, CT_aligned, + subject='sample_seeg', + subjects_dir=op.join(misc_path, 'seeg')) +# The `raw` object is modified to contain the channel locations +# after closing the GUI and can now be saved +gui.close() # close when done # %% # Let's plot the electrode contact locations on the subject's brain. @@ -222,13 +301,10 @@ def plot_overlay(image, compare, title, thresh=None): # identify their location. The estimated head->mri ``trans`` was used # when the electrode contacts were localized so we need to use it again here. -# estimate head->mri transform -subj_trans = mne.coreg.estimate_head_mri_t('sample_seeg', subjects_dir) - # plot the alignment -brain_kwargs = dict(cortex='low_contrast', alpha=0.2, background='white', - subjects_dir=subjects_dir) -brain = mne.viz.Brain('sample_seeg', **brain_kwargs) +brain_kwargs = dict(cortex='low_contrast', alpha=0.2, background='white') +brain = mne.viz.Brain('sample_seeg', subjects_dir=op.join(misc_path, 'seeg'), + **brain_kwargs) brain.add_sensors(raw.info, trans=subj_trans) view_kwargs = dict(azimuth=60, elevation=100, distance=350, focalpoint=(0, 0, -15)) @@ -269,7 +345,6 @@ def plot_overlay(image, compare, title, thresh=None): # is useful for getting a quick view of the data, but finalized # pipelines should use ``zooms=None`` instead! -CT_thresh = 0.8 # 0.95 is better for zooms=None! reg_affine, sdr_morph = mne.transforms.compute_volume_registration( subject_brain, template_brain, zooms=5, verbose=True) subject_brain_sdr = mne.transforms.apply_volume_registration( @@ -297,9 +372,11 @@ def plot_overlay(image, compare, title, thresh=None): montage = raw.get_montage() montage.apply_trans(subj_trans) +# higher thresh such as 0.5 (default) works when `zooms=None` montage_warped, elec_image, warped_elec_image = mne.warp_montage_volume( - montage, CT_aligned, reg_affine, sdr_morph, - subject_from='sample_seeg', subjects_dir=subjects_dir, thresh=CT_thresh) + montage, CT_aligned, reg_affine, sdr_morph, thresh=0.1, + subject_from='sample_seeg', subjects_dir_from=op.join(misc_path, 'seeg'), + subject_to='fsaverage', subjects_dir_to=subjects_dir) fig, axes = plt.subplots(2, 1, figsize=(8, 8)) nilearn.plotting.plot_glass_brain(elec_image, axes=axes[0], cmap='Dark2') @@ -334,7 +411,7 @@ def plot_overlay(image, compare, title, thresh=None): raw.set_montage(montage_warped) # plot the resulting alignment -brain = mne.viz.Brain('fsaverage', **brain_kwargs) +brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, **brain_kwargs) brain.add_sensors(raw.info, trans=template_trans) brain.show_view(**view_kwargs)